In [11]:
from PIL import Image
import numpy as np
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
from p5a import ReducingresolutionClass
from p5c import eval_resolution

In [3]:
# Load the images
images = [Image.open(f'{i}.jpg') for i in range(1, 11)]
reducing = ReducingresolutionClass(n=2)
features, labels, end_of_each_image, low_res_images =  reducing.reduce(images)
list_features = features
features = np.array(features)
labels = np.array(labels)

# Split the dataset into training, validation, and test sets
train_features, train_labels = features[:end_of_each_image[7]], labels[:end_of_each_image[7]]
val_features, val_labels = features[end_of_each_image[7]:end_of_each_image[8]], labels[end_of_each_image[7]:end_of_each_image[8]]
test_features, test_labels = features[end_of_each_image[8]:], labels[end_of_each_image[8]:]

In [7]:
from model_resolution import MLP, device
model = MLP().to(device)

# Define the loss function and the optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

# Convert the datasets to PyTorch tensors and move them to the appropriate device
train_features = torch.tensor(train_features, dtype=torch.float32, requires_grad=True).to(device)
train_labels = torch.tensor(train_labels, dtype=torch.float32, requires_grad=True).to(device)
val_features = torch.tensor(val_features, dtype=torch.float32, requires_grad=True).to(device)
val_labels = torch.tensor(val_labels, dtype=torch.float32, requires_grad=True).to(device)



In [73]:
# Initialize lists to store the losses
train_losses = []
val_losses = []

# Train the MLP
for epoch in range(10):  # 100 epochs
    model.train()  # Set the model to training mode
    optimizer.zero_grad()  # Reset the gradients
    train_outputs = model(train_features)  # Forward pass
    loss = criterion(train_outputs, train_labels)  # Compute the loss
    loss.backward()  # Backward pass
    optimizer.step()  # Update the weights

    # Store the training loss
    train_losses.append(loss.item())

    # Print the loss for this epoch
    print(f'Epoch {epoch+1}/{100} - Training Loss: {loss.item()}')

    # Validate the MLP
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # No need to track the gradients
        val_outputs = model(val_features)  # Forward pass
        val_loss = criterion(val_outputs, val_labels)  # Compute the loss

        # Store the validation loss
        val_losses.append(val_loss.item())

        if epoch % 10 == 0:
            # Print the validation loss
            print(f'Epoch {epoch+1}/{100} - Validation Loss: {val_loss.item()}')

# Plot the learning curve
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.legend()
plt.show()


Epoch 1/100 - Loss: 11442.35546875
Epoch 2/100 - Loss: 10745.9326171875
Epoch 3/100 - Loss: 10042.49609375
Epoch 4/100 - Loss: 9347.54296875
Epoch 5/100 - Loss: 8663.595703125
Epoch 6/100 - Loss: 7966.0205078125
Epoch 7/100 - Loss: 7256.9951171875
Epoch 8/100 - Loss: 6543.2666015625
Epoch 9/100 - Loss: 5834.6640625
Epoch 10/100 - Loss: 5138.53076171875
Validation Loss: 3093.09033203125


In [25]:
# Save the model
torch.save(model.state_dict(), 'model_part_b.pth')

In [26]:
# Convert the test dataset to PyTorch tensors and move them to the appropriate device
test_features = torch.tensor(test_features, dtype=torch.float32, requires_grad=True).to(device)
test_labels = torch.tensor(test_labels, dtype=torch.float32, requires_grad=True).to(device)
# Calculate the error function value for the test dataset
model = MLP().to(device)
model.load_state_dict(torch.load('model_part_b.pth', map_location=device))
model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # No need to track the gradients
    test_outputs = model(test_features)  # Forward pass
    test_loss = criterion(test_outputs, test_labels)  # Compute the loss

# Print the test loss
print(f'Test Loss: {test_loss.item()}')

# Generate high-resolution images
high_res_images = []
for i in end_of_each_image.keys():
    if i > 1:
        low_res_image_np = features[end_of_each_image[i-1]:end_of_each_image[i]]
    else:
        low_res_image_np = features[0:end_of_each_image[1]]# Convert PIL Image to numpy array
    low_res_image_tensor = torch.tensor(low_res_image_np, dtype=torch.float32, requires_grad=True).to(device)
    high_res_image = model(low_res_image_tensor).cpu().detach().numpy()
    high_res_images.append(high_res_image)


  test_features = torch.tensor(test_features, dtype=torch.float32, requires_grad=True).to(device)
  test_labels = torch.tensor(test_labels, dtype=torch.float32, requires_grad=True).to(device)


Test Loss: 48.44804000854492


In [28]:
# Compare the high-resolution images with the original images
eval_resolution(high_res_images, images)

Image 1 - SSIM: 0.06539547853300348, PSNR: 8.642003968432451
Image 2 - SSIM: 0.2675556486208231, PSNR: 11.153144572364369
Image 3 - SSIM: 0.1785107723494155, PSNR: 12.935236537359977
Image 4 - SSIM: 0.08564555237944006, PSNR: 11.029140547952522
Image 5 - SSIM: 0.07327564128957158, PSNR: 10.027084870792697
Image 6 - SSIM: 0.1612746523768042, PSNR: 10.440442849978439
Image 7 - SSIM: 0.05159811357307465, PSNR: 8.124327969466986
Image 8 - SSIM: 0.13729326014008453, PSNR: 12.784259049037427
Image 9 - SSIM: 0.03649404387448192, PSNR: 10.97273126478477
Image 10 - SSIM: 0.23285826969205778, PSNR: 10.875041496713564


In [29]:
# Convert the numpy array to integer type
high_res_images_quantized = []
for i, (high_res_img, img) in enumerate(zip(high_res_images, images)):
    high_res_img = high_res_img.reshape(img.width, img.height, 3).astype(np.uint8)
    high_res_image_ = high_res_img.astype(np.uint8)
    high_res_images_quantized.append(high_res_image_)

# Convert the numpy array to a PIL Image
high_res_images_pil = [Image.fromarray(img).convert('RGB') for img in high_res_images_quantized]

In [54]:

for i, high_res_images_pili in enumerate(high_res_images_pil):
    high_res_images_pili = high_res_images_pili.rotate(-90, expand=True)
    from PIL import ImageOps
    high_res_images_pili = ImageOps.mirror(high_res_images_pili)
    high_res_images_pili.save(f'{i + 1}_b.jpg')