In [10]:
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import ToTensor 
from torchvision.transforms import Compose,RandomCrop, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation, ColorJitter, ToTensor
import torch.nn.functional as F
import random

class ResidualBlock(nn.Module):
    """Residual unit: conv3x3 -> BN -> ReLU -> Dropout2d -> conv3x3 -> BN, plus identity skip.

    Both convolutions use 3x3 kernels with padding 1 and keep the channel count,
    so the branch output has the same shape as the input and can be added to it.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        # The attribute names below become state_dict keys; keep them stable so
        # previously saved checkpoints continue to load.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout2d(p=0.2)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return the learned residual branch applied to ``x``, added back onto ``x``."""
        branch = self.bn1(self.conv1(x))
        branch = self.dropout(self.relu(branch))
        branch = self.bn2(self.conv2(branch))
        return branch + x

class SRNN(nn.Module):
    """2x super-resolution network that learns a correction on top of bicubic upscaling.

    Pipeline: 9x9 conv (3 -> 64 channels) -> 5 residual blocks -> 3x3 conv (64 -> 16)
    -> bicubic 2x upsample + 3x3 conv + ReLU -> 3x3 conv (16 -> 3). The result is
    added to a bicubic 2x upscale of the input, so the network only predicts the
    residual detail.

    The upscaling factor is fixed at 2. Attribute names define the state_dict keys
    and must stay stable for saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.input_conv = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.residual_blocks = nn.Sequential(*(ResidualBlock(64) for _ in range(5)))
        self.channel_reduction = nn.Conv2d(64, 16, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.output_conv = nn.Conv2d(16, 3, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Upscale ``x`` by 2: bicubic interpolation plus the predicted correction."""
        # Global skip path: plain bicubic upscaling of the low-res input.
        skip = F.interpolate(x, scale_factor=2, mode='bicubic', align_corners=False)

        features = self.relu(self.input_conv(x))
        features = self.residual_blocks(features)
        features = self.relu(self.channel_reduction(features))
        correction = self.output_conv(self.upsample(features))
        return correction + skip
    
def extract_augmented_patches(img_path, scale=2, crop_size=64, num_patches=1000):
    """Sample random (low-res, high-res) training patch pairs from a single image.

    Each HR patch is a random ``crop_size*scale`` square crop of the image. The crop
    is passed through random horizontal/vertical flips, a random rotation of up to
    30 degrees and colour jitter. Its LR counterpart is the *augmented* HR patch,
    bicubic-downscaled to ``crop_size x crop_size``, so the pair stays aligned.

    Args:
        img_path: path to the source image (converted to RGB).
        scale: integer factor between HR and LR patch sizes. Note that SRNN upscales
            by exactly 2.
        crop_size: side length of the LR patches, in pixels.
        num_patches: number of patch pairs to draw.

    Returns:
        ``(lr, hr)`` float tensors produced by ``ToTensor``. Their shapes are
        ``(num_patches, 3, crop_size, crop_size)`` and
        ``(num_patches, 3, crop_size*scale, crop_size*scale)``. When
        ``num_patches <= 0``, both are empty along dim 0.

    Raises:
        ValueError: if ``scale`` or ``crop_size`` is not positive, or if the image is
            smaller than one HR patch.
    """
    if scale <= 0 or crop_size <= 0:
        raise ValueError(
            f"scale and crop_size must be positive, got scale={scale}, crop_size={crop_size}"
        )

    # Context manager so the underlying file handle is released once decoded.
    with Image.open(img_path) as img:
        hr_img = img.convert('RGB')

    width, height = hr_img.size
    scaled_crop = crop_size * scale
    if width < scaled_crop or height < scaled_crop:
        # Without this check random.randint fails with an opaque "empty range" error.
        raise ValueError(
            f"image {img_path!r} is {width}x{height}, smaller than one "
            f"{scaled_crop}x{scaled_crop} HR patch (crop_size={crop_size}, scale={scale})"
        )

    # NOTE(review): RandomRotation is used with its default expand=False / fill=0, so
    # rotated patches get black corners that end up in both the HR target and the LR
    # input. Consider rotating a larger crop and centre-cropping instead.
    augmentation = Compose([
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        RandomRotation(degrees=30),
        ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)
    ])
    to_tensor = ToTensor()

    patches_lr, patches_hr = [], []
    for _ in range(num_patches):
        # Random top-left corner such that the whole HR patch lies inside the image.
        left = random.randint(0, width - scaled_crop)
        top = random.randint(0, height - scaled_crop)
        hr_crop = hr_img.crop((left, top, left + scaled_crop, top + scaled_crop))
        hr_aug = augmentation(hr_crop)

        # Degrade the augmented HR patch (not the raw crop) so LR/HR stay pixel-aligned.
        lr = hr_aug.resize((crop_size, crop_size), Image.BICUBIC)

        patches_lr.append(to_tensor(lr))
        patches_hr.append(to_tensor(hr_aug))

    if not patches_lr:
        # torch.stack rejects an empty list; return correctly shaped empty batches.
        return (torch.empty(0, 3, crop_size, crop_size),
                torch.empty(0, 3, scaled_crop, scaled_crop))
    return torch.stack(patches_lr), torch.stack(patches_hr)


