In [1]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils,models
import cv2
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import multivariate_normal
from scipy import random, linalg
from sklearn.model_selection import train_test_split
import torch.optim as optim
import re
import json
import time
from tqdm import tqdm_notebook

import sys
from utils import *
import networks

In [2]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

class CGData(Dataset):
    """Dataset containing all sequences with artifacts
    
    Generates three distorted images as input data.
    
    """
    
    def __init__(self, root_dir, indices, sample_size, resize=True):
        self.root_dir = root_dir
        self.ToPIL = transforms.ToPILImage()
        self.ToTensor = transforms.ToTensor()
        self.indices = indices
        self.resize=resize
        self.sample_size = sample_size
        
        files = os.listdir(root_dir)
        match = lambda x: len(re.findall("img_\d+_\d.jpg", x))== 1
        cut_string = lambda x: eval(re.sub("_.*","",re.sub("img_","",x)))

        files = list(filter(match,files))
        files = list(map(cut_string,files))


        first,last = min(files),max(files)
        self.offset = first
        self.last = last
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):        
        
        idx = self.indices[idx]
        count = 0
        img_files = None
        imgs = None
        label = None
        while True:
            
            n = self.sample_size
            nrs = np.random.choice(range(1,10), size=n, replace=False).tolist()
            img_files = [self.root_dir +  "img_" +str(idx)+ "_" + str(nr) + ".jpg" for nr in nrs]
            exists = all([os.path.isfile(img_file) for img_file in img_files])
            count+=1
            try:
                imgs = [cv2.imread(file) for file in img_files]
                imgs = [img[...,::-1]- np.zeros_like(img) for img in imgs]

                label_file = self.root_dir + "books/img " + "("+str(idx - 1)+").jpg"
                label = cv2.imread(label_file)
                label = label[...,::-1]- np.zeros_like(label)
                break

            except:
                idx = np.random.randint(len(self.indices))
                idx = self.indices[idx]

        
        
        if self.resize:
            label = cv2.resize(label, dsize=(256,256))
            imgs = [ cv2.resize(img, dsize=(256,256)) for img in imgs]
        

        H,W,C = imgs[0].shape
        if H<W:
            label = np.rot90(label)
            label -= np.zeros_like(label)
            imgs = [np.rot90(img) for img in imgs]- np.zeros_like(label)
        
        flip = np.random.randint(-1,3)
        if flip < 2:
            label = cv2.flip(label,flip)- np.zeros_like(label)
            imgs = [cv2.flip(img,flip) for img in imgs]- np.zeros_like(label)

        
        imgs = [self.ToTensor(img) for img in imgs]
        #data = torch.unsqueeze(data,0)
        imgs = torch.stack(imgs)
        
        label = label.astype(np.uint8)
        label = self.ToTensor(label)
        #label = torch.unsqueeze(label,0)
        return imgs, label

# Architecture

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        out += x
        return out
    
class DilatedResidualBlock(nn.Module):
    def __init__(self, in_planes, out_planes, dilation):
        super(DilatedResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, dilation=dilation, padding=dilation, bias=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, dilation=dilation,padding=dilation, bias=True)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        out += x
        return out
        
