In [None]:
from unet import *
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
import torch.optim as optim
from tqdm import trange
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pandas as pd
import scipy.io
import pandas as pd
# tensor summary import
from torchsummary import summary

In [None]:
# Build the network, pick the compute device, and define every training
# hyperparameter that the later cells read.
Net = build_unet()

# use CUDA when it is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("==> Device: {}".format(device))

# move model to the device
Net.to(device)

# define loss function
# criterion = nn.L1Loss()
criterion = nn.MSELoss()

####### HYPERPARAMETERS #######

# learning rate = 10^-4
lr = 0.0001

# side length (in pixels) of the square patches cropped for training
patch_size = 512

# define optimizer
optimizer = optim.Adam(Net.parameters(), lr=lr)

# batch size
batch_size = 1

# define number of epochs
n_epochs = 100

# keep track of the best validation loss
# (np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0)
valid_loss_min = np.inf

# number of epochs to wait before stopping
early_stopping = 5


##### LOSS  #######

# per-iteration training losses and validation losses collected over the run
train_loss = []
valid_loss = []

# initialize the early_stopping object
# early_stopping = EarlyStopping(patience=early_stopping, verbose=True)



In [None]:
def random_crop(image, label, patch_size):
    """
    Crop the same randomly positioned square patch from an image and its label.

    Args:
        image: array indexed as [row, column, ...]; its first two dimensions
            define the range the patch position is sampled from.
        label: array cropped with exactly the same window as ``image``.
            It is assumed to share the image's height and width; this is
            not checked here.
        patch_size: side length of the square patch, in pixels.

    Returns:
        Tuple ``(image_patch, label_patch)`` of slices taken at the same
        offsets.

    Raises:
        ValueError: if ``patch_size`` exceeds the image height or width.
    """
    # get the shape of the image
    h, w = image.shape[:2]

    if patch_size > h or patch_size > w:
        raise ValueError(
            "patch_size {} does not fit into an image of size {}x{}".format(patch_size, h, w)
        )

    # get the top left corner of the random crop; randint's upper bound is
    # exclusive, so the +1 makes the last valid offset reachable and lets an
    # image of exactly patch_size work (offset 0 is then the only choice)
    x = np.random.randint(0, w - patch_size + 1)
    y = np.random.randint(0, h - patch_size + 1)

    # crop the image and the label with the same window
    image = image[y:y + patch_size, x:x + patch_size]
    label = label[y:y + patch_size, x:x + patch_size]

    return image, label

In [None]:
def pack_raw(raw, black_level=512, white_level=16383):
    """
    Pack a single-plane Bayer mosaic into a 4-channel, half-resolution array.

    The mosaic is read from ``raw.raw_image_visible``, converted to float32,
    shifted by the black level, clipped at zero and divided by
    ``white_level - black_level``.

    Args:
        raw: object exposing a 2-D ``raw_image_visible`` array (presumably a
            rawpy image — confirm against the caller).
        black_level: value subtracted from every sample before normalisation.
            Defaults to 512, the value previously hard-coded here.
        white_level: sensor value that maps to 1.0 after normalisation.
            Defaults to 16383, the value previously hard-coded here.

    Returns:
        float32 array of shape (H/2, W/2, 4). The channels are, in order, the
        samples at (even row, even col), (even row, odd col),
        (odd row, odd col) and (odd row, even col).

    Note:
        Height and width are expected to be even; otherwise the four planes
        differ in shape and the concatenation fails.
    """
    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - black_level, 0) / (white_level - black_level)  # subtract the black level

    # add a trailing channel axis so the four planes can be joined along it
    im = np.expand_dims(im, axis=2)
    H, W = im.shape[:2]

    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out


In [None]:
# read list.csv as a comma-separated table without a header row
df = pd.read_csv('list.csv', sep=',', header=None)

# first column holds the inputs, second column the matching labels
# (the training loop passes these entries to rawpy.imread, i.e. they are file paths)
input_df = df.iloc[:, 0]
label_df = df.iloc[:, 1]

# create a list of (input, label) tuples
image_label_list = list(zip(input_df, label_df))

# NOTE(review): bare instance of the torch Dataset base class; nothing in this
# cell uses it (the DataLoader below iterates the plain list). Kept so that any
# other cell referring to `dataset` keeps working — confirm and remove if unused.
dataset = Dataset()

