<a href="https://colab.research.google.com/github/jrtrj/ImageSharpening_KD/blob/local/ImageSharpening.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pytorch-msssim
!pip install albumentations==1.3.1

import os
import cv2
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from pytorch_msssim import ssim as ssim_loss
import albumentations as A
from albumentations.pytorch import ToTensorV2
from skimage.metrics import structural_similarity as ssim
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt

In [16]:
class DIV2KDatasetAugmented(Dataset):
    """Dataset yielding (degraded, clean) image pairs generated on the fly.

    Every ``*.png`` file directly inside ``hr_dir`` is one sample. On each
    access the image is loaded, resized to a square, and a randomly degraded
    copy (blur, optional downscale, optional Gaussian noise) is produced, so
    the degraded input for a given index differs between accesses.

    Args:
        hr_dir: Directory containing the clean PNG images.
        image_size: Side length, in pixels, that every image is resized to.
            Defaults to 128.
    """

    def __init__(self, hr_dir, image_size=128):
        # Sorted so that index -> file mapping is stable across runs.
        self.hr_paths = sorted(glob.glob(os.path.join(hr_dir, "*.png")))
        self.image_size = image_size

        # Degradation pipeline applied to the clean image only: exactly one
        # kind of blur always, then downscaling and noise with some probability.
        self.degradation_transform = A.Compose([
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 7), p=0.7),
                A.MotionBlur(blur_limit=(3, 7), p=0.3),
            ], p=1.0),
            A.Downscale(scale_min=0.6, scale_max=0.8, interpolation=cv2.INTER_CUBIC, p=0.5),
            A.GaussNoise(var_limit=(10, 50), p=0.3),
        ])

        # With zero mean and unit std, Normalize only rescales pixel values
        # (by Albumentations' default max_pixel_value) before the array is
        # converted to a channels-first tensor.
        self.to_tensor = A.Compose([
            A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
            A.pytorch.ToTensorV2()
        ])

    def __len__(self):
        """Return the number of PNG files found in ``hr_dir``."""
        return len(self.hr_paths)

    def __getitem__(self, idx):
        """Return ``(degraded, clean)`` tensors for the image at ``idx``.

        Raises:
            OSError: If the image file cannot be read or decoded.
        """
        path = self.hr_paths[idx]
        img_hr_np = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if img_hr_np is None:
            raise OSError(f"cv2.imread failed to load image: {path}")
        img_hr_np = cv2.cvtColor(img_hr_np, cv2.COLOR_BGR2RGB)
        img_hr_np = cv2.resize(img_hr_np, (self.image_size, self.image_size))

        img_degraded_np = self.degradation_transform(image=img_hr_np)['image']

        img_hr = self.to_tensor(image=img_hr_np)['image']
        img_degraded = self.to_tensor(image=img_degraded_np)['image']

        return img_degraded, img_hr

In [18]:
class ResidualBlock(nn.Module):
    """Residual unit: conv-BN-ReLU-conv-BN on a branch, added to the input, then ReLU.

    All convolutions are 3x3 with stride 1 and padding 1, so the channel count
    and spatial size of the input are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        # Sub-modules are created in the order conv1, bn1, conv2, bn2 so that
        # weight initialisation order and state_dict keys stay the same.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(x + branch(x))``."""
        branch = self.conv1(x)
        branch = self.bn1(branch)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        branch = self.bn2(branch)
        return self.relu(x + branch)

