In [None]:
from pathlib import Path
from shutil import copyfile
#
import torch
import torchvision
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
#
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
#
from models import UNet

## Configs

In [None]:
# Root of the dataset. It is expected to hold "train" and "valid" folders,
# each with "inputs" and "labels" sub-folders of PNG files (see DataSet).
p_data = Path("./data/128_gray")
# Fail fast with a clear message: the bare `p_data.exists()` expression that
# used to live here was evaluated and silently discarded.
if not p_data.exists():
    raise FileNotFoundError("Data directory not found: {}".format(p_data.resolve()))
#
p_train = p_data / "train"
p_valid = p_data / "valid"
#
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#
# Logging cadence, in optimizer steps.
freq_print = 10    # print the training loss every `freq_print` steps
freq_vis = 1000    # re-plot the visualization samples every `freq_vis` steps
#
n_epochs = 10      # number of passes over the training loader
img_size = 128     # side length used for the model summary input shape
lr = 0.01          # learning rate handed to the optimizer
n_imgs_viz = 4     # number of samples kept for the progress visualization

## Utils

In [None]:
def show_imgs(X, Y, P, scale_factor=2):
    """Plot a grid with one row per sample: input, label and prediction.

    Args:
        X: Batch of input images, indexed along its first dimension. Each
            item must be accepted by ``torchvision.transforms.ToPILImage``.
        Y: Batch of label images with at least as many items as ``X``.
        P: Batch of predicted images with at least as many items as ``X``.
        scale_factor: Width and height, in inches, given to every grid cell.
            The default of 2 reproduces the previous fixed figure size.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    n_cols = 3
    n_rows = X.shape[0]
    if n_rows == 0:
        # Nothing to draw; plt.subplots rejects a grid with zero rows.
        return
    # squeeze=False keeps `axs` two-dimensional even for a single row, so the
    # [row, col] indexing below also works for a batch of one sample.
    fig, axs = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * scale_factor, n_rows * scale_factor),
        squeeze=False,
    )
    for ax in axs.flatten():
        ax.set_xticks([])
        ax.set_yticks([])
    to_pil = T.ToPILImage()
    for row_idx in range(n_rows):
        # Column order: input, label, prediction.
        for col_idx, batch in enumerate((X, Y, P)):
            axs[row_idx][col_idx].imshow(np.array(to_pil(batch[row_idx])), cmap="gray")
    plt.tight_layout()
    plt.show()

class DataSet:
    """Map-style dataset of paired PNG images.

    Inputs are read from ``<p_data>/inputs`` and labels from
    ``<p_data>/labels``. The two folders must contain the same file names;
    the i-th input is paired with the i-th label after sorting by name.

    Args:
        p_data: ``pathlib.Path`` of the split directory (e.g. the train folder).
        transform: Optional callable applied to the input image and then,
            in a separate call, to the label image.
    """

    def __init__(self, p_data, transform=None):
        self.p_inputs = p_data / "inputs"
        self.p_labels = p_data / "labels"
        #
        # glob() yields entries in filesystem-dependent order, so sort both
        # listings by file name to make the input/label pairing deterministic.
        self.inputs = sorted(self.p_inputs.glob("*.png"), key=lambda p: p.name)
        self.labels = sorted(self.p_labels.glob("*.png"), key=lambda p: p.name)
        #
        assert [i.name for i in self.inputs] == [l.name for l in self.labels], (
            "inputs and labels must contain the same PNG file names: "
            "{} vs {}".format(self.p_inputs, self.p_labels)
        )
        #
        self.transform = transform

    def __len__(self):
        """Return the number of input/label pairs."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair at position ``idx``.

        Both images are opened with PIL. When a transform is set it is called
        once per image.
        """
        x = Image.open(self.inputs[idx])
        y = Image.open(self.labels[idx])

        if self.transform is not None:
            # NOTE(review): the transform is invoked separately for x and y, so
            # a randomized transform would presumably not stay aligned between
            # the two - confirm before adding random augmentations.
            x = self.transform(x)
            y = self.transform(y)
        return x, y

# Data loading, model setup and training

In [None]:
# DATA
# Preprocessing shared by inputs and labels: PIL image -> tensor.
# No normalisation (e.g. T.Normalize([0.5], [0.5])) is applied.
transform = T.Compose([T.ToTensor()])
#
# Paired input/label datasets, one per split.
ds_train = DataSet(p_train, transform=transform)
ds_valid = DataSet(p_valid, transform=transform)
#
# Shuffled mini-batch loaders; only the training loader uses worker processes.
dl_train = torch.utils.data.DataLoader(
    ds_train,
    batch_size=16,
    shuffle=True,
    num_workers=8,
)
dl_valid = torch.utils.data.DataLoader(
    ds_valid,
    batch_size=16,
    shuffle=True,
)

In [None]:
# MODEL
# Build the single-class UNet and place it on the selected device in one go.
model = UNet(n_class=1).to(device)
# Print a layer-by-layer summary for a single-channel square input.
summary(model, input_size=(1, img_size, img_size))

In [None]:
# Data for visualization of training progress
# Keep the first few samples of one training batch; the same samples are
# re-plotted during training so that successive plots are comparable.
X_vis, Y_vis = next(iter(dl_train))
X_vis = X_vis[:n_imgs_viz]
Y_vis = Y_vis[:n_imgs_viz]
#
# Plotting only: run the forward pass without autograd so that no graph is
# built or retained, matching the visualization call in the training loop.
with torch.no_grad():
    P_vis = model(X_vis.to(device))
#
# SHOW
show_imgs(X_vis, Y_vis, P_vis)

In [None]:
# TRAIN
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#
# 'train_losses' holds one value per optimizer step,
# 'test_losses' one value per epoch (mean of the per-batch validation losses).
results = {'train_losses': [], 'test_losses': []}
step = 0
for epoch in range(n_epochs):
    for X, Y in dl_train:
        X = X.to(device)
        Y = Y.to(device)
        #
        pred = model(X)
        loss = criterion(pred, Y)
        loss_value = loss.item()
        results["train_losses"].append(loss_value)
        #
        # Clear gradients before backward so stale .grad values (e.g. left by
        # an interrupted earlier run of this cell) cannot leak into the step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1

        # Steps 1-19 are always reported to show the very start of training;
        # afterwards the regular cadences apply.
        if step % freq_print == 0 or step < 20:
            # Epoch is shown 1-based so the last epoch reads n_epochs/n_epochs.
            print("Epoch[{}/{}] Step {}, Loss: {:.3f}".format(epoch + 1, n_epochs, step, loss_value))

        if step % freq_vis == 0 or step < 20:
            with torch.no_grad():
                P_vis = model(X_vis.to(device))
            show_imgs(X_vis, Y_vis, P_vis)

    # Validation pass at the end of every epoch, so 'test_losses' is populated.
    was_training = model.training
    model.eval()
    valid_losses = []
    with torch.no_grad():
        for X_val, Y_val in dl_valid:
            valid_loss = criterion(model(X_val.to(device)), Y_val.to(device))
            valid_losses.append(valid_loss.item())
    # Restore whichever mode the model was in before validation.
    model.train(was_training)
    if valid_losses:
        results["test_losses"].append(sum(valid_losses) / len(valid_losses))
        print("Epoch[{}/{}] Validation loss: {:.3f}".format(epoch + 1, n_epochs, results["test_losses"][-1]))

In [None]:
# Fetch one batch from the (shuffled) validation loader for a qualitative
# check; this rebinds the notebook-level X and Y.
X, Y = next(iter(dl_valid))

In [None]:
# Forward pass on the validation batch. Gradient tracking is disabled because
# the prediction is only plotted, never back-propagated.
with torch.no_grad():
    P = model(X.to(device))

In [None]:
# Plot input, label and prediction side by side, one row per sample.
show_imgs(X, Y, P)