ToDo
 * Plot PSNR versus epoch
 * Compute SSIM and plot it versus epoch
 * Save inferred images for last epoch for each network architecture
 * Try SRCNN(9-5-5): https://github.com/kunal-visoulia/Image-Restoration-using-SRCNN/blob/master/Image%20Super%20Resolution.
 
 Literature
 * Learning to Generate Images With Perceptual Similarity Metrics, https://arxiv.org/pdf/1511.06409.pdf
 * Image Super-Resolution Using Deep Convolutional Networks, https://arxiv.org/pdf/1501.00092.pdf
 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchmetrics.image import PeakSignalNoiseRatio
import matplotlib.pyplot as plt
import os, re
from PIL import Image as PILImage

In [None]:
# --- Hardware discovery ---
# List every CUDA device PyTorch can see, then select the compute device.
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")
for gpu_id in range(num_gpus):
    gpu_name = torch.cuda.get_device_name(gpu_id)
    print(f"   GPU {gpu_id}: {gpu_name}")

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- Reproducibility ---
# Seed the CPU generator and the generators of all CUDA devices.
seed_value = 18
torch.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
# Disable cuDNN auto-tuning and ask cuDNN for deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Run configuration ---
BATCH_SIZE = 16
precision = torch.float32
dataset_images_path = "C:/_sw/eb_python/deep_learning/autoencoder/codec_face/dataset/eric"

# Per-dataset loading parameters:
#   keyword - substring a file name must contain to be loaded
#   Hin     - image width after resizing  (other values tried: 768, 192, 800)
#   Vin     - image height after resizing (other values tried: 896, 224, 1000)
#   Cin     - number of colour channels
property_dataset = {
    "Face": dict(keyword="face", Hin=384, Vin=448, Cin=3),
}

In [None]:
# Convert the string into a list of strings and integers that is naturally sortable.
def natural_sort_key(s):
    """Return a sort key that orders embedded numbers by value.

    With this key "img2" sorts before "img10". Text fragments are compared
    case-insensitively.

    Args:
        s: The string (typically a file name) to build a key for.

    Returns:
        A list alternating lower-cased text fragments (even positions, possibly
        empty) and integers (odd positions).
    """
    # Splitting on a capturing group always yields text, digits, text, ...
    # Classifying by position instead of str.isdigit() keeps the layout
    # text/int/text for every input: isdigit() is also true for characters
    # such as "²" or "٣", which int() may reject and which would otherwise put
    # an int where other keys hold a str (TypeError while sorting).
    parts = re.split(r'([0-9]+)', s)
    return [int(part) if pos % 2 else part.lower() for pos, part in enumerate(parts)]

def ImportDataset_images(path, device, precision, parameters, max_images=3001):
    """Load the keyword-matching PNG files of a directory into one tensor.

    Files are selected when their name ends with ".png" and contains
    ``parameters["keyword"]``; they are loaded in natural sort order
    (see ``natural_sort_key``).

    Args:
        path: Directory containing the images.
        device: Device the resulting tensor is moved to.
        precision: dtype the resulting tensor is converted to.
        parameters: Dict providing "keyword" (file-name filter), "Vin"
            (target height) and "Hin" (target width).
        max_images: Maximum number of images to load; ``None`` loads all
            matching files. Defaults to 3001, the limit that used to be
            hard-coded in the loop.

    Returns:
        Tensor of shape (num_images, channels, Vin, Hin) with values
        normalized by mean 0.5 / std 0.5 after ``ToTensor``.

    Raises:
        RuntimeError: If no image is selected for loading.
    """
    # Filter for files that have the .png extension and include the specific word
    keyword = parameters["keyword"]
    filtered_files = [
        file for file in os.listdir(path)
        if file.endswith('.png') and keyword in file
    ]
    filtered_files.sort(key=natural_sort_key)   # Sort using the natural sort key
    print("   " + str(len(filtered_files)) + " files found with keyword: " + parameters["keyword"])

    selected_files = filtered_files if max_images is None else filtered_files[:max_images]
    if not selected_files:
        # torch.stack([]) would fail with a message that hides the real cause.
        raise RuntimeError(
            f"No '.png' file containing '{keyword}' available to load from: {path}"
        )

    transform = transforms.Compose([
            transforms.Resize((parameters["Vin"], parameters["Hin"])),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))
        ])

    images_list = []
    for index, file in enumerate(selected_files):
        print("\r   Image #" + str(index + 1) + "/" + str(len(filtered_files)), end="")
        # Open the file that was actually found and sorted, rather than
        # rebuilding a "<keyword>_picture_<index>.png" name that may not exist.
        # The context manager closes the file handle once the pixel data has
        # been copied into the tensor.
        with PILImage.open(os.path.join(path, file)) as image:
            # NOTE(review): assumes every file decodes to the same number of
            # channels; otherwise torch.stack below fails -- confirm dataset.
            images_list.append(transform(image))

    data = torch.stack(images_list)
    return data.to(device=device, dtype=precision)

