In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import os,sys
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import time

from helper_functions import *

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
NBR_EPOCHS = 2
BATCH_SIZE = 10
LEARNING_RATE = 1e-3
DIM = 1

In [None]:
%%time
from torch.utils.data import Dataset, DataLoader
class imagesDataset(Dataset): 
    
    def __init__(self):
        imgs_init, gt_imgs_init = load_train_dataset()
        
        # Data augmentation
        self.imgs = compose_all_functions_for_data(imgs_init)
        self.gt_imgs = compose_all_functions_for_data(gt_imgs_init)
        self.n_samples = len(self.imgs)
        
    def __getitem__(self, index):
        return self.imgs[index], self.gt_imgs[index]

    def __len__(self):
        return self.n_samples


# create dataset
torch.manual_seed(1)
dataset = imagesDataset()
train_loader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

In [None]:
img_temp = dataset[5432][0]
to_PIL = T.ToPILImage()
to_PIL(img_temp)

In [None]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        
        # 
        self.pool_d = nn.MaxPool2d(2, 2)
        self.pool_u = nn.Upsample(scale_factor=2)
        
        # Activation function
        self.activ = nn.ReLU()
        self.final_activ = nn.Sigmoid()
        
        # Convolution Downwards
        self.conv_1 = nn.Conv2d(3, 64, (3,3), padding=(1, 1))
        self.conv_2 = nn.Conv2d(64, 64, (3,3), padding=(1, 1))
        
        self.conv_3 = nn.Conv2d(64, 128, (3,3), padding=(1, 1))
        self.conv_4 = nn.Conv2d(128, 128, (3,3), padding=(1, 1))
        
        self.conv_5 = nn.Conv2d(128, 256, (3,3), padding=(1, 1))
        self.conv_6 = nn.Conv2d(256, 256, (3,3), padding=(1, 1))
        
        self.conv_7 = nn.Conv2d(256, 512, (3,3), padding=(1, 1))
        self.conv_8 = nn.Conv2d(512, 512, (3,3), padding=(1, 1))
        
        self.conv_9 = nn.Conv2d(512, 1024, (3,3), padding=(1, 1))
        self.conv_10 = nn.Conv2d(1024, 1024, (3,3), padding=(1, 1))
        
        
        # Upconvolution
        self.upconv_1 = nn.Conv2d(512+1024, 512, (3,3), padding=(1, 1))
        self.upconv_2 = nn.Conv2d(512, 512, (3,3), padding=(1, 1))
        
        self.upconv_3 = nn.Conv2d(256+512, 256, (3,3), padding=(1, 1))
        self.upconv_4 = nn.Conv2d(256, 256, (3,3), padding=(1, 1))
        
        self.upconv_5 = nn.Conv2d(128+256, 128, (3,3), padding=(1, 1))
        self.upconv_6 = nn.Conv2d(128, 128, (3,3), padding=(1, 1))
        
        self.upconv_7 = nn.Conv2d(64+128, 64, (3,3), padding=(1, 1))
        self.upconv_8 = nn.Conv2d(64, 64, (3,3), padding=(1, 1))
        self.upconv_9 = nn.Conv2d(64, 2, (1,1))


    def forward(self, x):
        # Convolution with activation and max_pooling
        xd_1 = self.activ(self.conv_1(x))
        xd_2 = self.activ(self.conv_2(xd_1))
    
        xd_3 = self.activ(self.conv_3(self.pool_d(xd_2)))
        xd_4 = self.activ(self.conv_4(xd_3))
        
        xd_5 = self.activ(self.conv_5(self.pool_d(xd_4)))
        xd_6 = self.activ(self.conv_6(xd_5))
        
        xd_7 = self.activ(self.conv_7(self.pool_d(xd_6)))
        xd_8 = self.activ(self.conv_8(xd_7))
        
        xd_9 = self.activ(self.conv_9(self.pool_d(xd_8)))
        xd_10 = self.pool_u(self.activ(self.conv_10(xd_9)))

        # "Fractionally / Backward strided convolution" with activation and upsampling
        xu_1 = self.activ(self.upconv_1(torch.cat((xd_8, xd_10), dim=DIM)))
        xu_2 = self.pool_u(self.activ(self.upconv_2(xu_1)))
        
        xu_3 = self.activ(self.upconv_3(torch.cat((xd_6, xu_2), dim=DIM)))
        xu_4 = self.pool_u(self.activ(self.upconv_4(xu_3)))
        
        xu_5 = self.activ(self.upconv_5(torch.cat((xd_4, xu_4), dim=DIM)))
        xu_6 = self.pool_u(self.activ(self.upconv_6(xu_5)))
        
        xu_7 = self.activ(self.upconv_7(torch.cat((xd_2, xu_6), dim=DIM)))
        xu_8 = self.activ(self.upconv_8(xu_7))
        xu_9 = self.final_activ(self.upconv_9(xu_8))
    
        return xu_9

model = ConvNet().to(device)

In [None]:
#criterion = nn.CrossEntropyLoss()# nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
#loss function used in our neural network
criterion = nn.CrossEntropyLoss()

#criterion = nn.BCEWithLogitsLoss()

In [None]:
%%time

n_total_steps = len(train_loader)
epoch = 0
while True:
    for i, (images, labels) in enumerate(train_loader):
        # Measure training time of one batch sample
        start = time.time()
        
        # FORWARD PASS
        predictions = model(images)
        print(predictions.shape)
        print(labels.shape)
        loss = criterion(predictions, labels)

        # BACKWARD PASS
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch},  Batch {i} - Duration: {time.time()-start}, Loss:{loss.item():.4f}")
        epoch+=1

print('Finished Training')

In [20]:
a = torch.randint(0, 1, size = (10, 2, 400, 400))
b = torch.argmax(a, dim=1)
print(b.shape)

torch.Size([2, 400, 400])
