In [None]:
from __future__ import print_function
%matplotlib inline  
import matplotlib.pyplot as plt

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np

from data import DataSource

In [None]:
import os

# Root of the experiment data. Override with the DEBLUR_DATA_DIR environment
# variable instead of editing an absolute, user-specific path in the notebook.
DATA_DIR = os.environ.get('DEBLUR_DATA_DIR', '/home/ihradis/projects/2016-07-07_JPEG/data/')
# Keep a trailing separator: the original passed the directory with one.
dataDir = os.path.join(DATA_DIR, '')

filterCount = 500
dataSource = DataSource(
    os.path.join(dataDir, 'skyscraper.tst'),
    dataDir,
    minSize=250, length=[15, 15], filterCount=filterCount)
print(len(dataSource.images))

In [None]:
class DeblurClassic(nn.Module):
    """Non-blind deconvolution of an image by direct optimization.

    The latent sharp image is the module's only trainable parameter. ``forward``
    re-blurs it with the fixed per-channel PSF (a grouped, 'valid' convolution)
    and also returns a normalized |gradient|**1.2 smoothness penalty.

    Args:
        image: initial estimate of the sharp image, indexed (N, 3, H, W); the
            notebook passes a single blurred image (N == 1).
        psf: per-channel blur kernels holding exactly 3 * kh * kw values
            (e.g. shape (1, 3, kh, kw)); it is reshaped to (3, 1, kh, kw).
    """

    def __init__(self, image, psf):
        super(DeblurClassic, self).__init__()
        self.image = nn.Parameter(image.clone(), requires_grad=True)
        # Buffers travel with the module through .cuda()/.to(). As plain
        # attributes, c0/c1 stayed on the CPU and constrain() mixed devices
        # once the image parameter lived on the GPU.
        self.register_buffer('psf', psf.view(3, 1, psf.shape[2], psf.shape[3]).clone())
        self.register_buffer('c0', torch.zeros(1, 1, 1, 1))
        self.register_buffer('c1', torch.ones(1, 1, 1, 1))

    def forward(self):
        """Return ``(reblurred, smoothness)`` for the current image estimate.

        ``reblurred`` is (N, 3, H - kh + 1, W - kw + 1). ``smoothness`` is
        sum(|dx|**1.2) + sum(|dy|**1.2) divided by C * (H - 1) * W, the
        per-sample size of the vertical-difference map (the batch dimension is
        not part of the divisor).
        """
        img = self.image
        x = F.conv2d(img, self.psf, groups=3)
        dx = torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]) ** 1.2
        dy = torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1]) ** 1.2
        # dx is (N, C, H-1, W) and dy is (N, C, H, W-1); they do not broadcast,
        # so sum each separately rather than adding them elementwise.
        dd = (torch.sum(dx) + torch.sum(dy)) / (dx.shape[1] * dx.shape[2] * dx.shape[3])
        return x, dd

    def constrain(self):
        """Clamp the image estimate in place to the range [c0, c1] = [0, 1]."""
        clamped = torch.min(torch.max(self.image.data, self.c0), self.c1)
        self.image.data.copy_(clamped)

        
        
class DeblurGAN(nn.Module):
    """Batch of latent sharp images optimized jointly against per-image PSFs.

    Each image/channel pair has its own kernel: ``psf`` of shape (N, C, kh, kw)
    is flattened to N*C depthwise filters and applied with one grouped
    convolution. ``forward`` returns the re-blurred batch together with a
    |gradient|**normPower smoothness penalty.

    Args:
        images: initial estimates, (N, C, H, W).
        psf: kernels, (N, C, kh, kw).

    The module is built on whatever device its inputs are on; call ``.cuda()``
    on it afterwards to move the parameter and the registered buffers together.
    """

    def __init__(self, images, psf):
        super(DeblurGAN, self).__init__()
        self.images = nn.Parameter(images.clone(), requires_grad=True)
        # Register the fixed tensors directly as buffers. The original first
        # assigned them as attributes, which makes register_buffer raise
        # KeyError("attribute ... already exists").
        self.register_buffer('psf', psf.view(
            psf.shape[0] * psf.shape[1], 1,
            psf.shape[2], psf.shape[3]).clone())
        self.register_buffer('c0', torch.zeros(1, 1, 1, 1))
        self.register_buffer('c1', torch.ones(1, 1, 1, 1))
        self.normPower = 1.2

    def forward(self):
        """Return ``(reblurred, smoothness)`` for the whole batch.

        ``reblurred`` is (N, C, H - kh + 1, W - kw + 1). ``smoothness`` is
        sum(|dx|**p) + sum(|dy|**p), with p = normPower, divided by
        N * C * (H - 1) * W.
        """
        n, c, h, w = self.images.shape
        # Fold the batch into the channel axis so a single grouped convolution
        # applies every image/channel its own kernel (groups == N*C).
        x = self.images.view(1, n * c, h, w)
        x = F.conv2d(x, self.psf, groups=self.psf.shape[0])
        x = x.view(n, c, x.shape[2], x.shape[3])
        dx = torch.abs(self.images[:, :, 1:, :] - self.images[:, :, :-1, :]) ** self.normPower
        dy = torch.abs(self.images[:, :, :, 1:] - self.images[:, :, :, :-1]) ** self.normPower
        # dx and dy have different shapes and do not broadcast: sum separately.
        dd = (torch.sum(dx) + torch.sum(dy)) / (dx.shape[0] * dx.shape[1] * dx.shape[2] * dx.shape[3])
        return x, dd

    def constrain(self):
        """Clamp all image estimates in place to the range [c0, c1] = [0, 1]."""
        clamped = torch.min(torch.max(self.images.data, self.c0), self.c1)
        self.images.data.copy_(clamped)