In [None]:
# The whole dataset is loaded onto the CPU here; the training cell moves
# each batch to the compute device when it is used.
datasets_img_face = ImportDataset_images(
    path=dataset_images_path,
    device="cpu",
    precision=precision,
    parameters=property_dataset["Face"],
)
data_loader = torch.utils.data.DataLoader(
    dataset=datasets_img_face,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Sanity check: draw one shuffled batch and report its value range.
dataiter = iter(data_loader)
images = next(dataiter)
print(f"Min: {torch.min(images)!s} - Max: {torch.max(images)!s}")

# Last expression of the cell: the notebook displays the dataset shape.
datasets_img_face.shape

In [None]:
# C(ch16 k3 s1 p1)BN:R C(ch32 k3 s1 p1)BN:R C(ch64 k3 s1 p1)BN:R C(ch128 k3 s1 p1)BN:R C(ch256 k3 s1 p1)BN:R C(ch512 k3 s1 p1)BN:R C(ch1024 k5)BN:R
class Autoencoder_Conv(nn.Module):
    """Convolutional autoencoder with a 100-unit bottleneck (7-conv encoder).

    The layer sizes fit inputs of shape (N, 3, 448, 384): six
    conv/BN/ReLU/max-pool stages reduce the image to 7x6, an unpadded 5x5
    convolution reduces it to 3x2, and the flattened 1024*3*2 features pass
    through two linear layers before the transposed-convolution decoder
    restores (N, 3, 448, 384). The output goes through tanh, i.e. lies in
    [-1, 1].

    Note: decoder layers are named dec_convtrans1, 3, 4, ... 8; there is no
    dec_convtrans2 in this variant.
    """

    def __init__(self):
        super().__init__()

        # Input: N, 3, 448, 384. Shapes below are after BN, ReLU and 2x2 pooling.
        # Convolutions that feed a BatchNorm layer carry no bias.
        self.enc_conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1, bias = False)       # -> N, 16, 224, 192
        self.enc_bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = nn.Conv2d(16, 32, 3, stride=1, padding=1, bias = False)      # -> N, 32, 112, 96
        self.enc_bn2 = nn.BatchNorm2d(32)
        self.enc_conv3 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias = False)       # -> N, 64, 56, 48
        self.enc_bn3 = nn.BatchNorm2d(64)
        self.enc_conv4 = nn.Conv2d(64, 128, 3, stride=1, padding=1, bias = False)      # -> N, 128, 28, 24
        self.enc_bn4 = nn.BatchNorm2d(128)
        self.enc_conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=1, bias = False)      # -> N, 256, 14, 12
        self.enc_bn5 = nn.BatchNorm2d(256)
        self.enc_conv6 = nn.Conv2d(256, 512, 3, stride=1, padding=1, bias = False)     # -> N, 512, 7, 6
        self.enc_bn6 = nn.BatchNorm2d(512)
        self.enc_conv7 = nn.Conv2d(512, 1024, 5, bias = False)    # no pooling -> N, 1024, 3, 2
        self.enc_bn7 = nn.BatchNorm2d(1024)

        # Bottleneck: 6144 features -> 100 -> 6144
        self.fc1 = nn.Linear(1024 * 3 * 2, 100)
        self.fc2 = nn.Linear(100, 1024 * 3 * 2)

        # Decoder input: N, 1024, 3, 2
        self.dec_convtrans1 = nn.ConvTranspose2d(1024, 512, 5)       # -> N, 512, 7, 6
        self.dec_convtrans3 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1)    # -> N, 256, 14, 12
        self.dec_convtrans4 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)    # -> N, 128, 28, 24
        self.dec_convtrans5 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)    # -> N, 64, 56, 48
        self.dec_convtrans6 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)    # -> N, 32, 112, 96
        self.dec_convtrans7 = nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1)    # -> N, 16, 224, 192
        self.dec_convtrans8 = nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1)    # -> N, 3, 448, 384


    def forward(self, x):
        """Encode and reconstruct a batch of images.

        Args:
            x: Tensor of shape (N, 3, 448, 384).

        Returns:
            Reconstruction of shape (N, 3, 448, 384) with values in [-1, 1].

        Raises:
            RuntimeError: If the input size does not reduce to a 3x2 feature
                map (raised by the first linear layer).
        """
        # Encoder
        x = self.maxpool(self.relu(self.enc_bn1(self.enc_conv1(x))))
        x = self.maxpool(self.relu(self.enc_bn2(self.enc_conv2(x))))
        x = self.maxpool(self.relu(self.enc_bn3(self.enc_conv3(x))))
        x = self.maxpool(self.relu(self.enc_bn4(self.enc_conv4(x))))
        x = self.maxpool(self.relu(self.enc_bn5(self.enc_conv5(x))))
        x = self.maxpool(self.relu(self.enc_bn6(self.enc_conv6(x))))
        x = self.relu(self.enc_bn7(self.enc_conv7(x)))

        # Bottleneck
        # Flatten per sample. view(-1, 6144) could silently merge samples when
        # the feature map is not 3x2 (e.g. three 1x2 maps become one "sample");
        # keeping the batch dimension makes fc1 reject such inputs instead.
        x = torch.flatten(x, start_dim=1)     # [batch_size, 1024 * 3 * 2]
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = x.view(x.size(0), 1024, 3, 2)     # [batch_size, 1024, 3, 2]

        # Decoder
        x = F.relu(self.dec_convtrans1(x))
        x = F.relu(self.dec_convtrans3(x))
        x = F.relu(self.dec_convtrans4(x))
        x = F.relu(self.dec_convtrans5(x))
        x = F.relu(self.dec_convtrans6(x))
        x = F.relu(self.dec_convtrans7(x))
        # tanh keeps the output in [-1, 1], the range produced by
        # Normalize(0.5, 0.5) on the input images (a sigmoid would give [0, 1]).
        x = torch.tanh(self.dec_convtrans8(x))

        return x

