In [None]:
from __future__ import print_function
%matplotlib inline  
import matplotlib.pyplot as plt

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np

In [None]:
from helper import generateMotionBlurPSF
import os
from dataHelper import generateShakePSF, generateDefocusPSF

def virtualCamera(rotXdeg, rotYdeg, rotZdeg, x, y, z, f, cropW, cropH):
    """Build the 3x3 homography of a virtual pinhole camera looking at a flat image.

    The image plane is lifted into 3D with the point (x, y) moved to the
    origin, rotated about the X, Y and Z axes, pushed `z` units along the
    Z axis and finally projected with focal length `f` and principal point
    (cropW/2, cropH/2).

    Args:
        rotXdeg, rotYdeg, rotZdeg: rotation angles in degrees.
        x, y: image point that ends up on the optical axis.
        z: translation along the Z axis applied after the rotation.
        f: focal length of the intrinsic matrix.
        cropW, cropH: output size; only used for the principal point.

    Returns:
        3x3 float64 `np.matrix` mapping homogeneous source-image
        coordinates to output coordinates, suitable for e.g.
        cv2.warpPerspective(src, H, (cropW, cropH)).
    """
    # Degrees -> radians.
    angX, angY, angZ = (deg * np.pi / 180 for deg in (rotXdeg, rotYdeg, rotZdeg))
    cosX, sinX = np.cos(angX), np.sin(angX)
    cosY, sinY = np.cos(angY), np.sin(angY)
    cosZ, sinZ = np.cos(angZ), np.sin(angZ)

    # 2D -> 3D: homogeneous (u, v, 1) becomes (u - x, v - y, 0, 1).
    lift = np.array([[1, 0, -x],
                     [0, 1, -y],
                     [0, 0, 0],
                     [0, 0, 1]], dtype=np.float64)

    # Elementary rotations about the three axes (homogeneous 4x4).
    aboutX = np.array([[1, 0, 0, 0],
                       [0, cosX, -sinX, 0],
                       [0, sinX, cosX, 0],
                       [0, 0, 0, 1]], dtype=np.float64)
    aboutY = np.array([[cosY, 0, sinY, 0],
                       [0, 1, 0, 0],
                       [-sinY, 0, cosY, 0],
                       [0, 0, 0, 1]], dtype=np.float64)
    aboutZ = np.array([[cosZ, -sinZ, 0, 0],
                       [sinZ, cosZ, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]], dtype=np.float64)
    rotation = aboutX.dot(aboutY).dot(aboutZ)

    # Translation along Z: the distance between camera and image plane.
    push = np.array([[1, 0, 0, 0],
                     [0, 1, 0, 0],
                     [0, 0, 1, z],
                     [0, 0, 0, 1]], dtype=np.float64)

    # Camera intrinsics, 3D -> 2D.
    intrinsics = np.array([[f, 0, cropW/2, 0],
                           [0, f, cropH/2, 0],
                           [0, 0, 1, 0]], dtype=np.float64)

    # Same association order as (A2 * T) * (R * A1); callers rely on the
    # result being an np.matrix (they multiply with `*`).
    return np.matrix(intrinsics.dot(push).dot(rotation.dot(lift)))

