# In [None]:
import os, torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

# Directory (under the Colab /content tree) that holds the trained checkpoint.
WORK_DIR = "/content/work_fundus_sr"
# Checkpoint loaded below; it is read as a dict with "cfg" and "model" entries.
MODEL_PATH = os.path.join(WORK_DIR, "final_model.pt")

class ResBlock(torch.nn.Module):
    """Residual block: conv -> ReLU -> conv, added back onto the input scaled by 0.1.

    Args:
        c: number of channels in and out. Input and output widths are equal,
            so the skip connection is a plain element-wise addition. The 3x3
            kernels with padding=1 keep the spatial size unchanged.
    """

    def __init__(self, c):
        super().__init__()
        # The layers must stay in a single Sequential named ``body`` at
        # indices 0/1/2. The checkpoint is loaded with strict=True, so its
        # keys (``body.0.*``, ``body.2.*``) depend on this exact layout.
        layers = [
            torch.nn.Conv2d(c, c, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(c, c, 3, padding=1),
        ]
        self.body = torch.nn.Sequential(*layers)

    def forward(self, x):
        # Scale the branch output by 0.1 before adding it to the identity path.
        update = self.body(x)
        return x + update * 0.1

class EDSR_Lite(torch.nn.Module):
    """Small EDSR-style super-resolution network.

    Pipeline: ``head`` (3 -> c channels), then ``body`` (``n_res`` ResBlocks)
    with a global skip, then ``tail``. The tail widens to c*scale^2 channels,
    rearranges them with PixelShuffle into ``scale`` x larger feature maps,
    and projects back to 3 channels.

    Args:
        scale: integer upsampling factor passed to PixelShuffle.
        n_res: number of residual blocks in the body.
        c: feature width used throughout the network.
    """

    def __init__(self, scale=2, n_res=8, c=64):
        super().__init__()
        nn = torch.nn
        self.head = nn.Conv2d(3, c, 3, padding=1)
        blocks = [ResBlock(c) for _ in range(n_res)]
        self.body = nn.Sequential(*blocks)
        # The tail's sub-module indices (0: conv, 1: shuffle, 2: conv) are part
        # of the state_dict keys, so the order must not change.
        upsample_channels = c * scale ** 2
        self.tail = nn.Sequential(
            nn.Conv2d(c, upsample_channels, 3, padding=1),
            nn.PixelShuffle(scale),
            nn.Conv2d(c, 3, 3, padding=1),
        )

    def forward(self, x):
        feats = self.head(x)
        # Global residual: the body's output is added back onto the head features.
        return self.tail(feats + self.body(feats))

# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles the file, so only point MODEL_PATH at a
# checkpoint you trust. Tensors are mapped to CPU here and moved to `device` below.
ckpt = torch.load(MODEL_PATH, map_location="cpu")
scale = ckpt["cfg"]["scale"]

# n_res / c are fixed here and must match the trained weights. strict=True
# makes any key or shape mismatch with the checkpoint raise instead of passing silently.
model = EDSR_Lite(scale=scale, n_res=8, c=64)
model.load_state_dict(ckpt["model"], strict=True)
model = model.eval().to(device)

# Converters between PIL images and float tensors for inference.
to_tensor, to_pil = transforms.ToTensor(), transforms.ToPILImage()

def super_resolve(img):
    """Upscale a single image with the loaded model.

    Args:
        img: a PIL image, or anything ``to_tensor`` accepts. PIL images whose
            mode is not "RGB" (e.g. grayscale "L" or "RGBA") are converted to
            RGB first, because the network's first layer expects 3 input channels.

    Returns:
        The super-resolved image as a PIL image. Model output is clamped to
        [0, 1] before conversion.
    """
    # Objects without a ``mode`` attribute (e.g. arrays) are passed through unchanged.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")
    batch = to_tensor(img).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():  # inference only: no autograd graph is built
        out = model(batch).clamp(0, 1)
    return to_pil(out.squeeze(0).cpu())

from google.colab import files

# Opens an upload dialog in the notebook; the returned dict is keyed by filename.
uploaded = files.upload()
for name in uploaded:
    # Re-open each upload by its filename. The `with` block closes the file
    # handle once convert() has decoded the pixels. A non-image or missing
    # file is reported and skipped instead of aborting the remaining uploads.
    try:
        with Image.open(name) as src:
            img = src.convert("RGB")
    except OSError as err:  # Pillow's UnidentifiedImageError subclasses OSError
        print(f"Skipping {name}: {err}")
        continue

    sr_img = super_resolve(img)
    print(f"Original: {img.size}, SR: {sr_img.size}")

    fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(12, 6))
    ax_in.imshow(img)
    ax_in.set_title(f"Original {img.size}")
    ax_in.axis("off")
    ax_out.imshow(sr_img)
    ax_out.set_title(f"Super-Resolution {sr_img.size}")
    ax_out.axis("off")
    plt.show()
    # Release the figure so memory does not grow with many uploads.
    plt.close(fig)