class DeepSetNet(nn.Module):
    """ Deep Set Residual Neural Network """
    def __init__(self, encoder_num_blocks=10, decoder_num_blocks=10, smooth_num_blocks=6, planes=32,block=ResidualBlock ):
        super(DeepSetNet, self).__init__()
        self.planes = planes
        self.input = nn.Conv2d(3, self.planes, kernel_size=3, stride=1, padding=1, bias=True)
        self.output= nn.Conv2d(self.planes, 3, kernel_size=3, stride=1, padding=1, bias=True)
        
        
        # Create a down-/up-sampling architecture
        self.downsample = []
        self.upsample = []
        n = planes
        for i in range(2):
            self.downsample.append( nn.Conv2d(in_channels = n, out_channels=n*2, kernel_size=3, stride=2, padding=1 ) )
            self.downsample.append(nn.ReLU(inplace=True))

            
            self.upsample = [nn.ReLU(inplace=True)] + self.upsample
            self.upsample = [nn.ConvTranspose2d(in_channels=n*2, out_channels=n, kernel_size=3, stride=2, padding=1, output_padding=1)] + self.upsample
            n *= 2

        self.downsample = nn.Sequential(*self.downsample)
        self.upsample = nn.Sequential(*self.upsample)
        
        
        # Embedding of downsampled features
        self.encoder = self._make_layer(block, n, encoder_num_blocks)
        self.decoder = self._make_layer(block, n, decoder_num_blocks)
        self.smooth  = self._make_smooth_layer(planes, smooth_num_blocks)
        
    def _make_layer(self, block, planes, num_blocks):
        layers = []
        for i in range(num_blocks):
            layers.append(DilatedResidualBlock(planes, planes,2))
        return nn.Sequential(*layers)
    
    def _make_smooth_layer(self, planes, num_blocks):
        layers = []
        dilation = 1
        for i in range(num_blocks):
            layers.append(DilatedResidualBlock(planes,planes,dilation))
            if i%2 == 0:
                dilation *= 2
        layers.append( nn.Conv2d(in_channels = planes, out_channels=planes, kernel_size=3, stride=1, padding=1 ) )
        layers.append(nn.ReLU(inplace=True))
        layers.append( nn.Conv2d(in_channels = planes, out_channels=planes, kernel_size=3, stride=1, padding=1 ) )
        return nn.Sequential(*layers)
            
        

    def forward(self, x):
        """Forward pass of our DeepSet Network 
        
        x: of tensor of size (B, S, C, H, W)
        """

        xs = torch.split(x,1,dim = 1)
        xs = [torch.squeeze(x,dim=1) for x in xs]
        embedding = [self.encoder(self.downsample(self.input(x))) for x in xs]
        embedding = torch.stack(embedding).mean(0)
        out = self.output(self.smooth(self.upsample(self.decoder(embedding))))

        
        return out

In [4]:
model_dir = "./modelsGAN/"
if not os.path.isdir(model_dir):
    os.mkdir(model_dir)

print("The following directory will be used in all further steps:  " + model_dir)


# Create Dataset and split it into Training and test set
minibatch_size = 10
data_dir = "D:/250x250/"
files = os.listdir(data_dir)
match = lambda x: len(re.findall("img_\d+_\d.jpg", x))== 1
cut_string = lambda x: eval(re.sub("_.*","",re.sub("img_","",x)))

files = list(filter(match,files))
files = list(map(cut_string,files))


first, last = min(files),max(files)
print(first, last)

n = last - first + 1
train, test = train_test_split(list(range(first, last+1)))
#train, test = train_test_split(range(first, first+100))

if os.path.isfile(model_dir + "/trainingIdx.txt"):
    f1 = open(model_dir + "/trainingIdx.txt", "r")
    f2 = open(model_dir + "/testIdx.txt", "r")
    train = eval(f1.read())
    test = eval(f2.read())
    f1.close()
    f2.close()
else:
    write(model_dir + "/trainingIdx.txt",train)
    write(model_dir + "/testIdx.txt",test)

    
if os.path.isfile(model_dir + "/params.json"):
    f = open(model_dir + "/params.json", "r")
    params = json.loads(f.read())
    f.close()
else:
    params = {
        "epoch": 0,
        "time": 0,
    }
    f = open(model_dir + "/params.json", "w")
    f.write(json.dumps(params))
    f.close()


device = torch.device("cuda:0")

# Initialize Neural Network
net = DeepSetNet(encoder_num_blocks=10, decoder_num_blocks=5, planes=18)

if  os.path.isfile(model_dir + "/nn.pt"):
    net.load_state_dict(torch.load(model_dir + "/nn.pt"))

net = net.to(device)
D = networks.MultiscaleDiscriminator(3,getIntermFeat=True)
D = D.to(device)

criterion = nn.MSELoss()
epochs = 10000

The following directory will be used in all further steps:  ./modelsGAN/
6462 126632


In [5]:
optimizer = optim.Adam(net.parameters(),lr = 0.0001,)
optimizer.zero_grad()

