In [2]:
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import ToTensor 
from torchvision.transforms import Compose,RandomApply, RandomAffine, RandomGrayscale, RandomPerspective, RandomCrop, GaussianBlur, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation, ColorJitter, ToTensor
import torch.nn.functional as F
import random
from torch.utils.data import Dataset
from PIL import Image
import math


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x + residual

class SRNN(nn.Module):
    def __init__(self):
        super(SRNN, self).__init__()
        self.input_conv = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(16)]  
        )
        self.channel_reduction = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.output_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x_input = F.interpolate(x, scale_factor=4, mode='bicubic', align_corners=False)
        x = self.relu(self.input_conv(x))
        x = self.residual_blocks(x)
        x = self.channel_reduction(x)
        x = self.upsample(x)
        x = self.output_conv(x)
        return x + x_input
    

class SRDataset(Dataset):
    def __init__(self, img_path, crop_size=64, scale=4, transform=None):
        self.hr_image = Image.open(img_path).convert('RGB')
        self.crop_size = crop_size
        self.scale = scale
        self.transform = transform or Compose([
            RandomHorizontalFlip(),
            RandomVerticalFlip(),
            RandomRotation(degrees=135),  
            ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.4),
            RandomApply([GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.1),
        ])

    def __len__(self):
        return 5000  

    def __getitem__(self, idx):
        width, height = self.hr_image.size
        scaled_crop = self.crop_size * self.scale

        left = random.randint(0, width - scaled_crop)
        top = random.randint(0, height - scaled_crop)
        hr_crop = self.hr_image.crop((left, top, left + scaled_crop, top + scaled_crop))
        hr_crop = self.transform(hr_crop)

        lr_crop = hr_crop.resize((self.crop_size, self.crop_size), Image.BICUBIC)

        return ToTensor()(lr_crop), ToTensor()(hr_crop)

def calculate_psnr(output, target, max_pixel_value=1.0):
    mse = torch.mean((output - target) ** 2)
    if mse == 0:
        return float('inf')
    psnr = 10 * math.log10(max_pixel_value ** 2 / mse.item())
    return psnr



In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SRNN().to(device)

state_dict = torch.load('/kaggle/input/model4-sr/pytorch/default/1/model4.pth', weights_only=True)
model.load_state_dict(state_dict)
model.eval()

criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)


train_dataset = SRDataset('/kaggle/input/sr-dataset/dataset/train/1.jpg', crop_size=64)
val_dataset = SRDataset('/kaggle/input/sr-dataset/dataset/validation/2.jpg', crop_size=64)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)


best_psnr = 0.0

# Antrenare
for epoch in range(200):
    model.train()
    total_loss = 0.0
    total_psnr = 0.0

    for lr, hr in train_loader:
        lr, hr = lr.to(device), hr.to(device)
        optimizer.zero_grad()
        sr = model(lr)
        loss = criterion(sr, hr)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * lr.size(0)
        total_psnr += calculate_psnr(sr, hr) * lr.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    avg_psnr = total_psnr / len(train_loader.dataset)
    scheduler.step()

    # Validare
    model.eval()
    val_loss = 0.0
    val_psnr = 0.0
    with torch.no_grad():
        for lr, hr in val_loader:
            lr, hr = lr.to(device), hr.to(device)
            sr = model(lr)
            loss = criterion(sr, hr)
            val_loss += loss.item() * lr.size(0)
            val_psnr += calculate_psnr(sr, hr) * lr.size(0)

    val_loss /= len(val_loader.dataset)
    val_psnr /= len(val_loader.dataset)

    if val_psnr > best_psnr:
        best_psnr = val_psnr
        torch.save(model.state_dict(), 'best_srnn_model.pth')
        print(f"[Epoch {epoch+1}] Model salvat (Val PSNR: {val_psnr:.2f} dB)")

    print(f"Epoch {epoch+1:03d} | Loss: {avg_loss:.4f} | Train PSNR: {avg_psnr:.2f} dB | Val PSNR: {val_psnr:.2f} dB")

torch.save(model.state_dict(), 'srnn_model.pth')

In [None]:
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
import torch
import torch.nn as nn
import os
import torch
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage
import torch.nn as nn
import math
import os

def calculate_psnr(output, target, max_pixel_value=1.0):
    mse = torch.mean((output - target) ** 2)
    if mse == 0:
        return float('inf')
    psnr = 10 * math.log10(max_pixel_value ** 2 / mse.item())
    return psnr

def preprocess_full_image(img_path, scale=4):
    hr = Image.open(img_path).convert('RGB')

    w, h = hr.size
    w -= w % scale
    h -= h % scale
    hr = hr.crop((0, 0, w, h))

    lr = hr.resize((w // scale, h // scale), Image.BICUBIC)
    return ToTensor()(lr).unsqueeze(0), ToTensor()(hr).unsqueeze(0)


def test_model_full_image(model_path, test_img_path, output_path, scale=4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = SRNN().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    input_img, target_img = preprocess_full_image(test_img_path, scale=scale)
    input_img = input_img.to(device)
    target_img = target_img.to(device)

    with torch.no_grad():
        output = model(input_img)

    criterion = nn.L1Loss()
    psnr = calculate_psnr(output, target_img)
    loss = criterion(output, target_img).item()

    to_pil = ToPILImage()
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    to_pil(input_img.squeeze(0).cpu()).save("scale4/before.jpg")
    to_pil(output.squeeze(0).cpu()).save(output_path)
    to_pil(target_img.squeeze(0).cpu()).save("scale4/ground_truth.jpg")

    print(f"Output salvat in: {output_path}")
    print(f"PSNR pe imaginea completa: {psnr:.2f} dB")
    print(f"Loss pe imaginea completa: {loss:.4f}")

test_model_full_image(
    model_path='models4-try/model-x.pth',
    test_img_path='dataset/validation/2.jpg',
    output_path='scale4/after.jpg',
    scale=4
)

# Downsampled by 4x -> PSNR: 34.62 dB

Output salvat în: scale4/after.jpg
PSNR pe imaginea completa: 30.02 dB
Loss pe imaginea completa: 0.0137