class DataSource(object):
    """Pool of sharp image crops and blur kernels used to synthesise training batches.

    Crops are cut from the images listed in `imageSource` with a random
    perspective warp (`updateCrops`). Blur kernels are built from the output
    of `generateDefocusPSF` filtered with the output of `generateShakePSF`
    (`generateFilters`). `getBatch` pairs random crops with random kernels and
    returns (sharp, blurred, kernels). Both pools are ring buffers: once
    full, new entries overwrite the oldest ones.

    Args:
        imageSource: text file with one image file name per line.
        path: directory prepended to every listed file name.
        filterCount: capacity of the kernel pool.
        cropCount: capacity of the crop pool.
        cropSize: side length in pixels of the square crops.
        device: torch device that holds both pools. Defaults to 'cuda',
            where the pools were always allocated before this parameter
            existed.
    """

    def __init__(self, imageSource, path='', filterCount=1000, cropCount=1000, cropSize=128,
                 device='cuda'):
        with open(imageSource, 'r') as f:
            # Blank lines would otherwise turn into the bare directory name.
            self.fileNames = [os.path.join(path, line.strip()) for line in f if line.strip()]

        self.path = path
        self.cropSize = cropSize
        self.imgPos = 0      # next entry of fileNames to read
        self.xRSdev = 15     # std-dev (degrees) of the random rotation about X
        self.yRSdev = 15     # ... about Y
        self.zRSdev = 6      # ... about Z
        self.cropPos = 0     # next slot of the crop ring buffer
        self.crops = torch.zeros(cropCount, 3, cropSize, cropSize,
                                 dtype=torch.uint8, device=device)

        self.maxShakeSize = 13
        self.maxDefocusSize = 7
        self.maxFilterSize = self.maxShakeSize + self.maxDefocusSize + 1
        self.filterPos = 0   # next slot of the kernel ring buffer
        self.filters = torch.zeros(filterCount, 1, self.maxFilterSize, self.maxFilterSize,
                                   dtype=torch.float32, device=device)

        self.rng = np.random.RandomState(1)

    def _randomCrop(self, image):
        """Cut one randomly warped cropSize x cropSize patch out of `image`.

        Returns None when the warped crop would reach outside the image or
        when the patch is nearly uniform (pixel std-dev below 8).
        """
        rotXdeg = np.random.normal() * self.xRSdev
        rotYdeg = np.random.normal() * self.yRSdev
        rotZdeg = np.random.normal() * self.zRSdev
        x = np.random.rand() * image.shape[1]
        y = np.random.rand() * image.shape[0]

        # Focal length for a random field of view of 30-100 degrees over a
        # 1700 px wide view; the camera distance is f scaled by 0.7-1.3.
        fov = (30 + np.random.rand() * 70) / 180 * np.pi
        f = (1700 / 2.0) / np.tan(fov / 2.0)
        dist = f * (0.7 + np.random.rand() * 0.6)

        H = virtualCamera(rotXdeg, rotYdeg, rotZdeg, x, y, dist, f, self.cropSize, self.cropSize)

        # Map the four corners of the output crop back into the source image.
        last = self.cropSize - 1
        dst = np.matrix([[0, 0, 1],
                         [last, 0, 1],
                         [last, last, 1],
                         [0, last, 1]], dtype=np.float32)
        src = np.linalg.inv(H) * dst.T
        src = src[0:2] / src[2:3]

        # Reject crops whose footprint is not fully inside the source image.
        if np.any(src < 0) or np.any(src[0] >= image.shape[1]) or np.any(src[1] >= image.shape[0]):
            return None

        crop = cv2.warpPerspective(image, H, (self.cropSize, self.cropSize), flags=cv2.INTER_LINEAR)
        if crop.std() < 8:
            return None
        return crop

    def updateCrops(self, cropsPerImage):
        """Add at least `cropsPerImage` new crops to the crop ring buffer.

        Images are read round-robin from `fileNames`; up to `cropsPerImage`
        crops are attempted per image until enough have been collected.

        Raises:
            RuntimeError: if every listed image fails in a row (this used to
                loop forever).
        """
        crops = []
        failures = 0
        while len(crops) < cropsPerImage:
            if failures >= len(self.fileNames):
                raise RuntimeError(
                    'None of the {} listed images could be processed.'.format(len(self.fileNames)))
            imageName = self.fileNames[self.imgPos]
            self.imgPos = (self.imgPos + 1) % len(self.fileNames)
            try:
                # fileNames already contain self.path; joining it a second
                # time broke relative paths.
                image = cv2.imread(imageName)
                if image is None:
                    raise IOError('cv2.imread returned None')
                failures = 0
                if image.shape[0] * 1.25 < self.cropSize and image.shape[1] * 1.25 < self.cropSize:
                    continue

                for _ in range(cropsPerImage):
                    crop = self._randomCrop(image)
                    if crop is not None:
                        crops.append(crop)
            except Exception:
                # Best effort: report and move on to the next image, but let
                # KeyboardInterrupt / SystemExit through.
                failures += 1
                print('ERROR: While reading image "{}".'.format(imageName))

        for crop in crops:
            # HWC (OpenCV layout) -> CHW.
            chw = torch.from_numpy(np.ascontiguousarray(crop.transpose(2, 0, 1)))
            # `non_blocking` is the current name of the old `async` flag,
            # which no longer parses on Python >= 3.7.
            self.crops[self.cropPos].copy_(chw, non_blocking=True)
            self.cropPos = (self.cropPos + 1) % self.crops.shape[0]

    def generateFilters(self, filterCount):
        """Generate `filterCount` random blur kernels and store them in the kernel ring buffer."""
        filters = []
        for _ in range(filterCount):
            # Always odd and in 1..maxShakeSize.
            shakeSize = int(np.random.rand() * (self.maxShakeSize + 1)) // 2 * 2 + 1
            defocusRadius = np.random.rand() * self.maxDefocusSize * 0.5

            psf, center = generateDefocusPSF(radius=defocusRadius)
            # NOTE(review): assumes psf is square and its size has the same
            # parity as maxFilterSize, otherwise the reshape below fails -
            # confirm against generateDefocusPSF.
            border = int((self.maxFilterSize - psf.shape[0]) / 2)
            psf = np.pad(psf, [(border, border), (border, border)], mode='constant')

            shakePsf = generateShakePSF(self.rng, resolution=shakeSize, halflife=0.75)
            psf = cv2.filter2D(psf, -1, shakePsf)

            psf /= psf.sum()  # normalise to unit sum
            filters.append(psf.reshape(1, self.maxFilterSize, self.maxFilterSize))

        for psf in filters:
            kernel = torch.as_tensor(psf, dtype=torch.float32)
            self.filters[self.filterPos].copy_(kernel, non_blocking=True)
            self.filterPos = (self.filterPos + 1) % self.filters.shape[0]

    def colorManipulation(self, data):
        """Randomly degrade a batch in place: contrast/colour gain, offset, noise, gamma.

        Args:
            data: float tensor indexed as (sample, channel, row, column).
                It is modified in place.

        Returns:
            The same tensor `data`.
        """
        self.colorSdev = 0.07
        self.contrastSdev = 0.5
        self.minColorSdev = 0.1
        self.gammaSdev = 0.2
        self.noiseSdev = 0.02

        def noise(*shape):
            # Standard-normal float32 tensor on the same device as `data`.
            return torch.empty(*shape, dtype=torch.float32, device=data.device).normal_()

        samples, channels = data.shape[0], data.shape[1]

        # Per-channel colour gain and per-sample contrast reduction (gain = 2^(c - |k|)).
        colorCoef = noise(samples, channels, 1, 1) * self.colorSdev
        contrast = (noise(samples, 1, 1, 1) * self.contrastSdev).abs_()
        data *= torch.pow(2.0, colorCoef - contrast)

        # Additive per-sample brightness offset.
        data += (noise(samples, 1, 1, 1) * self.minColorSdev).abs_()

        # Gaussian pixel noise with a per-sample strength.
        noiseSdev = (noise(samples, 1, 1, 1) * self.noiseSdev).abs_()
        data += torch.empty_like(data).normal_() * noiseSdev
        data.clamp_(0, 1.0)

        # Per-sample gamma of 2^g.
        gamma = torch.pow(2.0, noise(samples, 1, 1, 1) * self.gammaSdev)
        data.pow_(gamma)

        return data

    def getBatch(self, count=32):
        """Draw a random training batch.

        Args:
            count: number of samples.

        Returns:
            Tuple (crops, blurred, filters):
            crops   - (count, 3, cropSize, cropSize) float, pool values / 256.
            blurred - the crops convolved with their kernel (no padding, so
                      maxFilterSize - 1 pixels smaller per axis) and passed
                      through `colorManipulation`.
            filters - (3 * count, 1, maxFilterSize, maxFilterSize); each
                      sample's kernel appears three times in a row, once per
                      colour channel.
        """
        cropIdx = torch.as_tensor(np.random.choice(self.crops.shape[0], count),
                                  dtype=torch.long, device=self.crops.device)
        # One kernel per sample, repeated for the three colour channels.
        perSample = np.random.choice(self.filters.shape[0], count)
        filterIdx = torch.as_tensor(np.repeat(perSample, 3),
                                    dtype=torch.long, device=self.filters.device)

        crops = self.crops[cropIdx].float() / 256.0
        filters = self.filters[filterIdx].contiguous()

        # Fold (sample, channel) into the channel axis and use a grouped
        # convolution so that every channel gets its own kernel.
        planes = crops.shape[0] * 3
        with torch.no_grad():
            blurred = F.conv2d(
                crops.view(1, planes, crops.shape[2], crops.shape[3]),
                filters.view(planes, 1, filters.shape[2], filters.shape[3]),
                bias=None, groups=planes)
        blurred = blurred.view(crops.shape[0], 3, blurred.shape[2], blurred.shape[3])

        self.colorManipulation(blurred)

        return crops, blurred, filters
        
        
        

