In [None]:
# Run this cell once (uncomment the line below), then restart the kernel and continue with the next cell
# %pip install --upgrade ipywidgets

In [None]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:98% !important; }</style>"))

%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
import json
from pprint import pprint

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

import training
from training import load_config, generate_basic_dataloader, generate_sppf_dataloader, get_model, train_loop, cal_regression_metrics
from utils.preprocessing import load_image, apply_img_preprocessing

from inference import load_saved_model, pred_degradation_value

# Training

In [None]:
# Read the experiment configuration from its YAML file and print it for reference.
config_path = "./configs/cnn.yaml"
config = load_config(config_path)
pprint(config)

In [None]:
# Inline copy of the configuration for quick experiments.
# NOTE: this assignment replaces whatever `config` was loaded from the YAML file
# in the previous cell; edit the values here to try different settings.
config = {
    "dataset_loc": {
        "train": {
            "degradation_values_csv": "./data/bdd100k/segments/degradation_segment_labels_train.csv",
            "img_dir": "./data/bdd100k/segments/train/",
        },
        "val": {
            "degradation_values_csv": "./data/bdd100k/segments/degradation_segment_labels_val.csv",
            "img_dir": "./data/bdd100k/segments/val/",
        },
    },
    "enable_cuda": True,
    "model": {"in_channels": 3, "out_dim": 1},
    "results_loc": "experiment_results/",
    "training": {
        "batch_size": 1,
        "learning_rate": 0.05,
        "num_epochs": 10,
        "num_workers": 2,
        "resume_checkpoint": None,
        "save_checkpoint_freq": 1,
    },
}

In [None]:
# Split the configuration into the three sections used by the cells below.
model_config, train_config, dataset_config = (
    config[section] for section in ("model", "training", "dataset_loc")
)

In [None]:
# Select the device used by the training helpers (they read `training.DEVICE`).
# CUDA is used only when it is both requested in the config and actually
# available. Otherwise CPU is assigned explicitly, so that setting
# `enable_cuda` to False really takes effect instead of leaving
# `training.DEVICE` at whatever value (if any) it had before this cell ran.
use_cuda = bool(config["enable_cuda"]) and torch.cuda.is_available()
training.DEVICE = torch.device("cuda" if use_cuda else "cpu")
print(f"Using DEVICE: {training.DEVICE}")

In [None]:
# Build the train / validation data loaders.
#
# Optional: to try a different preprocessing pipeline, define a transform like
# the one sketched below and pass it through `transform=` instead of None.
# (`preprocess_config` is not defined in this notebook; supply your own sizes.)
# resize_height, resize_width = preprocess_config["resize_height"], preprocess_config["resize_width"]
# img_transform = transforms.Compose([
#     transforms.ToPILImage(),
#     transforms.Resize((resize_height, resize_width)),
#     transforms.ToTensor()
# ])

# Keyword arguments that are identical for both loaders.
shared_loader_args = {
    "batch_size": train_config["batch_size"],
    "num_workers": train_config["num_workers"],
    "transform": None,   # send new transform here if required
}

# Training loader (subset_size=1000 is passed to the helper).
train_loader, train_size = generate_sppf_dataloader(
    image_dir=dataset_config["train"]["img_dir"],
    degradation_values_csv=dataset_config["train"]["degradation_values_csv"],
    subset_size=1000,
    **shared_loader_args,
)

# Validation loader (subset_size=500 is passed to the helper).
val_loader, val_size = generate_sppf_dataloader(
    image_dir=dataset_config["val"]["img_dir"],
    degradation_values_csv=dataset_config["val"]["degradation_values_csv"],
    subset_size=500,
    **shared_loader_args,
)

print(f"Train Dataset loaded. #samples: {train_size}")
print(f"Validation Dataset loaded. #samples: {val_size}")

In [None]:
# Sanity check: pull one batch from the training loader and look at its first sample.
batch_img, degradation_values = next(iter(train_loader))
print("Image batch shape:", batch_img.shape)

# imshow expects channels-last data, so reorder (c, h, w) -> (h, w, c).
sample_img = np.transpose(batch_img[0].numpy(), (1, 2, 0))

fig, ax = plt.subplots(figsize=(15, 5))
ax.set_title("input image")
ax.imshow(sample_img)

# Target value paired with the displayed image.
print("Degradation value:", degradation_values[0].item())

In [None]:
# Initialize the model, the loss function, and the optimizer.
model_name = "cnn_sppf"
model = get_model(model_name, in_channels=model_config['in_channels'], out_dim=model_config['out_dim'])
model = model.to(training.DEVICE)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=train_config["learning_rate"])