In [None]:
import torch.optim as optim
import math


def calculate_psnr(output, target, max_pixel_value=1.0):
    """Peak signal-to-noise ratio, in dB, between two tensors.

    The MSE is averaged over *all* elements. For a batch, the result is therefore
    the PSNR of the pooled MSE, not the mean of the per-image PSNRs.

    Args:
        output: predicted tensor. It may require grad; the metric is computed
            without tracking gradients.
        target: reference tensor with the same (or a broadcastable) shape.
        max_pixel_value: peak signal value, e.g. 1.0 for [0, 1] images or 255 for
            uint8 images.

    Returns:
        PSNR as a Python float, or ``float('inf')`` when the tensors are identical.
    """
    # Metric only: don't extend the autograd graph when called on training outputs.
    with torch.no_grad():
        pred = output.detach()
        ref = target.detach()
        if not (pred.is_floating_point() and ref.is_floating_point()):
            # Integer images (e.g. uint8) would wrap around on subtraction and cannot
            # be averaged by torch.mean, so compute in float64 instead.
            pred, ref = pred.double(), ref.double()
        mse = torch.mean((pred - ref) ** 2).item()
    if mse == 0:
        return float('inf')
    return 10 * math.log10(max_pixel_value ** 2 / mse)

# Setup: device, model, loss, optimizer and learning-rate schedule.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SRNN().to(device)
# To start from a saved checkpoint instead of random init, uncomment:
# model.load_state_dict(torch.load('srnn_model.pth'))
# model.eval()

# L1 loss; Adam with light weight decay. StepLR halves the learning rate every 50
# scheduler steps, and scheduler.step() is called once per epoch below.
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# Build the datasets once, up front: 1000 random augmented patch pairs from a single
# training image and 50 from a single validation image. The validation patches are
# augmented too. extract_augmented_patches defaults to scale=2, so HR patches are
# 128x128 and LR patches are 64x64.
# NOTE(review): no seed is set for `random` or `torch`, so the patches and the weight
# initialisation differ on every run.
train_lr, train_hr = extract_augmented_patches(
    img_path='/kaggle/input/sr-dataset/dataset/train/1.jpg',
    crop_size=64, num_patches=1000
)
val_lr, val_hr = extract_augmented_patches(
    img_path='/kaggle/input/sr-dataset/dataset/validation/2.jpg',
    crop_size=64, num_patches=50
)

# Move the whole (small) dataset to the device once, instead of per step.
train_lr, train_hr = train_lr.to(device), train_hr.to(device)
val_lr, val_hr = val_lr.to(device), val_hr.to(device)


# Best validation PSNR so far; a checkpoint is written whenever it improves.
best_psnr = 0.0

# Training
for epoch in range(200):
    model.train()
    total_loss = 0.0
    total_psnr = 0.0

    # One optimizer step per patch (batch size 1), visiting patches in a fixed order.
    # NOTE(review): batch size 1 makes BatchNorm's batch statistics very noisy and the
    # loop slow; consider mini-batches (with a re-tuned learning rate).
    for i in range(len(train_lr)):
        # `lr` / `hr` are the low-/high-resolution patches (not the learning rate);
        # unsqueeze(0) restores the batch dimension removed by indexing.
        lr = train_lr[i].unsqueeze(0)
        hr = train_hr[i].unsqueeze(0)

        optimizer.zero_grad()
        sr = model(lr)
        loss = criterion(sr, hr)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Train PSNR is measured on the output computed *before* this step's update.
        total_psnr += calculate_psnr(sr, hr)

    # Epoch averages over patches; train PSNR is a mean of per-patch PSNRs.
    avg_loss = total_loss / len(train_lr)
    avg_psnr = total_psnr / len(train_lr)
    scheduler.step()

    # Validation: one forward pass over all validation patches at once.
    model.eval()
    with torch.no_grad():
        val_output = model(val_lr)
        # NOTE(review): val_loss is computed but never reported or used.
        val_loss = criterion(val_output, val_hr).item()
        # PSNR of the MSE pooled over the whole validation batch (not directly
        # comparable with the per-patch-averaged train PSNR above).
        val_psnr = calculate_psnr(val_output, val_hr)

        # Keep only the best checkpoint by validation PSNR
        # (the message "Model salvat" means "model saved").
        if val_psnr > best_psnr:
            best_psnr = val_psnr
            torch.save(model.state_dict(), 'best_srnn_model.pth')
            print(f"[Epoch {epoch+1}] Model salvat (Val PSNR: {val_psnr:.2f} dB)")

    print(f"Epoch {epoch+1:03d} | Loss: {avg_loss:.4f} | Train PSNR: {avg_psnr:.2f} dB | Val PSNR: {val_psnr:.2f} dB")