In [None]:
# Capacities of the kernel pool and the crop pool held by DataSource.
filterCount=20000
cropCount=20000
#del dataSource
# First argument: list file with one image name per line; second: directory
# prepended to every listed name.
dataSource = DataSource(
    '/mnt/matylda1/hradis/2015-03-06_image_restoration/PDF/data/allImages.rnd', 
    '/mnt/matylda1/hradis/2015-03-06_image_restoration/PDF/data',
    filterCount=filterCount, cropCount=cropCount, cropSize=156) 

cropsPerImage=40
# Prime the crop pool: each call adds at least cropsPerImage crops, so after
# five calls cropPos is above the 4*cropsPerImage threshold used below.
for i in range(5):
    dataSource.updateCrops(cropsPerImage)
# cropPos wraps around modulo the pool size, so it only drops below the
# threshold once the whole pool has been filled - keep adding until then.
while(dataSource.cropPos >=  4*cropsPerImage):
    dataSource.updateCrops(cropsPerImage)
    print(dataSource.cropPos)
    
# Fill the whole kernel pool in one go.
dataSource.generateFilters( filterCount)
    

In [None]:
from helper import collage
# getBatch returns (sharp crops, blurred crops, kernels); `img` is only a
# placeholder here and is overwritten by the collages below.
img, blurred, psf = dataSource.getBatch(64)

