In [2]:
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import ToTensor 
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation, ColorJitter, ToTensor


class ResidualBlock(nn.Module):
    """Two 3x3 same-padding convolutions wrapped in an identity skip connection.

    Both convolutions map ``channels`` to ``channels`` with ``padding=1``, so
    the branch output has the same shape as the input and can be added to it.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        # NOTE(review): bn1 and bn2 are registered (so they appear in the
        # state_dict) but forward() never calls them. They are kept so that
        # existing checkpoints still load; confirm whether normalisation was
        # meant to be applied.
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``conv2(relu(conv1(x))) + x``."""
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        return branch + x

class SRNN(nn.Module):
    """Residual CNN that upsamples an RGB image with a sub-pixel convolution.

    Pipeline: 9x9 input convolution -> stack of residual blocks -> 3x3
    convolution producing ``4 * upscale_factor ** 2`` channels ->
    ``PixelShuffle`` (which trades those channels for an
    ``upscale_factor``-times larger spatial size, leaving 4 channels) -> 3x3
    convolution back to 3 output channels.

    With the default arguments the layers, their shapes and the state_dict
    keys are identical to the original fixed x2 network, so checkpoints saved
    from it keep loading.

    Args:
        upscale_factor: Integer spatial magnification (default 2). Should match
            the ``scale`` used to create the low-resolution inputs.
        num_features: Channel width of the feature extractor (default 64).
        num_blocks: Number of ``ResidualBlock`` modules (default 5).

    Raises:
        ValueError: If ``upscale_factor`` is smaller than 1.
    """

    # Channels left after PixelShuffle; fixed at 4 to stay compatible with the
    # original architecture (16 channels shuffled by a factor of 2).
    _POST_SHUFFLE_CHANNELS = 4

    def __init__(self, upscale_factor=2, num_features=64, num_blocks=5):
        super().__init__()
        if upscale_factor < 1:
            raise ValueError(f"upscale_factor must be >= 1, got {upscale_factor}")

        self.upscale_factor = upscale_factor
        self.input_conv = nn.Conv2d(3, num_features, kernel_size=9, padding=4)
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features) for _ in range(num_blocks)]
        )
        # PixelShuffle(r) maps (C * r^2, H, W) -> (C, H * r, W * r), so the
        # convolution feeding it must emit C * r^2 channels.
        self.channel_reduction = nn.Conv2d(
            num_features,
            self._POST_SHUFFLE_CHANNELS * upscale_factor ** 2,
            kernel_size=3,
            padding=1,
        )
        self.upsample = nn.PixelShuffle(upscale_factor=upscale_factor)
        self.output_conv = nn.Conv2d(
            self._POST_SHUFFLE_CHANNELS, 3, kernel_size=3, padding=1
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Upscale a batch of 3-channel images.

        Args:
            x: Input tensor with 3 channels in dimension 1 (required by
                ``input_conv``).

        Returns:
            Tensor with 3 channels whose height and width are
            ``upscale_factor`` times those of ``x``. No activation is applied
            after the last convolution, so values are not bounded to [0, 1].
        """
        x = self.relu(self.input_conv(x))
        x = self.residual_blocks(x)
        x = self.relu(self.channel_reduction(x))
        x = self.upsample(x)
        x = self.output_conv(x)
        return x
    
def preprocess(img_path, scale=2, augment=False):
    """Load an image and build an aligned (low-res, high-res) tensor pair.

    The high-resolution image is optionally augmented first, and the
    low-resolution input is then derived from that same augmented image by
    bicubic downscaling. Applying the random pipeline to the two images
    separately would draw different flips, rotation angles and colour jitter
    for each, so the input and the target would no longer show the same
    content.

    The high-resolution image is cropped (from the top-left corner) so that
    its width and height are exact multiples of ``scale``; this keeps the
    target the same size as an ``scale``-times upscaled low-res input.

    Args:
        img_path: Path of the image file to open.
        scale: Integer downscaling factor used to create the low-res image.
        augment: When True, apply random flips, a rotation of up to 25 degrees
            and colour jitter to the image.

    Returns:
        tuple: ``(lr, hr)`` tensors produced by ``ToTensor`` with a leading
        batch dimension of 1 added to each.

    Raises:
        ValueError: If ``scale`` is smaller than 1 or the image is smaller
            than ``scale`` pixels along either side.
    """
    if scale < 1:
        raise ValueError(f"scale must be >= 1, got {scale}")

    # convert() returns a fully loaded copy, so the file can be closed here.
    with Image.open(img_path) as img:
        hr = img.convert('RGB')

    if augment:
        augmentation = Compose([
            RandomHorizontalFlip(p=0.5),
            RandomVerticalFlip(p=0.5),
            RandomRotation(degrees=25),
            ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
        ])
        # Augment once, before downscaling, so lr and hr stay aligned.
        hr = augmentation(hr)

    w, h = hr.size
    lr_w, lr_h = w // scale, h // scale
    if lr_w == 0 or lr_h == 0:
        raise ValueError(
            f"image of size {w}x{h} is too small for scale {scale}"
        )

    # Drop the remainder rows/columns so hr is exactly scale * lr in size.
    hr_w, hr_h = lr_w * scale, lr_h * scale
    if (hr_w, hr_h) != (w, h):
        hr = hr.crop((0, 0, hr_w, hr_h))

    lr = hr.resize((lr_w, lr_h), Image.BICUBIC)

    return ToTensor()(lr).unsqueeze(0), ToTensor()(hr).unsqueeze(0)


In [None]:
import torch.optim as optim
import math

def calculate_psnr(output, target, max_pixel_value=1.0):
    """Peak signal-to-noise ratio, in dB, between two tensors.

    Computed as ``10 * log10(max_pixel_value ** 2 / MSE)`` where MSE is the
    mean squared difference over all elements.

    Args:
        output: Predicted tensor.
        target: Reference tensor; must be broadcastable with ``output``.
        max_pixel_value: Largest possible pixel value (1.0 by default).

    Returns:
        float: The PSNR in decibels, or ``float('inf')`` when the mean squared
        error is exactly zero.
    """
    squared_error = (output - target) ** 2
    mse_value = torch.mean(squared_error).item()

    # Identical inputs: avoid dividing by zero and report infinite PSNR.
    if mse_value == 0:
        return float('inf')

    peak_power = max_pixel_value ** 2
    return 10 * math.log10(peak_power / mse_value)

# Build a freshly initialised network.
model = SRNN()
# Uncomment the next two lines to start from a previously saved checkpoint
# instead of random weights:
# model.load_state_dict(torch.load('srnn_model.pth'))
# model.eval()

# Pixel-wise mean squared error between the network output and the target.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# A single training pair and a single validation pair are loaded once, before
# the loop. Because preprocess() is called only here, the random augmentation
# is drawn once and the same augmented pair is reused in every epoch.
input_img, target_img = preprocess('dataset/train/1.jpg', augment=True)
val_input, val_target = preprocess('dataset/validation/2.jpg', augment=False)

# Learning-rate schedule: multiply the learning rate by 0.5 every 20
# scheduler steps (one step is taken per epoch below).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# Training
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    output = model(input_img)
    loss = criterion(output, target_img)

    loss.backward()
    optimizer.step()
    scheduler.step() 

    # Training PSNR, measured on the forward pass made before this epoch's
    # weight update.
    train_psnr = calculate_psnr(output, target_img)
    
    # Validation: evaluation mode, no gradient tracking.
    model.eval()
    with torch.no_grad():
        val_output = model(val_input)
        val_psnr = calculate_psnr(val_output, val_target)
    
    # Report the metrics for this epoch.
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Train PSNR: {train_psnr:.2f} dB, Val PSNR: {val_psnr:.2f} dB")

# Save the trained model's weights (state_dict only).
torch.save(model.state_dict(), 'srnn_model.pth')


In [3]:
from torchvision.transforms import  ToPILImage

def test_model(model_path, test_img_path, output_path, input_save_path="before.jpg"):
    """Super-resolve one image with a saved SRNN checkpoint and save the result.

    The test image is downscaled by a factor of 2 with ``preprocess``, the
    low-resolution version is fed through the network, and both the
    low-resolution input and the network output are written to disk.

    Args:
        model_path: Path to a state_dict saved with ``torch.save``.
        test_img_path: Path of the image to process.
        output_path: Where the super-resolved image is written.
        input_save_path: Where the low-resolution network input is written
            (defaults to ``"before.jpg"``, the previously hard-coded path).
    """
    model = SRNN()
    # map_location='cpu' lets a checkpoint that was saved from a GPU be loaded
    # on a CPU-only machine; the model built above lives on the CPU.
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (or pass weights_only=True where the installed
    # PyTorch version supports it).
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

    # Only the low-resolution input is needed for inference.
    input_img, _ = preprocess(test_img_path, scale=2)

    with torch.no_grad():
        output = model(input_img)

    # The last layer has no bounded activation, so predictions can fall
    # outside [0, 1]. ToPILImage scales float tensors by 255 and casts to
    # uint8, and out-of-range values are not valid after that cast, so clamp
    # first.
    output_img = ToPILImage()(output.squeeze(0).clamp(0.0, 1.0))

    ToPILImage()(input_img.squeeze(0)).save(input_save_path)

    output_img.save(output_path)
    print(f"Output saved to {output_path}")

# Run inference on one image from dataset/test; the upscaled result is written
# to 'after.jpg' and the low-resolution input to 'before.jpg'.
test_model(model_path='models2/srnn_model2_scale2.pth',
           test_img_path='dataset/test/b.jpg',
           output_path='after.jpg')

Output saved to after.jpg
