In [None]:
# ======================= Environment Setup & Global Configuration =======================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms.functional as TF
from tqdm import tqdm
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device being used: {device}")

SCALE_FACTOR = 4     
BATCH_SIZE = 16    
LEARNING_RATE = 1e-4 
NUM_EPOCHS = 50


Device being used: cuda


In [None]:
# ======================= Smart Super-Resolution Dataset & DataLoaders =======================
class SRDataset(Dataset):
    def __init__(self, root_dir, split='train'):
        self.hr_dir = os.path.join(root_dir, 'HR', split)
        self.lr_dir = os.path.join(root_dir, 'LR', split)

        self.image_filenames = [
            x for x in os.listdir(self.hr_dir)
            if x.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))
        ]

    def __getitem__(self, index):
        hr_filename = self.image_filenames[index]
        hr_path = os.path.join(self.hr_dir, hr_filename)

        name_no_ext = os.path.splitext(hr_filename)[0]

        lr_path = None
        for ext in ['.png', '.jpg', '.jpeg', '.bmp']:
            p = os.path.join(self.lr_dir, name_no_ext + ext)
            if os.path.exists(p):
                lr_path = p
                break

        if lr_path is None:
            return self.__getitem__(np.random.randint(0, len(self.image_filenames)))

        hr_img = Image.open(hr_path).convert('RGB')
        lr_img = Image.open(lr_path).convert('RGB')

        return TF.to_tensor(lr_img).to(device), TF.to_tensor(hr_img).to(device)

    def __len__(self):
        return len(self.image_filenames)


DATASET_PATH = "/kaggle/input/croped-data/Mini_Dataset_Smart_Kaggle"

train_ds = SRDataset(DATASET_PATH, split='train')
val_ds = SRDataset(DATASET_PATH, split='val')

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

print("Smart Loader Ready!")
print(f"Train Patches: {len(train_ds)}")
print(f"Val Patches: {len(val_ds)}")


Smart Loader Ready!
Train Patches: 31180
Val Patches: 3890


In [None]:
# ======================= SRResNet Generator Architecture (Residual Learning + PixelShuffle Upscaling) =======================

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        return x + residual


class UpsampleBlock(nn.Module):
    def __init__(self, channels, scale_factor):
        super(UpsampleBlock, self).__init__()
        self.conv = nn.Conv2d(channels, channels * (scale_factor ** 2), kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x


class SRResNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_res_blocks=16):
        super(SRResNet, self).__init__()

        self.conv_input = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )

        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(n_res_blocks)]
        )

        self.conv_mid = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )

        self.upsample = nn.Sequential(
            UpsampleBlock(64, 2),
            UpsampleBlock(64, 2)
        )

        self.conv_output = nn.Conv2d(64, out_channels, kernel_size=9, padding=4)

    def forward(self, x):
        out1 = self.conv_input(x)
        out = self.res_blocks(out1)
        out = self.conv_mid(out)
        out = out + out1
        out = self.upsample(out)
        out = self.conv_output(out)
        return out


model = SRResNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.L1Loss()

print("Model created successfully!")


Model created successfully!


In [None]:
# ======================= SRResNet Training Loop (L1 Loss + PSNR Monitoring + Best Model Checkpoint) =======================

def train_model():
    print("Starting Training...")
    best_loss = float('inf')

    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss = 0
        train_psnr = 0

        loop = tqdm(train_loader, leave=True)

        for lr_imgs, hr_imgs in loop:
            optimizer.zero_grad()

            sr_imgs = model(lr_imgs)
            loss = criterion(sr_imgs, hr_imgs)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            with torch.no_grad():
                mse = nn.MSELoss()(sr_imgs, hr_imgs)
                if mse == 0:
                    psnr = 100
                else:
                    psnr = 10 * torch.log10(1 / mse)
                train_psnr += psnr.item()

            loop.set_description(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")
            loop.set_postfix(loss=loss.item(), psnr=psnr.item())

        avg_loss = train_loss / len(train_loader)
        avg_psnr = train_psnr / len(train_loader)

        print(f" -> Epoch {epoch+1} | Avg Loss: {avg_loss:.5f} | Avg PSNR: {avg_psnr:.2f} dB")

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), "best_sr_model.pth")
            print(" -> Model Saved!")


