In [None]:
from __future__ import print_function
%matplotlib inline  
import matplotlib.pyplot as plt

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np

In [None]:

class SqueezeExitationBlock(nn.Module):
    """Convolution whose output channels are re-weighted by a learned gate.

    The block applies one ``nn.Conv2d``, derives one weight in (0, 1) per
    channel from the globally average-pooled activations (1x1 convolution
    followed by a sigmoid), multiplies the activations by those weights and
    finishes with BatchNorm and ReLU.

    Args:
        inputChannels: number of channels of the incoming tensor.
        filterSize: kernel size of the convolution. Kernels larger than 3 are
            applied without padding; otherwise ``(filterSize - 1) // 2``
            pixels of padding are used.
        filterCount: number of output channels.
    """

    def __init__(self, inputChannels, filterSize=3, filterCount=32):
        super(SqueezeExitationBlock, self).__init__()

        # Floor division keeps the padding an int on Python 3 as well;
        # true division would hand nn.Conv2d a float padding.
        padding = 0 if filterSize > 3 else (filterSize - 1) // 2
        self.layer1 = nn.Conv2d(in_channels=inputChannels,
                                padding=padding,
                                out_channels=filterCount,
                                kernel_size=filterSize)

        # The attribute name (including its spelling) determines the
        # state_dict keys, so it is intentionally left unchanged.
        self.sqeeze = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels=filterCount,
                      out_channels=filterCount,
                      kernel_size=1),
            nn.Sigmoid())

        self.output = nn.Sequential(
            nn.BatchNorm2d(num_features=filterCount, momentum=0.75),
            nn.ReLU()
            )

    def forward(self, input_x):
        """Return BatchNorm+ReLU of the gated convolution output."""
        act = self.layer1(input_x)
        # One weight per channel (spatial size 1x1), broadcast over the map.
        weights = self.sqeeze(act)
        out = self.output(act * weights)
        return out

class ConvBlock(nn.Module):
    """A single convolution followed by BatchNorm and ReLU.

    Args:
        inputChannels: number of channels of the incoming tensor.
        filterSize: kernel size of the convolution. Kernels larger than 3 are
            applied without padding; otherwise ``(filterSize - 1) // 2``
            pixels of padding are used.
        filterCount: number of output channels.
    """

    def __init__(self, inputChannels, filterSize=3, filterCount=32):
        super(ConvBlock, self).__init__()

        # Floor division keeps the padding an int on Python 3 as well;
        # true division would hand nn.Conv2d a float padding.
        padding = 0 if filterSize > 3 else (filterSize - 1) // 2
        self.layer1 = nn.Conv2d(in_channels=inputChannels,
                                padding=padding,
                                out_channels=filterCount,
                                kernel_size=filterSize)

        self.output = nn.Sequential(
            nn.BatchNorm2d(num_features=filterCount, momentum=0.75),
            nn.ReLU()
            )

    def forward(self, input_x):
        """Return ReLU(BatchNorm(conv(input_x)))."""
        act = self.layer1(input_x)
        out = self.output(act)
        return out
    
class LinearBlock(nn.Module):
    """Sequential stack of ``ConvBlock`` layers.

    One ``ConvBlock`` is created per (size, count) pair taken in order from
    ``filterSizes`` and ``filterCounts``. The first layer consumes
    ``inputChannels`` channels; every later layer consumes the filter count
    of the layer before it.
    """

    def __init__(self, inputChannels, filterSizes=[3, 3], filterCounts=[32, 32]):
        super(LinearBlock, self).__init__()

        # Input width of layer i: inputChannels for i == 0, otherwise the
        # filter count of layer i - 1.
        counts = list(filterCounts)
        inWidths = [inputChannels] + counts

        stack = [
            ConvBlock(inputChannels=inWidth, filterCount=count, filterSize=size)
            for inWidth, size, count in zip(inWidths, filterSizes, counts)
        ]
        self.convLayers = nn.Sequential(*stack)

    def forward(self, input_x):
        """Run ``input_x`` through all layers in order."""
        return self.convLayers(input_x)
    
    
    
