In [None]:
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import Compose, RandomCrop, GaussianBlur, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation, ColorJitter, ToTensor, RandomApply, RandomPerspective, RandomAffine
import torch.nn.functional as F
import random
from torch.utils.data import Dataset
import math


class ResidualBlock(nn.Module):
    """Residual unit: conv -> BN -> ReLU -> Dropout2d -> conv -> BN, plus identity shortcut.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so the
    branch output has the same shape as the input and can be added to it directly.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        same_size = {"kernel_size": 3, "padding": 1}
        # Attribute names and construction order are significant: they define the
        # state_dict keys of saved checkpoints and the order of random weight init.
        self.conv1 = nn.Conv2d(channels, channels, **same_size)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout2d(p=0.3)
        self.conv2 = nn.Conv2d(channels, channels, **same_size)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``branch(x) + x``."""
        branch = self.dropout(self.relu(self.bn1(self.conv1(x))))
        branch = self.bn2(self.conv2(branch))
        return branch + x

class SRNN(nn.Module):
    """16x single-image super-resolution network with a global bicubic skip.

    Features are extracted at low resolution (9x9 conv, five residual blocks,
    3x3 channel reduction 32 -> 16), then upsampled by four successive 2x bicubic
    stages (2 ** 4 = 16x), each followed by a 3x3 conv and ReLU. The result is
    projected back to 3 channels and added to a bicubic 16x upscale of the input,
    so the network learns a correction on top of bicubic interpolation.
    """

    # Number of 2x upsampling stages; the overall upscale factor is 2 ** this (16).
    _NUM_UPSAMPLE_STAGES = 4

    def __init__(self):
        super(SRNN, self).__init__()
        self.input_conv = nn.Conv2d(3, 32, kernel_size=9, padding=4)
        self.residual_blocks = nn.Sequential(*[ResidualBlock(32) for _ in range(5)])
        self.channel_reduction = nn.Conv2d(32, 16, kernel_size=3, padding=1)

        # 16x upsampling (2x per stage, 4 stages). The loop produces exactly the same
        # layer order and Sequential indices as writing the stages out by hand
        # (convs at upsample.1/.4/.7/.10), so existing checkpoints still load and
        # the random-init order is unchanged.
        stages = []
        for _ in range(self._NUM_UPSAMPLE_STAGES):
            stages += [
                nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False),
                nn.Conv2d(16, 16, kernel_size=3, padding=1),
                nn.ReLU(),
            ]
        self.upsample = nn.Sequential(*stages)
        self.output_conv = nn.Conv2d(16, 3, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Upscale ``x`` by 16x.

        Args:
            x: 4-D batch of 3-channel images, presumably (N, 3, H, W) in [0, 1]
                — TODO confirm value range with the data pipeline.

        Returns:
            Tensor of shape (N, 3, 16*H, 16*W). The output is not clamped.
        """
        # Global skip: the bicubic upscale must use the same factor as the learned path.
        scale = 2 ** self._NUM_UPSAMPLE_STAGES
        x_input = F.interpolate(x, scale_factor=scale, mode='bicubic', align_corners=False)
        x = self.relu(self.input_conv(x))
        x = self.residual_blocks(x)
        x = self.relu(self.channel_reduction(x))
        x = self.upsample(x)
        x = self.output_conv(x)
        return x + x_input


class SRDataset(Dataset):
    """Random-crop super-resolution dataset drawn from a single high-resolution image.

    Each item is a new random square crop of side ``crop_size * scale`` from the
    HR image. The crop is passed through ``transform``, then bicubically downscaled
    to ``crop_size`` to form the low-resolution input. The ``idx`` argument is
    ignored: every access draws a fresh crop from the ``random`` module.

    Args:
        img_path: Path to the high-resolution source image (converted to RGB).
        crop_size: Side length in pixels of the low-resolution crop.
        scale: Upscaling factor; the HR crop side is ``crop_size * scale``.
        transform: Callable applied to the HR PIL crop. It must return a PIL image,
            because its result is resized with ``Image.resize``. When ``None``, a
            default geometric + photometric augmentation pipeline is used.
        num_samples: Length reported by ``__len__``, i.e. crops per epoch.
    """

    def __init__(self, img_path, crop_size=64, scale=16, transform=None, num_samples=3000):
        # Decode eagerly and release the underlying file handle immediately.
        with Image.open(img_path) as img:
            self.hr_image = img.convert('RGB')
        self.crop_size = crop_size
        self.scale = scale
        self.num_samples = num_samples
        # `is None`, not `or`: a caller-supplied transform that happens to be falsy
        # (e.g. an empty nn.Sequential) must not be silently replaced by augmentation.
        if transform is None:
            transform = Compose([
                RandomHorizontalFlip(),
                RandomVerticalFlip(),
                RandomRotation(degrees=90),
                ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3),
                RandomApply([GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.5),
                RandomPerspective(distortion_scale=0.5, p=0.5),
                RandomAffine(degrees=15, translate=(0.05, 0.05), scale=(0.95, 1.05))
            ])
        self.transform = transform

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return an ``(lr, hr)`` pair of tensors produced by ``ToTensor``.

        Raises:
            ValueError: If the source image is smaller than the HR crop.
        """
        width, height = self.hr_image.size
        scaled_crop = self.crop_size * self.scale
        if width < scaled_crop or height < scaled_crop:
            raise ValueError(
                f"HR image is {width}x{height}, smaller than the required "
                f"{scaled_crop}x{scaled_crop} crop "
                f"(crop_size={self.crop_size}, scale={self.scale})"
            )

        left = random.randint(0, width - scaled_crop)
        top = random.randint(0, height - scaled_crop)
        hr_crop = self.hr_image.crop((left, top, left + scaled_crop, top + scaled_crop))
        hr_crop = self.transform(hr_crop)

        # The LR input is derived from the already-transformed HR crop, so the pair
        # stays pixel-aligned under every augmentation.
        lr_crop = hr_crop.resize((self.crop_size, self.crop_size), Image.BICUBIC)

        to_tensor = ToTensor()
        return to_tensor(lr_crop), to_tensor(hr_crop)

def calculate_psnr(output, target, max_pixel_value=1.0):
    """Return the peak signal-to-noise ratio (dB) between ``output`` and ``target``.

    The MSE is averaged over all elements jointly. For a batch, this gives the PSNR
    of the batch-mean MSE, not the mean of per-image PSNRs.

    Args:
        output: Predicted tensor (may require grad; it is detached here).
        target: Reference tensor, broadcast-compatible with ``output``.
        max_pixel_value: Peak value of the signal range, e.g. 1.0 for [0, 1]
            images or 255 for 8-bit images.

    Returns:
        PSNR as a Python float, or ``float('inf')`` when the inputs are identical.
    """
    # Detach: the metric needs no gradient, so don't extend the autograd graph
    # when this is called on a training output.
    output = output.detach()
    target = target.detach()
    # Integer images (e.g. uint8) would wrap around on subtraction, and torch.mean
    # rejects integer dtypes, so promote them. Float inputs keep their precision.
    if not output.is_floating_point():
        output = output.float()
    if not target.is_floating_point():
        target = target.float()

    mse = torch.mean((output - target) ** 2).item()
    if mse == 0:
        return float('inf')
    return 10 * math.log10(max_pixel_value ** 2 / mse)


In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SRNN().to(device)
# To resume from a previous run, load its weights here, e.g.:
# model.load_state_dict(torch.load('srnn_model.pth', map_location=device))

criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# Halve the learning rate every 50 epochs (scheduler.step() is called once per epoch below).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

train_dataset = SRDataset('/kaggle/input/datasetscale16v2/1.jpg', crop_size=32, scale=16)
# Validation must not use the random training augmentations (colour jitter, blur,
# rotation, perspective, affine). With them, val PSNR measures robustness to random
# distortions rather than reconstruction quality, which makes best-model selection noisy.
# An empty Compose is an identity transform, and unlike a lambda it can be pickled
# if DataLoader workers are started with the "spawn" method.
# NOTE(review): crop positions are still random each epoch, so val PSNR still
# fluctuates somewhat between epochs.
val_dataset = SRDataset('/kaggle/input/datasetscale16v2/2.jpg', crop_size=32, scale=16,
                        transform=Compose([]))

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

best_psnr = 0.0

for epoch in range(200):
    model.train()
    total_loss = 0.0
    total_psnr = 0.0

    for lr, hr in train_loader:
        # Batches come from pinned memory, so the host->device copy can be asynchronous.
        lr = lr.to(device, non_blocking=True)
        hr = hr.to(device, non_blocking=True)
        optimizer.zero_grad()
        sr = model(lr)
        loss = criterion(sr, hr)
        loss.backward()
        optimizer.step()

        batch_size = lr.size(0)
        total_loss += loss.item() * batch_size
        # The metric needs no gradient; detaching avoids extending the autograd graph.
        total_psnr += calculate_psnr(sr.detach(), hr) * batch_size

    avg_loss = total_loss / len(train_loader.dataset)
    avg_psnr = total_psnr / len(train_loader.dataset)
    scheduler.step()

    model.eval()
    val_loss = 0.0
    val_psnr = 0.0
    with torch.no_grad():
        for lr, hr in val_loader:
            lr = lr.to(device, non_blocking=True)
            hr = hr.to(device, non_blocking=True)
            sr = model(lr)
            loss = criterion(sr, hr)
            batch_size = lr.size(0)
            val_loss += loss.item() * batch_size
            val_psnr += calculate_psnr(sr, hr) * batch_size

    val_loss /= len(val_loader.dataset)
    val_psnr /= len(val_loader.dataset)

    # Keep the checkpoint with the best validation PSNR seen so far.
    if val_psnr > best_psnr:
        best_psnr = val_psnr
        torch.save(model.state_dict(), '/kaggle/working/best_srnn_model.pth')
        print(f"[Epoch {epoch+1}] Model salvat (Val PSNR: {val_psnr:.2f} dB)")

    print(f"Epoch {epoch+1:03d} | Loss: {avg_loss:.4f} | Train PSNR: {avg_psnr:.2f} dB | Val PSNR: {val_psnr:.2f} dB")

# Weights from the final epoch (not necessarily the best one).
torch.save(model.state_dict(), '/kaggle/working/srnn_model.pth')