class ResNetSharpen(nn.Module):
    """Teacher network: predicts a correction image and adds it to the input.

    Structure: 3->64 conv, ReLU, ``num_blocks`` ResidualBlock(64) units,
    64->3 conv, and a global skip connection from input to output.

    Args:
        num_blocks: Number of residual blocks in the trunk.
    """

    def __init__(self, num_blocks=8):
        super().__init__()
        self.conv_in = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)

        trunk = []
        for _ in range(num_blocks):
            trunk.append(ResidualBlock(64))
        self.residual_layers = nn.Sequential(*trunk)

        # Three output channels so the correction can be added to the input.
        self.conv_out = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``x`` plus the correction predicted by the network."""
        features = self.relu(self.conv_in(x))
        correction = self.conv_out(self.residual_layers(features))
        # Global residual connection: the network only models the difference.
        return x + correction

class StudentCNN(nn.Module):
    """Small three-convolution student network.

    ``body`` maps the 3-channel input to a 3-channel tensor (3->32->32->3,
    all 3x3 convolutions with padding 1); the forward pass subtracts that
    tensor from the input.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x - body(x)``."""
        predicted = self.body(x)
        return x - predicted

In [19]:
# Perceptual Loss using a pre-trained VGG19
from torchvision.models import vgg19, VGG19_Weights

class VGGPerceptualLoss(nn.Module):
    """L1 distance between frozen VGG19 feature maps of two image batches.

    The feature extractor is the first 36 layers of torchvision's pretrained
    ``vgg19().features`` with gradients disabled for its weights.

    The module does not pick a device itself: move it (or the loss module that
    owns it) with ``.to(device)`` before use, which also moves the
    normalization buffers.
    """

    def __init__(self):
        super().__init__()
        weights = VGG19_Weights.DEFAULT
        # No `.to(device)` here: relying on a module-level `device` makes the
        # class unusable before that global exists. Placement is left to the
        # caller's `.to(...)`, which covers registered sub-modules and buffers.
        vgg = vgg19(weights=weights).features.eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.feature_extractor = nn.Sequential(*list(vgg.children())[:36])
        self.loss_fn = nn.L1Loss()

        # The pretrained VGG19 weights expect ImageNet-normalized inputs.
        # Non-persistent buffers follow `.to(device)` without entering the
        # state_dict.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def _normalize(self, images):
        """Apply ImageNet mean/std; assumes RGB batches scaled to [0, 1] (N, 3, H, W)."""
        return (images - self.mean) / self.std

    def forward(self, pred, target):
        """Return the mean absolute difference between VGG features of ``pred`` and ``target``."""
        pred_features = self.feature_extractor(self._normalize(pred))
        target_features = self.feature_extractor(self._normalize(target))
        return self.loss_fn(pred_features, target_features)

# A total loss function combining all criteria
class TotalLoss(nn.Module):
    """Weighted sum of an SSIM term, an MSE term and a VGG perceptual term.

    ``loss = ssim_w * (1 - SSIM) + (1 - ssim_w) * MSE + perceptual_w * perceptual``

    Args:
        ssim_w: Weight of the SSIM term; the MSE term gets ``1 - ssim_w``.
        perceptual_w: Weight of the perceptual term.
    """

    def __init__(self, ssim_w=0.8, perceptual_w=0.01):
        super().__init__()
        self.ssim_w = ssim_w
        self.perceptual_w = perceptual_w
        self.mse = nn.MSELoss()
        self.perceptual = VGGPerceptualLoss()

    def ssim(self, p, t):
        """Return ``1 - SSIM(p, t)`` averaged over the batch (data range 1.0)."""
        return 1 - ssim_loss(p, t, data_range=1.0, size_average=True)

    def forward(self, pred, target):
        """Return the combined scalar loss for ``pred`` against ``target``."""
        structural_term = self.ssim_w * self.ssim(pred, target)
        pixel_term = (1 - self.ssim_w) * self.mse(pred, target)
        feature_term = self.perceptual_w * self.perceptual(pred, target)
        return (structural_term + pixel_term) + feature_term

In [20]:
# Build the dataset from the DIV2K high-resolution training images.
data_dir = "/content/DIV2K_train_HR"
full_dataset = DIV2KDatasetAugmented(data_dir)

# 90% for training, 10% for validation
train_size = int(0.9 * len(full_dataset))
# Validation takes the remainder so the two sizes always sum to the dataset length.
val_size = len(full_dataset) - train_size

# Fixed seed so the train/validation membership is the same on every run.
generator = torch.Generator().manual_seed(42)

# Both subsets wrap the same dataset object, so validation inputs go through
# the same random degradation as training inputs each time they are read.
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=generator)

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Training samples: 360
Validation samples: 40


In [None]:
# --- Train the Teacher Model ---

# Use the GPU when one is available; later cells reuse this `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model = ResNetSharpen(num_blocks=8).to(device)

# Combined SSIM + MSE + perceptual loss, moved to the same device as the model.
criterion = TotalLoss(ssim_w=0.8, perceptual_w=0.01).to(device)
optimizer = optim.Adam(teacher_model.parameters(), lr=1e-4)
epochs = 10
# Cosine learning-rate decay spanning the whole run (stepped once per epoch below).
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

print("--- Starting Training for the teacher Model ---")

def evaluate_ssim(model, loader):
    """Return the mean SSIM of ``model`` outputs against targets over ``loader``.

    The model is switched to eval mode (and left in it) and run without
    gradient tracking. Each batch's mean SSIM is weighted by its batch size,
    so a smaller final batch does not skew the result.

    Args:
        model: Network mapping a batch of inputs to a batch of outputs.
        loader: Iterable of ``(inputs, targets)`` tensor batches.

    Returns:
        The sample-weighted average SSIM as a Python float.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    print("Evaluating SSIM...")
    model.eval()
    total_ssim = 0.0
    count = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            score = ssim_loss(outputs, targets, data_range=1.0, size_average=True)
            batch_size = inputs.size(0)
            total_ssim += score.item() * batch_size
            count += batch_size
    # An empty loader would otherwise surface as an opaque ZeroDivisionError.
    if count == 0:
        raise ValueError("evaluate_ssim received a loader that yielded no samples")
    return total_ssim / count


# Teacher training: one pass over train_loader per epoch, then validation SSIM.
for epoch in range(epochs):
    # evaluate_ssim() leaves the model in eval mode, so re-enable training mode each epoch.
    teacher_model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = teacher_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # The learning-rate schedule advances once per epoch, not per batch.
    scheduler.step()
    # Average of the per-batch losses for this epoch.
    train_loss = running_loss / len(train_loader)

    val_ssim = evaluate_ssim(teacher_model, val_loader)

    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.5f} | Val SSIM: {val_ssim:.4f}")