# Kernels come back three times in a row (once per colour channel), so every
# third entry gives one kernel per sample.
data = psf.cpu().numpy()
img = collage(data[::3], normSamples=True)
print(img.shape)
plt.figure(figsize=(6,6))
plt.imshow(img[:,:,0])

# Blurred and colour-degraded crops, normalised to the brightest pixel for display.
data = blurred.cpu().numpy()
print('Min/max:', data.min(), data.max())
img = collage(data)
plt.figure(figsize=(12,12))
plt.imshow(img / img.max())

In [None]:
import time

In [None]:

class SqueezeExitationBlock(nn.Module):
    """Convolution with per-channel gating, followed by batch norm and ReLU.

    The convolution output is multiplied by sigmoid gates computed from its
    own global average (one gate per output channel).

    Args:
        inputChannels: number of input channels.
        filterSize: kernel size. Sizes up to 3 are padded so that the
            spatial size is kept (for odd sizes); larger kernels are applied
            without padding and shrink the output by filterSize - 1.
        filterCount: number of output channels.
    """

    def __init__(self, inputChannels, filterSize=3, filterCount=32):
        super(SqueezeExitationBlock, self).__init__()

        # Integer division: on Python 3 `/` yields a float, which conv2d
        # rejects as padding.
        padding = 0 if filterSize > 3 else (filterSize - 1) // 2
        self.layer1 = nn.Conv2d(in_channels=inputChannels,
                                padding=padding,
                                out_channels=filterCount,
                                kernel_size=filterSize)

        # Attribute name (including its spelling) is part of the state_dict
        # keys, so it is kept as is.
        self.sqeeze = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels=filterCount,
                      out_channels=filterCount,
                      kernel_size=1),
            nn.Sigmoid())

        self.output = nn.Sequential(
            nn.BatchNorm2d(num_features=filterCount, momentum=0.75),
            nn.ReLU()
            )

    def forward(self, input_x):
        """Return ReLU(BN(conv(x) * gates(conv(x))))."""
        act = self.layer1(input_x)
        weights = self.sqeeze(act)
        out = self.output(act * weights)
        return out

