In [2]:
from unet import *
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
import torch.optim as optim
from tqdm import trange
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
Net = build_unet()
# check if CUDA is available, and set it as the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("==> Device: {}".format(device))

# move model to the device
Net.to(device)

# define loss function
# criterion = nn.L1Loss()
criterion = nn.MSELoss()

####### HYPERPARAMETERS #######

# learning rate = 10^-4
lr = 0.0001

# patch size
patch_size = 512

# define optimizer
optimizer = optim.Adam(Net.parameters(), lr=lr)

# batch size
batch_size = 1

# define number of epochs
n_epochs = 10

# keep track of the best validation loss
valid_loss_min = np.Inf

# number of epochs to wait before stopping
early_stopping = 5


##### LOSS  #######

# keep track of training and validation loss
train_loss = []
valid_loss = []

# initialize the early_stopping object
# early_stopping = EarlyStopping(patience=early_stopping, verbose=True)



==> Device: cuda


In [4]:
def random_crop(img, patch_size):
    """
    Crop a random patch from the image
    """
    h, w = img.shape[:2]
    x = np.random.randint(0, w - patch_size)
    y = np.random.randint(0, h - patch_size)
    return img[y:y + patch_size, x:x + patch_size]

In [42]:
####### SAMPLE INPUT DATA #######
labels_dir = '/home/deeplearning/images/SonyImages/Sony/long/'
images_dir = '/home/deeplearning/images/SonyImages/Sony/short/'

# labels_fns = glob.glob(labels_dir + '0*.ARW')
# images_fns = glob.glob(images_dir + '0*.ARW')

# label_ids = [int(os.path.basename(label_fn)[0:5]) for label_fn in labels_fns]
# image_ids = [int(os.path.basename(image_fn)[0:5]) for image_fn in images_fns]
# print(label_ids)
# print(image_ids)

# # # sort the images and labels
# # labels_fns.sort()
# # images_fns.sort()

# # only train on 10 images, discard the rest
# labels_fns = labels_fns[:10]
# images_fns = images_fns[:10]

# # # find all the input images from image_dir tha tcontain the same train_ids
# images = [glob.glob(images_dir + '%05d_00*.ARW' % train_id) for train_id in image_ids]
# labels = [glob.glob(labels_dir + '%05d_00*.ARW' % train_id) for train_id in label_ids]

# # print them
# print(labels)
# print(images)


# train_fns = glob.glob(labels_dir + '0*.ARW')
# train_ids = [int(os.path.basename(train_fn)[0:5]) for train_fn in train_fns]

# labels = train_ids[:10]

# # find all the input images from image_dir tha tcontain the same train_ids 
# images = [glob.glob(images_dir + '%05d_00*.ARW' % train_id) for train_id in train_ids]
# labels = [glob.glob(labels_dir + '%05d_00*.ARW' % train_id) for train_id in train_ids]

# # sort them
# train_ids.sort()
# images.sort()

# # print them
# print(train_ids)
# print(images)
# 


images = glob.glob(images_dir + '0*.ARW')
labels = glob.glob(labels_dir + '0*.ARW')

# # sort the images and labels
images.sort()
labels.sort()

# only train on 10 images, discard the rest
images = images[:10]
labels = labels[:10]

epoch_loss = 0

# loop over epochs with tqdm progress bar
t = trange(n_epochs, leave=True)
for epoch in t:
    # initialize the training and validation loss for this epoch
    train_loss_epoch = 0.0
    valid_loss_epoch = 0.0

    # set the model to training mode
    Net.train()

    # loop over the training data
    for i in range(len(images)):
        # load the first image and label
        image = rawpy.imread(images[i])
        label = rawpy.imread(labels[i])

        # convert the image and label to numpy arrays
        image = image.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        label = label.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)

        # convert the image and label to float32 data type
        image = np.float32(image / 65535.0)
        label = np.float32(label / 65535.0)

        # crop the image and label to 512 x 512
        image = random_crop(image, 512)
        label = random_crop(label, 512)

        # display the image and label
        # plt.imshow(image)
        # plt.show()
        # plt.imshow(label)
        # plt.show()

        # display file name of label and input to check if they are correct
        print("==> Label: {}".format(labels[i]))
        print("==> Image: {}".format(images[i]))
        print("==> Image shape: {}".format(image.shape))
        print("==> Label shape: {}".format(label.shape))        

        # convert the image and label to tensors
        image = torch.from_numpy(np.expand_dims(np.transpose(image, (2, 0, 1)), axis=0))
        label = torch.from_numpy(np.expand_dims(np.transpose(label, (2, 0, 1)), axis=0))

        # move the image and label to the device
        image, label = image.to(device), label.to(device)

        # forward pass
        output = Net(image)

        # calculate the loss
        loss = criterion(output, label)

        # backward pass
        loss.backward()

        # update the weights
        optimizer.step()

        # clear the gradients
        optimizer.zero_grad()

        # print the loss
        print("==> Loss: {}".format(loss.item()))
        epoch_loss = loss.item()

        # update the training loss
        train_loss_epoch += loss.item()

        # append loss to the list
        train_loss.append(loss.item())