class AggregationNet(nn.Module):
    """Condenses an input tensor into a spatially 1x1 embedding.

    Three ``LinearBlock`` stages, each followed by 2x2 max pooling, feed a
    head made of a 1x1 convolution to 256 channels, ReLU, global average
    pooling, a 1x1 convolution to ``outChannels`` and a final ReLU.
    """

    def __init__(self, inputChannels, outChannels):
        super(AggregationNet, self).__init__()
        self.compressBlockInfo = [(2, 16), (2, 24), (2, 32)]

        # The first stage is hand-specified; only entries 1.. of
        # compressBlockInfo are used to build the remaining stages.
        stemSizes = [13, 1, 3]
        stemCounts = [32, 32, 16]
        stages = [LinearBlock(inputChannels, stemSizes, stemCounts)]
        channels = stemCounts[-1]

        for depth, width in self.compressBlockInfo[1:]:
            stages.append(LinearBlock(channels, [3] * depth, [width] * depth))
            channels = width
        self.compressBlocks = torch.nn.ModuleList(stages)

        head = [
            nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels=256, out_channels=outChannels, kernel_size=1),
            nn.ReLU(),
        ]
        self.outLayers = nn.Sequential(*head)

    def forward(self, input_x):
        """Return the embedding computed from ``input_x``."""
        features = input_x
        for stage in self.compressBlocks:
            # Every stage halves the spatial resolution after its convolutions.
            features = F.max_pool2d(stage(features), 2, stride=2)
        return self.outLayers(features)
    
class ColorNet(nn.Module):
    """Per-pixel colour transform modulated by a second tensor.

    The input channels are extended with their squares and cubes and passed
    through a stack of 1x1 convolutions. After each hidden convolution the
    activations are multiplied by channel weights computed from
    ``modulation`` (1x1 convolution followed by ``Softmax2d``), passed through
    ReLU and batch-normalised. A final 1x1 convolution maps back to
    ``inputChannels`` and the result is squashed with tanh.

    Args:
        inputChannels: channel count of the tensor given as ``input_x``.
        modulationChannels: channel count of the ``modulation`` tensor.
    """

    def __init__(self, inputChannels, modulationChannels):
        super(ColorNet, self).__init__()

        self.filterCount = 16
        self.filterCounts = [self.filterCount, self.filterCount]
        self.layerMultipliers = torch.nn.ModuleList()
        self.layerBN = torch.nn.ModuleList()

        # One gate (modulation -> channel weights) and one BatchNorm per hidden layer.
        for fCount in self.filterCounts:
            self.layerMultipliers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=modulationChannels, out_channels=fCount, kernel_size=1),
                    nn.Softmax2d(),
                )
            )
            self.layerBN.append(
                nn.BatchNorm2d(num_features=fCount, momentum=0.75))

        self.layers = torch.nn.ModuleList()
        # inputChannels * 3 because forward() concatenates x, x**2 and x**3.
        self.layers.append(
            nn.Conv2d(in_channels=inputChannels*3, out_channels=self.filterCount, kernel_size=1, padding=0))
        self.layers.append(
            nn.Conv2d(in_channels=self.filterCount, out_channels=self.filterCount, kernel_size=1, padding=0))

        self.lastLayer = nn.Conv2d(in_channels=self.filterCount, out_channels=inputChannels, kernel_size=1, padding=0)

    def forward(self, input_x, modulation):
        """Return the tanh-bounded transform of ``input_x``.

        ``modulation`` must yield gate outputs that broadcast against the
        hidden activations (e.g. spatial size 1x1).
        """
        x = torch.cat((input_x, torch.pow(input_x, 2), torch.pow(input_x, 3)), dim=1)
        for layer, multiplier, bNorm in zip(self.layers, self.layerMultipliers, self.layerBN):
            x = F.relu(layer(x) * multiplier(modulation), inplace=True)
            x = bNorm(x)

        # torch.tanh replaces the deprecated F.tanh (same result, no warning).
        return torch.tanh(self.lastLayer(x))
    
class FilterNet(nn.Module):
    """Maps a spatially 1x1 tensor to a normalised square filter.

    Three hidden stages of 1x1 convolution (512 channels), BatchNorm and ReLU
    are followed by a 1x1 convolution to ``filterSize * filterSize`` channels.
    A softmax over those channels makes the entries non-negative and sum to
    one; the result is reshaped to ``(N, 1, filterSize, filterSize)``.

    Args:
        inputChannels: channel count of the input tensor.
        filterSize: side length of the produced filter.
    """

    def __init__(self, inputChannels, filterSize):
        super(FilterNet, self).__init__()

        self.filterSize = filterSize
        hiddenChannels = 512

        # Built in the same order as before so the nn.Sequential indices
        # (and therefore the state_dict keys) stay layers.0 ... layers.9.
        modules = []
        channels = inputChannels
        for _ in range(3):
            modules.append(nn.Conv2d(in_channels=channels, out_channels=hiddenChannels, kernel_size=1))
            modules.append(nn.BatchNorm2d(num_features=hiddenChannels, momentum=0.75))
            modules.append(nn.ReLU())
            channels = hiddenChannels
        modules.append(
            nn.Conv2d(in_channels=channels, out_channels=filterSize*filterSize, kernel_size=1))
        self.layers = nn.Sequential(*modules)

    def forward(self, input_x):
        """Return the filter for ``input_x``.

        The final ``view`` only succeeds when the spatial size of the input
        is 1x1, because each sample must hold exactly ``filterSize**2`` values.
        """
        x = self.layers(input_x)
        # Explicit channel dimension: this is what the implicit choice picked
        # for 4-D input, but without the deprecation warning.
        x = F.softmax(x, dim=1)
        return x.view(x.size(0), 1, self.filterSize, self.filterSize)