In [None]:
# C(ch16 k3 s1 p1)BN:R C(ch32 k3 s1 p1)BN:R C(ch64 k3 s1 p1)BN:R C(ch128 k3 s1 p1)BN:R C(ch256 k3 s1 p1)BN:R C(ch512 k3 s1 p1)BN:R C(ch1024 k3 s1 p1)BN:R C(ch2048 k5)BN:R
class Autoencoder_Conv(nn.Module):
    """Convolutional autoencoder with a 100-unit bottleneck (8-conv encoder).

    The layer sizes fit inputs of shape (N, 3, 896, 768): seven
    conv/BN/ReLU/max-pool stages reduce the image to 7x6, an unpadded 5x5
    convolution reduces it to 3x2, and the flattened 2048*3*2 features pass
    through two linear layers before the transposed-convolution decoder
    restores (N, 3, 896, 768). The output goes through tanh, i.e. lies in
    [-1, 1].

    NOTE(review): with 448x384 inputs the feature map is only 3x3 after the
    seventh pooling stage, which is smaller than the 5x5 kernel of enc_conv8,
    so the forward pass fails. Use Vin=896 / Hin=768 with this variant.
    """

    def __init__(self):
        super().__init__()

        # Input: N, 3, 896, 768. Shapes below are after BN, ReLU and 2x2 pooling.
        # Convolutions that feed a BatchNorm layer carry no bias.
        self.enc_conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1, bias = False)       # -> N, 16, 448, 384
        self.enc_bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = nn.Conv2d(16, 32, 3, stride=1, padding=1, bias = False)      # -> N, 32, 224, 192
        self.enc_bn2 = nn.BatchNorm2d(32)
        self.enc_conv3 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias = False)       # -> N, 64, 112, 96
        self.enc_bn3 = nn.BatchNorm2d(64)
        self.enc_conv4 = nn.Conv2d(64, 128, 3, stride=1, padding=1, bias = False)      # -> N, 128, 56, 48
        self.enc_bn4 = nn.BatchNorm2d(128)
        self.enc_conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=1, bias = False)      # -> N, 256, 28, 24
        self.enc_bn5 = nn.BatchNorm2d(256)
        self.enc_conv6 = nn.Conv2d(256, 512, 3, stride=1, padding=1, bias = False)     # -> N, 512, 14, 12
        self.enc_bn6 = nn.BatchNorm2d(512)
        self.enc_conv7 = nn.Conv2d(512, 1024, 3, stride=1, padding=1, bias = False)    # -> N, 1024, 7, 6
        self.enc_bn7 = nn.BatchNorm2d(1024)
        # Unlike the other encoder convolutions this one keeps its bias; it is
        # left unchanged so the parameter set of the model stays the same.
        self.enc_conv8 = nn.Conv2d(1024, 2048, 5)     # no pooling -> N, 2048, 3, 2
        self.enc_bn8 = nn.BatchNorm2d(2048)

        # Bottleneck: 12288 features -> 100 -> 12288
        self.fc1 = nn.Linear(2048 * 3 * 2, 100)
        self.fc2 = nn.Linear(100, 2048 * 3 * 2)

        # Decoder input: N, 2048, 3, 2
        self.dec_convtrans1 = nn.ConvTranspose2d(2048, 1024, 5)       # -> N, 1024, 7, 6
        self.dec_convtrans2 = nn.ConvTranspose2d(1024, 512, 3, stride=2, padding=1, output_padding=1)    # -> N, 512, 14, 12
        self.dec_convtrans3 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1)    # -> N, 256, 28, 24
        self.dec_convtrans4 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)    # -> N, 128, 56, 48
        self.dec_convtrans5 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)    # -> N, 64, 112, 96
        self.dec_convtrans6 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)    # -> N, 32, 224, 192
        self.dec_convtrans7 = nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1)    # -> N, 16, 448, 384
        self.dec_convtrans8 = nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1)    # -> N, 3, 896, 768


    def forward(self, x):
        """Encode and reconstruct a batch of images.

        Args:
            x: Tensor of shape (N, 3, 896, 768).

        Returns:
            Reconstruction of shape (N, 3, 896, 768) with values in [-1, 1].

        Raises:
            RuntimeError: If the input is too small for the 5x5 convolution or
                does not reduce to a 3x2 feature map (raised by the first
                linear layer).
        """
        # Encoder
        x = self.maxpool(self.relu(self.enc_bn1(self.enc_conv1(x))))
        x = self.maxpool(self.relu(self.enc_bn2(self.enc_conv2(x))))
        x = self.maxpool(self.relu(self.enc_bn3(self.enc_conv3(x))))
        x = self.maxpool(self.relu(self.enc_bn4(self.enc_conv4(x))))
        x = self.maxpool(self.relu(self.enc_bn5(self.enc_conv5(x))))
        x = self.maxpool(self.relu(self.enc_bn6(self.enc_conv6(x))))
        x = self.maxpool(self.relu(self.enc_bn7(self.enc_conv7(x))))
        x = self.relu(self.enc_bn8(self.enc_conv8(x)))

        # Bottleneck
        # Flatten per sample. view(-1, 12288) could silently merge samples when
        # the feature map is not 3x2 (e.g. two 3x3 maps become three "samples");
        # keeping the batch dimension makes fc1 reject such inputs instead.
        x = torch.flatten(x, start_dim=1)     # [batch_size, 2048 * 3 * 2]
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = x.view(x.size(0), 2048, 3, 2)     # [batch_size, 2048, 3, 2]

        # Decoder
        x = F.relu(self.dec_convtrans1(x))
        x = F.relu(self.dec_convtrans2(x))
        x = F.relu(self.dec_convtrans3(x))
        x = F.relu(self.dec_convtrans4(x))
        x = F.relu(self.dec_convtrans5(x))
        x = F.relu(self.dec_convtrans6(x))
        x = F.relu(self.dec_convtrans7(x))
        # tanh keeps the output in [-1, 1], the range produced by
        # Normalize(0.5, 0.5) on the input images (a sigmoid would give [0, 1]).
        x = torch.tanh(self.dec_convtrans8(x))

        return x

