In [None]:
from __future__ import print_function
%matplotlib inline  
import matplotlib.pyplot as plt

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np

In [None]:
from helper import generateMotionBlurPSF
import os
class DataSource(object):
    """In-memory pool of sharp images and synthetic motion-blur kernels.

    Construction generates a bank of random kernels (``self.filters``) and
    loads every image named in a list file (``self.images``).  ``getBatch``
    then draws random square crops and random kernels from the two pools.
    """

    def __init__(self, imageSource, path='', length=[0,15], orientation=[0,180], filterCount=2000, minSize=100):
        """Generate the kernel bank, then load the images.

        Args:
            imageSource: text file with one image file name per line.
            path: directory joined in front of every name from ``imageSource``.
            length: ``[low, high]`` range the blur length is drawn from;
                ``length[1]`` is also the side the kernels are padded towards.
            orientation: ``[low, high]`` range the blur orientation is drawn from.
            filterCount: number of kernels to generate.
            minSize: images must exceed ``minSize`` plus the kernel side in
                both spatial dimensions to be kept.
        """
        print('Generating filters.')
        self.generateFilters(length, orientation, filterCount)
        print('Generating filters --- DONE.')

        # A crop has to be larger than the requested size by the kernel side,
        # because getBatch() enlarges every crop by that amount.
        self.minSize = minSize + self.filters[0].shape[2]
        print('Reading images.')
        self.readImages(imageSource, path, self.minSize)
        print('Reading images --- DONE.')

    def readImages(self, imageSource, path, minSize):
        """Load the images listed in ``imageSource`` into ``self.images``.

        Pixel values are mapped with ``v / 256 * 0.8 + 0.1`` and each image is
        stored as a float32 array laid out (channels, height, width).  Images
        that are not larger than ``minSize`` in both dimensions are skipped
        with a warning; unreadable entries are reported and skipped.
        """
        self.images = []
        with open(imageSource, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not file names.
                    continue
                try:
                    # cv2.imread returns None for unreadable files, which makes
                    # .astype raise and lands in the handler below.
                    newImage = cv2.imread(os.path.join(path, line)).astype(np.float32)
                    newImage /= 256.0
                    newImage *= 0.8
                    newImage += 0.1
                    # A single-channel image has rank 2; replicate it to three
                    # channels.  (len(newImage) would test the height instead.)
                    if newImage.ndim == 2:
                        newImage = np.expand_dims(newImage, axis=2)
                        newImage = np.repeat(newImage, 3, axis = 2)
                    if newImage.shape[0] > minSize and newImage.shape[1] > minSize:
                        self.images.append(newImage.transpose(2,0,1))
                    else:
                        print('Warning: Image is too small "{}".'.format(line))
                except Exception:
                    # Best-effort loading: report and carry on, but do not
                    # swallow KeyboardInterrupt / SystemExit.
                    print('ERROR: While reading image "{}".'.format(line))

    def generateFilters(self, length, orientation, filterCount):
        """Fill ``self.filters`` with ``filterCount`` random kernels.

        Orientation and length are drawn uniformly from their ranges.  Each
        kernel is zero-padded towards side ``length[1]``, replicated to three
        channels and stacked into an array (filterCount, 3, side, side).
        """
        self.filters = []
        for i in range(filterCount):
            o = (orientation[1] - orientation[0]) * np.random.random_sample() + orientation[0]
            l = (length[1] - length[0]) * np.random.random_sample() + length[0]
            psf = generateMotionBlurPSF(o, l)
            # NOTE(review): assumes generateMotionBlurPSF returns a square 2-D
            # array whose side has the same parity as length[1]; otherwise the
            # padded kernels differ in size and np.stack below fails -- confirm.
            border = int((length[1] - psf.shape[0]) / 2)
            psf = np.pad(psf, [(border,border), (border,border)], mode='constant')
            psf = np.expand_dims(psf, axis=0)
            psf = np.repeat(psf, 3, axis = 0)
            self.filters.append(psf)
        self.filters = np.stack(self.filters, axis=0)

    def getBatch(self, count=32, cropSize=100):
        """Draw ``count`` random crops and ``count`` random kernels.

        Returns:
            data: array (count, channels, s, s) with ``s = cropSize + kernel side``.
            idx: indices into ``self.filters`` of the chosen kernels.
            psf: the chosen kernels, ``self.filters[idx]``.
        """
        # Enlarge the crop by the kernel side so the caller can filter it
        # without padding.
        cropSize = cropSize + self.filters[0].shape[2]

        idx = np.random.choice(len(self.images), count)
        images = [self.images[i] for i in idx]
        outImages = []
        for image in images:
            # Stored images are strictly larger than self.minSize, so these
            # ranges are non-empty as long as cropSize does not exceed it.
            i1 = np.random.randint(image.shape[1] - cropSize)
            i2 = np.random.randint(image.shape[2] - cropSize)
            outImages.append(image[:, i1:i1+cropSize, i2:i2+cropSize])
        data = np.stack(outImages)

        idx = np.random.choice(self.filters.shape[0], count)
        psf = self.filters[idx]

        return data, idx, psf
        
        

In [None]:
# Build the data pool: 10000 random blur kernels plus every image named in
# the list file that is large enough for 150-pixel crops.
# NOTE(review): both paths are absolute and machine-specific.
filterCount = 10000
dataSource = DataSource(
    imageSource='/home/ihradis/projects/2016-07-07_JPEG/data/skyscraper.15k',
    path='/home/ihradis/projects/2016-07-07_JPEG/data/',
    length=[1, 15],
    filterCount=filterCount,
    minSize=150,
)
# Number of images that passed the size check.
print(len(dataSource.images))

In [None]:
from helper import collage
import time


In [None]:
class PsfNet(nn.Module):
    """Convolutional encoder mapping a PSF tensor to an embedding vector.

    The conv stack is Conv2d -> BatchNorm2d -> ReLU per layer, followed by a
    fully connected head, L2 normalisation and a final BatchNorm1d.

    NOTE: this class is defined again in a later cell of this notebook; once
    that cell has run, the later definition is the one in effect.
    """

    def __init__(self, input_img, embeddingDim):
        """
        Args:
            input_img: example input; its channel count sizes the first conv
                layer and it is pushed through the conv stack once to find
                the flattened feature size for the first linear layer.
            embeddingDim: length of the produced embedding.
        """
        super(PsfNet, self).__init__()

        self.filterSizes = [5, 3, 3]
        self.filterCounts = [24, 48, 64]
        self.strides = [1, 1]
        self.fcSizes = [128, embeddingDim]

        # zip() stops at the shortest list, so with two strides only the
        # first two filter sizes / counts become layers.
        convStack = []
        channels = input_img.data.shape[1]
        for kernel, width, step in zip(self.filterSizes, self.filterCounts, self.strides):
            convStack.extend([
                nn.Conv2d(in_channels=channels,
                          out_channels=width,
                          kernel_size=kernel,
                          stride=step),
                nn.BatchNorm2d(num_features=width, momentum=0.8),
                nn.ReLU(),
            ])
            channels = width
        self.convLayers = nn.Sequential(*convStack).cuda()

        # Probe pass: the conv output shape determines the linear input size.
        probe = self.convLayers(input_img)
        print('Conv output shape:', probe.data.shape)
        features = probe.data.shape[1]*probe.data.shape[2]*probe.data.shape[3]

        fcStack = []
        for hidden in self.fcSizes[:-1]:
            fcStack.append(nn.Linear(features, hidden))
            fcStack.append(nn.ReLU())
            features = hidden
        # Final projection to the embedding has no activation.
        fcStack.append(nn.Linear(features, self.fcSizes[-1]))
        self.fcLayers = nn.Sequential(*fcStack).cuda()

        self.outBNorm = nn.BatchNorm1d(self.fcSizes[-1], momentum=0.8)

    def forward(self, input_x):
        """Return the batch-normalised, L2-normalised embedding of ``input_x``."""
        features = self.convLayers(input_x)
        flat = features.view(features.data.shape[0], -1)
        embedding = F.normalize(self.fcLayers(flat), p=2, dim=1, eps=1e-12)
        return self.outBNorm(embedding)

class LinearFilterNet(nn.Module):
    """Stack of plain convolutions with a residual connection to the input.

    All layers but the last are followed by tanh; the last layer's output is
    added to the centre crop of the input.

    NOTE: this class is defined again in a later cell of this notebook; once
    that cell has run, the later definition is the one in effect.
    """

    def __init__(self, filterSizes, filterCounts):
        """
        Args:
            filterSizes: kernel size of each conv layer.
            filterCounts: output channels of each conv layer.  The last entry
                has to match the input's channel count for the residual add.
        """
        # Must name this class; the original passed LinearNet here, which
        # fails on construction.
        super(LinearFilterNet, self).__init__()

        lastChannels = 3
        layers = []
        for fSize, fCount in zip(filterSizes, filterCounts):
            layers.append(nn.Conv2d(lastChannels, fCount, fSize))
            lastChannels = fCount

        self.layers = nn.ModuleList(layers)

    def forward(self, input_x, psf):
        """Filter ``input_x``; ``psf`` is accepted for interface parity and unused."""
        x = input_x
        for l in list(self.layers)[:-1]:
            x = torch.tanh(l(x))
        x = self.layers[-1](x)

        # Unpadded convolutions shrink the map; crop the input to match.
        # Floor division keeps the offset an int on Python 3, and explicit end
        # indices keep the slice non-empty when nothing was cropped (b == 0).
        b = (input_x.size(2) - x.size(2)) // 2
        input_x = input_x[:, :, b:input_x.size(2) - b, b:input_x.size(3) - b]
        x = x + input_x

        return x
    
    
class WeightNet(nn.Module):
    """Maps an embedding to one softmax weight vector per requested size.

    A shared Linear -> BatchNorm -> ReLU trunk expands the embedding to 128
    features; a separate Linear -> Softmax head per entry of ``outDims``
    produces the weight vectors.

    NOTE: this class is defined again in a later cell of this notebook; once
    that cell has run, the later definition is the one in effect.
    """

    def __init__(self, inDim, outDims):
        """
        Args:
            inDim: length of the input embedding.
            outDims: iterable of output sizes, one head per entry.
        """
        super(WeightNet, self).__init__()

        # The trunk input is (batch, features), which is BatchNorm1d's layout;
        # BatchNorm2d expects 4-D input.  Parameter names are the same, so
        # existing state dicts still load.
        self.expandModule = nn.Sequential(
            nn.Linear(inDim, 128),
            nn.BatchNorm1d(num_features=128, momentum=0.8),
            nn.ReLU(),
        )

        self.weightModules = nn.ModuleList()
        for outDim in outDims:
            # Normalise over the feature axis explicitly instead of relying
            # on the implicit-dim fallback.
            self.weightModules.append(
                nn.Sequential(nn.Linear(128, outDim), nn.Softmax(dim=1)))

    def forward(self, input_x):
        """Return a list with one (batch, outDim) weight tensor per head."""
        x = self.expandModule(input_x)
        return [modul(x) for modul in self.weightModules]
        
        

class DeconvNet(nn.Module):
    """Convolutional network whose channels are re-weighted from the PSF.

    The PSF is embedded by ``psfNet``; ``WeightNet`` turns the embedding into
    one weight vector per conv layer, and each layer's output channels are
    scaled by its vector before batch norm and tanh.  A final 3x3 convolution
    produces three channels that are added to the centre crop of the input.

    NOTE: this class is defined again in a later cell of this notebook; once
    that cell has run, the later definition is the one in effect.
    """

    def __init__(self, psfNet, psfDim, filterSizes, filterCounts):
        """
        Args:
            psfNet: module mapping a PSF batch to embeddings of length ``psfDim``.
            psfDim: embedding length fed to the weight network.
            filterSizes: kernel size of each weighted conv layer.
            filterCounts: output channels of each weighted conv layer.
        """
        super(DeconvNet, self).__init__()

        self.psfNet = psfNet

        self.weightNet = WeightNet(psfDim, filterCounts)

        lastChannels = 3
        self.layers = nn.ModuleList()
        self.BNlayers = nn.ModuleList()
        for fSize, fCount in zip(filterSizes, filterCounts):
            self.layers.append(nn.Conv2d(lastChannels, fCount, fSize))
            lastChannels = fCount
            self.BNlayers.append(nn.BatchNorm2d(num_features=lastChannels, momentum=0.5))

        self.lastLayer = nn.Conv2d(lastChannels, 3, 3)

    def forward(self, input_x, psf):
        """Return the input's centre crop plus the predicted correction."""
        x = input_x
        emb = self.psfNet(psf)
        weights = self.weightNet(emb)

        for i in range(len(self.layers)):
            x = self.layers[i](x)
            # One scalar per (sample, channel), broadcast over the spatial axes.
            x = x * weights[i].view(x.size(0), x.size(1), 1, 1)
            x = self.BNlayers[i](x)
            x = torch.tanh(x)

        x = self.lastLayer(x)
        # Unpadded convolutions shrink the map; crop the input to match.
        # Floor division keeps the offset an int on Python 3, and explicit end
        # indices keep the slice non-empty when b == 0.
        b = (input_x.size(2) - x.size(2)) // 2
        input_x = input_x[:, :, b:input_x.size(2) - b, b:input_x.size(3) - b]
        x = x + input_x

        return x

In [None]:
class LinearNet(nn.Module):
    """Stack of plain convolutions with a residual connection to the input.

    All layers but the last are followed by tanh; the last layer's output is
    added to the centre crop of the input.
    """

    def __init__(self, filterSizes, filterCounts):
        """
        Args:
            filterSizes: kernel size of each conv layer.
            filterCounts: output channels of each conv layer.  The last entry
                has to match the input's channel count for the residual add.
        """
        super(LinearNet, self).__init__()

        lastChannels = 3
        layers = []
        for fSize, fCount in zip(filterSizes, filterCounts):
            layers.append(nn.Conv2d(lastChannels, fCount, fSize))
            lastChannels = fCount

        self.layers = nn.ModuleList(layers)

    def forward(self, input_x):
        """Return the input's centre crop plus the output of the conv stack."""
        x = input_x
        for l in list(self.layers)[:-1]:
            x = torch.tanh(l(x))
        x = self.layers[-1](x)

        # Unpadded convolutions shrink the map; crop the input to match.
        # Floor division keeps the offset an int on Python 3, and explicit end
        # indices keep the slice non-empty when nothing was cropped (b == 0).
        b = (input_x.size(2) - x.size(2)) // 2
        input_x = input_x[:, :, b:input_x.size(2) - b, b:input_x.size(3) - b]
        x = x + input_x

        return x
                



In [None]:
class Net(nn.Module):
    """Looks up a bank of convolution filters per kernel ID and applies it.

    Every kernel ID owns a learned embedding that is reshaped into
    ``filterCount`` filters of shape (3, filterSize, filterSize); each sample
    of the batch is convolved with the filters of its own ID.
    """

    def __init__(self, ):
        super(Net, self).__init__()

        self.embDim = 32
        self.filterSize = 25
        self.filterCount = 3

        # NOTE(review): maxFilterID is read from module scope and is not
        # defined anywhere in the visible notebook; it must be set to the
        # number of distinct kernel IDs before this class is instantiated.
        self.psfEmbed = nn.Embedding(
            maxFilterID, self.filterCount*self.filterSize*self.filterSize*3)

    def computeEmb(self, psfIdx):
        """Return the raw embedding rows for the given kernel IDs."""
        return self.psfEmbed(psfIdx)

    def forward(self, input_x, psf):
        """Convolve each sample with the filters learned for its kernel ID.

        Args:
            input_x: image batch (batch, 3, H, W).
            psf: integer kernel IDs, one per sample, shaped (batch, 1) so that
                each embedding row reshapes to exactly one filter bank.

        Returns:
            Tensor (batch, filterCount, H - filterSize + 1, W - filterSize + 1).
        """
        # The original body read an undefined name `psfIdx`; the IDs arrive
        # through the `psf` parameter.
        psfIdx = psf
        batchSize = psfIdx.data.shape[0]

        f = self.computeEmb(psfIdx)
        res = []
        for i in range(batchSize):
            # Keep the batch axis (size 1) so F.conv2d accepts the sample.
            img = input_x[i:i+1]
            w = f[i].view(self.filterCount, 3, self.filterSize, self.filterSize)
            res.append(F.conv2d(img, w))

        return torch.cat(res, dim=0)


In [None]:
class PsfNet(nn.Module):
    """Convolutional encoder mapping a PSF tensor to an embedding vector.

    Three Conv2d -> BatchNorm2d -> ReLU stages (the middle one with stride 2)
    feed a fully connected head; the result is L2-normalised and passed
    through a final BatchNorm1d.
    """

    def __init__(self, input_img, embeddingDim):
        """
        Args:
            input_img: example input; its channel count sizes the first conv
                layer and it is pushed through the conv stack once to find
                the flattened feature size for the first linear layer.
            embeddingDim: length of the produced embedding.
        """
        super(PsfNet, self).__init__()

        self.filterSizes = [5, 3, 3]
        self.filterCounts = [24, 48, 64]
        self.strides = [1, 2, 1]
        self.fcSizes = [128, embeddingDim]

        convStack = []
        channels = input_img.data.shape[1]
        for kernel, width, step in zip(self.filterSizes, self.filterCounts, self.strides):
            convStack.extend([
                nn.Conv2d(in_channels=channels,
                          out_channels=width,
                          kernel_size=kernel,
                          stride=step),
                nn.BatchNorm2d(num_features=width, momentum=0.5),
                nn.ReLU(),
            ])
            channels = width
        self.convLayers = nn.Sequential(*convStack).cuda()

        # Probe pass: the conv output shape determines the linear input size.
        probe = self.convLayers(input_img)
        print('Conv output shape:', probe.data.shape)
        features = probe.data.shape[1]*probe.data.shape[2]*probe.data.shape[3]

        fcStack = []
        for hidden in self.fcSizes[:-1]:
            fcStack.append(nn.Linear(features, hidden))
            fcStack.append(nn.ReLU())
            features = hidden
        # Final projection to the embedding has no activation.
        fcStack.append(nn.Linear(features, self.fcSizes[-1]))
        self.fcLayers = nn.Sequential(*fcStack).cuda()

        self.outBNorm = nn.BatchNorm1d(self.fcSizes[-1], momentum=0.5)

    def forward(self, input_x):
        """Return the batch-normalised, L2-normalised embedding of ``input_x``."""
        features = self.convLayers(input_x)
        flat = features.view(features.data.shape[0], -1)
        embedding = F.normalize(self.fcLayers(flat), p=2, dim=1, eps=1e-12)
        return self.outBNorm(embedding)

class LinearFilterNet(nn.Module):
    """Stack of plain convolutions with a residual connection to the input.

    All layers but the last are followed by tanh; the last layer's output is
    added to the centre crop of the input.  The ``psf`` argument of
    ``forward`` is accepted but not used.
    """

    def __init__(self, filterSizes, filterCounts):
        """
        Args:
            filterSizes: kernel size of each conv layer.
            filterCounts: output channels of each conv layer.  The last entry
                has to match the input's channel count for the residual add.
        """
        # Must name this class; the original passed LinearNet here, which
        # fails on construction.
        super(LinearFilterNet, self).__init__()

        lastChannels = 3
        layers = []
        for fSize, fCount in zip(filterSizes, filterCounts):
            layers.append(nn.Conv2d(lastChannels, fCount, fSize))
            lastChannels = fCount

        self.layers = nn.ModuleList(layers)

    def forward(self, input_x, psf):
        """Return the input's centre crop plus the output of the conv stack."""
        x = input_x
        for l in list(self.layers)[:-1]:
            x = torch.tanh(l(x))
        x = self.layers[-1](x)

        # Unpadded convolutions shrink the map; crop the input to match.
        # Floor division keeps the offset an int on Python 3, and explicit end
        # indices keep the slice non-empty when nothing was cropped (b == 0).
        b = (input_x.size(2) - x.size(2)) // 2
        input_x = input_x[:, :, b:input_x.size(2) - b, b:input_x.size(3) - b]
        x = x + input_x

        return x

In [None]:
class WeightNet(nn.Module):
    """Maps an embedding to one softmax weight vector per requested size.

    A shared Linear -> BatchNorm -> ReLU trunk expands the embedding to 128
    features; a separate Linear -> Softmax head per entry of ``outDims``
    produces the weight vectors.
    """

    def __init__(self, inDim, outDims):
        """
        Args:
            inDim: length of the input embedding.
            outDims: iterable of output sizes, one head per entry.
        """
        super(WeightNet, self).__init__()

        # The trunk input is (batch, features), which is BatchNorm1d's layout;
        # BatchNorm2d expects 4-D input.  Parameter names are the same, so
        # existing state dicts still load.
        self.expandModule = nn.Sequential(
            nn.Linear(inDim, 128),
            nn.BatchNorm1d(num_features=128, momentum=0.5),
            nn.ReLU(),
        )

        self.weightModules = nn.ModuleList()
        for outDim in outDims:
            # Normalise over the feature axis explicitly instead of relying
            # on the implicit-dim fallback.
            self.weightModules.append(
                nn.Sequential(nn.Linear(128, outDim), nn.Softmax(dim=1)))

    def forward(self, input_x):
        """Return a list with one (batch, outDim) weight tensor per head."""
        x = self.expandModule(input_x)
        return [modul(x) for modul in self.weightModules]
        
        

class DeconvNet(nn.Module):
    """Convolutional network whose channels are re-weighted from the PSF.

    The PSF is embedded by ``psfNet``; ``WeightNet`` turns the embedding into
    one weight vector per conv layer, and each layer's output channels are
    scaled by its vector before batch norm and tanh.  A final 3x3 convolution
    produces three channels that are added to the centre crop of the input.
    """

    def __init__(self, psfNet, psfDim, filterSizes, filterCounts):
        """
        Args:
            psfNet: module mapping a PSF batch to embeddings of length ``psfDim``.
            psfDim: embedding length fed to the weight network.
            filterSizes: kernel size of each weighted conv layer.
            filterCounts: output channels of each weighted conv layer.
        """
        super(DeconvNet, self).__init__()

        self.psfNet = psfNet

        self.weightNet = WeightNet(psfDim, filterCounts)

        lastChannels = 3
        self.layers = nn.ModuleList()
        self.BNlayers = nn.ModuleList()
        for fSize, fCount in zip(filterSizes, filterCounts):
            self.layers.append(nn.Conv2d(lastChannels, fCount, fSize))
            lastChannels = fCount
            self.BNlayers.append(nn.BatchNorm2d(num_features=lastChannels, momentum=0.5))

        self.lastLayer = nn.Conv2d(lastChannels, 3, 3)

    def forward(self, input_x, psf):
        """Return the input's centre crop plus the predicted correction."""
        x = input_x
        emb = self.psfNet(psf)
        weights = self.weightNet(emb)

        for i in range(len(self.layers)):
            x = self.layers[i](x)
            # One scalar per (sample, channel), broadcast over the spatial axes.
            x = x * weights[i].view(x.size(0), x.size(1), 1, 1)
            x = self.BNlayers[i](x)
            x = torch.tanh(x)

        x = self.lastLayer(x)
        # Unpadded convolutions shrink the map; crop the input to match.
        # Floor division keeps the offset an int on Python 3, and explicit end
        # indices keep the slice non-empty when b == 0.
        b = (input_x.size(2) - x.size(2)) // 2
        input_x = input_x[:, :, b:input_x.size(2) - b, b:input_x.size(3) - b]
        x = x + input_x

        return x

In [None]:

def getData(reader, batchSize, cropSize, maxNoise = 0):
    """Draw a batch from ``reader`` and blur it on the GPU.

    Every image channel is filtered with its own kernel in a single grouped
    ``F.conv2d`` call by folding batch and channel into one axis.

    Args:
        reader: object whose ``getBatch(batchSize, cropSize)`` returns
            ``(images, kernelIds, kernels)`` as numpy arrays.
        batchSize: number of samples.
        cropSize: crop size forwarded to the reader.
        maxNoise: if positive, Gaussian noise is added to each blurred sample
            with a per-sample scale drawn uniformly from ``[0, maxNoise)``.

    Returns:
        ``(blurred, sharp, kernels, kernelIds)`` as CUDA Variables;
        ``kernelIds`` has shape (batchSize, 1).
    """
    cropsNp, idsNp, kernelsNp = reader.getBatch(batchSize, cropSize)

    kernelIds = Variable(torch.LongTensor(idsNp).view(batchSize, 1).cuda())
    sharp = Variable(torch.Tensor(cropsNp.astype(dtype=np.float32)).cuda())
    kernels = Variable(torch.Tensor(kernelsNp.astype(dtype=np.float32)).cuda())

    # One group per (sample, channel) plane: each plane meets only its kernel.
    planes = 3*batchSize
    stackedImages = sharp.view(1, planes, sharp.data.shape[2], sharp.data.shape[3])
    stackedKernels = kernels.view(planes, 1, kernels.data.shape[2], kernels.data.shape[3])
    blurred = F.conv2d(stackedImages, stackedKernels, groups=planes)

    # Unfold the planes back into (batch, channel).
    blurred = blurred.view(batchSize, 3, blurred.data.shape[2], blurred.data.shape[3])

    if maxNoise > 0:
        sigma = torch.cuda.FloatTensor(blurred.data.shape[0]).uniform_(0, maxNoise).view(-1, 1, 1, 1)
        noise = torch.cuda.FloatTensor(blurred.data.shape).normal_() * sigma
        blurred += Variable(noise)

    return blurred, sharp, kernels, kernelIds

# Build the PSF encoder.  A small batch is drawn only so that PsfNet can
# read the kernel tensor's shape when sizing its layers.
psfEmbeddingDim = 64
blurred, sharp, psf, fid = getData(dataSource, 2, 150)
psfNet = PsfNet(psf, psfEmbeddingDim).cuda()
# Smoke test: embed the same kernels and show the embedding shape.
out = psfNet(psf)
print(out.data.shape)


In [None]:
# Build the deconvolution network around the PSF encoder: eight weighted conv
# layers (kernel sizes / channel counts below) plus DeconvNet's own last layer.
net = DeconvNet(psfNet, psfEmbeddingDim, [35,1,1,3,3,3,3,3], [64,128,128,64,64,64,64,64])
print(net)
net.cuda()
# Smoke test on the batch drawn in the previous cell; the result is discarded.
net(blurred, psf)

# Training bookkeeping: averaged losses, the iterations they were logged at,
# and the global iteration counter.
lossHistory = []
lossPositions = []
iteration = 0

criterion = torch.nn.MSELoss()

In [None]:
# Adam over all parameters of `net` (which include those of the PSF encoder
# it holds as `psfNet`).
optimizer = torch.optim.Adam(net.parameters(), lr=0.0004)


In [None]:
# Work out which crop size to request so that the network output is
# `targetSize` pixels wide, and reset the running-loss accumulators.
maxNoise = 4.0 / 256.0
acc = 0           # sum of training losses since the last report
dataAcc = 0       # sum of blurred-vs-sharp losses since the last report
iterationAcc = 0  # iterations accumulated since the last report
targetSize = 28
testSize = 150
# Probe the network once with a known crop size and measure its output.
blurred, sharp, psf, fid = getData(dataSource, 1, testSize, maxNoise)
out = net(blurred, psf)
# How many pixels the output falls short of the requested crop size.
border = testSize - out.data.shape[2]
cropSize = targetSize + border
print('Crop size:', cropSize, 'border:', border)

In [None]:
# Resume from a saved checkpoint: restore the weights and continue the
# iteration counter right after the checkpoint's number.
# NOTE(review): the file name is hard-coded; this cell fails if the file is
# missing from the working directory.
d = torch.load('model_105999.mod')
net.load_state_dict(d)
iteration = 106000


In [None]:

# Training loop.  Every `viewStep` iterations the averaged losses are logged,
# a visualisation batch is rendered, and collages, loss history and model
# weights are written to the working directory.
maxNoise = 5.0 / 256.0
# One forward pass on a test-sized batch before training starts.
blurred, sharp, psf, fid = getData(dataSource, 1, testSize, maxNoise)
out = net(blurred, psf)
batchSize = 128
viewStep = 20

for i in range(1000000):
    iteration += 1
    optimizer.zero_grad()
    blurred, sharp, psf, fid = getData(dataSource, batchSize, cropSize, maxNoise)

    out = net(blurred, psf)
    # The output is smaller than the sharp crop; compare against its centre.
    # Floor division keeps the offsets ints on Python 3 as well.
    b = ((sharp.data.shape[2] - out.data.shape[2]) // 2,
         (sharp.data.shape[3] - out.data.shape[3]) // 2)
    loss = criterion(out, sharp[:, :, b[0]:-b[0], b[1]:-b[1]])
    loss.backward()
    optimizer.step()

    iterationAcc += 1
    # NOTE(review): `.data[0]` is the pre-0.4 PyTorch way to read a scalar
    # loss; newer releases return a 0-dim tensor and need `.item()` instead.
    acc += loss.data[0]

    # Baseline for comparison: error of the blurred input itself.
    b2 = ((sharp.data.shape[2] - blurred.data.shape[2]) // 2,
          (sharp.data.shape[3] - blurred.data.shape[3]) // 2)
    dataLoss = criterion(blurred, sharp[:, :, b2[0]:-b2[0], b2[1]:-b2[1]])
    dataAcc += dataLoss.data[0]

    if iteration % viewStep == viewStep-1 and iterationAcc > viewStep / 3:
        #net.eval()
        # Report the mean losses since the previous report and reset.
        acc /= iterationAcc
        dataAcc /= iterationAcc
        iterationAcc = 0
        print(iteration, acc, dataAcc)
        lossHistory.append(acc)
        lossPositions.append(iteration)
        acc = 0
        dataAcc = 0
        vizSize = 128
        vizBatchSize = 32

        # Fresh batch with larger crops for the visualisation.
        blurred, sharp, psf, fid = getData(dataSource, vizBatchSize, vizSize+border, maxNoise)
        out = net(blurred, psf)
        res = out.data.cpu().numpy()
        print(res.min(), res.max())

        fig = plt.figure(figsize=(14, 10), dpi=80, facecolor='w', edgecolor='k')
        plt.subplot(1, 2, 1)
        # Clamp to the displayable range before building the collage.
        res = np.maximum(np.minimum(res, 1.0), 0.0)
        recColl = collage(res)
        # Reverse the channel axis for display.
        plt.imshow(recColl[:512,:512,::-1])

        # Crop the blurred input to the output size so both panels line up.
        b = (blurred.data.shape[2] - out.data.shape[2]) // 2
        blurred = blurred[:, :, b:-b, b:-b]
        blurred = blurred.data.cpu().numpy()

        plt.subplot(1, 2, 2)
        blurColl = collage(blurred)
        plt.imshow(blurColl[:512,:512,::-1])

        plt.draw()

        plt.show()

        plt.subplot(1, 2, 1)

        # First-layer filters, rescaled to [0, 1] for display.
        l0 = net.layers[0]
        l0 = l0.weight.data.cpu().numpy()
        l0 -= l0.min()
        l0 /= l0.max()

        filterColl = collage(l0)
        plt.imshow(filterColl[:,:,::-1])
        plt.subplot(1, 2, 2)
        plt.plot(lossPositions, lossHistory)
        plt.show()
        net.train()
        # Persist loss history, collages and a checkpoint for this report.
        np.savez('loss_2.npz', lossPositions, lossHistory)
        cv2.imwrite('{:06d}_blur.png'.format(iteration), blurColl*256)
        cv2.imwrite('{:06d}_rec.png'.format(iteration), recColl*256)
        cv2.imwrite('filter_{:06d}.png'.format(iteration), filterColl*256)
        torch.save(net.state_dict(), 'model_{:06d}.mod'.format(iteration))
        


In [None]:
# Loss curve without its first logged point, plus the current counters.
plt.plot(lossPositions[1:], lossHistory[1:])
print(iteration, iterationAcc)

# First-layer filters rescaled to [0, 1] and arranged into a collage
# (built but not displayed here).
l0 = net.layers[0].weight.data.cpu().numpy()
l0 -= l0.min()
l0 /= l0.max()
img = collage(l0)


In [None]:

# Show the learned embedding table as a collage of filters.
# NOTE(review): `psfEmbed`, `filterCount` and `filterSize` are attributes of
# the `Net` class only; with `net` built as a DeconvNet (as in the cell that
# creates it) this cell raises AttributeError.
data = net.psfEmbed.weight.data.cpu()
# Rescale to [0, 1] for display.
data -= data.min()
data /= data.max()
data = data.numpy()
# One row of the table holds `filterCount` filters of 3 x size x size.
data = data.reshape(data.shape[0]*net.filterCount, 3, net.filterSize, net.filterSize)
print(data.shape)
img = collage(data)
plt.imshow(img)
plt.draw()


In [None]:
# Collage of the most recent network output `out`.
# NOTE(review): `data` is whatever an earlier cell left behind, and the
# division by 255 differs from the other cells, which show collages
# unscaled -- confirm the intended value range.
res = out.data.cpu().numpy()
img = collage(res)
print(img.shape, res.shape, data.shape)
plt.subplot(1, 2, 1)
plt.imshow(img/255.0)
plt.show()

In [None]:
# Turn kernel IDs into filters through an embedding + two linear layers and
# show the first 16 of them.
# NOTE(review): this cell depends on names the visible notebook never
# defines at module level: `Tidx` is local to getData, and `psfFC1` /
# `psfFC2` appear only in commented-out code of `Net`.  It also assumes a
# batch of exactly 128 in the final view().
psfIdx = Variable(Tidx.cuda())
f = net.psfEmbed(psfIdx)
# Drop the singleton middle axis of the embedding lookup.
f = f.view(f.data.shape[0], f.data.shape[2])
f = F.relu(net.psfFC1(f))
f = net.psfFC2(f)
f = F.softmax(f)
f = f.view(128, 3, net.filterSize, net.filterSize)

res = f.data.cpu().numpy()
plt.subplot(1, 2, 1)
img = collage(res[:16,:,:,:], True)
plt.imshow(img)
plt.show()

In [None]:
# Display the shape of the array left in `res` by the previously run cell.
res.shape