class ConvBlock(nn.Module):
    """Convolution followed by batch normalisation and ReLU.

    Args:
        inputChannels: number of input channels.
        filterSize: kernel size. Sizes up to 3 are padded so that the
            spatial size is kept (for odd sizes); larger kernels are applied
            without padding and shrink the output by filterSize - 1.
        filterCount: number of output channels.
    """

    def __init__(self, inputChannels, filterSize=3, filterCount=32):
        super(ConvBlock, self).__init__()

        # Integer division: on Python 3 `/` yields a float, which conv2d
        # rejects as padding.
        padding = 0 if filterSize > 3 else (filterSize - 1) // 2
        self.layer1 = nn.Conv2d(in_channels=inputChannels,
                                padding=padding,
                                out_channels=filterCount,
                                kernel_size=filterSize)

        self.output = nn.Sequential(
            nn.BatchNorm2d(num_features=filterCount, momentum=0.75),
            nn.ReLU()
            )

    def forward(self, input_x):
        """Return ReLU(BN(conv(x)))."""
        act = self.layer1(input_x)
        out = self.output(act)
        return out
    
class LinearBlock(nn.Module):
    """A plain chain of ConvBlocks.

    Layer k uses kernel size filterSizes[k] and produces filterCounts[k]
    channels; its input is the output of layer k - 1 (or the block input
    for the first layer).
    """

    def __init__(self, inputChannels, filterSizes=[3, 3], filterCounts=[32, 32]):
        super(LinearBlock, self).__init__()

        # Input width of every layer: the block input first, then the
        # output width of the preceding layer.
        widthsIn = [inputChannels] + list(filterCounts[:-1])
        stages = [
            ConvBlock(inputChannels=widthIn, filterCount=widthOut, filterSize=size)
            for widthIn, size, widthOut in zip(widthsIn, filterSizes, filterCounts)
        ]
        self.convLayers = nn.Sequential(*stages)

    def forward(self, input_x):
        """Run the input through all ConvBlocks in order."""
        return self.convLayers(input_x)

class AggregationNet(nn.Module):
    """Reduce an image batch to one 128-channel 1x1 descriptor per sample.

    Three LinearBlocks, each followed by 2x2 max pooling, feed a 1x1
    convolution to 256 channels, global average pooling and a final 1x1
    convolution to 128 channels.
    """

    def __init__(self, inputChannels):
        super(AggregationNet, self).__init__()
        # (layer count, channel count) per stage. The first stage is built
        # by hand below, so only the later entries are actually used.
        self.compressBlockInfo = [(2, 16), (2, 24), (2, 32)]
        self.compressBlocks = torch.nn.ModuleList()

        # Entry stage: 13x13 -> 1x1 -> 3x3 convolutions.
        entryCounts = [32, 32, 16]
        self.compressBlocks.append(LinearBlock(inputChannels, [13, 1, 3], entryCounts))
        width = entryCounts[-1]

        # Remaining stages: stacks of 3x3 convolutions of constant width.
        for depth, stageWidth in self.compressBlockInfo[1:]:
            self.compressBlocks.append(
                LinearBlock(width, [3] * depth, [stageWidth] * depth))
            width = stageWidth

        self.outLayers = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=256, kernel_size=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1),
            nn.ReLU(),
            )

    def forward(self, input_x):
        """Return the (N, 128, 1, 1) descriptor of `input_x`."""
        features = input_x
        for stage in self.compressBlocks:
            features = F.max_pool2d(stage(features), 2, stride=2)
        return self.outLayers(features)
    
