In [53]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19, VGG19_Weights
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import os
import argparse
import numpy as np

In [54]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using:", device)

Using: cuda


In [55]:
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [57]:
import sys
sys.path.append('../models')

from SuperResolutionCNN import SuperResolutionCNN

student = SuperResolutionCNN().to(device)
student.apply(weights_init)

SuperResolutionCNN(
  (entry): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (res_blocks): Sequential(
    (0): ResidualBlock(
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (1): ResidualBlock(
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (2): ResidualBlock(
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (3): ResidualBlock(
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1

In [58]:
sys.path.append('C:/Users/manas/image-sharpness/EDSR-PyTorch/src')
from model.edsr import make_model


In [59]:
args_teacher = argparse.Namespace(
    n_resblocks=16,         # number of residual blocks used in training
    n_feats=256,            # number of feature maps
    res_scale=1.0,          # residual scaling factor
    scale=[1],              # upscaling factor (1 for sharpness task)
    n_colors=3,             # RGB image (3 channels)
    rgb_range=1           # pixel value range (0-255)
)
teacher = make_model(args_teacher).to(device)
teacher.load_state_dict(torch.load('C:/Users/manas/image-sharpness/EDSR-PyTorch/experiment/edsr_sharpness_finetune_x1/model/model_1.pt'))
teacher.eval()
for param in teacher.parameters():
    param.requires_grad = False

  teacher.load_state_dict(torch.load('C:/Users/manas/image-sharpness/EDSR-PyTorch/experiment/edsr_sharpness_finetune_x1/model/model_1.pt'))


In [60]:
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        weights = VGG19_Weights.DEFAULT
        vgg = vgg19(weights=weights).features[:36].eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        self.layer_ids = [3, 8, 17, 26, 35]
        self.criterion = nn.L1Loss()

    def forward(self, sr, hr):
        x, y = sr, hr
        loss = 0
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            y = layer(y)
            if i in self.layer_ids:
                loss += self.criterion(x, y)
        return loss

In [61]:
class FeatureDistillationLoss(nn.Module):
    def __init__(self, teacher_channels_list):
        super().__init__()
        self.criterion = nn.L1Loss()
        self.adapters = nn.ModuleList([
            nn.Conv2d(tc, 64, kernel_size=1) for tc in teacher_channels_list
        ])

    def forward(self, feat_s, feat_t):
        loss = 0
        for fs, ft, adapter in zip(feat_s, feat_t, self.adapters):
            adapted_ft = adapter(ft)
            loss += self.criterion(fs, adapted_ft.detach())
        return loss

In [62]:
class EdgeLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def sobel(self, x):
        sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32).view(1, 1, 3, 3)
        sobel_x = sobel_x.repeat(x.size(1), 1, 1, 1).to(x.device)
        sobel_y = sobel_y.repeat(x.size(1), 1, 1, 1).to(x.device)
        grad_x = F.conv2d(x, sobel_x, padding=1, groups=x.size(1))
        grad_y = F.conv2d(x, sobel_y, padding=1, groups=x.size(1))
        return torch.sqrt(grad_x ** 2 + grad_y ** 2)

    def forward(self, sr, hr):
        return self.criterion(self.sobel(sr), self.sobel(hr))

In [63]:
rec_loss_fn = nn.L1Loss()
perc_loss_fn = VGGPerceptualLoss().to(device)
feat_loss_fn = FeatureDistillationLoss(teacher_channels_list=[256,256,256]).to(device)
edge_loss_fn = EdgeLoss()

In [64]:
from torch.utils.data import DataLoader
from data.sharpness import Sharpness

class Args:
    def __init__(self):
        self.scale = [1]
        self.dir_data = 'C:/Users/manas/image-sharpness/EDSR-PyTorch/src/data'  # or wherever your `sharpness/train/HR, LR` dirs are
        self.batch_size = 16
        self.patch_size = 48
        self.n_colors = 3
        self.rgb_range = 255
        self.test_every = 1000

args = Args()

train_dataset = Sharpness(args, train=True)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
val_dataset = Sharpness(args, train=False)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)

In [65]:
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)


In [66]:
lr, hr = train_dataset[0]
print(lr.shape, hr.shape)

torch.Size([3, 48, 48]) torch.Size([3, 48, 48])


In [67]:
from tqdm import tqdm
losses = []
vgg_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
vgg_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
for epoch in range(5):
    student.train()
    epoch_loss = 0

    for lr, hr in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        lr, hr = lr.to(device), hr.to(device)
        lr = lr / 255.0
        hr = hr / 255.0

        if torch.isnan(lr).any():
            print("❌ NaN detected in Low-Resolution (lr) input!")
            raise ValueError("NaN in lr")

        if torch.isnan(hr).any():
            print("❌ NaN detected in High-Resolution (hr) ground truth!")
            raise ValueError("NaN in hr")

        with torch.no_grad():
            t_out, t_feats = teacher(lr, return_features=True)
            print('Teacher output stats:', torch.isnan(t_out).any(), t_out.min(), t_out.max())
        if torch.isnan(t_out).any():
            print("❌ NaN in teacher output")
            raise ValueError("NaN in teacher output")
            

        s_out, s_feats = student(lr, return_features=True)
        print('Student output stats:', torch.isnan(s_out).any(), s_out.min(), s_out.max())

        if torch.isnan(s_out).any():
            print("❌ NaN in student output")
            raise ValueError("NaN in student output")

        print("✅ Student output range:", s_out.min().item(), s_out.max().item())

        s_out_norm = (s_out - vgg_mean) / vgg_std
        hr_norm = (hr - vgg_mean) / vgg_std

        rec_loss = rec_loss_fn(s_out, hr)
        perc_loss = perc_loss_fn(s_out_norm, hr_norm)
        feat_loss = feat_loss_fn(s_feats, t_feats)
        grad_loss = edge_loss_fn(s_out, hr)

        total_loss = (
            1.0 * rec_loss +
            0.1 * perc_loss +
            0.1 * feat_loss +
            0.1 * grad_loss
        )
        if torch.isnan(total_loss):
            print("❌ NaN detected in total_loss!")
            print(f"rec_loss: {rec_loss.item()}")
            print(f"perc_loss: {perc_loss.item()}")
            print(f"feat_loss: {feat_loss.item()}")
            print(f"grad_loss: {grad_loss.item()}")
            raise ValueError("NaN detected — stopping training")

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        epoch_loss += total_loss.item()
        losses.append(total_loss.item())

    print(f"Epoch [{epoch+1}], Loss: {epoch_loss/len(train_loader):.4f}")

    # Plotting loss
    plt.plot(losses)
    plt.xlabel("Batch")
    plt.ylabel("Loss")
    plt.title(f"Training Loss after Epoch {epoch+1}")
    plt.show()

    # Save checkpoint
    torch.save(student.state_dict(), f"student_epoch{epoch+1}.pth")

    # ---------- Evaluation on Validation Set ----------
    student.eval()
    total_psnr = 0
    total_ssim = 0
    count = 0

    with torch.no_grad():
        for lr, hr in val_loader:
            lr, hr = lr.to(device), hr.to(device)
            lr = lr / 255.0
            hr = hr / 255.0
            pred = student(lr)

            pred_img = pred[0].detach().cpu().numpy().transpose(1, 2, 0)
            hr_img = hr[0].detach().cpu().numpy().transpose(1, 2, 0)

            pred_img = np.clip(pred_img, 0, 1)
            hr_img = np.clip(hr_img, 0, 1)

            total_psnr += psnr(hr_img, pred_img, data_range=1.0)
            total_ssim += ssim(hr_img, pred_img, channel_axis=-1, data_range=1.0)
            count += 1

    avg_psnr = total_psnr / count
    avg_ssim = total_ssim / count
    psnr_scores.append(avg_psnr)
    ssim_scores.append(avg_ssim)

    print(f"Validation PSNR: {avg_psnr:.2f} | SSIM: {avg_ssim:.4f}")

    # Visualize example
    grid = make_grid([lr[0].cpu(), pred[0].cpu(), hr[0].cpu()], nrow=3)
    plt.imshow(grid.permute(1, 2, 0))
    plt.title("LR | Predicted | HR")
    plt.axis("off")
    plt.show()


Epoch 1:   0%|          | 0/2776 [00:00<?, ?it/s]

Teacher output stats: 

Epoch 1:   0%|          | 1/2776 [00:01<48:53,  1.06s/it]

tensor(False, device='cuda:0') tensor(-0.0885, device='cuda:0') tensor(1.1167, device='cuda:0')
Student output stats: tensor(False, device='cuda:0') tensor(-155.9790, device='cuda:0', grad_fn=<MinBackward1>) tensor(92.9339, device='cuda:0', grad_fn=<MaxBackward1>)
✅ Student output range: -155.97897338867188 92.93385314941406
Teacher output stats: 

Epoch 1:   0%|          | 1/2776 [00:01<1:14:38,  1.61s/it]

tensor(False, device='cuda:0') tensor(-0.0231, device='cuda:0') tensor(1.0376, device='cuda:0')
Student output stats: tensor(True, device='cuda:0') tensor(nan, device='cuda:0', grad_fn=<MinBackward1>) tensor(nan, device='cuda:0', grad_fn=<MaxBackward1>)
❌ NaN in student output





ValueError: NaN in student output