def train_or_evaluate(data, G, D, 
                      optimizer_G, optimizer_D, Image_pool,
                      L_GAN, L_FM, 
                      n_samples, mode="train"):
    
    dataset = CGData(data_dir,data,n_samples)
    dataloader = DataLoader(dataset, batch_size=minibatch_size,
                         shuffle=False, num_workers=0)
    
    error_G = 0.0
    error_D = 0.0
    error_mae = 0.0
    #error_VGG = 0.0
    
    MAE = nn.L1Loss()
    MSE = nn.MSELoss()
    
    samples = 0
    if mode == "train":
        net.train()
        for x,y in tqdm_notebook(dataloader, desc ="Training on size "+str(n_samples)):
            
                x = x.to(device)
                y = y.to(device)

                # Compute Fake Image using a pretrained ResNet as initial (conditional) solution
                fake_image = G(x)
                
                ####### Optimze Discriminator ###########
                # Optimize for fake input
                fake_query = Image_pool.query(fake_image)
                pred_fake_pool = D(fake_query)
                loss_D_fake = L_GAN(pred_fake_pool,False)       

                # Optimize for real input
                pred_real = D(y)
                loss_D_real = L_GAN(pred_real,True)
                loss_D = 0.5*(loss_D_fake+loss_D_real)

                ####### Optimize Generator ###########
                # Optimize to fool the discriminator (fake passability)
                pred_false_positve = D(fake_image)
                loss_G_GAN = L_GAN(pred_false_positve,True)
                loss_mae = MAE(fake_image,y)
                
                 # GAN feature matching loss
                loss_G_GAN_Feat = 0
                feat_weights = 4.0 / (3.0 + 1.0)
                D_weights = 1.0 / 1.0
                for i in range(3):
                    for j in range(len(pred_false_positve[i])-1):
                        loss_G_GAN_Feat += D_weights * feat_weights * \
                            L_FM(pred_false_positve[i][j], pred_real[i][j].detach()) * 10.0

                #logg_vgg = L_VGG(fake_image,y)*10

                # Compute loss for G and D    
                loss_D = 0.5*(loss_D_fake+loss_D_real)
                loss_G = loss_G_GAN + loss_G_GAN_Feat + loss_mae*10

                error_G += loss_G.item()
                error_D += loss_D.item()
                error_mae += loss_mae.item()


                # Apply BackProp for D
                loss_D.backward()
                optimizer_D.step()

                # Apply BackProp for G
                loss_G.backward()
                optimizer_G.step()
                
                optimizer_D.zero_grad()  
                optimizer_G.zero_grad()
                samples += 1
                
                
        error_G /=samples
        error_D /=samples
        error_mae /= samples
        print("Error Generator: %f | Error Discriminator %f | MAE %f"%(error_G, error_D, error_mae))
        return error_G, error_D, error_mae

    else:
        net.eval()

        
        error_mae = 0
        error_mse = 0

        with torch.no_grad():
            for x,y in tqdm_notebook(dataloader, desc ="Evaluating on size "+str(n_samples)):
                #x,y = train_set[i]
                x = x.to(device)
                y = y.to(device)
                output = net.forward(x)
                mae_loss = MAE(output,y)                
                mse_loss = MSE(output,y)

                error_mae += mae_loss.item()
                error_mse += mse_loss.item()

                samples += 1
                
        error_mae /= samples
        error_mse /= samples
        print("MSE %f | MAE %f"%(error_mse, error_mae))
        return error_mse, error_mae
        
    return error


def training(G, D, train_ids, test_ids,
             model_dir,
             params,
             epochs=1000,
             device=torch.device("cuda:0")):
    
    last_epoch = params["epoch"]
    t = params["time"]

    
    optimizer_D = optim.Adam(D.parameters(), lr = 0.0002,)
    optimizer_D.zero_grad()
    optimizer_G = optim.Adam(G.parameters(), lr = 0.0002,)
    optimizer_G.zero_grad()
    
    
    Image_pool = networks.ImagePool(50)
    L_GAN = networks.GANLoss(use_lsgan=True,tensor = torch.cuda.FloatTensor)
    L_FM = nn.L1Loss()
    
    train_loss = 0
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer_G, 5, 0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer_D, 5, 0.1)

    
    # This simulates the Learning rate updates
    for _ in range(last_epoch):
        scheduler.step()
    

    for epoch in range(last_epoch+1,epochs):
        
        start = time.time()
        samples = 0
        train_loss = 0
        
        train_ids = list(train_ids)
        test_ids = list(test_ids)
        np.random.shuffle(train_ids)
        np.random.shuffle(test_ids)
        

        for i in range(1,9):
            train_error = train_or_evaluate(train_ids, G, D, 
                                            optimizer_G, optimizer_D, Image_pool,
                                            L_GAN, L_FM, i, mode="train")
            end = time.time()
            t += (end - start)
            write(model_dir + "/train_errors.csv",[epoch,train_error,t])
            torch.save(net.state_dict(), model_dir + "/nn"+ str(i)+".pt")

        for i in range(1,10):
            test_error = train_or_evaluate(train_ids, G, D, 
                                            optimizer_G, optimizer_D, Image_pool,
                                            L_GAN, L_FM, i, mode="test")
            end = time.time()
            t += (end - start)
            write(model_dir + "/test_errors.csv",[epoch,test_error,t])

        torch.save(net.state_dict(), model_dir + "/nn.pt")
        
        #torch.save(net.state_dict(), model_dir + "/nn"+ str(epoch) +".pt")

            
        params["epoch"] = epoch
        params["time"] = t
        f = open(model_dir + "/params.json", "w")
        f.write(json.dumps(params))
        f.close()
        #if(train_loss>2*best_loss):
            #net.load_state_dict(torch.load(model_dir + "/nn.pt"))
        scheduler.step()


