In [None]:
# Install the Python packages this notebook depends on.
!pip install basicsr gdown kaggle torch torchvision scikit-image --quiet

# Rewrite basicsr's import of `torchvision.transforms.functional_tensor.rgb_to_grayscale`
# to `torchvision.transforms.functional`, presumably because the installed torchvision
# no longer provides the `functional_tensor` module.
# NOTE(review): this patches only the pip-installed copy under dist-packages; the
# source checkout cloned and installed in develop mode below is NOT patched.
# Confirm which copy `import basicsr` actually resolves to.
!sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/' /usr/local/lib/python3.*/dist-packages/basicsr/data/degradations.py

# Install BasicSR from a source checkout in development mode, then return to /content,
# which the paths used later in this notebook are relative to.
!git clone https://github.com/xinntao/BasicSR
%cd BasicSR
!python setup.py develop
%cd /content

# Ask the user for their Kaggle API token; files.upload() saves it into the
# current working directory (/content).
from google.colab import files
print("⬆️ Upload kaggle.json from your Kaggle account")
files.upload()  # Upload kaggle.json
# Move the token to where the `kaggle` CLI below looks for it and make it
# readable/writable by the owner only.
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Fetch the pretrained EDSR x4 weights and unpack them into ./edsr_weights,
# from which the teacher checkpoint is loaded further down.
!kaggle datasets download -d giancarlocuticchia/pretrained-edsr-x4-models
!unzip -q pretrained-edsr-x4-models.zip -d edsr_weights

import os
import shutil

# Image types kept from an upload. Matched case-insensitively so that files
# such as "IMG_0001.JPG" are not silently dropped.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _stash_uploads(uploaded, dest_dir):
    """Move uploaded image files from the working directory into *dest_dir*.

    ``uploaded`` is what ``files.upload()`` returned; only its keys (the file
    names it wrote into the working directory) are used. Files whose
    extension is not in IMAGE_EXTENSIONS stay where they are and are
    reported, instead of being ignored without a trace.
    """
    os.makedirs(dest_dir, exist_ok=True)
    for fname in uploaded:
        if fname.lower().endswith(IMAGE_EXTENSIONS):
            shutil.move(fname, os.path.join(dest_dir, fname))
        else:
            print(f"⚠️ Skipping {fname!r}: not a .png/.jpg/.jpeg image")


print("⬆️ Upload your high-res training images")
uploaded_train = files.upload()
train_dir = "/content/images"
_stash_uploads(uploaded_train, train_dir)

print("⬆️ Upload your high-res test images")
uploaded_test = files.upload()
test_dir = "/content/test_images"
_stash_uploads(uploaded_test, test_dir)

import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from skimage.metrics import structural_similarity as ssim_metric
from torch.utils.data import Dataset, DataLoader
from torchvision.io import ImageReadMode, read_image
from torchvision.utils import save_image