In [None]:
class DeblurDiscriminator(nn.Module):
    """Small fully-convolutional critic producing one probability per image.

    Three convolutions (PReLU between them) reduce to a single channel that is
    globally average-pooled, so any input size large enough for the 7x7 and
    two 3x3 'valid' convolutions is accepted. Input is (N, 3, H, W); output is
    (N,) in (0, 1), matching the per-image targets fed to ``nn.BCELoss``.
    """

    def __init__(self):
        super(DeblurDiscriminator, self).__init__()
        baseCount = 32
        self.convLayers = nn.Sequential(
            nn.Conv2d(3, baseCount, 7), nn.PReLU(),
            nn.Conv2d(baseCount, baseCount, 3), nn.PReLU(),
            nn.Conv2d(baseCount, 1, 3),
            nn.AdaptiveAvgPool2d((1, 1))
        )

    def forward(self, x):
        # One score per sample. The original .view(-1, 2) paired up consecutive
        # samples into an (N/2, 2) tensor, which breaks for odd N and does not
        # match the (N,) label tensors.
        scores = self.convLayers(x).view(-1)
        return torch.sigmoid(scores)


In [None]:
def getData(reader, batchSize, cropSize, maxNoise = 0.08, device='cuda'):
    """Sample sharp crops from ``reader``, blur each with its own PSF, add noise.

    Args:
        reader: object with ``getBatch(batchSize, cropSize)`` returning
            ``(data, idx, psf)`` with numpy arrays for ``data`` and ``psf``
            (``DataSource`` in this notebook). Assumed layout: data is
            (B, C, H, W) and psf is (B, C, kh, kw) -- TODO confirm in DataSource.
        batchSize, cropSize: forwarded unchanged to ``reader.getBatch``.
        maxNoise: standard deviation of the additive Gaussian noise; 0 disables
            it. Despite the name it is a fixed sigma, not an upper bound.
        device: device for the returned tensors (default ``'cuda'``, as before).

    Returns:
        ``(blurred, data, psf, idx)``: ``blurred`` is (B, C, H - kh + 1, W - kw + 1)
        ('valid' convolution, plus noise); ``data`` is the intensity-remapped
        sharp batch; ``psf`` the kernels as a float32 tensor; ``idx`` whatever
        the reader returned.
    """
    data, idx, psf = reader.getBatch(batchSize, cropSize)
    # Affine intensity remap inherited from the original experiment.
    data = (data - 0.1) / 0.8

    # np.array(..., order='C') always copies into a contiguous, positively
    # strided float32 array, so the tensors never alias the reader's buffers.
    data = torch.from_numpy(np.array(data, dtype=np.float32, order='C')).to(device)
    psf = torch.from_numpy(np.array(psf, dtype=np.float32, order='C')).to(device)

    n, c = data.shape[0], data.shape[1]
    # Fold batch into channels so one grouped conv blurs each image/channel
    # with its own kernel.
    blurred = F.conv2d(
        data.view(1, n * c, data.shape[2], data.shape[3]),
        psf.view(n * c, 1, psf.shape[2], psf.shape[3]),
        groups=n * c)
    blurred = blurred.view(n, c, blurred.shape[2], blurred.shape[3])

    if maxNoise > 0:
        blurred = blurred + torch.randn_like(blurred) * maxNoise

    return blurred, data, psf, idx