class HourGlassNet(nn.Module):
    """Encoder/decoder restoration network with globally gated stages.

    An AggregationNet condenses the input into a 128-channel descriptor.
    After every encoder and decoder stage the feature map is multiplied by
    channel weights derived from that descriptor (1x1 convolution followed
    by a softmax over channels) and batch-normalised. Encoder stages are
    separated by 2x2 max pooling; decoder stages upsample by 2 and
    concatenate the matching encoder output.

    The first encoder stage uses an unpadded 13x13 convolution, so the
    output is 12 pixels smaller than the input along each spatial axis.

    Args:
        inputChannels: number of channels of the input and of the output.
    """

    def __init__(self, inputChannels):
        super(HourGlassNet, self).__init__()

        self.aggregation = AggregationNet(inputChannels)

        # (layer count, channel count) per stage.
        self.compressBlockInfo = [(2, 24), (2, 32), (3, 64)]
        self.decompressBlockInfo = [(2, 32), (2, 24)]
        self.compressBlocks = torch.nn.ModuleList()
        self.decompressBlocks = torch.nn.ModuleList()

        # One gate generator and one batch norm per stage, encoder stages
        # first, then decoder stages.
        self.blockMultipliers = torch.nn.ModuleList()
        self.blockBN = torch.nn.ModuleList()
        for fCount in [j[1] for j in self.compressBlockInfo + self.decompressBlockInfo]:
            self.blockMultipliers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=128, out_channels=fCount, kernel_size=1),
                    nn.Softmax2d(),
                )
            )
            self.blockBN.append(
                nn.BatchNorm2d(num_features=fCount, momentum=0.75))

        lastChannels = inputChannels

        # Entry stage: 13x13 -> 1x1 -> 3x3 convolutions.
        filterSizes = [13, 1, 3]
        filterCounts = [64, 32, 24]
        self.compressBlocks.append(LinearBlock(lastChannels, filterSizes, filterCounts))
        lastChannels = filterCounts[-1]

        for layerCount, filterCount in self.compressBlockInfo[1:]:
            self.compressBlocks.append(
                LinearBlock(lastChannels, [3] * layerCount, [filterCount] * layerCount))
            lastChannels = filterCount

        # Channel counts of the encoder outputs that are bypassed to the
        # decoder (every stage except the deepest one).
        compressFilterCounts = [fc for lc, fc in self.compressBlockInfo[:-1]]

        for compressFC, (layerCount, filterCount) in zip(compressFilterCounts[::-1],
                                                         self.decompressBlockInfo):
            self.decompressBlocks.append(
                LinearBlock(lastChannels + compressFC, [3] * layerCount, [filterCount] * layerCount))
            lastChannels = filterCount

        fSize = 3
        self.outputLayer = nn.Conv2d(in_channels=lastChannels,
                                     # integer division: a float padding is
                                     # rejected by conv2d on Python 3
                                     padding=(fSize - 1) // 2,
                                     out_channels=inputChannels,
                                     kernel_size=fSize)

    def forward(self, input_x):
        """Restore `input_x`; the result has the same channel count and is 12 px smaller per axis."""
        agg = self.aggregation(input_x)
        stage = 0  # index into blockMultipliers / blockBN, shared by both halves

        x = input_x
        compressStages = []
        lastCompress = len(self.compressBlocks) - 1
        for i, module in enumerate(self.compressBlocks):
            x = module(x)
            x = self.blockBN[stage](x * self.blockMultipliers[stage](agg))
            stage += 1
            if i < lastCompress:
                compressStages.append(x)
                x = F.max_pool2d(x, 2, stride=2)

        for module, bypass in zip(self.decompressBlocks, compressStages[::-1]):
            # F.upsample_nearest is deprecated; this is its documented replacement.
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = torch.cat((x, bypass), dim=1)
            x = module(x)
            x = self.blockBN[stage](x * self.blockMultipliers[stage](agg))
            stage += 1

        x = self.outputLayer(x)
        #x = F.sigmoid(x)
        return x

In [None]:
# Build the restoration network for 3-channel images and move it to the GPU.
net = HourGlassNet(3)
print(net)
net.cuda()

# Loss curve bookkeeping: averaged losses and the iterations they were
# recorded at. The same three names are reset again in the next cell.
lossHistory = []
lossPositions = []
iteration = 0



In [None]:
# Fresh optimiser state; re-running this cell also resets the loss history
# and the iteration counter.
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
lossHistory = []
lossPositions = []
iteration = 0

# Mean squared error between the network output and the sharp crop.
criterion = torch.nn.MSELoss().cuda()

In [None]:
# Running sums used by the training cell between two reports:
acc = 0            # sum of the network losses
dataAcc = 0        # sum of the blurred-vs-sharp baseline losses
iterationAcc = 0   # number of iterations accumulated