class StudentNet(nn.Module):
    """Small fully-convolutional RGB-to-RGB network (the distillation student).

    Two 3x3 convolutions with 32 channels and ReLU, followed by a 3x3
    convolution back to 3 channels squashed into (0, 1) by a sigmoid.
    Padding keeps the spatial size unchanged.
    """

    def __init__(self):
        super().__init__()
        width = 32
        # The names conv1..conv3 are the state_dict keys written to and read
        # from student_model.pth, so they must stay as they are.
        self.conv1 = nn.Conv2d(3, width, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(width, 3, kernel_size=3, padding=1)

    def forward(self, x):
        feats = x
        for hidden in (self.conv1, self.conv2):
            feats = F.relu(hidden(feats))
        return torch.sigmoid(self.conv3(feats))

from basicsr.archs.edsr_arch import EDSR

# NOTE(review): only the channel counts and upscale factor are given, so the
# network width/depth are EDSR's constructor defaults; confirm they match the
# architecture the checkpoint below was trained with.
teacher = EDSR(num_in_ch=3, num_out_ch=3, upscale=4)
# torch.load unpickles the file, which comes from a third-party Kaggle
# dataset. Only run this on a checkpoint you trust.
state_dict = torch.load("/content/edsr_weights/edsr_x4-4f62e9ef.pt", map_location=torch.device('cpu'))
# BasicSR-format checkpoints nest the weights under 'params' / 'params_ema'.
# Unwrap them so the tensors themselves, not the wrapper key, are matched.
for _wrapper in ('params_ema', 'params'):
    if isinstance(state_dict, dict) and _wrapper in state_dict:
        state_dict = state_dict[_wrapper]
        break
_incompatible = teacher.load_state_dict(state_dict, strict=False)
# strict=False silently skips every name that does not match, so check the
# outcome. A teacher with none of its weights loaded is random noise and
# would poison the distillation loss.
_n_expected = len(teacher.state_dict())
if len(_incompatible.missing_keys) == _n_expected:
    raise RuntimeError(
        f"Teacher checkpoint matched none of EDSR's {_n_expected} parameters; "
        f"its key names differ (e.g. {list(state_dict)[:3]}). "
        "Convert the checkpoint to basicsr's EDSR naming before use."
    )
if _incompatible.missing_keys or _incompatible.unexpected_keys:
    print(
        "⚠️ Teacher checkpoint only partially loaded: "
        f"{len(_incompatible.missing_keys)} missing, "
        f"{len(_incompatible.unexpected_keys)} unexpected keys"
    )
teacher.eval()

class BlurDataset(Dataset):
    """(degraded, sharp) image pairs built from the images in a directory.

    Each file is decoded as 3-channel RGB, converted to float32 and resized
    to 256x256; that is the sharp target. The degraded input is the target
    bicubically down-sampled to 64x64 and up-sampled back to 256x256.
    """

    # Matched case-insensitively so that e.g. "IMG_0001.JPG" is not skipped.
    _EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, img_dir):
        # Sorted so the sample order is reproducible: evaluate() names its
        # outputs by index, and raw glob order depends on the filesystem.
        self.paths = sorted(
            p for p in glob.glob(os.path.join(img_dir, '*'))
            if p.lower().endswith(self._EXTENSIONS)
        )
        self.transform = T.Compose([
            T.ConvertImageDtype(torch.float32),
            T.Resize((256, 256)),
        ])
        # Built once here rather than on every __getitem__ call.
        self.degrade = T.Compose([
            T.Resize((64, 64), interpolation=T.InterpolationMode.BICUBIC),
            T.Resize((256, 256), interpolation=T.InterpolationMode.BICUBIC),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Force RGB. Grayscale (1-channel) and RGBA (4-channel) files would
        # otherwise reach the 3-input-channel networks and crash them.
        img = read_image(self.paths[idx], mode=ImageReadMode.RGB)
        img = self.transform(img)
        blurred = self.degrade(img)
        return blurred, img

from torchvision.transforms.functional import resize

def ssim(img1, img2):
    """Return the structural similarity of two images given as torch tensors.

    Each image may be (C, H, W) or a batch of one, (1, C, H, W). Pixel
    values are taken to span [0, 1] (``data_range=1.0``). The comparison is
    done by scikit-image on channels-last CPU copies.

    Raises:
        ValueError: if a 4-D input holds more than one image.
    """
    def _to_hwc(img):
        # Drop only an explicit leading batch dimension. A blanket squeeze()
        # would also remove a size-1 channel, height or width and make the
        # permute below fail.
        if img.dim() == 4:
            if img.shape[0] != 1:
                raise ValueError(f"ssim expects a single image, got batch of {img.shape[0]}")
            img = img[0]
        return img.detach().permute(1, 2, 0).cpu().numpy()

    return ssim_metric(_to_hwc(img1), _to_hwc(img2), data_range=1.0, channel_axis=2)

def psnr(img1, img2):
    """Peak signal-to-noise ratio, in dB, for images whose peak value is 1.0.

    Returns the int 100 when the inputs are identical (the ratio would
    otherwise be infinite); otherwise a 0-d tensor.
    """
    err = F.mse_loss(img1, img2)
    if err != 0:
        return 20 * torch.log10(1.0 / torch.sqrt(err))
    return 100

def kd_loss(student_out, gt, teacher_out, alpha=0.1):
    """Knowledge-distillation loss.

    Mean squared error of the student against the ground truth, plus
    ``alpha`` times its mean squared error against the teacher.

    Args:
        student_out: student prediction; its last two dims are (H, W).
        gt: ground-truth image with the same shape as ``student_out``.
        teacher_out: teacher prediction. It is resized to the student's
            (H, W) only when its spatial size differs.
        alpha: weight of the teacher term (default 0.1, the value that was
            previously hard-coded).

    Returns:
        Scalar loss tensor.
    """
    if teacher_out.shape[-2:] != student_out.shape[-2:]:
        teacher_out = resize(teacher_out, size=student_out.shape[-2:])
    return F.mse_loss(student_out, gt) + alpha * F.mse_loss(student_out, teacher_out)

# Training hyper-parameters used by train().
BATCH_SIZE: int = 4        # images per optimisation step
EPOCHS: int = 10           # full passes over the training set
LR: float = 1e-4           # Adam learning rate

# evaluate() writes its predictions into this directory; create it up front.
SAVE_DIR: str = "/content/results"
os.makedirs(SAVE_DIR, exist_ok=True)

def train():
    """Train StudentNet on ``train_dir`` by distilling from the EDSR teacher.

    Each step minimises ``kd_loss``: the student's MSE against the sharp
    target plus a weighted MSE against the teacher's prediction. The mean
    loss is printed every epoch, and the final weights are saved to
    "/content/student_model.pth".

    Raises:
        RuntimeError: if ``train_dir`` contains no usable images.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = StudentNet().to(device)
    teacher_model = teacher.to(device)
    teacher_model.eval()

    dataset = BlurDataset(train_dir)
    if len(dataset) == 0:
        # Fail early with a clear message instead of a ZeroDivisionError
        # when the epoch loss is averaged below.
        raise RuntimeError(f"No training images found in {train_dir!r}")
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        for blurred, sharp in loader:
            blurred, sharp = blurred.to(device), sharp.to(device)
            with torch.no_grad():
                # The teacher was built with upscale=4. Give it the
                # quarter-resolution image (approximately the 64x64 image
                # BlurDataset derived `blurred` from) rather than the
                # already-upsampled input. Its prediction then lands on the
                # target grid instead of at 4x the target size, which also
                # cuts the teacher's compute and memory.
                low_res = F.interpolate(
                    sharp,
                    size=(sharp.shape[-2] // 4, sharp.shape[-1] // 4),
                    mode='bicubic',
                    align_corners=False,
                    antialias=True,
                )
                teacher_out = teacher_model(low_res)
            student_out = model(blurred)
            loss = kd_loss(student_out, sharp, teacher_out)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(loader):.4f}")

    torch.save(model.state_dict(), "/content/student_model.pth")
    print("\n✅ Training Complete")

def evaluate():
    """Score the saved student on ``test_dir`` and save its predictions.

    Loads "/content/student_model.pth" and runs the model over every test
    image, one at a time. Each prediction is written to
    ``SAVE_DIR/output_{i}.png``, and the mean SSIM (as a percentage) and
    mean PSNR (dB) against the sharp targets are printed.

    Raises:
        RuntimeError: if ``test_dir`` contains no usable images.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = StudentNet().to(device)
    model.load_state_dict(torch.load("/content/student_model.pth", map_location=device))
    model.eval()

    dataset = BlurDataset(test_dir)
    if len(dataset) == 0:
        # Without this guard the averages below divide by zero.
        raise RuntimeError(f"No test images found in {test_dir!r}")
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    total_ssim, total_psnr, count = 0.0, 0.0, 0

    with torch.no_grad():
        for i, (blurred, sharp) in enumerate(loader):
            blurred, sharp = blurred.to(device), sharp.to(device)
            out = model(blurred)
            # Safety resize: keep prediction and target on the same grid.
            sharp = resize(sharp, [256, 256])
            out = resize(out, [256, 256])
            # float() turns the metric results (numpy scalar / 0-d tensor /
            # int) into plain Python numbers so the running sums stay on the
            # host with a single type.
            total_ssim += float(ssim(out[0], sharp[0]))
            total_psnr += float(psnr(out, sharp))
            count += 1
            save_image(out, f"{SAVE_DIR}/output_{i}.png")

    print(f"\n📊 Average SSIM: {total_ssim / count * 100:.2f}%")
    print(f"📊 Average PSNR: {total_psnr / count:.2f} dB")

# Run the pipeline end to end: fit the student, then score it on the test set.
for _stage in (train, evaluate):
    _stage()