In [None]:
# Build the autoencoder and move its parameters to the selected device.
model = Autoencoder_Conv()
model = model.to(device)

# Reconstruction objective: mean squared error between output and input.
criterion = nn.MSELoss()

# Adam optimizer over all model parameters, with a small weight decay.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-5)

In [None]:
### Train model ###

num_epochs = 100
val_every = 1                   # Run validation every `val_every` training batches

# One fixed batch is reused for validation at every step.
# NOTE: it is drawn from the training loader, so it is not held-out data.
dataiter = iter(data_loader)
val_imgs = next(dataiter)
val_imgs = val_imgs.to(device)
val_imgs_cpu = val_imgs.cpu()   # Constant across steps: copy to CPU only once
outputs = []                    # (epoch, batch, psnr, loss) per validation step

# Images are normalized to [-1, 1] and the model ends in tanh, so the data
# range is 2.0. Stating it explicitly avoids inferring it from a tensor.
psnr_metric = PeakSignalNoiseRatio(data_range=2.0)

for epoch in range(num_epochs):
    for batch, train_imgs in enumerate(data_loader, start=1):
        model.train()           # Validation below switches to eval mode
        optimizer.zero_grad()

        train_imgs = train_imgs.to(device)
        train_out = model(train_imgs)
        loss = criterion(train_out, train_imgs)
        loss.backward()
        optimizer.step()

        if batch % val_every == 0:
            model.eval()
            with torch.no_grad():  # Ensure no gradients are computed for validation
                val_out = model(val_imgs)
                # torchmetrics expects (preds, target): prediction first.
                psnr_value = psnr_metric(val_out.cpu(), val_imgs_cpu).item()
                # `batch` is already 1-based, so it is printed as is and
                # matches the value stored in `outputs`.
                print(f'Epoch:{epoch+1}, Batch:{batch}, Loss:{loss.item():.4f}, PSNR:{psnr_value:.2f}dB')
                # Only scalars are stored. The image-grid cell further down
                # expects (epoch, batch, imgs, recon, ...) tuples instead;
                # keeping image tensors for every validation step would need
                # far more memory.
                outputs.append((epoch, batch, psnr_value, loss.item()))