if __name__ == "__main__":
    train_model()

Starting Training...


Epoch [1/50]: 100%|██████████| 1949/1949 [15:06<00:00,  2.15it/s, loss=0.0307, psnr=26]  


 -> Epoch 1 | Avg Loss: 0.04509 | Avg PSNR: 23.94 dB
 -> Model Saved!


Epoch [2/50]: 100%|██████████| 1949/1949 [11:04<00:00,  2.94it/s, loss=0.0329, psnr=25]  


 -> Epoch 2 | Avg Loss: 0.03201 | Avg PSNR: 25.92 dB
 -> Model Saved!


Epoch [3/50]: 100%|██████████| 1949/1949 [10:53<00:00,  2.98it/s, loss=0.0295, psnr=25.6]


 -> Epoch 3 | Avg Loss: 0.02985 | Avg PSNR: 26.38 dB
 -> Model Saved!


Epoch [4/50]: 100%|██████████| 1949/1949 [11:01<00:00,  2.95it/s, loss=0.0244, psnr=28.1]


 -> Epoch 4 | Avg Loss: 0.02876 | Avg PSNR: 26.60 dB
 -> Model Saved!


Epoch [5/50]: 100%|██████████| 1949/1949 [10:43<00:00,  3.03it/s, loss=0.0394, psnr=24.6]


 -> Epoch 5 | Avg Loss: 0.02824 | Avg PSNR: 26.73 dB
 -> Model Saved!


Epoch [6/50]: 100%|██████████| 1949/1949 [10:52<00:00,  2.99it/s, loss=0.0365, psnr=23]  


 -> Epoch 6 | Avg Loss: 0.02783 | Avg PSNR: 26.83 dB
 -> Model Saved!


Epoch [7/50]: 100%|██████████| 1949/1949 [10:58<00:00,  2.96it/s, loss=0.0248, psnr=27.4]


 -> Epoch 7 | Avg Loss: 0.02755 | Avg PSNR: 26.88 dB
 -> Model Saved!


Epoch [8/50]: 100%|██████████| 1949/1949 [10:58<00:00,  2.96it/s, loss=0.0182, psnr=29.9]


 -> Epoch 8 | Avg Loss: 0.02724 | Avg PSNR: 26.97 dB
 -> Model Saved!


Epoch [9/50]: 100%|██████████| 1949/1949 [11:29<00:00,  2.83it/s, loss=0.0257, psnr=28.2]


 -> Epoch 9 | Avg Loss: 0.02717 | Avg PSNR: 26.98 dB
 -> Model Saved!


Epoch [10/50]: 100%|██████████| 1949/1949 [11:39<00:00,  2.79it/s, loss=0.0256, psnr=27.3]


 -> Epoch 10 | Avg Loss: 0.02689 | Avg PSNR: 27.05 dB
 -> Model Saved!


Epoch [11/50]: 100%|██████████| 1949/1949 [11:24<00:00,  2.85it/s, loss=0.0221, psnr=28.7]


 -> Epoch 11 | Avg Loss: 0.02662 | Avg PSNR: 27.08 dB
 -> Model Saved!


Epoch [12/50]: 100%|██████████| 1949/1949 [11:18<00:00,  2.87it/s, loss=0.0345, psnr=25]  


 -> Epoch 12 | Avg Loss: 0.02659 | Avg PSNR: 27.09 dB
 -> Model Saved!


Epoch [13/50]: 100%|██████████| 1949/1949 [11:03<00:00,  2.94it/s, loss=0.0394, psnr=22.1]


 -> Epoch 13 | Avg Loss: 0.02638 | Avg PSNR: 27.14 dB
 -> Model Saved!


