In [1]:
from __future__ import print_function, division

# Standard library
import ast
import json
import os
import re
import sys
import time

# Third-party (original relative import order kept on purpose)
import torch
import pandas as pd
from skimage import io
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils,models
import cv2
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import multivariate_normal
from scipy import random, linalg
from sklearn.model_selection import train_test_split
import torch.optim as optim
from tqdm import tqdm_notebook

# Local helpers; the star import stays last, as before, so the names it
# provides keep taking precedence over the imports above.
from utils import *

In [2]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

class CGData(Dataset):
    """Dataset pairing sets of distorted images with their clean label image.

    For the sequence id stored at ``indices[idx]`` an item consists of
    ``sample_size`` distinct distorted images ``img_<id>_<nr>.jpg`` (``nr``
    drawn without replacement from 1..9) and the label image
    ``books/img (<id - 1>).jpg``, all located below ``root_dir``. When any of
    these files cannot be read, a randomly chosen other entry of ``indices``
    is tried instead.

    Args:
        root_dir: Directory prefix of the image files. It is concatenated
            with the file names as-is, so it has to end with a path separator.
        indices: Sequence ids served by this dataset.
        sample_size: Number of distorted images per item (at most 9, because
            they are drawn without replacement from 1..9).
        resize: When true, every image is resized to 256x256.

    Raises:
        ValueError: If ``root_dir`` contains no ``img_<id>_<digit>.jpg`` file.
    """

    # Distorted files are named img_<sequence id>_<single digit>.jpg.
    _FILE_PATTERN = re.compile(r"img_(\d+)_\d\.jpg")
    # Upper bound on fallback draws, so that a broken data directory produces
    # an error instead of an endless loop.
    _MAX_ATTEMPTS = 1000

    def __init__(self, root_dir, indices, sample_size, resize=True):
        self.root_dir = root_dir
        self.ToPIL = transforms.ToPILImage()
        self.ToTensor = transforms.ToTensor()
        self.indices = indices
        self.resize = resize
        self.sample_size = sample_size

        ids = [self._sequence_id(name) for name in os.listdir(root_dir)]
        ids = [seq_id for seq_id in ids if seq_id is not None]
        if not ids:
            raise ValueError("no img_<id>_<digit>.jpg files found in %r" % (root_dir,))

        # Smallest and largest sequence id present on disk.
        self.offset = min(ids)
        self.last = max(ids)

    @staticmethod
    def _sequence_id(filename):
        """Return the integer sequence id of a distorted-image file name, else None."""
        found = CGData._FILE_PATTERN.fullmatch(filename)
        # int() instead of eval(): no code execution, and zero-padded ids parse.
        return int(found.group(1)) if found else None

    def __len__(self):
        return len(self.indices)

    def _load_sample(self, idx):
        """Read one item from disk as RGB arrays.

        Returns ``(imgs, label)`` or ``None`` when any of the files could not
        be read.
        """
        nrs = np.random.choice(range(1, 10), size=self.sample_size, replace=False).tolist()
        paths = [self.root_dir + "img_" + str(idx) + "_" + str(nr) + ".jpg" for nr in nrs]
        paths.append(self.root_dir + "books/img " + "(" + str(idx - 1) + ").jpg")

        arrays = []
        for path in paths:
            try:
                bgr = cv2.imread(path)
            except cv2.error:
                return None
            if bgr is None:  # cv2.imread signals an unreadable file with None
                return None
            # Reverse the channel axis (BGR -> RGB); copy() materialises the
            # negatively strided view as a regular contiguous array.
            arrays.append(bgr[..., ::-1].copy())
        return arrays[:-1], arrays[-1]

    def __getitem__(self, idx):
        """Return ``(imgs, label)`` for position ``idx`` of ``indices``.

        ``imgs`` is the stack of the ``sample_size`` distorted images and
        ``label`` the clean image, each converted with ``transforms.ToTensor``.
        The same random rotation/flip decision is applied to all of them.

        Raises:
            RuntimeError: If no readable item was found after
                ``_MAX_ATTEMPTS`` tries.
        """
        idx = self.indices[idx]
        for _ in range(self._MAX_ATTEMPTS):
            sample = self._load_sample(idx)
            if sample is not None:
                break
            # Fall back to a random other entry when files are missing.
            idx = self.indices[np.random.randint(len(self.indices))]
        else:
            raise RuntimeError(
                "could not load a complete sample from %r after %d attempts"
                % (self.root_dir, self._MAX_ATTEMPTS)
            )
        imgs, label = sample

        if self.resize:
            label = cv2.resize(label, dsize=(256, 256))
            imgs = [cv2.resize(img, dsize=(256, 256)) for img in imgs]

        # Bring landscape images into portrait orientation.
        H, W = imgs[0].shape[:2]
        if H < W:
            label = np.rot90(label).copy()
            imgs = [np.rot90(img).copy() for img in imgs]

        # flip is one of -1, 0, 1 (the cv2.flip codes) or 2, meaning "no flip".
        flip = np.random.randint(-1, 3)
        if flip < 2:
            label = cv2.flip(label, flip)
            imgs = [cv2.flip(img, flip) for img in imgs]

        imgs = torch.stack([self.ToTensor(img) for img in imgs])
        label = self.ToTensor(label.astype(np.uint8))
        return imgs, label

