In [1]:
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import ToTensor 


class SRNN(nn.Module):
    def __init__(self):
        super(SRNN, self).__init__()
        self.layer1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.layer2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.layer3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.layer3(x)
        return x

def preprocess(img_path, scale=2):
    hr = Image.open(img_path).convert('RGB')
    w, h = hr.size
    lr = hr.resize((w // scale, h // scale), Image.BICUBIC)
    lr_up = lr.resize((w, h), Image.BICUBIC)
    return ToTensor()(lr_up).unsqueeze(0), ToTensor()(hr).unsqueeze(0)

In [None]:
import torch.optim as optim
import math
from PIL import Image
from torchvision.transforms import ToTensor 


def preprocess(img_path, scale=2):
    hr = Image.open(img_path).convert('RGB')
    w, h = hr.size
    lr = hr.resize((w // scale, h // scale), Image.BICUBIC)
    lr_up = lr.resize((w, h), Image.BICUBIC)
    return ToTensor()(lr_up).unsqueeze(0), ToTensor()(hr).unsqueeze(0)


def calculate_psnr(output, target, max_pixel_value=1.0):
    mse = torch.mean((output - target) ** 2)
    if mse == 0:
        return float('inf')
    psnr = 10 * math.log10(max_pixel_value ** 2 / mse.item())
    return psnr

model = SRNN()
# model.load_state_dict(torch.load('srnn_model.pth'))
# model.eval()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

input_img, target_img = preprocess('dataset/train/1.jpg')
val_input, val_target = preprocess('dataset/validation/2.jpg')


# Antrenare
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    output = model(input_img)
    loss = criterion(output, target_img)

    loss.backward()
    optimizer.step()

    # PSNR
    train_psnr = calculate_psnr(output, target_img)
    
    # Validare
    model.eval()
    with torch.no_grad():
        val_output = model(val_input)
        val_psnr = calculate_psnr(val_output, val_target)
    
    # Rezultate
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Train PSNR: {train_psnr:.2f} dB, Val PSNR: {val_psnr:.2f} dB")

# Salvez modelul antrenat
torch.save(model.state_dict(), 'srnn_model.pth')


In [4]:
from torchvision.transforms import  ToPILImage

def test_model(model_path, test_img_path, output_path):

    model = SRNN()
    model.load_state_dict(torch.load(model_path))
    model.eval()

    input_img, target_img = preprocess(test_img_path,scale=2)


    with torch.no_grad():
        output = model(input_img)

    output_img = ToPILImage()(output.squeeze(0))

    input_img = ToPILImage()(input_img.squeeze(0))
    input_img.save("before.jpg")

    output_img.save(output_path)
    print(f"Output saved to {output_path}")

test_model(
    model_path='models2/srnn_model2_scale2.pth', 
    test_img_path='dataset/test/b.jpg', 
    output_path='after.jpg'     
)

RuntimeError: Error(s) in loading state_dict for SRNN:
	Missing key(s) in state_dict: "layer1.weight", "layer1.bias", "layer2.weight", "layer2.bias", "layer3.weight", "layer3.bias". 
	Unexpected key(s) in state_dict: "input_conv.weight", "input_conv.bias", "residual_blocks.0.conv1.weight", "residual_blocks.0.conv1.bias", "residual_blocks.0.bn1.weight", "residual_blocks.0.bn1.bias", "residual_blocks.0.bn1.running_mean", "residual_blocks.0.bn1.running_var", "residual_blocks.0.bn1.num_batches_tracked", "residual_blocks.0.conv2.weight", "residual_blocks.0.conv2.bias", "residual_blocks.0.bn2.weight", "residual_blocks.0.bn2.bias", "residual_blocks.0.bn2.running_mean", "residual_blocks.0.bn2.running_var", "residual_blocks.0.bn2.num_batches_tracked", "residual_blocks.1.conv1.weight", "residual_blocks.1.conv1.bias", "residual_blocks.1.bn1.weight", "residual_blocks.1.bn1.bias", "residual_blocks.1.bn1.running_mean", "residual_blocks.1.bn1.running_var", "residual_blocks.1.bn1.num_batches_tracked", "residual_blocks.1.conv2.weight", "residual_blocks.1.conv2.bias", "residual_blocks.1.bn2.weight", "residual_blocks.1.bn2.bias", "residual_blocks.1.bn2.running_mean", "residual_blocks.1.bn2.running_var", "residual_blocks.1.bn2.num_batches_tracked", "residual_blocks.2.conv1.weight", "residual_blocks.2.conv1.bias", "residual_blocks.2.bn1.weight", "residual_blocks.2.bn1.bias", "residual_blocks.2.bn1.running_mean", "residual_blocks.2.bn1.running_var", "residual_blocks.2.bn1.num_batches_tracked", "residual_blocks.2.conv2.weight", "residual_blocks.2.conv2.bias", "residual_blocks.2.bn2.weight", "residual_blocks.2.bn2.bias", "residual_blocks.2.bn2.running_mean", "residual_blocks.2.bn2.running_var", "residual_blocks.2.bn2.num_batches_tracked", "residual_blocks.3.conv1.weight", "residual_blocks.3.conv1.bias", "residual_blocks.3.bn1.weight", "residual_blocks.3.bn1.bias", "residual_blocks.3.bn1.running_mean", "residual_blocks.3.bn1.running_var", "residual_blocks.3.bn1.num_batches_tracked", "residual_blocks.3.conv2.weight", "residual_blocks.3.conv2.bias", "residual_blocks.3.bn2.weight", "residual_blocks.3.bn2.bias", "residual_blocks.3.bn2.running_mean", "residual_blocks.3.bn2.running_var", "residual_blocks.3.bn2.num_batches_tracked", "residual_blocks.4.conv1.weight", "residual_blocks.4.conv1.bias", "residual_blocks.4.bn1.weight", "residual_blocks.4.bn1.bias", "residual_blocks.4.bn1.running_mean", "residual_blocks.4.bn1.running_var", "residual_blocks.4.bn1.num_batches_tracked", "residual_blocks.4.conv2.weight", "residual_blocks.4.conv2.bias", "residual_blocks.4.bn2.weight", "residual_blocks.4.bn2.bias", "residual_blocks.4.bn2.running_mean", "residual_blocks.4.bn2.running_var", "residual_blocks.4.bn2.num_batches_tracked", "channel_reduction.weight", "channel_reduction.bias", "output_conv.weight", "output_conv.bias". 

In [None]:
#  de considerat :
# data augmentation
# GANs