In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn

import torch.optim as optim
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter

from data import get_dataloaders
from model import Dehazer
from loss import HybridLoss

In [None]:
# Enable IPython's autoreload extension so that edits to the imported project
# modules (data, model, loss) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Seed every RNG the notebook has imported so that runs are repeatable.
# A single constant keeps the three libraries in agreement.
SEED = 43
random.seed(SEED)
np.random.seed(SEED)
# Trailing semicolon suppresses the Generator repr that would otherwise be
# echoed as the cell output.
torch.manual_seed(SEED);

In [None]:
# Data configuration.
# DATASET picks the corpus; the supported options are
# "ohaze", "dh/Middlebury" and "dh/NYU".
# TRAIN_SPLIT is presumably the fraction of samples used for training and
# BATCH_SIZE the loader batch size -- confirm against data.get_dataloaders.
TRAIN_SPLIT = 0.8
BATCH_SIZE = 8
DATASET = "ohaze"

train_loader, test_loader = get_dataloaders(
    DATASET,
    TRAIN_SPLIT,
    BATCH_SIZE,
)

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
""" Define the hybrid loss """

# VGG for perceptual loss
vgg_model = models.vgg16(weights="VGG16_Weights.DEFAULT").features
# features[:24] keeps layers 0..23 of torchvision's VGG16, i.e. everything up to
# and including the max-pool that follows relu4_3.
feat_extractor = nn.Sequential(*list(vgg_model.children())[:24])
# Collapse the spatial dimensions so each image yields one value per channel.
feat_extractor.add_module("avgpool", nn.AdaptiveAvgPool2d((1, 1)))

# ResNet for perceptual loss (alternative backbone, disabled)
# resnet = models.resnet18(weights="ResNet18_Weights.DEFAULT")
# feat_extractor = nn.Sequential(*list(resnet.children())[:7])
# feat_extractor.add_module("avgpool",nn.AdaptiveAvgPool2d((1, 1)))

# The extractor is a fixed, pretrained network: the optimizer only ever receives
# the dehazer's parameters, so gradients stored on these weights would never be
# used or zeroed. Freezing them avoids that wasted memory and compute while
# still letting gradients flow through the extractor to the dehazer's output.
feat_extractor.requires_grad_(False)
feat_extractor.eval()

num_params_f = sum(p.numel() for p in feat_extractor.parameters())
print(f"Number of parameters in the feature extractor: {num_params_f}")

# Loss
# GAMMA is forwarded to HybridLoss as `gamma`; see loss.py for how it is used.
GAMMA = 1.5
feat_extractor.to(device)
criterion = HybridLoss(feat_extractor, gamma=GAMMA)

In [None]:
# Instantiate the dehazing network and report how many parameters it holds.
model = Dehazer()

num_params_m = 0
for param in model.parameters():
    num_params_m += param.numel()

print(f"Number of parameters in the model: {num_params_m}")

In [None]:
# Adam over the dehazer's parameters; `learning_rate` is also read by train()
# when it builds the log directory name.
learning_rate = 1e-3
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

In [None]:
def _to_display_uint8(chw, eps=1e-8):
    """Turn one channel-first float image into an HWC uint8 image for imshow.

    Each channel is divided by its own maximum, scaled to 0..255 and cast to
    uint8. Channels whose maximum is not positive are left unscaled, and values
    are clipped to [0, 1] before the cast so negatives cannot wrap around.
    """
    chw = np.nan_to_num(np.asarray(chw, dtype=np.float32))
    channel_max = chw.reshape(chw.shape[0], -1).max(axis=1)
    # Avoid dividing by zero (or flipping signs) for empty/black channels.
    channel_max = np.where(channel_max > eps, channel_max, 1.0)
    scaled = np.clip(chw / channel_max[:, None, None], 0.0, 1.0) * 255
    return scaled.astype("uint8").transpose(1, 2, 0)