In [None]:
class ImageRepository(object):
    """Work-in-progress holder for a data reader plus deconvolution settings.

    The original cell was an unfinished sketch that did not parse. This version
    is syntactically valid: ``sharpen`` runs the same classical
    (``DeblurClassic`` + Adam, lr 0.2) deconvolution used later in the notebook.
    ``size`` and ``activeMemory`` are stored but not used yet.
    """

    def __init__(self, dataReader, size=128, activeMemory=500, iterations=100):
        self.reader = dataReader
        self.size = size
        self.activeMemory = activeMemory
        self.iterations = iterations
        self.criterion = torch.nn.MSELoss()

    def sharpen(self, blurred, psf, lr=0.2, tvWeight=0.0000005):
        """Deconvolve ``blurred`` with ``psf`` for ``self.iterations`` Adam steps.

        Args:
            blurred: blurred image tensor accepted by ``DeblurClassic`` (the
                notebook uses (1, 3, H, W)).
            psf: matching kernels (3 * kh * kw values).
            lr: Adam learning rate.
            tvWeight: weight of the smoothness term returned by the net.

        Returns:
            The estimated sharp image tensor (same shape as ``blurred``).
        """
        net = DeblurClassic(blurred.data, psf.data)
        if blurred.is_cuda:
            net.cuda()
        # Build the optimizer only after the net is on its final device.
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)

        for _ in range(self.iterations):
            optimizer.zero_grad()
            out, dd = net()
            # Centre-crop the observation to the 'valid' convolution output.
            top = (blurred.shape[2] - out.shape[2]) // 2
            left = (blurred.shape[3] - out.shape[3]) // 2
            target = blurred[:, :, top:top + out.shape[2], left:left + out.shape[3]]
            loss = self.criterion(out, target) + tvWeight * dd
            loss.backward()
            optimizer.step()
            # Keep the estimate a valid image in [0, 1].
            net.image.data.clamp_(0.0, 1.0)

        return net.image.data
        

In [None]:
# Pre-generate `bunchCount` fixed bunches of blurred/sharp training pairs, each
# with its own DeblurGAN (trainable latent images) and Adam optimizer.
bunchSize = 16
resolution = 128
bunchCount = 10
noise = 0.004

deblurNets = []
originalImages = []
blurredImages = []
imageOptimizers = []
for i in range(bunchCount):
    blurred, sharp, psf, fid = getData(dataSource, bunchSize, resolution, noise)
    bunchNet = DeblurGAN(blurred.data.cpu(), psf.data.cpu())
    # Move to the GPU *before* constructing the optimizer, as the torch.optim
    # docs recommend, so Adam is built around the final parameters.
    bunchNet.cuda()
    deblurNets.append(bunchNet)
    originalImages.append(sharp.data)
    blurredImages.append(blurred.data)
    imageOptimizers.append(torch.optim.Adam(bunchNet.parameters(), lr=0.001))

In [None]:
# Losses: BCE for the real/fake critic, MSE for the reconstruction term.
dCriterion = nn.BCELoss()
imgCriterion = nn.MSELoss()

# Critic lives on the GPU; its optimizer is built afterwards.
discriminator = DeblurDiscriminator().cuda()
dOptimizator = torch.optim.Adam(discriminator.parameters(), lr=0.0001)

In [None]:
def cropTensor(what, to):
    """Centre-crop ``what`` spatially to the height/width of ``to``.

    Both are indexed (N, C, H, W); only dimensions 2 and 3 are cropped. When
    the size difference is odd, the extra row/column is dropped from the
    bottom/right. Equal sizes return ``what`` unchanged.
    """
    h, w = to.shape[2], to.shape[3]
    # Integer offsets: '/' yields floats in Python 3, which cannot index.
    top = (what.shape[2] - h) // 2
    left = (what.shape[3] - w) // 2
    # Explicit end indices: the old 'b:-b' slice was empty for b == 0 and one
    # element too long for odd size differences.
    return what[:, :, top:top + h, left:left + w]
from helper import collage