# create a dataloader object over the path pairs
# NOTE(review): shuffle=False keeps the file order identical in every epoch — confirm this is intended for training
dataloader = DataLoader(image_label_list, batch_size=batch_size, shuffle=False)

# len(dataloader) would count batches, which only matches the number of images
# when batch_size == 1, so report the number of listed pairs instead
print('number of images in dataset: {}'.format(len(image_label_list)))

In [None]:
# Print a summary of the network built in the configuration cell.
# NOTE(review): some torchsummary releases require an explicit input size
# argument — confirm that calling summary() with only the model works with the
# installed package.
summary(Net)

In [None]:
# Training loop: for every epoch, iterate over the listed (input, label) file
# pairs, decode both files, crop one random patch, and take one optimizer step.
# NOTE(review): rawpy is not imported in the notebook's import cell; presumably
# it is re-exported by `from unet import *` — confirm.
epoch_loss = 0  # loss of the most recent iteration
cnt = 0         # global iteration counter, also drives the debug plots

# loop over epochs with tqdm progress bar
t = trange(n_epochs, leave=True)
for epoch in t:
    # initialize the training and validation loss for this epoch
    train_loss_epoch = 0.0
    valid_loss_epoch = 0.0

    # set the model to training mode
    Net.train()

    # loop over the training data
    #### IMPORTANT ####
    # Each listed pair contributes exactly one random patch per epoch, so a
    # small list (e.g. 3 pairs) means only 3 iterations per epoch. Drawing
    # several patches per image would be a cheap way to get more out of it.
    for input_paths, label_paths in dataloader:
        print(' ------------ new batch --------------')
        print(input_paths)
        print(label_paths)

        # only the first sample of each batch is used; with batch_size > 1
        # the remaining samples of the batch are ignored
        input_path = input_paths[0]
        label_path = label_paths[0]

        # decode both raw files; the context managers release the
        # underlying file handles as soon as the pixels are extracted
        with rawpy.imread(input_path) as raw_input:
            image = raw_input.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        with rawpy.imread(label_path) as raw_label:
            label = raw_label.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)

        # scale the 16-bit output to [0, 1] and store it as float32
        image = np.float32(image / 65535.0)
        label = np.float32(label / 65535.0)

        # show debug figures on every 20th iteration
        show_debug = cnt % 20 == 0

        if show_debug:
            # display the full image and label
            plt.imshow(image)
            plt.show()
            plt.imshow(label)
            plt.show()

        ####### POSTPROCESSING #######

        # crop the image and label to patch_size x patch_size
        image, label = random_crop(image, label, patch_size)

        if show_debug:
            # display the cropped image and label
            plt.imshow(image)
            plt.show()
            plt.imshow(label)
            plt.show()

        # convert the image and label to tensors: HWC -> CHW plus a leading batch axis
        image = torch.from_numpy(np.expand_dims(np.transpose(image, (2, 0, 1)), axis=0))
        label = torch.from_numpy(np.expand_dims(np.transpose(label, (2, 0, 1)), axis=0))

        # move the image and label to the device
        image, label = image.to(device), label.to(device)

        # forward pass
        output = Net(image)

        if show_debug:
            # display the input patch and the network output as separate
            # figures (without the intermediate show() the output would be
            # drawn over the input patch)
            plt.imshow(image[0].cpu().detach().numpy().transpose(1, 2, 0))
            plt.show()
            plt.imshow(output[0].cpu().detach().numpy().transpose(1, 2, 0))
            plt.show()

        # calculate the loss
        loss = criterion(output, label)

        # clear the gradients before backpropagating, so gradients left over
        # from an interrupted run of this cell cannot leak into this step
        optimizer.zero_grad()

        # backward pass
        loss.backward()

        # update the weights
        optimizer.step()

        # read the scalar loss once and reuse it for all bookkeeping
        loss_value = loss.item()
        epoch_loss = loss_value

        # update the training loss
        train_loss_epoch += loss_value

        # append loss to the list
        train_loss.append(loss_value)

        cnt = cnt + 1
        

# draw the per-iteration training loss curve on the current axes
loss_ax = plt.gca()
loss_ax.plot(train_loss)
loss_ax.set_title("Training Loss")
loss_ax.set_xlabel("Iteration")
loss_ax.set_ylabel("Loss")
plt.show()

# plot 