In [None]:
from functools import partial
def updateAll(source, filtersPerStep=20, cropsPerStep=20, pause=0.3):
    """Keep refreshing the sample pools of `source` until asked to stop.

    Meant to run in a background thread. The loop ends after the current
    round once `source.stop` is set to a true value from elsewhere.

    Args:
        source: object providing generateFilters(n) and updateCrops(n);
            its `stop` attribute is reset to False on entry.
        filtersPerStep: kernels generated per round.
        cropsPerStep: value passed to updateCrops per round.
        pause: seconds to sleep between rounds.
    """
    source.stop = False
    while not source.stop:
        source.generateFilters(filtersPerStep)
        source.updateCrops(cropsPerStep)
        time.sleep(pause)
        
from threading import Thread
# Background worker that keeps replacing pool entries while training runs.
# Set `dataSource.stop = True` to make it finish after its current round.
updateThread = Thread(target=partial(updateAll, dataSource))
# Daemon thread: the worker loops until told to stop, and as a regular
# thread it would keep the interpreter alive on shutdown.
updateThread.daemon = True
updateThread.start()

In [None]:
# Set to the iteration number of a saved checkpoint to resume training;
# 0 starts from the current (fresh) state.
continueIt = 0
if continueIt:
    iteration = continueIt
    # Saved by the training cell with np.savez(positions, history), hence
    # the positional keys arr_0 / arr_1.
    data = np.load('loss_2.npz')
    lossPositions = data['arr_0'].tolist()
    lossHistory = data['arr_1'].tolist()
    d = torch.load('model_{:06d}.mod'.format(iteration))
    net.load_state_dict(d)

net = net.cuda()


In [None]:
# Show the current training iteration counter.
iteration

In [None]:
batchSize = 32
viewStep = 500   # report, plot and checkpoint every viewStep iterations


def _centerCrop(images, height, width):
    """Return the central height x width window of a (N, C, H, W) batch.

    Works for tensors and NumPy arrays. Unlike a `b:-b` slice it is also
    correct when nothing has to be removed (b == 0) and never produces
    float indices.
    """
    top = (images.shape[2] - height) // 2
    left = (images.shape[3] - width) // 2
    return images[:, :, top:top + height, left:left + width]


# Uses the notebook globals net, optimizer, criterion, dataSource and the
# counters iteration / iterationAcc / acc / dataAcc. The loop is effectively
# endless; interrupt the kernel to stop training.
print(iteration, dataSource.filterPos, dataSource.cropPos)
t1 = time.time()
for i in range(1000000):
    iteration += 1
    optimizer.zero_grad()
    sharp, blurred, psf = dataSource.getBatch(batchSize)
    out = net(blurred)
    # The network output is smaller than the sharp crop; compare centres.
    loss = criterion(out, _centerCrop(sharp, out.shape[2], out.shape[3]))
    loss.backward()
    optimizer.step()

    iterationAcc += 1
    acc += loss.item()

    # Baseline: error of the untouched blurred input against the sharp crop.
    with torch.no_grad():
        dataLoss = criterion(blurred, _centerCrop(sharp, blurred.shape[2], blurred.shape[3]))
    dataAcc += dataLoss.item()

    if iteration % viewStep == viewStep-1 and iterationAcc > viewStep / 3:
        elapsed = time.time() - t1
        t1 = time.time()
        print('Iteration:', iteration, dataSource.filterPos, dataSource.cropPos)
        print('Img per second: ', iterationAcc * batchSize / elapsed)
        print('Batch per second: ', iterationAcc / elapsed)
        print('Elapsed time: ', elapsed)

        # Fresh batch for the visual check.
        sharp, blurred, psf = dataSource.getBatch(batchSize)

        acc /= iterationAcc
        dataAcc /= iterationAcc
        print('Errors:', acc, dataAcc)
        lossHistory.append(acc)
        lossPositions.append(iteration)
        iterationAcc = 0
        acc = 0
        dataAcc = 0

        # Inference only: no autograd graph is needed for the preview.
        net = net.eval()
        with torch.no_grad():
            out = net(blurred)
        net = net.train()

        out = out.clamp_(0, 1.0).cpu().numpy()
        print(out.min(), out.max())

        fig = plt.figure(figsize=(14, 10), dpi=80, facecolor='w', edgecolor='k')
        plt.subplot(1, 2, 1)
        recColl = collage(out)
        plt.imshow(recColl[:512,:512,::-1])

        # Crop the input to the output size so both collages line up.
        blurred = _centerCrop(blurred, out.shape[2], out.shape[3]).cpu().numpy()

        plt.subplot(1, 2, 2)
        blurColl = collage(blurred)
        plt.imshow(blurColl[:512,:512,::-1])

        plt.show()

        plt.subplot(1, 1, 1)

        plt.semilogy(lossPositions, lossHistory)
        plt.show()
        np.savez('loss_2.npz', lossPositions, lossHistory)
        cv2.imwrite('{:06d}_blur.png'.format(iteration), blurColl*256)
        cv2.imwrite('{:06d}_rec.png'.format(iteration), recColl*256)
        torch.save(net.state_dict(), 'model_{:06d}.mod'.format(iteration))
        