# Number of batches per epoch, independent of the loop variables above.
num_batches = len(data_loader)

In [None]:
### Generate image of ground truth and inferred data for each batch of each epoch ###

lst_number = [0, 1, 2, 3, 4]    # Index for the 5 images to be used for inference

out_dir = 'results/_temp'
os.makedirs(out_dir, exist_ok=True)     # savefig does not create missing directories

# NOTE(review): the stride of 25 presumably matches a run that validated every
# 25 batches, i.e. `outputs` holds num_batches//25 entries per epoch -- confirm
# against the training cell, which currently records every batch.
entries_per_epoch = num_batches // 25


def _to_display(chw_image):
    """Return a (height, width, channel) image min-max scaled to [0, 1]."""
    hwc = chw_image.transpose(1, 2, 0)  # (channel, height, width) -> (height, width, channel)
    span = hwc.max() - hwc.min()
    if span == 0:
        return hwc - hwc.min()          # Constant image: avoid dividing by zero
    return (hwc - hwc.min()) / span


for idx_epoch in range(0, num_epochs, 1):
    for idx_batch in range(0, entries_per_epoch, 1):
        # Parenthesized stride: `idx_epoch*num_batches//25` evaluates as
        # (idx_epoch*num_batches)//25, which drifts away from the start of the
        # epoch's entries whenever num_batches is not a multiple of 25.
        entry = outputs[idx_epoch * entries_per_epoch + idx_batch]
        epoch, batch, imgs, recon = entry[0], entry[1], entry[2], entry[3]
        if not (torch.is_tensor(imgs) and torch.is_tensor(recon)):
            raise TypeError(
                "outputs entries must be (epoch, batch, ground-truth images, "
                "reconstructed images, ...) tensors to draw the image grid"
            )
        imgs = imgs.detach().numpy()
        recon = recon.detach().numpy()

        plt.figure(figsize=(6, 5))

        # Top row: ground truth
        for i in range(5):
            plt.subplot(2, 5, i+1)
            plt.imshow(_to_display(imgs[lst_number[i]]))
            plt.axis('off')

        # Bottom row: reconstructions
        for i in range(5):
            plt.subplot(2, 5, 5+i+1) # row_length + i + 1
            plt.imshow(_to_display(recon[lst_number[i]]))
            plt.axis('off')

        plt.suptitle('Epoch #' + str(epoch) + ' - Batch #' + str(batch) + "\nC(ch16 k3 s1 p1)R C(ch32 k3 s1 p1)R C(ch64 k3 s1 p1)R C(ch128 k3 s1 p1)R C(ch256 k3 s1 p1)R C(ch512 k3 s1 p1 + Bottleneck)", fontsize=6)
        plt.savefig(out_dir + '/face_autoencoder_epoch' + str(epoch) + '_batch' + str(batch).zfill(3) + '.png', format='png', dpi=300)
        plt.close()  # Close the figure to avoid displaying it