# plot loss
plt.plot(train_loss)
plt.title("Training Loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()

# plot 

  0%|          | 0/10 [00:00<?, ?it/s]

==> Label: /home/deeplearning/images/SonyImages/Sony/long/00001_00_10s.ARW
==> Image: /home/deeplearning/images/SonyImages/Sony/short/00001_00_0.04s.ARW
==> Image shape: (512, 512, 3)
==> Label shape: (512, 512, 3)


  return F.mse_loss(input, target, reduction=self.reduction)


==> Loss: 0.0892486423254013
==> Label: /home/deeplearning/images/SonyImages/Sony/long/00002_00_10s.ARW
==> Image: /home/deeplearning/images/SonyImages/Sony/short/00001_00_0.1s.ARW
==> Image shape: (512, 512, 3)
==> Label shape: (512, 512, 3)
==> Loss: 0.06475134193897247
==> Label: /home/deeplearning/images/SonyImages/Sony/long/00004_00_10s.ARW
==> Image: /home/deeplearning/images/SonyImages/Sony/short/00001_01_0.04s.ARW
==> Image shape: (512, 512, 3)
==> Label shape: (512, 512, 3)
==> Loss: 0.03906021639704704
==> Label: /home/deeplearning/images/SonyImages/Sony/long/00009_00_10s.ARW
==> Image: /home/deeplearning/images/SonyImages/Sony/short/00001_01_0.1s.ARW
==> Image shape: (512, 512, 3)
==> Label shape: (512, 512, 3)
==> Loss: 0.0393085777759552
==> Label: /home/deeplearning/images/SonyImages/Sony/long/00010_00_10s.ARW
==> Image: /home/deeplearning/images/SonyImages/Sony/short/00001_02_0.1s.ARW
==> Image shape: (512, 512, 3)
==> Label shape: (512, 512, 3)
==> Loss: 0.0397891327738

 10%|█         | 1/10 [00:08<01:19,  8.86s/it]

==> Label: /home/deeplearning/images/SonyImages/Sony/long/00017_00_10s.ARW
==> Image: /home/deeplearning/images/SonyImages/Sony/short/00001_07_0.1s.ARW
==> Image shape: (512, 512, 3)
==> Label shape: (512, 512, 3)
==> Loss: 0.01960989646613598
==> Label: /home/deeplearning/images/SonyImages/Sony/long/00001_00_10s.ARW
==> Image: /home/deeplearning/images/SonyImages/Sony/short/00001_00_0.04s.ARW
==> Image shape: (512, 512, 3)
==> Label shape: (512, 512, 3)
==> Loss: 0.03892054408788681
==> Label: /home/deeplearning/images/SonyImages/Sony/long/00002_00_10s.ARW
==> Image: /home/deeplearning/images/SonyImages/Sony/short/00001_00_0.1s.ARW
==> Image shape: (512, 512, 3)
==> Label shape: (512, 512, 3)
==> Loss: 0.04740601032972336
==> Label: /home/deeplearning/images/SonyImages/Sony/long/00004_00_10s.ARW
==> Image: /home/deeplearning/images/SonyImages/Sony/short/00001_01_0.04s.ARW
==> Image shape: (512, 512, 3)
==> Label shape: (512, 512, 3)
==> Loss: 0.015070416033267975


 10%|█         | 1/10 [00:11<01:44, 11.64s/it]


KeyboardInterrupt: 