# Alternating training: one critic step, then one step on the latent images of
# a randomly chosen bunch. Running averages are printed every 1000 iterations.
accCount = 0
accDLoss = 0
accRecLoss = 0
accGenLoss = 0
for i in range(1000000):
    # --- discriminator step: real sharp crops vs. current re-blurred outputs ---
    bunch = np.random.randint(bunchCount)
    out, dd = deblurNets[bunch]()

    sharp = cropTensor(originalImages[bunch], out)
    batchCount = sharp.shape[0]
    # Detached: the critic update must not push gradients into the images.
    batch = torch.cat([sharp.detach().cuda(), out.detach()])
    # Real targets are 0.9 instead of 1 (one-sided label smoothing); fake are 0.
    labels = torch.cat([torch.ones(batchCount) * 0.9,
                        torch.zeros(batchCount)]).cuda()

    dOptimizator.zero_grad()
    dOut = discriminator(batch)
    dLoss = dCriterion(dOut, labels)
    dLoss.backward()
    dOptimizator.step()

    accCount += 1
    accDLoss += dLoss.item()

    # --- image step: data term + adversarial term + smoothness prior ---
    blurred = blurredImages[bunch]
    imageOptimizers[bunch].zero_grad()

    # NOTE(review): the adversarial loss scores `out`, the *re-blurred*
    # estimate, while recLoss pulls `out` toward the blurry observation. If the
    # intent is a sharpness prior on the latent images, this should probably
    # score deblurNets[bunch].images instead -- confirm.
    # This backward also leaves grads on the critic; they are cleared by
    # dOptimizator.zero_grad() at the next iteration.
    genLoss = dCriterion(discriminator(out), torch.ones(batchCount).cuda())
    recLoss = imgCriterion(out, cropTensor(blurred, out))
    imgLoss = recLoss + 0.0005 * genLoss + 0.0005 * dd
    imgLoss.backward()
    imageOptimizers[bunch].step()
    deblurNets[bunch].constrain()

    accRecLoss += recLoss.item()
    accGenLoss += genLoss.item()

    if (i % 1000) == 0:
        print(accDLoss / accCount, accRecLoss / accCount, accGenLoss / accCount)
        accDLoss = 0
        accRecLoss = 0
        accGenLoss = 0
        accCount = 0

        # Left: ground-truth sharp images; right: current latent estimates
        # (channels reversed for display).
        fig = plt.figure(figsize=(14, 10), dpi=80, facecolor='w', edgecolor='k')
        plt.subplot(1, 2, 1)
        plt.imshow(collage(originalImages[bunch].cpu().numpy()[0:4, ::-1, :, :]))
        plt.subplot(1, 2, 2)
        plt.imshow(collage(deblurNets[bunch].images.data[0:4].cpu().numpy()[:, ::-1, :, :]))
        plt.draw()
        plt.show()
    

In [None]:
# One 512px sample for the classical-deconvolution experiment below.
blurred, sharp, psf, fid = getData(dataSource, 1, 512, 0.02)

# Side by side: sharp (left) vs. blurred + noise (right); channels reversed
# (BGR -> RGB) for display.
fig, (axSharp, axBlurred) = plt.subplots(
    1, 2, figsize=(14, 10), dpi=80, facecolor='w', edgecolor='k')
axSharp.imshow(sharp.data.cpu().numpy()[0].transpose(1, 2, 0)[:, :, ::-1])
axBlurred.imshow(blurred.data.cpu().numpy()[0].transpose(1, 2, 0)[:, :, ::-1])
plt.draw()
plt.show()



In [None]:
# Build the classical deconvolution model once and move it to the GPU; the
# original constructed it twice and discarded the GPU-resident instance.
net = DeblurClassic(blurred.data, psf.data)
net.cuda()


In [None]:
net.cuda()

blurred = blurred.cuda()
psf = psf.cuda()

out, dd = net()
# Integer centre offsets of the 'valid' convolution output inside the
# observation. The old float '/2' offsets cannot index in Python 3, and
# 'b:-b' would be empty for a zero margin.
top = (blurred.shape[2] - out.shape[2]) // 2
left = (blurred.shape[3] - out.shape[3]) // 2
target = blurred[:, :, top:top + out.shape[2], left:left + out.shape[3]]

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.2)


def lossEvaluator():
    """Closure for optimizer.step: data term plus a tiny smoothness prior.

    The closure form also lets torch.optim.LBFGS be swapped in for Adam.
    """
    optimizer.zero_grad()
    out, dd = net()
    loss = criterion(out, target) + 0.0000005 * dd
    loss.backward()
    return loss


for i in range(101):
    net.constrain()
    loss = optimizer.step(lossEvaluator)

    if i % 100 == 0:
        net.constrain()
        print(loss.item())
        # Left: ground truth; right: current estimate (BGR -> RGB for display).
        fig = plt.figure(figsize=(14, 10), dpi=80, facecolor='w', edgecolor='k')
        plt.subplot(1, 2, 1)
        plt.imshow(sharp.data.cpu().numpy().transpose(0, 2, 3, 1)[0][:, :, ::-1])
        plt.subplot(1, 2, 2)
        plt.imshow(net.image.data.cpu().numpy().transpose(0, 2, 3, 1)[0][:, :, ::-1])
        plt.draw()
        plt.show()

# Re-blurred final estimate; channels reversed like every other image shown.
out, dd = net()
plt.imshow(out.data.cpu().numpy().transpose(0, 2, 3, 1)[0][:, :, ::-1])

In [None]:
# Scratch check: Tensor.max(other) broadcasts a 1-element tensor against a
# larger one elementwise (the pattern the constrain() methods rely on).
# Deterministic inputs: torch.Tensor(n) returns uninitialized memory.
a = torch.zeros(2).max(torch.ones(1) + 1)