In [None]:
import os
import cv2
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt

In [None]:
class DIV2KDataset(Dataset):
    """Paired (degraded, high-resolution) image dataset built from a folder of PNGs.

    Each item is produced by resizing the source image to ``hr_size`` x
    ``hr_size``, bicubically shrinking it by ``scale`` and bicubically
    enlarging it back, so the degraded input and the target have the same
    spatial size.

    Args:
        hr_dir: Directory searched (non-recursively) for ``*.png`` files.
        transform: Optional callable applied to both the degraded image and
            the target before they are returned.
        hr_size: Side length, in pixels, of the square target image.
        scale: Integer factor by which the target is shrunk to create the
            degraded image.

    Raises:
        ValueError: If ``hr_size`` or ``scale`` is smaller than 1, or if
            ``scale`` is larger than ``hr_size``.
    """

    def __init__(self, hr_dir, transform=None, hr_size=128, scale=2):
        if hr_size < 1 or scale < 1 or hr_size // scale < 1:
            raise ValueError(
                f"hr_size and scale must be >= 1 with scale <= hr_size "
                f"(got hr_size={hr_size}, scale={scale})"
            )
        # Sorted so that an index always refers to the same file between runs.
        self.hr_paths = sorted(glob.glob(os.path.join(hr_dir, "*.png")))
        self.transform = transform
        self.hr_size = hr_size
        self.scale = scale

    def __len__(self):
        """Return the number of PNG files found in ``hr_dir``."""
        return len(self.hr_paths)

    def __getitem__(self, idx):
        """Return ``(degraded, target)`` for the image at position ``idx``.

        Raises:
            OSError: If the image file cannot be read or decoded.
        """
        path = self.hr_paths[idx]
        img_hr = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the error would surface inside cvtColor.
        if img_hr is None:
            raise OSError(f"Could not read image file: {path}")
        img_hr = cv2.cvtColor(img_hr, cv2.COLOR_BGR2RGB)

        hr_shape = (self.hr_size, self.hr_size)
        img_hr = cv2.resize(img_hr, hr_shape)  # Manage memory

        # Shrink then enlarge with bicubic interpolation to create the
        # degraded input at the same resolution as the target.
        lr_side = self.hr_size // self.scale
        img_lr = cv2.resize(img_hr, (lr_side, lr_side), interpolation=cv2.INTER_CUBIC)
        img_blur = cv2.resize(img_lr, hr_shape, interpolation=cv2.INTER_CUBIC)

        if self.transform:
            img_hr = self.transform(img_hr)
            img_blur = self.transform(img_blur)

        return img_blur, img_hr


