In [79]:
# Local UNet definitions (build_unet). Kept first so the explicit imports below
# take precedence over any same-named symbol the star import may export.
from unet import *

import glob
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

# NOTE(review): `rawpy` (used by the training cell) is not imported here; it
# presumably arrives via `from unet import *` — confirm, or import it explicitly.

In [82]:
####### REPRODUCIBILITY #######

# Seed the RNGs this notebook draws from, before the model is built, so that
# Restart & Run All reproduces the same run:
#   - torch: presumably used for weight init inside build_unet
#   - numpy: random crop offsets
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

####### MODEL #######

Net = build_unet()
# check if CUDA is available, and set it as the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("==> Device: {}".format(device))

# move model to the device; do this before building the optimizer from its parameters
Net.to(device)

# define loss function: mean absolute error (L1) between prediction and target
criterion = nn.L1Loss()

####### HYPERPARAMETERS #######

# learning rate = 10^-4
lr = 0.0001

# side length (pixels) of the square training patches cropped from each image
patch_size = 512

# define optimizer
optimizer = optim.Adam(Net.parameters(), lr=lr)

# batch size
batch_size = 1

# define number of epochs
n_epochs = 10

# keep track of the best validation loss
# (np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0)
valid_loss_min = np.inf

# patience: number of epochs without validation improvement before stopping
early_stopping = 5


##### LOSS AND ACCURACY #######

# keep track of training and validation loss (one entry per epoch)
train_loss = []
valid_loss = []

# keep track of training and validation accuracy
train_acc = []
valid_acc = []

# TODO: early stopping is not implemented yet; `early_stopping` above is the
# intended patience for an EarlyStopping(patience=..., verbose=True) helper.


==> Device: cuda


In [90]:
def random_crop(img, patch_size):
    """
    Crop a random square patch from an image.

    Args:
        img: array indexed as ``img[y, x, ...]`` (height first, width second);
            any trailing axes (e.g. channels) are kept whole.
        patch_size: side length of the square patch in pixels.

    Returns:
        A view of ``img`` of shape ``(patch_size, patch_size, *img.shape[2:])``.

    Raises:
        ValueError: if the image is smaller than ``patch_size`` in either dimension.
    """
    h, w = img.shape[:2]
    if h < patch_size or w < patch_size:
        raise ValueError(
            "image of size {}x{} is smaller than patch_size={}".format(h, w, patch_size)
        )
    # np.random.randint excludes its upper bound; the +1 makes the right-/bottom-most
    # window reachable and lets an image exactly patch_size wide/tall be cropped.
    x = np.random.randint(0, w - patch_size + 1)
    y = np.random.randint(0, h - patch_size + 1)
    return img[y:y + patch_size, x:x + patch_size]

In [91]:
####### SAMPLE INPUT DATA #######
labels_dir = '/home/deeplearning/images/SonyImages/Sony/long/'
images_dir = '/home/deeplearning/images/SonyImages/Sony/short/'

images = glob.glob(images_dir + '0*.ARW')
labels = glob.glob(labels_dir + '0*.ARW')

# sort the images and labels
images.sort()
labels.sort()

# loop over epochs with tqdm progress bar
for epoch in tqdm.tqdm(range(n_epochs)):
    # initialize the training and validation loss for this epoch
    train_loss_epoch = 0.0
    valid_loss_epoch = 0.0

    # initialize the training and validation accuracy for this epoch
    train_acc_epoch = 0.0
    valid_acc_epoch = 0.0

    # set the model to training mode
    Net.train()

    # loop over the training data
    for i in range(len(images)):
        # load the first image and label
        image = rawpy.imread(images[i])
        label = rawpy.imread(labels[i])

        # convert the image and label to numpy arrays
        image = image.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        label = label.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)

        # convert the image and label to float32 data type
        image = np.float32(image / 65535.0)
        label = np.float32(label / 65535.0)

        # crop the image and label to 512 x 512
        image = random_crop(image, 512)
        label = random_crop(label, 512)

        # convert the image and label to tensors
        image = torch.from_numpy(np.expand_dims(np.transpose(image, (2, 0, 1)), axis=0))
        label = torch.from_numpy(np.expand_dims(np.transpose(label, (2, 0, 1)), axis=0))

        # move the image and label to the device
        image, label = image.to(device), label.to(device)

        # forward pass
        output = Net(image)

        # calculate the loss
        loss = criterion(output, label)

        # backward pass
        loss.backward()

        # update the weights
        optimizer.step()

        # clear the gradients
        optimizer.zero_grad()

        # print the loss
        print("==> Loss: {}".format(loss.item()))

        # update the training loss
        train_loss_epoch += loss.item()


  return F.l1_loss(input, target, reduction=self.reduction)


==> Loss: 0.35670313239097595
==> Loss: 0.3889467120170593
==> Loss: 0.23229238390922546
==> Loss: 0.18158477544784546
==> Loss: 0.30527669191360474
==> Loss: 0.25263768434524536
==> Loss: 0.22558127343654633
==> Loss: 0.2563576400279999
==> Loss: 0.25788840651512146
==> Loss: 0.20300433039665222
==> Loss: 0.13902375102043152
==> Loss: 0.15284602344036102
==> Loss: 0.11452110856771469
==> Loss: 0.10676401853561401


  0%|          | 0/10 [00:11<?, ?it/s]


KeyboardInterrupt: 