Epoch [14/50]: 100%|██████████| 1949/1949 [11:10<00:00,  2.91it/s, loss=0.0283, psnr=26.5]


 -> Epoch 14 | Avg Loss: 0.02632 | Avg PSNR: 27.15 dB
 -> Model Saved!


Epoch [15/50]: 100%|██████████| 1949/1949 [11:12<00:00,  2.90it/s, loss=0.0363, psnr=23.8]


 -> Epoch 15 | Avg Loss: 0.02622 | Avg PSNR: 27.18 dB
 -> Model Saved!


Epoch [16/50]: 100%|██████████| 1949/1949 [11:07<00:00,  2.92it/s, loss=0.0332, psnr=25.6]


 -> Epoch 16 | Avg Loss: 0.02615 | Avg PSNR: 27.21 dB
 -> Model Saved!


Epoch [17/50]: 100%|██████████| 1949/1949 [11:04<00:00,  2.93it/s, loss=0.0222, psnr=29]  


 -> Epoch 17 | Avg Loss: 0.02606 | Avg PSNR: 27.21 dB
 -> Model Saved!


Epoch [18/50]: 100%|██████████| 1949/1949 [10:53<00:00,  2.98it/s, loss=0.0297, psnr=24.6]


 -> Epoch 18 | Avg Loss: 0.02589 | Avg PSNR: 27.26 dB
 -> Model Saved!


Epoch [19/50]: 100%|██████████| 1949/1949 [10:52<00:00,  2.99it/s, loss=0.0216, psnr=29]  


 -> Epoch 19 | Avg Loss: 0.02598 | Avg PSNR: 27.26 dB


Epoch [20/50]: 100%|██████████| 1949/1949 [10:59<00:00,  2.95it/s, loss=0.0213, psnr=30.3]


 -> Epoch 20 | Avg Loss: 0.02584 | Avg PSNR: 27.24 dB
 -> Model Saved!


Epoch [21/50]: 100%|██████████| 1949/1949 [10:54<00:00,  2.98it/s, loss=0.0271, psnr=26.9]


 -> Epoch 21 | Avg Loss: 0.02577 | Avg PSNR: 27.27 dB
 -> Model Saved!


Epoch [22/50]:  84%|████████▍ | 1636/1949 [08:51<01:45,  2.96it/s, loss=0.0336, psnr=25]  

In [None]:
# ======================= Model Export =======================

import torch

device = torch.device("cpu")
model = SRResNet().to(device)
model.load_state_dict(torch.load("/kaggle/working/best_sr_model.pth", map_location=device))
model.eval()

dummy_input = torch.rand(1, 3, 48, 48).to(device)
traced_model = torch.jit.trace(model, dummy_input)

traced_model.save("final_model.pt")

print("Saved")


In [None]:
# ======================= Inference =======================

import torch
from PIL import Image
import torchvision.transforms.functional as TF

MODEL_PATH = "/kaggle/working/srresnet_scripted.pt"
INPUT_IMAGE = "/kaggle/input/croped-data/Mini_Dataset_Smart_Kaggle/LR/test/3_160_aug.png"
OUTPUT_IMAGE = "/kaggle/working/3_160_SR.png"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model = torch.jit.load(MODEL_PATH, map_location=device)
model.eval()

img = Image.open(INPUT_IMAGE).convert("RGB")
lr_tensor = TF.to_tensor(img).unsqueeze(0).to(device)

with torch.no_grad():
    sr_tensor = model(lr_tensor)

sr_tensor = sr_tensor.clamp(0, 1)
sr_img = TF.to_pil_image(sr_tensor.squeeze(0))
sr_img.save(OUTPUT_IMAGE)

print("Super Resolution image saved at:", OUTPUT_IMAGE)


Using device: cuda
 Super Resolution image saved at: /kaggle/working/3_160_SR.png