# Architecture

In [3]:
class ResidualBlock(nn.Module):
    """Residual unit: conv3x3 -> ReLU -> conv3x3, plus the unchanged input.

    Both convolutions use stride 1 and padding 1, so the spatial size is kept.
    The input is added onto the output of the second convolution, which has
    ``out_planes`` channels, so ``x`` must be broadcastable onto that shape.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        conv_args = dict(kernel_size=3, stride=1, padding=1, bias=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, **conv_args)
        self.conv2 = nn.Conv2d(out_planes, out_planes, **conv_args)

    def forward(self, x):
        """Return ``conv2(relu(conv1(x))) + x``."""
        branch = self.conv2(F.relu(self.conv1(x)))
        # The skip connection is accumulated in place on the branch output.
        return branch.add_(x)
    
class DilatedResidualBlock(nn.Module):
    """Residual unit built from two dilated 3x3 convolutions.

    Both convolutions share the same ``dilation`` and use ``padding=dilation``
    with stride 1, so the spatial size is kept. The input is added onto the
    output of the second convolution, which has ``out_planes`` channels, so
    ``x`` must be broadcastable onto that shape.
    """

    def __init__(self, in_planes, out_planes, dilation):
        super().__init__()
        conv_args = dict(kernel_size=3, stride=1, dilation=dilation, padding=dilation, bias=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, **conv_args)
        self.conv2 = nn.Conv2d(out_planes, out_planes, **conv_args)

    def forward(self, x):
        """Return ``conv2(relu(conv1(x))) + x``."""
        branch = self.conv2(F.relu(self.conv1(x)))
        # The skip connection is accumulated in place on the branch output.
        return branch.add_(x)
        
class DeepSetNet(nn.Module):
    """Deep Set residual network.

    Every element of the input set is embedded by the same
    input-conv -> downsample -> encoder path. The embeddings are averaged over
    the set, and the mean is decoded, upsampled, smoothed and mapped back to
    3 channels.

    Args:
        encoder_num_blocks: Number of residual blocks applied per set element.
        decoder_num_blocks: Number of residual blocks applied to the mean embedding.
        smooth_num_blocks: Number of dilated residual blocks in the smoothing stage.
        planes: Channel count at full resolution (it doubles with each of the
            two stride-2 downsampling convolutions).
        block: Accepted and forwarded to ``_make_layer``, which does not use
            it; the layers are always built from ``DilatedResidualBlock``.
    """
    def __init__(self, encoder_num_blocks=10, decoder_num_blocks=10, smooth_num_blocks=6, planes=32,block=ResidualBlock ):
        super().__init__()
        self.planes = planes
        self.input = nn.Conv2d(planes and 3 or 3, planes, kernel_size=3, stride=1, padding=1, bias=True) if False else nn.Conv2d(3, planes, kernel_size=3, stride=1, padding=1, bias=True)
        self.output = nn.Conv2d(planes, 3, kernel_size=3, stride=1, padding=1, bias=True)

        # Two stride-2 stages. Each new upsampling stage is put in FRONT of
        # the previous ones so that the upsampling path mirrors the
        # downsampling path.
        down_layers, up_layers = [], []
        width = planes
        for _ in range(2):
            down_layers += [
                nn.Conv2d(in_channels=width, out_channels=width * 2, kernel_size=3, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ]
            up_layers[:0] = [
                nn.ConvTranspose2d(in_channels=width * 2, out_channels=width, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.ReLU(inplace=True),
            ]
            width *= 2

        self.downsample = nn.Sequential(*down_layers)
        self.upsample = nn.Sequential(*up_layers)

        # Residual stacks on the downsampled features, smoothing at full resolution.
        self.encoder = self._make_layer(block, width, encoder_num_blocks)
        self.decoder = self._make_layer(block, width, decoder_num_blocks)
        self.smooth = self._make_smooth_layer(planes, smooth_num_blocks)

    def _make_layer(self, block, planes, num_blocks):
        """Stack ``num_blocks`` dilation-2 residual blocks (``block`` is not used)."""
        return nn.Sequential(*[DilatedResidualBlock(planes, planes, 2) for _ in range(num_blocks)])

    def _make_smooth_layer(self, planes, num_blocks):
        """Build the smoothing stage.

        ``num_blocks`` dilated residual blocks whose dilation starts at 1 and
        doubles after every even-indexed block, followed by
        conv3x3 -> ReLU -> conv3x3.
        """
        stack = []
        dilation = 1
        for index in range(num_blocks):
            stack.append(DilatedResidualBlock(planes, planes, dilation))
            if index % 2 == 0:
                dilation *= 2
        stack.extend([
            nn.Conv2d(in_channels=planes, out_channels=planes, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=planes, out_channels=planes, kernel_size=3, stride=1, padding=1),
        ])
        return nn.Sequential(*stack)

    def forward(self, x):
        """Forward pass of the DeepSet network.

        x: tensor of size (B, S, C, H, W); S is the set dimension that gets
        averaged out.
        """
        # One (B, C, H, W) tensor per set element.
        members = [torch.squeeze(chunk, dim=1) for chunk in torch.split(x, 1, dim=1)]
        codes = [self.encoder(self.downsample(self.input(member))) for member in members]
        # The mean over the set makes the result independent of element order.
        pooled = torch.stack(codes).mean(0)
        decoded = self.upsample(self.decoder(pooled))
        return self.output(self.smooth(decoded))

In [4]:
# Directory holding everything that is persisted between runs (weights,
# train/test split, progress parameters, error logs).
model_dir = "./models/"
os.makedirs(model_dir, exist_ok=True)

print("The following directory will be used in all further steps:  " + model_dir)


# Create Dataset and split it into Training and test set
minibatch_size = 10
data_dir = "D:/250x250/"


def match(x):
    """Return True when ``x`` is a distorted-image file name ``img_<id>_<digit>.jpg``."""
    return re.fullmatch(r"img_\d+_\d\.jpg", x) is not None


def cut_string(x):
    """Return the integer sequence id contained in an ``img_<id>_<digit>.jpg`` name."""
    # int() instead of eval(): the text comes from file names on disk.
    return int(re.sub("_.*", "", re.sub("img_", "", x)))


files = [cut_string(name) for name in os.listdir(data_dir) if match(name)]

first, last = min(files), max(files)
print(first, last)

n = last - first + 1

train_idx_path = os.path.join(model_dir, "trainingIdx.txt")
test_idx_path = os.path.join(model_dir, "testIdx.txt")

if os.path.isfile(train_idx_path):
    # Reuse the persisted split so that resumed runs keep the same test set.
    # literal_eval only accepts Python literals, unlike eval().
    # NOTE(review): assumes write() stores the ids as a plain Python literal - confirm.
    with open(train_idx_path, "r") as f1, open(test_idx_path, "r") as f2:
        train = ast.literal_eval(f1.read())
        test = ast.literal_eval(f2.read())
else:
    train, test = train_test_split(list(range(first, last + 1)))
    #train, test = train_test_split(range(first, first+100))
    write(train_idx_path, train)
    write(test_idx_path, test)


# Progress bookkeeping: last finished epoch and accumulated run time.
params_path = os.path.join(model_dir, "params.json")
if os.path.isfile(params_path):
    with open(params_path, "r") as f:
        params = json.loads(f.read())
else:
    params = {
        "epoch": 0,
        "time": 0,
    }
    with open(params_path, "w") as f:
        f.write(json.dumps(params))


# Use the first GPU when there is one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Initialize Neural Network
net = DeepSetNet(encoder_num_blocks=10, decoder_num_blocks=5, planes=18)

weights_path = os.path.join(model_dir, "nn.pt")
if os.path.isfile(weights_path):
    # The network is still on the CPU here; map_location lets a checkpoint
    # saved from a GPU be loaded on a machine without one.
    net.load_state_dict(torch.load(weights_path, map_location="cpu"))

net = net.to(device)

criterion = nn.MSELoss()
epochs = 10000

The following directory will be used in all further steps:  ./models/
6462 126632


In [None]:
# Adam over all parameters of the network, starting from cleared gradients.
optimizer = optim.Adam(params=net.parameters(), lr=1e-4)
optimizer.zero_grad()

def train_or_evaluate(net, data, optimizer, n_samples, mode="train"):
    """Run one pass over ``data`` and return the mean per-batch loss.

    A ``CGData`` dataset is built from the notebook-level ``data_dir`` and
    iterated in order with batches of ``minibatch_size``; the notebook-level
    ``device`` and ``criterion`` are used as well.

    Args:
        net: The network to train or evaluate.
        data: Sequence ids to iterate over.
        optimizer: Optimizer stepped after every batch (training mode only).
        n_samples: Number of distorted images per item.
        mode: ``"train"`` updates the weights; any other value evaluates
            without gradients.

    Returns:
        The loss averaged over all batches, or ``nan`` when ``data`` yields
        no batch at all.
    """
    dataset = CGData(data_dir, data, n_samples)
    dataloader = DataLoader(dataset, batch_size=minibatch_size,
                            shuffle=False, num_workers=0)

    is_training = mode == "train"
    if is_training:
        net.train()
        desc = "Training on size " + str(n_samples)
    else:
        net.eval()
        desc = "Evaluating on size " + str(n_samples)

    error = 0.0
    batches = 0
    # Gradients are only tracked when the weights are actually updated.
    with torch.set_grad_enabled(is_training):
        for x, y in tqdm_notebook(dataloader, desc=desc):
            x = x.to(device)
            y = y.to(device)
            # Call the module itself (not .forward) so registered hooks run.
            output = net(x)
            loss = criterion(output, y)
            error += loss.item()
            if is_training:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            batches += 1

    # An empty loader must not end in a ZeroDivisionError.
    error = error / batches if batches else float("nan")
    if is_training:
        print("Sample size : %d \t Training Error %f"%(n_samples, error))
    else:
        print("Sample size : %d \t Test Error %f"%(n_samples, error))

    return error


def training(net, train_ids, test_ids,
             model_dir,optimizer,
             params,
             criterion= nn.MSELoss(),
             epochs=1000,
             device=torch.device("cuda:0")):
    """Train ``net`` with periodic evaluation and checkpointing.

    Every epoch runs one training pass per set size 1..9, then one evaluation
    pass per set size 1..9. Errors are appended to ``train_errors.csv`` /
    ``test_errors.csv`` and the weights and progress are saved in
    ``model_dir``, so an interrupted run continues after the last finished
    epoch.

    Args:
        net: The network to train.
        train_ids: Sequence ids used for training.
        test_ids: Sequence ids used for evaluation.
        model_dir: Directory for checkpoints, logs and ``params.json``.
        optimizer: Optimizer; its learning rate is scaled by 0.1 every 5 epochs.
        params: Dict with the keys ``"epoch"`` (last finished epoch) and
            ``"time"`` (accumulated seconds); updated in place.
        criterion: Not used by this function; the passes are delegated to
            ``train_or_evaluate``, which is not given this argument.
        epochs: Training stops before this epoch number is reached.
        device: Not used by this function (see ``criterion``).
    """
    last_epoch = params["epoch"]
    t = params["time"]

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, 0.1)

    # Replay the learning-rate updates of the epochs that are already done.
    for _ in range(last_epoch):
        scheduler.step()

    # Private copies, so the caller's sequences are never reordered.
    train_ids = list(train_ids)
    test_ids = list(test_ids)

    for epoch in range(last_epoch+1,epochs):

        np.random.shuffle(train_ids)
        np.random.shuffle(test_ids)

        # `tick` is moved forward after every measurement, so each stretch of
        # time is added to `t` exactly once.
        tick = time.time()

        for i in range(1,10):
            train_error = train_or_evaluate(net, train_ids, optimizer, i, mode="train")
            now = time.time()
            t += now - tick
            tick = now
            write(model_dir + "/train_errors.csv",[epoch,train_error,t])
            torch.save(net.state_dict(), model_dir + "/nn"+str(i)+".pt")
            torch.save(net.state_dict(), model_dir + "/nn.pt")

        for i in range(1,10):
            test_error = train_or_evaluate(net, test_ids, optimizer, i, mode="test")
            now = time.time()
            t += now - tick
            tick = now
            write(model_dir + "/test_errors.csv",[epoch,test_error,t])

        torch.save(net.state_dict(), model_dir + "/nn.pt")

        # Persist the progress so that a restart resumes after this epoch.
        params["epoch"] = epoch
        params["time"] = t
        with open(model_dir + "/params.json", "w") as f:
            f.write(json.dumps(params))

        scheduler.step()


# Start (or resume) training. The notebook-level `epochs` value is not passed
# here, so training() runs with its own default for `epochs`.
training(
    net,
    train,
    test,
    model_dir=model_dir,
    optimizer=optimizer,
    params=params,
    criterion=criterion,
)

HBox(children=(IntProgress(value=0, description='Training on size 1', max=9013, style=ProgressStyle(descriptio…


Sample size : 1 	 Training Error 0.009331


HBox(children=(IntProgress(value=0, description='Training on size 2', max=9013, style=ProgressStyle(descriptio…


Sample size : 2 	 Training Error 0.003665


HBox(children=(IntProgress(value=0, description='Training on size 3', max=9013, style=ProgressStyle(descriptio…


Sample size : 3 	 Training Error 0.002525


HBox(children=(IntProgress(value=0, description='Training on size 4', max=9013, style=ProgressStyle(descriptio…


Sample size : 4 	 Training Error 0.002075


HBox(children=(IntProgress(value=0, description='Training on size 5', max=9013, style=ProgressStyle(descriptio…


Sample size : 5 	 Training Error 0.001840


HBox(children=(IntProgress(value=0, description='Training on size 6', max=9013, style=ProgressStyle(descriptio…


Sample size : 6 	 Training Error 0.001691


HBox(children=(IntProgress(value=0, description='Training on size 7', max=9013, style=ProgressStyle(descriptio…


Sample size : 7 	 Training Error 0.001579


HBox(children=(IntProgress(value=0, description='Training on size 8', max=9013, style=ProgressStyle(descriptio…


Sample size : 8 	 Training Error 0.001497


HBox(children=(IntProgress(value=0, description='Training on size 9', max=9013, style=ProgressStyle(descriptio…


Sample size : 9 	 Training Error 0.001429


HBox(children=(IntProgress(value=0, description='Evaluating on size 1', max=3005, style=ProgressStyle(descript…


Sample size : 1 	 Test Error 0.026583


HBox(children=(IntProgress(value=0, description='Evaluating on size 2', max=3005, style=ProgressStyle(descript…


Sample size : 2 	 Test Error 0.006222


HBox(children=(IntProgress(value=0, description='Evaluating on size 3', max=3005, style=ProgressStyle(descript…


Sample size : 3 	 Test Error 0.003004


HBox(children=(IntProgress(value=0, description='Evaluating on size 4', max=3005, style=ProgressStyle(descript…


Sample size : 4 	 Test Error 0.002192


HBox(children=(IntProgress(value=0, description='Evaluating on size 5', max=3005, style=ProgressStyle(descript…


Sample size : 5 	 Test Error 0.001857


HBox(children=(IntProgress(value=0, description='Evaluating on size 6', max=3005, style=ProgressStyle(descript…


Sample size : 6 	 Test Error 0.001683


HBox(children=(IntProgress(value=0, description='Evaluating on size 7', max=3005, style=ProgressStyle(descript…


Sample size : 7 	 Test Error 0.001563


HBox(children=(IntProgress(value=0, description='Evaluating on size 8', max=3005, style=ProgressStyle(descript…


Sample size : 8 	 Test Error 0.001489


HBox(children=(IntProgress(value=0, description='Evaluating on size 9', max=3005, style=ProgressStyle(descript…


Sample size : 9 	 Test Error 0.001435


HBox(children=(IntProgress(value=0, description='Training on size 1', max=9013, style=ProgressStyle(descriptio…


Sample size : 1 	 Training Error 0.009556


HBox(children=(IntProgress(value=0, description='Training on size 2', max=9013, style=ProgressStyle(descriptio…


Sample size : 2 	 Training Error 0.003543


HBox(children=(IntProgress(value=0, description='Training on size 3', max=9013, style=ProgressStyle(descriptio…


Sample size : 3 	 Training Error 0.002364


HBox(children=(IntProgress(value=0, description='Training on size 4', max=9013, style=ProgressStyle(descriptio…


Sample size : 4 	 Training Error 0.001940


HBox(children=(IntProgress(value=0, description='Training on size 5', max=9013, style=ProgressStyle(descriptio…


Sample size : 5 	 Training Error 0.001732


HBox(children=(IntProgress(value=0, description='Training on size 6', max=9013, style=ProgressStyle(descriptio…


Sample size : 6 	 Training Error 0.001604


HBox(children=(IntProgress(value=0, description='Training on size 7', max=9013, style=ProgressStyle(descriptio…


Sample size : 7 	 Training Error 0.001510


HBox(children=(IntProgress(value=0, description='Training on size 8', max=9013, style=ProgressStyle(descriptio…


Sample size : 8 	 Training Error 0.001439


HBox(children=(IntProgress(value=0, description='Training on size 9', max=9013, style=ProgressStyle(descriptio…


Sample size : 9 	 Training Error 0.001388


HBox(children=(IntProgress(value=0, description='Evaluating on size 1', max=3005, style=ProgressStyle(descript…


Sample size : 1 	 Test Error 0.025479


HBox(children=(IntProgress(value=0, description='Evaluating on size 2', max=3005, style=ProgressStyle(descript…


Sample size : 2 	 Test Error 0.005864


HBox(children=(IntProgress(value=0, description='Evaluating on size 3', max=3005, style=ProgressStyle(descript…


Sample size : 3 	 Test Error 0.002892


HBox(children=(IntProgress(value=0, description='Evaluating on size 4', max=3005, style=ProgressStyle(descript…


Sample size : 4 	 Test Error 0.002125


HBox(children=(IntProgress(value=0, description='Evaluating on size 5', max=3005, style=ProgressStyle(descript…


Sample size : 5 	 Test Error 0.001807


HBox(children=(IntProgress(value=0, description='Evaluating on size 6', max=3005, style=ProgressStyle(descript…


Sample size : 6 	 Test Error 0.001634


HBox(children=(IntProgress(value=0, description='Evaluating on size 7', max=3005, style=ProgressStyle(descript…

In [None]:
# Training ids as a dataset that returns 9 distorted images per item,
# wrapped in a loader that keeps the given order.
trainset = CGData(root_dir=data_dir, indices=train, sample_size=9)
dataloader = DataLoader(
    dataset=trainset,
    batch_size=minibatch_size,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Visual check of one training item: add a batch axis, split the set into its
# individual images, print their sizes and plot inputs and label.
x, y = trainset[10]
x = x.unsqueeze(0)
y = y.unsqueeze(0)

xs = [torch.squeeze(chunk, dim=1) for chunk in torch.split(x, 1, dim=1)]
for view in xs:
    print(view.size())

plot(xs)
plot([y])