# Optionally resume from previously saved weights. Only the model weights are
# restored here; the optimizer starts from a fresh state.
checkpoint_path = train_config["resume_checkpoint"]
if checkpoint_path is not None:
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU machine can still be loaded on a CPU-only one
    # (without it torch.load raises when CUDA is unavailable).
    state_dict = torch.load(checkpoint_path, map_location=training.DEVICE, weights_only=True)
    model.load_state_dict(state_dict)

In [None]:
# from torchinfo import summary
# summary(model, input_size=(1, 3, 4, 4))

In [None]:
# Train the model with the loaders, loss and optimizer prepared in the cells above.
# `save_path` and `checkpoint_freq` are forwarded from the config; how they are
# used is defined by `training.train_loop`.
train_loop(
    model=model,
    loss_fn=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=train_config["num_epochs"],
    save_path=config["results_loc"],
    checkpoint_freq=train_config["save_checkpoint_freq"],
)

In [None]:
# Display the learning-curve image saved during training.
plot_saved_path = "experiment_results/train_log/learning_curve_2025-03-15_00-08-40.png"

# The file name above embeds the timestamp of one specific run, so it will not
# exist after a fresh training run. In that case fall back to the newest
# learning_curve_*.png in the same directory; the YYYY-MM-DD_HH-MM-SS stamp
# sorts chronologically as plain text.
# NOTE(review): assumes every run names its plot learning_curve_<timestamp>.png — confirm against train_loop.
if not os.path.exists(plot_saved_path):
    log_dir = os.path.dirname(plot_saved_path)
    available = os.listdir(log_dir) if os.path.isdir(log_dir) else []
    curve_files = sorted(
        name for name in available
        if name.startswith("learning_curve_") and name.endswith(".png")
    )
    if not curve_files:
        raise FileNotFoundError(
            f"No learning curve image found in {log_dir!r}; run the training cell first."
        )
    plot_saved_path = os.path.join(log_dir, curve_files[-1])
    print(f"Showing latest learning curve: {plot_saved_path}")

img = load_image(plot_saved_path)
plt.imshow(img)
plt.axis("off")
plt.show()

In [None]:
# Compute the regression metrics of the trained model on both splits
# (train first, then validation), then report them.
train_losses, val_losses = (
    cal_regression_metrics(model, loader) for loader in (train_loader, val_loader)
)

print(f"Train Loss: {train_losses}")
print(f"Validation Loss: {val_losses}")

# Inference

In [None]:
# Inference settings: which architecture to build and which saved weights to load.
# NOTE: `model_name` and `model_config` rebind the names used in the training section.
model_name = "cnn_sppf"
model_config = dict(in_channels=3, out_dim=1)

saved_weight_path = "experiment_results/checkpoints/cnn_sppf_final_2025-03-13_18-02-43.pth"

In [None]:
# Target size for the inference transform; keep it in line with the
# preprocessing that was used when the loaded model was trained.
resize_height = 720
resize_width = 1280

# Image transformation pipeline: array -> PIL image -> resized -> tensor.
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize((resize_height, resize_width)),
    transforms.ToTensor(),
]
img_transform = transforms.Compose(preprocessing_steps)

In [None]:
# Run inference on the GPU when one is available, otherwise on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(device_name)
print(f"Using DEVICE: {DEVICE}")

In [None]:
# Build the architecture, load the saved weights into it, and move it to DEVICE.
model = load_saved_model(
    model_name=model_name,
    saved_weight_path=saved_weight_path,
    **model_config,
).to(DEVICE)

In [None]:
# Load a single test image and display it before running the prediction.
test_img_path = "../damage_ratio_calc_data/segmented_objects/191_jpg.rf.e27c030e763e58ce48964e670158b6e7/191_jpg.rf.e27c030e763e58ce48964e670158b6e7_object_4.png"
test_img = load_image(test_img_path)
print("Test image shape:", test_img.shape)

fig, ax = plt.subplots()
ax.imshow(test_img)
ax.axis("off")
ax.set_title("Input image", fontsize=9)
plt.show()

In [None]:
# Predict the degradation value for the loaded test image.
# NOTE(review): the `img_transform` pipeline defined earlier in this notebook is
# not passed here (img_transform=None), which mirrors transform=None used for
# the training loaders — confirm this is the intended preprocessing.
pred_val = pred_degradation_value(model=model, test_img=test_img, img_transform=None, add_batch_dim=True, device=DEVICE)
print("Predicted degradation value:", pred_val)