In [1]:
import torch
from torch import optim
import torch.nn as nn
from torch.nn import MSELoss
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import torchvision.transforms as transforms
from torchvision.io import read_image
import numpy as np
from numpy import random
import matplotlib.pyplot as plt
import os
from os import listdir
from os.path import splitext
from glob import glob
from PIL import Image
from tqdm import tqdm
import logging
from unet_model import UNet

In [2]:
def imshow(img):
    import cv2
    import IPython
    _,ret = cv2.imencode('.jpg', img) 
    i = IPython.display.Image(data=ret)
    IPython.display.display(i)

In [3]:
class TheDataset(Dataset):
    def __init__(self, interlaced_dir, gtruth_dir, scale=1):
        self.interlaced_dir = interlaced_dir
        self.gtruth_dir = gtruth_dir
        self.scale = scale
        self.transform = transforms.Compose([transforms.RandomCrop(256), transforms.ToTensor()])
        
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'

        self.ids = [splitext(file)[0] for file in listdir(interlaced_dir)]
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img, scale):
        pil_img = pil_img
        
        return pil_img

    def __getitem__(self, i):
        idx = self.ids[i]
        gtruth_file = glob(self.gtruth_dir + idx + '.*')
        interlaced_file = glob(self.interlaced_dir + idx + '.*')

        assert len(gtruth_file) == 1, \
            f'Either no mask or multiple masks found for the ID {idx}: {gtruth_file}'
        assert len(interlaced_file) == 1, \
            f'Either no image or multiple images found for the ID {idx}: {interlaced_file}'
        gtruth = Image.open(gtruth_file[0])
        interlaced = Image.open(interlaced_file[0])

        assert interlaced.size == gtruth.size, \
            f'Image and mask {idx} should be the same size, but are {interlaced.size} and {gtruth.size}'
        
        seed = np.random.randint(2147483647) # make a seed with numpy generator 
        
        random.seed(seed) # apply this seed to img tranfsorms
        torch.manual_seed(seed) # needed for torchvision 0.7
        interlaced = self.transform(interlaced)
        interlaced = self.preprocess(interlaced, self.scale)
        
        random.seed(seed) # apply this seed to img tranfsorms
        torch.manual_seed(seed) # needed for torchvision 0.7
        gtruth = self.transform(gtruth)
        gtruth = self.preprocess(gtruth, self.scale)

        return interlaced, gtruth
    

In [4]:
interlaced_dir = "./dataset/interlaced/"
gtruth_dir = "./dataset/ground_truth/"
img_scale = 1

In [5]:
batch_size = 192
val_percent = 0.1

random.seed(23)
torch.manual_seed(23)

dataset = TheDataset(interlaced_dir, gtruth_dir, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True, drop_last=True)
# for inter, truth in tqdm(train_loader):

#     print(np.array(inter[0].permute(1,2,0)))
#     imshow(np.array(inter[0].permute(1,2,0)))
#     imshow(np.array(truth[0].permute(1,2,0)))
#     break

In [6]:
def train_net(idx, net, lr):

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-3)
    criterion = MSELoss()
    
    net.train()
    epoch_loss = 0

    ##################### TRAINING LOOP ########################
    
    for batch, (interlaced, truths) in enumerate(tqdm(train_loader)):
        interlaced = interlaced.to(device=device, dtype=torch.float32)
        truths = truths.to(device=device, dtype=torch.float32)

        net_pred = net(interlaced)
        loss = criterion(net_pred, truths)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 69 == 0:
            loss, current = loss.item(), batch * len(interlaced)
            print(f"loss: {loss}  [{current}/{n_train}]")

    print(f"Epoch {idx+1} loss: {epoch_loss/len(train_loader)}-------------------\n")

    ##################### VALIDATION LOOP ########################
    
    test_loss = 0
    with torch.no_grad():
        net.eval()
        for batch, (interlaced, truths) in enumerate(tqdm(val_loader)):

            interlaced = interlaced.to(device=device, dtype=torch.float32)
            truths = truths.to(device=device, dtype=torch.float32)

            net_pred = net(interlaced)
            test_loss += criterion(net_pred, truths).item()

    print(f"Test loss for epoch {idx+1}: {test_loss/len(val_loader)}-------------------\n")

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = UNet(n_channels=3, bilinear=True)
net = nn.DataParallel(net)
net.to(device=device)
lr_list = [1e-3, 1e-4, 1e-5, 1e-6]
for idx, lr in enumerate(lr_list):
    print(f"Epoch {idx+1}\n-------------------------------")
    train_net(idx=idx, net=net, lr=lr)
    torch.save(net, f'model-{idx+1}.pth')
    

  0%|          | 0/443 [00:00<?, ?it/s]

Epoch 1
-------------------------------


  0%|          | 1/443 [01:55<14:12:08, 115.67s/it]

loss: 0.5983479022979736  [0/84924]


 16%|█▌        | 70/443 [09:57<44:03,  7.09s/it]   