In [None]:
### Generate video of images ###

import cv2
import os
import glob
import re

# Directory containing images
img_dir = "results/C(ch16 k3 s1 p1)R C(ch32 k3 s1 p1)R C(ch64 k3 s1 p1)R C(ch128 k3 s1 p1)R C(ch256 k3 s1 p1)R C(ch512 k5)"  # Enter Directory of all images
data_path = os.path.join(img_dir, '*g')  # Candidate files; narrowed to PNG/JPG below


def _frame_order(file_path):
    """Sort key ordering embedded numbers by value (epoch2 before epoch10)."""
    parts = re.split(r'([0-9]+)', os.path.basename(file_path))
    return [int(part) if pos % 2 else part.lower() for pos, part in enumerate(parts)]


# Epoch numbers in the frame names are not zero-padded, so a plain
# lexicographic sort would place "epoch10" before "epoch2".
files = sorted(
    (f for f in glob.glob(data_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))),
    key=_frame_order,
)
if not files:
    raise FileNotFoundError(f"No PNG/JPG frames found in: {img_dir}")

# Frame properties
RATIO = 1
frame = cv2.imread(files[0])
if frame is None:
    raise ValueError(f"Could not read image: {files[0]}")
height, width, layers = frame.shape
size = (width//RATIO, height//RATIO)

# Video properties
out = cv2.VideoWriter(f"{img_dir}/face_autoencoder.avi", cv2.VideoWriter_fourcc(*'DIVX'), 10, size)  # 10 is the FPS, adjust as needed

try:
    for f in files:
        img = cv2.imread(f)
        if img is None:     # cv2.imread returns None for unreadable files
            print(f"Skipping unreadable frame: {f}")
            continue
        img = cv2.resize(img, size)
        out.write(img)
finally:
    out.release()           # Finalize the video file even if a frame fails