def test(model, test_loader, writer, epoch):
    """Visualise the model's predictions on the first batch of `test_loader`.

    The model is switched to eval mode and run without gradients on the first
    batch only. Ground-truth images are drawn in the top row and per-channel
    max-normalised predictions in the bottom row. The figure is saved to
    ``plot_{epoch}.png`` in the working directory, logged to `writer` under the
    tag ``Plot`` at step `epoch`, and displayed.

    Args:
        model: network to evaluate; must already live on the global `device`.
        test_loader: iterable yielding ``(images, targets)`` batches.
        writer: TensorBoard SummaryWriter owned by the caller. It is flushed
            but not closed here, so the caller can keep logging to it.
        epoch: value used in the file name and as the TensorBoard global step.

    Note:
        The model is left in eval mode; the caller is responsible for calling
        ``model.train()`` before resuming training.
    """
    model.eval()

    with torch.no_grad():
        # Only the first batch is visualised.
        batch = next(iter(test_loader), None)
        if batch is None:
            return
        images, targets = batch
        outputs = model(images.to(device))

    outputs_np = outputs.cpu().numpy()
    targets_np = targets.cpu().numpy()

    num_outs = outputs_np.shape[0]
    if num_outs == 0:
        return

    # squeeze=False keeps `axs` two-dimensional even for a batch of one image.
    fig, axs = plt.subplots(2, num_outs, figsize=(20, 8), squeeze=False)

    for i in range(num_outs):
        # Plot ground truth image in the first row
        axs[0, i].imshow(targets_np[i].transpose(1, 2, 0))
        axs[0, i].set_title("Ground Truth")
        axs[0, i].axis("off")

        # Plot predicted image in the second row
        axs[1, i].imshow(_to_display_uint8(outputs_np[i]))
        axs[1, i].set_title("Output")
        axs[1, i].axis("off")

    fig.tight_layout()
    plot_path = f'plot_{epoch}.png'
    fig.savefig(plot_path)

    # Read back exactly the file that was just written; the context manager
    # releases the file handle, and RGB conversion drops the PNG alpha channel.
    with Image.open(plot_path) as saved_plot:
        image = torch.from_numpy(np.array(saved_plot.convert("RGB"))).permute(2, 0, 1).float() / 255

    # Tag each evaluation with its epoch so earlier plots are not overwritten.
    writer.add_image('Plot', image, global_step=epoch)
    writer.flush()

    plt.show()
    # Release the figure so repeated evaluations do not accumulate open figures.
    plt.close(fig)


In [None]:
def train(model, train_loader, test_loader, num_epochs, eval_every=3):
    """Train `model` and periodically visualise predictions and save checkpoints.

    Relies on notebook-level globals: `device`, `criterion`, `optimizer`,
    `DATASET`, `learning_rate`, `GAMMA` and the `test` function.

    Logs go to ``./logs/{DATASET}/lr{learning_rate}_gamma{GAMMA}_epochs{num_epochs}``
    and checkpoints to its ``weights`` sub-directory, named by 1-based epoch
    number (``003.pth``, ``006.pth``, ...).

    Args:
        model: network to optimise; moved to `device` in place.
        train_loader: iterable yielding ``(images, targets)`` training batches.
        test_loader: loader forwarded to `test` for the periodic visualisation.
        num_epochs: number of passes over `train_loader`.
        eval_every: run `test` and write a checkpoint every this many epochs.

    Raises:
        ValueError: if `eval_every` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be a positive integer")

    model.to(device)

    log_dir = f"./logs/{DATASET}/lr{learning_rate}_gamma{GAMMA}_epochs{num_epochs}"
    writer = SummaryWriter(log_dir)
    weights_dir = os.path.join(log_dir, "weights")
    os.makedirs(weights_dir, exist_ok=True)

    # An earlier cell may have left the model in eval mode.
    model.train()

    try:
        for epoch in range(num_epochs):
            running_loss = 0.0
            num_batches = 0
            for images, targets in train_loader:
                images = images.to(device)
                targets = targets.to(device)

                # Forward pass
                outputs = model(images)

                # Compute the loss
                loss = criterion(outputs, targets)

                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                num_batches += 1

            # Mean loss over the batches actually seen; guard against an empty loader.
            epoch_loss = running_loss / max(num_batches, 1)
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.6f}")
            # The step is the 0-based epoch index, the same value passed to test().
            writer.add_scalar("Loss/train", epoch_loss, epoch)

            if (epoch + 1) % eval_every == 0:
                test(model, test_loader, writer, epoch)
                # test() switches the model to eval mode; restore training mode.
                model.train()
                checkpoint = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(checkpoint, os.path.join(weights_dir, f"{(epoch+1):03}.pth"))
    finally:
        # Always flush and close the event file, even if training is interrupted.
        writer.close()

In [None]:
# Launch the full training run.
NUM_EPOCHS = 30
train(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=NUM_EPOCHS,
)