loss: 0.00219259993173182  [13248/84924]


 31%|███▏      | 139/443 [17:48<18:22,  3.63s/it]  

loss: 0.0012196513125672936  [26496/84924]


 47%|████▋     | 208/443 [25:37<12:04,  3.08s/it]  

loss: 0.0011065284488722682  [39744/84924]


 63%|██████▎   | 277/443 [34:55<24:13,  8.76s/it]  

loss: 0.0010804092744365335  [52992/84924]


 78%|███████▊  | 346/443 [43:01<06:23,  3.96s/it]  

loss: 0.0008330150158144534  [66240/84924]


 94%|█████████▎| 415/443 [50:50<01:28,  3.15s/it]

loss: 0.0012177982134744525  [79488/84924]


100%|██████████| 443/443 [53:58<00:00,  7.31s/it]
  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 1 loss: 0.006072108653262594-------------------



100%|██████████| 49/49 [06:28<00:00,  7.93s/it]  
  0%|          | 0/443 [00:00<?, ?it/s]

Test loss for epoch 1: 0.0011211410639997647-------------------

Epoch 2
-------------------------------


  0%|          | 1/443 [02:00<14:49:46, 120.78s/it]

loss: 0.0013671801425516605  [0/84924]


 16%|█▌        | 70/443 [09:58<42:05,  6.77s/it]   

loss: 0.000753008876927197  [13248/84924]


 31%|███▏      | 139/443 [17:59<18:15,  3.61s/it]  

loss: 0.0009371443884447217  [26496/84924]


 47%|████▋     | 208/443 [25:57<12:06,  3.09s/it]  

loss: 0.0010625568684190512  [39744/84924]


 63%|██████▎   | 277/443 [35:22<23:51,  8.62s/it]  

loss: 0.0007109579164534807  [52992/84924]


 78%|███████▊  | 346/443 [43:20<06:17,  3.89s/it]  

loss: 0.000679807853884995  [66240/84924]


 94%|█████████▎| 415/443 [51:11<01:27,  3.12s/it]

loss: 0.0007141140522435308  [79488/84924]


100%|██████████| 443/443 [54:06<00:00,  7.33s/it]
  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 2 loss: 0.0008509693187440726-------------------



100%|██████████| 49/49 [06:35<00:00,  8.07s/it]  
  0%|          | 0/443 [00:00<?, ?it/s]

Test loss for epoch 2: 0.0028256699799236898-------------------

Epoch 3
-------------------------------


  0%|          | 1/443 [02:00<14:44:36, 120.08s/it]

loss: 0.0008871853351593018  [0/84924]


 16%|█▌        | 70/443 [09:40<41:30,  6.68s/it]   

loss: 0.0008345170645043254  [13248/84924]


 31%|███▏      | 139/443 [17:35<18:25,  3.64s/it]  

loss: 0.0008171935332939029  [26496/84924]


 47%|████▋     | 208/443 [25:38<12:05,  3.09s/it]  

loss: 0.0009973678970709443  [39744/84924]


 63%|██████▎   | 277/443 [34:48<24:46,  8.95s/it]  

loss: 0.0005905684665776789  [52992/84924]


 78%|███████▊  | 346/443 [42:46<06:17,  3.89s/it]

loss: 0.0008436227217316628  [66240/84924]


 94%|█████████▎| 415/443 [50:36<01:27,  3.13s/it]

loss: 0.0007305563194677234  [79488/84924]


100%|██████████| 443/443 [53:41<00:00,  7.27s/it]
  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 3 loss: 0.0006885684444043474-------------------



100%|██████████| 49/49 [06:34<00:00,  8.05s/it]  
  0%|          | 0/443 [00:00<?, ?it/s]

Test loss for epoch 3: 0.0005972195744552479-------------------

Epoch 4
-------------------------------


  0%|          | 1/443 [01:59<14:43:15, 119.90s/it]

loss: 0.0005889901076443493  [0/84924]


 16%|█▌        | 70/443 [09:44<41:32,  6.68s/it]   

loss: 0.0006303992704488337  [13248/84924]


 31%|███▏      | 139/443 [17:36<18:25,  3.64s/it]  

loss: 0.000691577501129359  [26496/84924]


 47%|████▋     | 208/443 [25:28<12:04,  3.08s/it]  

loss: 0.0005744582158513367  [39744/84924]


 63%|██████▎   | 277/443 [34:33<23:23,  8.45s/it]  

loss: 0.0006074088159948587  [52992/84924]


 78%|███████▊  | 346/443 [42:26<06:18,  3.90s/it]  

loss: 0.0006249517318792641  [66240/84924]


 94%|█████████▎| 415/443 [50:20<01:27,  3.13s/it]

loss: 0.0006432349910028279  [79488/84924]


100%|██████████| 443/443 [53:29<00:00,  7.24s/it]
  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 4 loss: 0.0006520167308093236-------------------



100%|██████████| 49/49 [06:40<00:00,  8.17s/it]  


Test loss for epoch 4: 0.0005256581388660992-------------------