# Persist the final teacher weights for the distillation stage.
torch.save(teacher_model.state_dict(), "teacher.pth")

In [22]:
# Writes the teacher weights to "teacher.pth" again (the training cell already
# does this after its last epoch); lets the checkpoint be re-saved on its own.
torch.save(teacher_model.state_dict(), "teacher.pth")

In [None]:
# Restore the trained teacher weights. map_location remaps the stored tensors
# onto the current device, so a checkpoint saved on a GPU runtime still loads
# on a CPU-only one.
teacher_model.load_state_dict(torch.load("teacher.pth", map_location=device))
# The teacher is only used for inference from here on.
teacher_model.eval()

# Fresh student and a new optimizer over the student's parameters only
# (this rebinds the `optimizer` name used for the teacher).
student_model = StudentCNN().to(device)
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)
distill_epochs = 20

def distillation_loss(s_out, t_out, target, alpha=0.75):
    """Blend the student's MSE to the teacher output with its MSE to the ground truth.

    Args:
        s_out: Student prediction.
        t_out: Teacher prediction for the same input.
        target: Ground-truth image.
        alpha: Weight of the teacher-matching term; the ground-truth term
            gets ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * MSE(s_out, t_out) + (1 - alpha) * MSE(s_out, target)``.
    """
    mse = nn.MSELoss()
    imitation_term = mse(s_out, t_out)
    supervised_term = mse(s_out, target)
    return alpha * imitation_term + (1 - alpha) * supervised_term

print("\n--- Starting Knowledge Distillation from Teacher ---")

# Student training: each batch is compared against both the frozen teacher's
# output and the ground truth via distillation_loss.
for epoch in range(distill_epochs):
    # evaluate_ssim() leaves the student in eval mode, so switch back each epoch.
    student_model.train()
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # The teacher only provides targets; no gradients are tracked for it.
        with torch.no_grad():
            teacher_out = teacher_model(inputs)

        student_out = student_model(inputs)

        loss = distillation_loss(student_out, teacher_out, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    student_val_ssim = evaluate_ssim(student_model, val_loader)
    print(f"[Student] Epoch {epoch+1}/{distill_epochs} | Val SSIM: {student_val_ssim:.4f}")

# Persist the distilled student weights.
torch.save(student_model.state_dict(), "student_final.pth")

In [None]:
# Final comparison of teacher and student on the validation loader.
teacher_model.eval()
student_model.eval()

print("\n--- Final Model Performance on the Unseen Validation Set ---")

# The dataset degrades images randomly on every read, so the two models are
# scored on different random degradations of the same validation images.
final_teacher_ssim = evaluate_ssim(teacher_model, val_loader)
final_student_ssim = evaluate_ssim(student_model, val_loader)

print(f"Teacher SSIM: {final_teacher_ssim:.4f}")
print(f"Final Distilled Student SSIM: {final_student_ssim:.4f}")


def visualize_results(dataset, num_images=5):
    """Plot degraded input, student output, teacher output and ground truth side by side.

    Uses the module-level ``student_model``, ``teacher_model`` and ``device``.
    One figure with four panels is shown per sample, and the student's SSIM
    against the ground truth is printed in its panel title.

    Args:
        dataset: Indexable dataset returning ``(degraded, clean)`` image
            tensors in channels-first layout.
        num_images: Maximum number of samples to show, starting at index 0.
            Capped at ``len(dataset)``.
    """
    print("\n--- Visualizing Model Outputs ---")

    def to_hwc(chw_tensor):
        # Channels-first tensor -> channels-last NumPy array for matplotlib/skimage.
        return chw_tensor.cpu().permute(1, 2, 0).numpy()

    # Never index past the end of a dataset smaller than num_images.
    num_to_show = min(num_images, len(dataset))

    for i in range(num_to_show):
        input_img, target_img = dataset[i]

        input_tensor = input_img.unsqueeze(0).to(device)

        with torch.no_grad():
            student_out = student_model(input_tensor)
            teacher_out = teacher_model(input_tensor)

        input_np = to_hwc(input_img)
        target_np = to_hwc(target_img)
        student_np = to_hwc(student_out.squeeze(0))
        teacher_np = to_hwc(teacher_out.squeeze(0))

        # SSIM is computed on the raw (unclipped) student output, matching how
        # evaluate_ssim scores the models.
        ssim_score = ssim(target_np, student_np, channel_axis=2, data_range=1.0)

        panels = [
            ("Degraded Input", input_np),
            (f"Student Output\nSSIM: {ssim_score:.4f}", student_np),
            ("Teacher Output", teacher_np),
            ("Ground Truth (Original)", target_np),
        ]

        fig, axes = plt.subplots(1, 4, figsize=(20, 5))
        for ax, (title, image) in zip(axes, panels):
            # Network outputs are unbounded; clip explicitly to the [0, 1]
            # range imshow expects for float RGB data.
            ax.imshow(np.clip(image, 0.0, 1.0))
            ax.set_title(title)
            ax.axis("off")

        plt.show()

# Show the first five validation samples (default num_images).
visualize_results(val_dataset)