In [None]:
# =================================== FINE TUNING =======================================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from tqdm import tqdm

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# normalize for VGG19
def normalize_vgg(tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
    return (tensor - mean) / std


In [None]:
class SRDataset(Dataset):
    # =================================== INITIALIZATION ===================================
    def __init__(self, root_dir, split='train'):
        self.lr_dir = os.path.join(root_dir, 'LR', split)
        self.hr_dir = os.path.join(root_dir, 'HR', split)

        if not os.path.exists(self.lr_dir) or not os.path.exists(self.hr_dir):
            raise FileNotFoundError(f"Path not found: {root_dir}")
        self.images = sorted([f for f in os.listdir(self.lr_dir) 
                              if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

    # =================================== LENGTH ===================================
    def __len__(self):
        return len(self.images)

    # =================================== GET ITEM ===================================
    def __getitem__(self, idx):
        lr_path = os.path.join(self.lr_dir, self.images[idx])
        hr_path = os.path.join(self.hr_dir, self.images[idx])
        lr = Image.open(lr_path).convert('RGB')
        hr = Image.open(hr_path).convert('RGB')
        return TF.to_tensor(lr), TF.to_tensor(hr)


In [None]:
DATASET_PATH = "/kaggle/input/croped-data/Mini_Dataset_Smart_Kaggle"
BATCH_SIZE = 16

# =================================== DATASET ===================================
train_ds = SRDataset(DATASET_PATH, split='train')  
val_ds   = SRDataset(DATASET_PATH, split='val')

# =================================== DATALOADERS ===================================
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

print(f"Loaders are ready! Train samples: {len(train_ds)}")


Loaders are ready! Train samples: 31180


In [None]:
# =================================== RESIDUAL BLOCK ===================================
class ResidualBlock(nn.Module):
    """Basic residual block with two conv layers and skip connection."""
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)  # Skip connection

# =================================== SRRESNET MODEL ===================================
class SRResNet(nn.Module):
    """Super-Resolution Residual Network."""
    def __init__(self, in_channels=3, out_channels=3, num_feat=64, num_blocks=16):
        super().__init__()
        self.input_conv = nn.Sequential(nn.Conv2d(in_channels, num_feat, 9, 1, 4), nn.PReLU())
        self.res_blocks = nn.Sequential(*[ResidualBlock(num_feat) for _ in range(num_blocks)])
        self.mid_conv = nn.Sequential(nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.BatchNorm2d(num_feat))
        self.upsample = nn.Sequential(
            nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1), nn.PixelShuffle(2), nn.PReLU(),
            nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1), nn.PixelShuffle(2), nn.PReLU()
        )
        self.output_conv = nn.Conv2d(num_feat, out_channels, 9, 1, 4)

    def forward(self, x):
        x_in = self.input_conv(x)
        res = self.res_blocks(x_in)
        x = self.mid_conv(res) + x_in  # Residual connection
        return self.output_conv(self.upsample(x))  # Final SR output

# =================================== DISCRIMINATOR ===================================
class Discriminator(nn.Module):
    """PatchGAN-like discriminator for adversarial training."""
    def __init__(self):
        super().__init__()

        def block(in_c, out_c, stride):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, stride, 1),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU(0.2)
            )

        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1), nn.LeakyReLU(0.2),
            block(64, 64, 2), block(64, 128, 1), block(128, 128, 2),
            block(128, 256, 1), block(256, 256, 2),
            nn.AdaptiveAvgPool2d(1), nn.Conv2d(256, 1, 1)
        )

    def forward(self, x):
        return torch.sigmoid(self.net(x)).view(-1)  # Output probability per image