In [None]:
# Time a single forward pass in evaluation mode, then switch back to training mode.
# NOTE(review): `blurred` is whatever the training cell last left behind; after
# a report step that cell rebinds it to a NumPy array, which `net` cannot
# consume - confirm the intended input.
net = net.eval()
%timeit out = net(blurred)
net = net.train()


In [None]:
plt.plot(lossPositions[1:], lossHistory[1:])
print(iteration, iterationAcc)
# HourGlassNet has no `layers` attribute; the first convolution of the main
# path is the first ConvBlock of the first encoder stage.
l0 = net.compressBlocks[0].convLayers[0].layer1
# Copy before normalising: for a model on the CPU, .numpy() shares memory
# with the weights and the in-place ops below would alter the model.
l0 = l0.weight.detach().cpu().numpy().copy()
l0 -= l0.min()
l0 /= l0.max()

img = collage(l0)
#plt.imshow(img[:,:,::-1])


In [None]:

# Show the min-max normalised weights of `net.psfEmbed` as a collage.
# NOTE(review): HourGlassNet as defined in this notebook has no `psfEmbed`,
# `filterCount` or `filterSize` attribute - this cell looks like a leftover
# from an earlier model; confirm and update or remove it.
data = net.psfEmbed.weight.data.cpu()
data -= data.min()
data /= data.max()
data = data.numpy()
data = data.reshape(data.shape[0]*net.filterCount, 3, net.filterSize, net.filterSize)
print(data.shape)
img = collage(data)
plt.imshow(img)
plt.draw()


In [None]:
# Show the most recent network output as a collage.
# NOTE(review): `out` must still be a tensor here; the training cell rebinds
# it to a NumPy array at every report step, in which case `.cpu()` fails.
# `data` is taken from the preceding cell.
res = out.data.cpu().numpy()
img = collage(res)
print(img.shape, res.shape, data.shape)
plt.subplot(1, 2, 1)
plt.imshow(img/255.0)
plt.show()

In [None]:
# Run indices through an embedding and two fully connected layers, apply a
# softmax and show the first 16 results as kernels.
# NOTE(review): `Tidx` is not defined anywhere in this notebook, and
# HourGlassNet has no `psfEmbed`, `psfFC1`, `psfFC2` or `filterSize` - this
# cell appears to belong to an earlier model; confirm and update or remove it.
psfIdx = Variable(Tidx.cuda())
f = net.psfEmbed(psfIdx)
f = f.view(f.data.shape[0], f.data.shape[2])
f = F.relu(net.psfFC1(f))
f = net.psfFC2(f)
f = F.softmax(f)
f = f.view(128, 3, net.filterSize, net.filterSize)

res = f.data.cpu().numpy()
plt.subplot(1, 2, 1)
img = collage(res[:16,:,:,:], True)
plt.imshow(img)
plt.show()

In [None]:
# Show the shape of the array produced by the previous cell.
res.shape

In [None]:
# (scratch cell) The original content was the incomplete expression
# "dataSource." - a tab-completion leftover that is a SyntaxError and stops
# "Run All". Nothing to execute here.