class HourGlassNet(nn.Module):
    """Encoder/decoder network whose stages are gated by a second tensor.

    Three compress blocks (the first two followed by 2x2 max pooling) are
    followed by two decompress blocks, each preceded by nearest-neighbour 2x
    upsampling and concatenation with the matching compress-stage output.
    After every block the activations are multiplied by channel weights
    computed from ``agg`` (1x1 convolution followed by ``Softmax2d``) and
    batch-normalised. A final 3x3 convolution maps back to ``inputChannels``.

    Args:
        inputChannels: channel count of the tensor given as ``input_x``
            (also the channel count of the output).
        aggChannels: channel count of the ``agg`` tensor.
    """

    def __init__(self, inputChannels, aggChannels):
        super(HourGlassNet, self).__init__()

        # (layer count, filter count) per block.
        self.compressBlockInfo = [(2, 24), (2, 32), (3, 64)]
        self.decompressBlockInfo = [(2, 32), (2, 24)]
        self.compressBlocks = torch.nn.ModuleList()
        self.decompressBlocks = torch.nn.ModuleList()

        # One gate and one BatchNorm per block, compress blocks first.
        self.blockMultipliers = torch.nn.ModuleList()
        self.blockBN = torch.nn.ModuleList()
        for _, fCount in self.compressBlockInfo + self.decompressBlockInfo:
            self.blockMultipliers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=aggChannels, out_channels=fCount, kernel_size=1),
                    nn.Softmax2d(),
                )
            )
            self.blockBN.append(
                nn.BatchNorm2d(num_features=fCount, momentum=0.75))

        lastChannels = inputChannels

        # The first compress block is hand-specified; compressBlockInfo[0]
        # only supplies the width of its gate/BatchNorm above.
        filterSizes = [13, 1, 3]
        filterCounts = [64, 32, 24]
        self.compressBlocks.append(LinearBlock(lastChannels, filterSizes, filterCounts))
        lastChannels = filterCounts[-1]

        for layerCount, filterCount in self.compressBlockInfo[1:]:
            filterSizes = [3] * layerCount
            filterCounts = [filterCount] * layerCount
            self.compressBlocks.append(LinearBlock(lastChannels, filterSizes, filterCounts))
            lastChannels = filterCount

        # Widths of the skip connections, deepest first, to size the
        # concatenated decompress inputs.
        compressFilterCounts = [fc for lc, fc in self.compressBlockInfo[:-1]]

        for compressFC, (layerCount, filterCount) in zip(compressFilterCounts[::-1], self.decompressBlockInfo):
            filterSizes = [3] * layerCount
            filterCounts = [filterCount] * layerCount
            self.decompressBlocks.append(LinearBlock(lastChannels + compressFC, filterSizes, filterCounts))
            lastChannels = filterCount

        fSize = 3
        # Floor division keeps the padding an int on Python 3 as well;
        # true division would hand nn.Conv2d a float padding.
        self.outputLayer = nn.Conv2d(in_channels=lastChannels,
                                     padding=(fSize - 1) // 2,
                                     out_channels=inputChannels,
                                     kernel_size=fSize)

    def forward(self, input_x, agg):
        """Return the reconstruction of ``input_x`` gated by ``agg``."""
        # Gates are consumed strictly in order: compress blocks, then decompress blocks.
        gates = iter(zip(self.blockMultipliers, self.blockBN))

        x = input_x
        compressStages = []
        lastCompress = len(self.compressBlocks) - 1
        for i, module in enumerate(self.compressBlocks):
            multiplier, bNorm = next(gates)
            x = bNorm(module(x) * multiplier(agg))
            if i < lastCompress:
                # Keep the pre-pooling activation as a skip connection.
                compressStages.append(x)
                x = F.max_pool2d(x, 2, stride=2)

        for module, bypass in zip(self.decompressBlocks, compressStages[::-1]):
            # NOTE: F.upsample_nearest is deprecated in newer PyTorch in favour
            # of F.interpolate; kept so the code still runs on older releases.
            x = F.upsample_nearest(x, scale_factor=2)
            x = torch.cat((x, bypass), dim=1)
            multiplier, bNorm = next(gates)
            x = bNorm(module(x) * multiplier(agg))

        return self.outputLayer(x)
    
    