In [None]:
# =================================== PERCEPTUAL LOSS ===================================
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss using pre-trained VGG19 features."""
    def __init__(self):
        super().__init__()
        vgg = vgg19(weights='DEFAULT').features[:22].eval()
        for p in vgg.parameters(): 
            p.requires_grad = False
        self.vgg = vgg.to(device)

    def forward(self, sr, hr):
        return F.mse_loss(self.vgg(normalize_vgg(sr)), self.vgg(normalize_vgg(hr)))

# =================================== TOTAL VARIATION LOSS ===================================
class TVLoss(nn.Module):
    """Total Variation Loss to enforce spatial smoothness."""
    def forward(self, x):
        h_tv = torch.pow((x[:,:,1:,:] - x[:,:,:-1,:]), 2).sum() 
        w_tv = torch.pow((x[:,:,:,1:] - x[:,:,:,:-1]), 2).sum()
        return (h_tv + w_tv) / (x.size(0) * x.size(1) * x.size(2) * x.size(3))

# LOSS WEIGHTS
λ_pix = 0.1 
λ_per = 0.02  
λ_adv = 0.005 
λ_tv  = 1e-9    

# LOSS FUNCTIONS
pixel_loss_fn = nn.L1Loss()        
vgg_loss_fn   = VGGPerceptualLoss()  
adv_loss_fn   = nn.BCELoss()     
tv_loss_fn    = TVLoss()    

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 250MB/s] 


In [6]:
model = SRResNet().to(device)
discriminator = Discriminator().to(device)

SMOOTH_MODEL = "/kaggle/input/best/pytorch/default/1/best_sr_model (1).pth" 
if os.path.exists(SMOOTH_MODEL):
    model.load_state_dict(torch.load(SMOOTH_MODEL, map_location=device), strict=False)
    print("Pre-trained smooth model loaded!")

opt_G = torch.optim.Adam(model.parameters(), lr=5e-5)
opt_D = torch.optim.Adam(discriminator.parameters(), lr=5e-5)

Pre-trained smooth model loaded!


In [None]:
# =================================== TRAINING CONFIG ===================================
EPOCHS = 10
SAVE_DIR = "/kaggle/working/models"
os.makedirs(SAVE_DIR, exist_ok=True)

# =================================== FINE-TUNING LOOP ===================================
for epoch in range(EPOCHS):
    model.train()
    discriminator.train()
    loop = tqdm(train_loader)
    
    for lr_imgs, hr_imgs in loop:
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

        # ------------------- TRAIN DISCRIMINATOR -------------------
        fake_sr = model(lr_imgs)
        d_loss = adv_loss_fn(discriminator(hr_imgs), torch.ones(lr_imgs.size(0)).to(device)) + \
                 adv_loss_fn(discriminator(fake_sr.detach()), torch.zeros(lr_imgs.size(0)).to(device))
        
        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # ------------------- TRAIN GENERATOR -------------------
        for _ in range(2):
            opt_G.zero_grad()
            
            fake_sr = model(lr_imgs)
            
            loss_pix = pixel_loss_fn(fake_sr, hr_imgs)
            loss_per = vgg_loss_fn(fake_sr, hr_imgs) 
            loss_adv = adv_loss_fn(discriminator(fake_sr), torch.ones(lr_imgs.size(0)).to(device))
            loss_tv  = tv_loss_fn(fake_sr)        

            g_loss = (λ_pix * loss_pix) + (λ_per * loss_per) + (λ_adv * loss_adv) + (λ_tv * loss_tv)

            g_loss.backward()
            opt_G.step()

        loop.set_description(f"Epoch [{epoch+1}/{EPOCHS}]")
        loop.set_postfix(G_loss=f"{g_loss.item():.4f}", D_loss=f"{d_loss.item():.4f}")

    # Save model checkpoint per epoch
    torch.save(model.state_dict(), f"{SAVE_DIR}/srresnet_skin_focus_epoch_{epoch+1}.pth")

print("Fine-tuning complete with focus on skin texture!")

Epoch [1/10]: 100%|██████████| 1949/1949 [40:50<00:00,  1.26s/it, D_loss=1.1143, G_loss=0.2712]
Epoch [2/10]: 100%|██████████| 1949/1949 [41:05<00:00,  1.26s/it, D_loss=0.0909, G_loss=0.5479]
Epoch [3/10]: 100%|██████████| 1949/1949 [41:01<00:00,  1.26s/it, D_loss=0.1887, G_loss=0.2788]
Epoch [4/10]: 100%|██████████| 1949/1949 [41:02<00:00,  1.26s/it, D_loss=0.1807, G_loss=0.3198]
Epoch [5/10]: 100%|██████████| 1949/1949 [41:06<00:00,  1.27s/it, D_loss=0.1904, G_loss=0.3612]
Epoch [6/10]: 100%|██████████| 1949/1949 [41:13<00:00,  1.27s/it, D_loss=1.1777, G_loss=0.3439]
Epoch [7/10]: 100%|██████████| 1949/1949 [41:14<00:00,  1.27s/it, D_loss=0.2100, G_loss=0.4219]
Epoch [8/10]: 100%|██████████| 1949/1949 [41:05<00:00,  1.27s/it, D_loss=0.8456, G_loss=0.2652]
Epoch [9/10]: 100%|██████████| 1949/1949 [41:06<00:00,  1.27s/it, D_loss=0.2340, G_loss=0.3074]
Epoch [10/10]: 100%|██████████| 1949/1949 [41:06<00:00,  1.27s/it, D_loss=0.2691, G_loss=0.2373]

Fine-tuning complete with focus on skin texture!





In [None]:
import torch

LAST_EPOCH_PATH = "/kaggle/working/models/srresnet_skin_focus_epoch_10.pth"

model = SRResNet().to(device)
model.load_state_dict(torch.load(LAST_EPOCH_PATH, map_location=device))
model.eval()

dummy_input = torch.randn(1, 3, 48, 48).to(device)
scripted_model = torch.jit.trace(model, dummy_input)
scripted_model.save("/kaggle/working/model_fine.pt")

print("done")

done


In [None]:
import torch
import gc
import numpy as np
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric

# =================================== MEMORY CLEANUP ===================================
torch.cuda.empty_cache()
gc.collect()

# =================================== METRICS STORAGE ===================================
psnr_values = []
ssim_values = []

print(f"Evaluating model with memory management on {len(test_ds)} images...")

# =================================== EVALUATION LOOP ===================================
with torch.no_grad():
    for lr_img, hr_img in tqdm(test_loader):
        lr_img = lr_img.to(device)
        
        try:
            sr_img = model(lr_img).clamp(0, 1).cpu()
            
            sr_np = sr_img.squeeze(0).permute(1, 2, 0).numpy()
            hr_np = hr_img.squeeze(0).permute(1, 2, 0).numpy()
            
            psnr_values.append(psnr_metric(hr_np, sr_np, data_range=1.0))
            ssim_values.append(ssim_metric(hr_np, sr_np, data_range=1.0, channel_axis=2))
            
        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"Image too large for GPU, skipping or try resizing.")
                torch.cuda.empty_cache()
                continue
            else:
                raise e
        
        # Force cleanup after each image to free VRAM
        del lr_img
        if 'sr_img' in locals(): del sr_img
        torch.cuda.empty_cache()
        gc.collect()

# =================================== RESULTS ===================================
if psnr_values:
    print("\n" + "="*30)
    print(f"Average PSNR: {np.mean(psnr_values):.2f} dB")
    print(f"Average SSIM: {np.mean(ssim_values):.4f}")
    print("="*30)
else:
    print("No images were processed due to memory limits.")


Evaluating model with memory management on 27 images...


  return forward_call(*args, **kwargs)
 22%|██▏       | 6/27 [00:33<01:58,  5.64s/it]

Image too large for GPU, skipping or try resizing.


 30%|██▉       | 8/27 [00:51<02:09,  6.82s/it]

Image too large for GPU, skipping or try resizing.


100%|██████████| 27/27 [03:50<00:00,  8.54s/it]


Average PSNR: 28.51 dB
Average SSIM: 0.8221