# Save the final-epoch model (separate from the best-by-validation checkpoint above).
torch.save(model.state_dict(), 'srnn_model.pth')


In [12]:
import torch
from torchvision.transforms import ToPILImage

def _load_test_pair(test_img_path, scale=2):
    """Load an image as the HR ground truth and bicubic-downscale it into the LR input.

    Mirrors the degradation used for training in extract_augmented_patches (PIL
    bicubic resize). The image is first cropped at the bottom/right so that both
    sides are multiples of ``scale``, keeping the LR and HR sizes exactly in ratio.

    Args:
        test_img_path: path to the test image (converted to RGB).
        scale: integer downscaling factor.

    Returns:
        ``(lr, hr)`` float tensors from ``ToTensor``, each with a leading batch
        dimension of 1.

    Raises:
        ValueError: if the image is smaller than ``scale`` pixels in either dimension.
    """
    with Image.open(test_img_path) as img:
        hr_img = img.convert('RGB')
    width, height = hr_img.size
    width, height = width - width % scale, height - height % scale
    if width == 0 or height == 0:
        raise ValueError(f"image {test_img_path!r} is too small for scale={scale}")
    hr_img = hr_img.crop((0, 0, width, height))
    lr_img = hr_img.resize((width // scale, height // scale), Image.BICUBIC)
    to_tensor = ToTensor()
    return to_tensor(lr_img).unsqueeze(0), to_tensor(hr_img).unsqueeze(0)


def test_model(model_path, test_img_path, output_path, scale=2):
    """Run a saved SRNN checkpoint on one image and write before/after/ground-truth files.

    The test image is treated as ground truth: it is downscaled by ``scale`` to form
    the model input, and the model output is saved to ``output_path``. The LR input
    is saved as "before.jpg" and the ground truth as "ground_truth.jpg", both in the
    current working directory.

    Args:
        model_path: path to a state_dict saved from SRNN.
        test_img_path: path to the test (ground-truth) image.
        output_path: where to save the super-resolved image.
        scale: downscaling factor for the LR input. SRNN always upscales by 2, so
            the output matches the ground-truth size only when ``scale == 2``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model.
    # NOTE(review): consider torch.load(..., weights_only=True) if checkpoints may
    # come from untrusted sources.
    model = SRNN().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Preprocess the test image. The previous `preprocess(...)` helper was not
    # defined anywhere in the notebook, so the call failed on a fresh kernel.
    input_tensor, target_tensor = _load_test_pair(test_img_path, scale=scale)
    input_tensor = input_tensor.to(device)

    # Inference.
    with torch.no_grad():
        output = model(input_tensor)

    # Convert to images. ToPILImage turns float tensors into bytes via mul(255).byte(),
    # which does not saturate out-of-range values. The residual + bicubic output can
    # leave [0, 1], so clamp it first.
    to_pil = ToPILImage()
    output_img = to_pil(output.squeeze(0).clamp(0.0, 1.0).cpu())
    input_img = to_pil(input_tensor.squeeze(0).cpu())
    target_img = to_pil(target_tensor.squeeze(0).cpu())

    # Save results.
    input_img.save("before.jpg")
    output_img.save(output_path)
    target_img.save("ground_truth.jpg")

    print(f"Output salvat în: {output_path}")
    print("Input low-res salvat ca: before.jpg")
    print("Ground truth salvat ca: ground_truth.jpg")

# Example call: super-resolve one test image with a saved checkpoint.
# NOTE(review): this checkpoint path differs from the files written by the training
# cell ('best_srnn_model.pth' / 'srnn_model.pth'); confirm which weights are intended.
test_model(
    test_img_path='dataset/test/b.jpg',
    model_path='models2/model2.pth',
    output_path='after.jpg',
    scale=2,
)


Output salvat în: after.jpg
Input low-res salvat ca: before.jpg
Ground truth salvat ca: ground_truth.jpg