class FullNet(nn.Module):
    """Wires four sub-networks into a single model.

    ``forward`` computes an embedding with the aggregation network, feeds it
    together with the input to the colour network, feeds it alone to the
    filter network, and finally feeds the colour network's output plus the
    embedding to the reconstruction network.
    """

    def __init__(self, inputChannels, aggregationNet, colorNet, filterNet, reconstructNet):
        super(FullNet, self).__init__()

        # inputChannels is accepted but not used by this class.
        self.aggregation = aggregationNet
        self.colorNet = colorNet
        self.reconstructNet = reconstructNet
        self.filterNet = filterNet

    def forward(self, input_x):
        """Return ``(reconstruction, colour-network output, filter-network output)``."""
        embedding = self.aggregation(input_x)
        corrected = self.colorNet(input_x, embedding)
        kernel = self.filterNet(embedding)

        restored = self.reconstructNet(corrected, embedding)
        return restored, corrected, kernel

In [None]:
# Channel count of the tensor produced by AggregationNet and consumed by the
# other three sub-networks.
imgEmbedding = 256

aggNet = AggregationNet(inputChannels=3, outChannels=imgEmbedding)
colorNet = ColorNet(inputChannels=3, modulationChannels=imgEmbedding)
recNet = HourGlassNet(inputChannels=3, aggChannels=imgEmbedding)
filterNet = FilterNet(inputChannels=imgEmbedding, filterSize=21)
net = FullNet(inputChannels=3,
              aggregationNet=aggNet,
              colorNet=colorNet,
              filterNet=filterNet,
              reconstructNet=recNet)
print(net)
net.cuda()

# Bookkeeping containers; none of the visible cells reads them.
lossHistory, lossPositions = [], []
iteration = 0


In [None]:
# Restore the weights from a local checkpoint file.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
net.load_state_dict(torch.load('model_279999.mod'))
# NOTE(review): the network is put in train() mode although the cells below
# only run it on camera frames — confirm this is intended rather than eval().
net = net.cuda().train()

In [None]:
# cv2 is also imported in the first cell of the notebook.
import cv2


# Open video capture device 0.
cap = cv2.VideoCapture(0)

# Print whether the capture device reports itself as opened.
print(cap.isOpened())

In [None]:
from time import time
size = 256
# Read frames from the camera, run the network on a centred crop and show
# the input crop, the network output and the colour-network output until
# Esc (key code 27) is pressed or a frame cannot be read.
while(True):
    ret, frame = cap.read()
    if not ret:
        # cap.read() signals failure through `ret`; the frame is then
        # unusable, so stop cleanly instead of failing on frame.shape below.
        print('Failed to read a frame from the camera; stopping.')
        break

    # Top-left corner of the centred size x size crop. Floor division keeps
    # the offsets ints on Python 3, where `/` would produce floats that are
    # not valid slice indices.
    b0 = (frame.shape[0] - size) // 2
    b1 = (frame.shape[1] - size) // 2

    crop = frame[b0:b0+size, b1:b1+size]
    # Channels-last frame -> 1 x 3 x size x size float32 batch.
    data = crop.transpose(2, 0, 1).reshape(1,3,size,size).astype(np.float32)
    t1 = time()
    data = Variable(torch.from_numpy(data).cuda() / 255)
    out, color_correction, t = net(data)

    # Back to a channels-last numpy image for display.
    out = out.data.cpu().numpy()[0].transpose(1, 2, 0)
    t2 = time()
    # print( t2-t1)
    color_correction = color_correction.data.cpu().numpy()[0].transpose(1, 2, 0)

    cv2.imshow('in', crop)
    cv2.imshow('out', out)
    cv2.imshow('color',color_correction)
    key = cv2.waitKey(1) & 0xFF
    if key == 27:
        break


In [None]:
# Release the capture device opened with cv2.VideoCapture in an earlier cell.
cap.release()

In [None]:
# Wait for a key press (no delay argument is given).
# NOTE(review): the windows created by cv2.imshow are not destroyed in any
# visible cell — confirm whether a cv2.destroyAllWindows() call is wanted.
cv2.waitKey()

In [None]:
# Print the shape of the last frame assigned in the capture loop.
print(frame.shape)

In [None]:
# NOTE(review): looks like leftover scratch — `data` here is whatever the
# capture loop last assigned (a Variable holding a CUDA tensor); confirm
# whether this cell is still needed.
torch.Tensor(data)