In [None]:
class DnCNN(nn.Module):
    """Residual CNN whose body output is subtracted from the input.

    The body is a 3->64 convolution + ReLU head, five 64->64
    convolution/BatchNorm/ReLU stages and a 64->3 convolution tail. Every
    convolution is 3x3 with padding 1, so the spatial size is unchanged.
    """

    def __init__(self):
        super().__init__()

        # Head: lift the 3 input channels to 64 feature channels.
        layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True)]

        # Five identical feature stages, each wrapped in its own Sequential.
        for _ in range(5):
            stage = nn.Sequential(
                nn.Conv2d(64, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
            layers.append(stage)

        # Tail: project back to 3 channels.
        layers.append(nn.Conv2d(64, 3, kernel_size=3, padding=1))

        self.body = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` minus the body's prediction for ``x``."""
        predicted_residual = self.body(x)
        return x - predicted_residual


In [None]:
# Preprocessing applied to both the degraded input and the target image.
transform = transforms.Compose([
    transforms.ToTensor()
])

data_dir = "/content/DIV2K_train_HR"
dataset = DIV2KDataset(data_dir, transform)

# Fail here, with the offending path in the message, rather than later inside
# the loader when the folder is missing or holds no PNG files.
if len(dataset) == 0:
    raise FileNotFoundError(
        f"No PNG images found in {data_dir!r}; download/extract the dataset or fix data_dir."
    )

train_loader = DataLoader(dataset, batch_size=8, shuffle=True)


In [None]:
# Pick the device once: use the GPU when CUDA is available, otherwise fall
# back to the CPU so this cell also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

teacher_model = DnCNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(teacher_model.parameters(), lr=1e-4)
epochs = 10

for epoch in range(epochs):
    teacher_model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = teacher_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")


Epoch 1/10 | Loss: 0.0228
Epoch 2/10 | Loss: 0.0103
Epoch 3/10 | Loss: 0.0098
Epoch 4/10 | Loss: 0.0097
Epoch 5/10 | Loss: 0.0096
Epoch 6/10 | Loss: 0.0095
Epoch 7/10 | Loss: 0.0094
Epoch 8/10 | Loss: 0.0094
Epoch 9/10 | Loss: 0.0094
Epoch 10/10 | Loss: 0.0094


In [None]:
# Persist the teacher's state_dict so it can be reloaded into a DnCNN later.
teacher_checkpoint_path = "dncnn_teacher.pth"
torch.save(teacher_model.state_dict(), teacher_checkpoint_path)

In [None]:
# SSIM Evaluation
def average_ssim(model, dataset, num_samples=100):
    """Return the mean SSIM between model outputs and targets.

    The first ``min(num_samples, len(dataset))`` items of ``dataset`` are
    scored. Each item must be an ``(input, target)`` pair of tensors laid out
    channel-first; they are converted to channel-last arrays before being
    passed to ``ssim`` with ``data_range=1.0``.

    Note: the model is switched to evaluation mode and left in that mode.

    Args:
        model: Module mapping a batched input tensor to an output of the same
            layout.
        dataset: Indexable collection of ``(input, target)`` tensor pairs.
        num_samples: Upper bound on the number of items to score.

    Returns:
        The arithmetic mean of the per-image SSIM values.

    Raises:
        ValueError: If there is nothing to score (empty dataset or
            ``num_samples`` <= 0).
    """
    model.eval()

    sample_count = min(num_samples, len(dataset))
    if sample_count <= 0:
        # Previously this fell through to a division by zero.
        raise ValueError(
            "average_ssim needs at least one sample "
            f"(len(dataset)={len(dataset)}, num_samples={num_samples})"
        )

    # Run on whatever device the model lives on; a model without parameters
    # has no device of its own, so fall back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    total_ssim = 0.0
    with torch.no_grad():
        for i in range(sample_count):
            inp, gt = dataset[i]
            out = model(inp.unsqueeze(0).to(device))
            # Drop the batch axis and move channels last for skimage.
            out_np = out.cpu().squeeze(0).permute(1, 2, 0).numpy()
            gt_np = gt.permute(1, 2, 0).numpy()
            total_ssim += ssim(gt_np, out_np, channel_axis=2, data_range=1.0)

    return total_ssim / sample_count

In [None]:
# Evaluate Teacher SSIM
teacher_model.eval()
# Score the first 50 items of `dataset` (the same dataset the loader trains on).
teacher_eval_samples = 50
teacher_ssim = average_ssim(
    teacher_model,
    dataset,
    num_samples=teacher_eval_samples,
)
print(f"Teacher Average SSIM: {teacher_ssim:.4f}")


Teacher Average SSIM: 0.7327


In [None]:
# Student Model
class StudentCNN(nn.Module):
    """Small three-convolution residual CNN used as the student.

    The body is 3->32, 32->32 and 32->3 convolutions (3x3, padding 1) with a
    ReLU after each of the first two; its output is subtracted from the input.
    """

    def __init__(self):
        super().__init__()
        conv_in = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        conv_mid = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        conv_out = nn.Conv2d(32, 3, kernel_size=3, padding=1)
        self.body = nn.Sequential(
            conv_in,
            nn.ReLU(inplace=True),
            conv_mid,
            nn.ReLU(inplace=True),
            conv_out,
        )

    def forward(self, x):
        """Return ``x`` minus the body's prediction for ``x``."""
        predicted_residual = self.body(x)
        return x - predicted_residual

In [None]:
# Distillation Loss = alpha * MSE(Student, Teacher) + (1-alpha) * MSE(Student, GroundTruth)
def distillation_loss(student_out, teacher_out, target, alpha=0.7):
    """Blend the student's MSE against the teacher with its MSE against the target.

    Args:
        student_out: Student prediction.
        teacher_out: Teacher prediction for the same input.
        target: Ground-truth tensor.
        alpha: Weight of the teacher term; the target term gets ``1 - alpha``.

    Returns:
        ``alpha * MSE(student, teacher) + (1 - alpha) * MSE(student, target)``.
    """
    mse = nn.MSELoss()
    teacher_term = mse(student_out, teacher_out)
    target_term = mse(student_out, target)
    return alpha * teacher_term + (1 - alpha) * target_term


In [None]:
# Train Student Model Using Teacher
teacher_model.eval()

# `device` is not defined anywhere else at notebook scope, so define it here.
# Taking it from the teacher keeps teacher, student and batches co-located.
device = next(teacher_model.parameters()).device

student_model = StudentCNN().to(device)
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)
epochs = 10

for epoch in range(epochs):
    student_model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        # The teacher only provides targets; no gradients are needed for it.
        with torch.no_grad():
            teacher_out = teacher_model(inputs)
        student_out = student_model(inputs)
        loss = distillation_loss(student_out, teacher_out, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"[Student] Epoch {epoch+1}/{epochs} | Loss: {running_loss/len(train_loader):.4f}")

torch.save(student_model.state_dict(), "student_dncnn.pth")
print("Student model saved.")


[Student] Epoch 1/10 | Loss: 0.0034
[Student] Epoch 2/10 | Loss: 0.0030
[Student] Epoch 3/10 | Loss: 0.0029
[Student] Epoch 4/10 | Loss: 0.0029
[Student] Epoch 5/10 | Loss: 0.0029
[Student] Epoch 6/10 | Loss: 0.0029
[Student] Epoch 7/10 | Loss: 0.0029
[Student] Epoch 8/10 | Loss: 0.0029
[Student] Epoch 9/10 | Loss: 0.0029
[Student] Epoch 10/10 | Loss: 0.0029
Student model saved.


In [None]:
# Student SSIM
student_model.eval()
# Score the first 50 items of `dataset` (the same dataset the loader trains on).
student_eval_samples = 50
student_ssim = average_ssim(
    student_model,
    dataset,
    num_samples=student_eval_samples,
)
print(f"Student Average SSIM: {student_ssim:.4f}")


Student Average SSIM: 0.7386