training(net, D, train, test, model_dir,params)

HBox(children=(IntProgress(value=0, description='Training on size 1', max=9013, style=ProgressStyle(descriptio…


Error Generator: 12.076235 | Error Discriminator 0.575856 | MAE 0.122965


HBox(children=(IntProgress(value=0, description='Training on size 2', max=9013, style=ProgressStyle(descriptio…


Error Generator: 8.088293 | Error Discriminator 0.648131 | MAE 0.071961


HBox(children=(IntProgress(value=0, description='Training on size 3', max=9013, style=ProgressStyle(descriptio…


Error Generator: 7.073338 | Error Discriminator 0.646505 | MAE 0.058283


HBox(children=(IntProgress(value=0, description='Training on size 4', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.678004 | Error Discriminator 0.636489 | MAE 0.053138


HBox(children=(IntProgress(value=0, description='Training on size 5', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.248131 | Error Discriminator 0.655023 | MAE 0.049377


HBox(children=(IntProgress(value=0, description='Training on size 6', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.107376 | Error Discriminator 0.636710 | MAE 0.047256


HBox(children=(IntProgress(value=0, description='Training on size 7', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.937742 | Error Discriminator 0.640997 | MAE 0.045323


HBox(children=(IntProgress(value=0, description='Training on size 8', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.748138 | Error Discriminator 0.638091 | MAE 0.043660


HBox(children=(IntProgress(value=0, description='Evaluating on size 1', max=9013, style=ProgressStyle(descript…


MSE 0.036324 | MAE 0.114361


HBox(children=(IntProgress(value=0, description='Evaluating on size 2', max=9013, style=ProgressStyle(descript…


MSE 0.011167 | MAE 0.066789


HBox(children=(IntProgress(value=0, description='Evaluating on size 3', max=9013, style=ProgressStyle(descript…


MSE 0.006212 | MAE 0.052997


HBox(children=(IntProgress(value=0, description='Evaluating on size 4', max=9013, style=ProgressStyle(descript…


MSE 0.004822 | MAE 0.047631


HBox(children=(IntProgress(value=0, description='Evaluating on size 5', max=9013, style=ProgressStyle(descript…


MSE 0.004289 | MAE 0.045041


HBox(children=(IntProgress(value=0, description='Evaluating on size 6', max=9013, style=ProgressStyle(descript…


MSE 0.003993 | MAE 0.043416


HBox(children=(IntProgress(value=0, description='Evaluating on size 7', max=9013, style=ProgressStyle(descript…


MSE 0.003823 | MAE 0.042399


HBox(children=(IntProgress(value=0, description='Evaluating on size 8', max=9013, style=ProgressStyle(descript…


MSE 0.003699 | MAE 0.041656


HBox(children=(IntProgress(value=0, description='Evaluating on size 9', max=9013, style=ProgressStyle(descript…


MSE 0.003617 | MAE 0.041117


HBox(children=(IntProgress(value=0, description='Training on size 1', max=9013, style=ProgressStyle(descriptio…


Error Generator: 8.546964 | Error Discriminator 0.516376 | MAE 0.080310


HBox(children=(IntProgress(value=0, description='Training on size 2', max=9013, style=ProgressStyle(descriptio…


Error Generator: 7.053150 | Error Discriminator 0.566802 | MAE 0.058785


HBox(children=(IntProgress(value=0, description='Training on size 3', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.453663 | Error Discriminator 0.583140 | MAE 0.051618


HBox(children=(IntProgress(value=0, description='Training on size 4', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.156228 | Error Discriminator 0.581513 | MAE 0.048130


HBox(children=(IntProgress(value=0, description='Training on size 5', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.938443 | Error Discriminator 0.583779 | MAE 0.045692


HBox(children=(IntProgress(value=0, description='Training on size 6', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.803781 | Error Discriminator 0.578164 | MAE 0.044038


HBox(children=(IntProgress(value=0, description='Training on size 7', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.666859 | Error Discriminator 0.579930 | MAE 0.042669


HBox(children=(IntProgress(value=0, description='Training on size 8', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.592813 | Error Discriminator 0.571637 | MAE 0.041811


HBox(children=(IntProgress(value=0, description='Evaluating on size 1', max=9013, style=ProgressStyle(descript…


MSE 0.035566 | MAE 0.113331


HBox(children=(IntProgress(value=0, description='Evaluating on size 2', max=9013, style=ProgressStyle(descript…


MSE 0.011255 | MAE 0.067006


HBox(children=(IntProgress(value=0, description='Evaluating on size 3', max=9013, style=ProgressStyle(descript…


MSE 0.006320 | MAE 0.053129


HBox(children=(IntProgress(value=0, description='Evaluating on size 4', max=9013, style=ProgressStyle(descript…


MSE 0.004792 | MAE 0.047217


HBox(children=(IntProgress(value=0, description='Evaluating on size 5', max=9013, style=ProgressStyle(descript…


MSE 0.004176 | MAE 0.044285


HBox(children=(IntProgress(value=0, description='Evaluating on size 6', max=9013, style=ProgressStyle(descript…


MSE 0.003850 | MAE 0.042497


HBox(children=(IntProgress(value=0, description='Evaluating on size 7', max=9013, style=ProgressStyle(descript…


MSE 0.003665 | MAE 0.041372


HBox(children=(IntProgress(value=0, description='Evaluating on size 8', max=9013, style=ProgressStyle(descript…


MSE 0.003542 | MAE 0.040593


HBox(children=(IntProgress(value=0, description='Evaluating on size 9', max=9013, style=ProgressStyle(descript…


MSE 0.003440 | MAE 0.039960


HBox(children=(IntProgress(value=0, description='Training on size 1', max=9013, style=ProgressStyle(descriptio…


Error Generator: 8.289270 | Error Discriminator 0.457629 | MAE 0.077797


HBox(children=(IntProgress(value=0, description='Training on size 2', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.989344 | Error Discriminator 0.495696 | MAE 0.057448


HBox(children=(IntProgress(value=0, description='Training on size 3', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.429699 | Error Discriminator 0.516210 | MAE 0.050760


HBox(children=(IntProgress(value=0, description='Training on size 4', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.138452 | Error Discriminator 0.521956 | MAE 0.047584


HBox(children=(IntProgress(value=0, description='Training on size 5', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.947866 | Error Discriminator 0.529345 | MAE 0.045321


HBox(children=(IntProgress(value=0, description='Training on size 6', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.831220 | Error Discriminator 0.525456 | MAE 0.043869


HBox(children=(IntProgress(value=0, description='Training on size 7', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.737628 | Error Discriminator 0.530805 | MAE 0.042581


HBox(children=(IntProgress(value=0, description='Training on size 8', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.661214 | Error Discriminator 0.527545 | MAE 0.041768


HBox(children=(IntProgress(value=0, description='Evaluating on size 1', max=9013, style=ProgressStyle(descript…


MSE 0.030771 | MAE 0.104332


HBox(children=(IntProgress(value=0, description='Evaluating on size 2', max=9013, style=ProgressStyle(descript…


MSE 0.009481 | MAE 0.061591


HBox(children=(IntProgress(value=0, description='Evaluating on size 3', max=9013, style=ProgressStyle(descript…


MSE 0.005464 | MAE 0.049570


HBox(children=(IntProgress(value=0, description='Evaluating on size 4', max=9013, style=ProgressStyle(descript…


MSE 0.004359 | MAE 0.045041


HBox(children=(IntProgress(value=0, description='Evaluating on size 5', max=9013, style=ProgressStyle(descript…


MSE 0.003923 | MAE 0.042810


HBox(children=(IntProgress(value=0, description='Evaluating on size 6', max=9013, style=ProgressStyle(descript…


MSE 0.003693 | MAE 0.041522


HBox(children=(IntProgress(value=0, description='Evaluating on size 7', max=9013, style=ProgressStyle(descript…


MSE 0.003542 | MAE 0.040644


HBox(children=(IntProgress(value=0, description='Evaluating on size 8', max=9013, style=ProgressStyle(descript…


MSE 0.003445 | MAE 0.040057


HBox(children=(IntProgress(value=0, description='Evaluating on size 9', max=9013, style=ProgressStyle(descript…


MSE 0.003375 | MAE 0.039631


HBox(children=(IntProgress(value=0, description='Training on size 1', max=9013, style=ProgressStyle(descriptio…


Error Generator: 8.288082 | Error Discriminator 0.412705 | MAE 0.077956


HBox(children=(IntProgress(value=0, description='Training on size 2', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.998844 | Error Discriminator 0.451548 | MAE 0.057194


HBox(children=(IntProgress(value=0, description='Training on size 3', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.486633 | Error Discriminator 0.474071 | MAE 0.050592


HBox(children=(IntProgress(value=0, description='Training on size 4', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.221484 | Error Discriminator 0.482994 | MAE 0.047466


HBox(children=(IntProgress(value=0, description='Training on size 5', max=9013, style=ProgressStyle(descriptio…


Error Generator: 6.028264 | Error Discriminator 0.493083 | MAE 0.045341


HBox(children=(IntProgress(value=0, description='Training on size 6', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.901434 | Error Discriminator 0.497912 | MAE 0.043898


HBox(children=(IntProgress(value=0, description='Training on size 7', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.813686 | Error Discriminator 0.496435 | MAE 0.042761


HBox(children=(IntProgress(value=0, description='Training on size 8', max=9013, style=ProgressStyle(descriptio…


Error Generator: 5.756246 | Error Discriminator 0.493665 | MAE 0.041908


HBox(children=(IntProgress(value=0, description='Evaluating on size 1', max=9013, style=ProgressStyle(descript…


MSE 0.030996 | MAE 0.106877


HBox(children=(IntProgress(value=0, description='Evaluating on size 2', max=9013, style=ProgressStyle(descript…


MSE 0.009987 | MAE 0.063944


HBox(children=(IntProgress(value=0, description='Evaluating on size 3', max=9013, style=ProgressStyle(descript…


MSE 0.005871 | MAE 0.051625


HBox(children=(IntProgress(value=0, description='Evaluating on size 4', max=9013, style=ProgressStyle(descript…


MSE 0.004625 | MAE 0.046529


HBox(children=(IntProgress(value=0, description='Evaluating on size 5', max=9013, style=ProgressStyle(descript…


MSE 0.004116 | MAE 0.043979


HBox(children=(IntProgress(value=0, description='Evaluating on size 6', max=9013, style=ProgressStyle(descript…


MSE 0.003845 | MAE 0.042493


HBox(children=(IntProgress(value=0, description='Evaluating on size 7', max=9013, style=ProgressStyle(descript…


MSE 0.003678 | MAE 0.041482


HBox(children=(IntProgress(value=0, description='Evaluating on size 8', max=9013, style=ProgressStyle(descript…


MSE 0.003562 | MAE 0.040760


HBox(children=(IntProgress(value=0, description='Evaluating on size 9', max=9013, style=ProgressStyle(descript…


MSE 0.003484 | MAE 0.040238


HBox(children=(IntProgress(value=0, description='Training on size 1', max=9013, style=ProgressStyle(descriptio…


Error Generator: 8.312424 | Error Discriminator 0.382640 | MAE 0.077943


HBox(children=(IntProgress(value=0, description='Training on size 2', max=9013, style=ProgressStyle(descriptio…

KeyboardInterrupt: 

In [None]:
trainset = CGData(data_dir,train,1)

In [None]:
x,y = trainset[np.random.randint(0,len(trainset))]
x = x.unsqueeze(0).to(device)
xs = torch.split(x,1,dim = 1)
xs = [torch.squeeze(x,dim=0).cpu() for x in xs]
y = y.unsqueeze(0)
plot(xs+[net(x).cpu()]+[y.cpu()])