# Examen Final Práctico - Daniel Crovo

In [None]:
from torch.utils.data import Dataset, DataLoader
import os 
import numpy as np
from PIL import Image
import multiprocessing
import torch 
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from tqdm.notebook import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import transforms
from torchmetrics import PeakSignalNoiseRatio
import torch.optim as optim
import matplotlib.pyplot as plt


In [None]:
# --- Data / I/O ---
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
NOISY_PATH = '../train/'
CLEAN_PATH = '../train_cleaned/'
VAL_PATH = '../test/'
SAVE_PATH = '../preds/'
NUM_WORKERS = multiprocessing.cpu_count()

# --- Model architecture ---
IN_CHANNELS = 1          # images are loaded as single-channel grayscale ('L')
FEATURE_SIZE = 64
KERNEL_SIZE = 3
# Only referenced by the commented-out fully connected bottleneck in DAE.
FC1_DIM = 128
FC2_DIM = 64
PADDING = 1

# --- Training ---
BATCH_SIZE = 16
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# DataLoader merely warns about it, so enable it only when CUDA is present.
PIN_MEMORY = DEVICE == 'cuda'
EPOCHS = 200
LR = 0.0001

# Inputs are scaled to [0, 1] (Normalize with max_pixel_value=255.0, mean 0,
# std 1) and the model ends in a sigmoid, so the data range is known to be 1.0;
# fix it instead of letting torchmetrics infer it from the data.
psnr = PeakSignalNoiseRatio(data_range=1.0).to(DEVICE)

In [None]:
class NoisyTrainData(Dataset):
    """Paired (clean, noisy) grayscale images for training the denoiser.

    Files in ``noisy_image_dir`` and ``clean_image_dir`` are paired by their
    position after sorting each directory listing, so both directories must
    contain the same number of entries in matching sort order (e.g. identical
    file names).

    Args:
        noisy_image_dir: Directory with the noisy input images.
        clean_image_dir: Directory with the corresponding clean targets.
        transform: Optional callable invoked as
            ``transform(image=clean, noisy_image=noisy)`` that returns a dict
            with ``'image'`` and ``'noisy_image'`` keys.

    Raises:
        ValueError: If the two directories hold a different number of
            entries, which would otherwise silently misalign the pairs.
    """

    def __init__(self, noisy_image_dir, clean_image_dir, transform=None):
        self.noisy_image_dir = noisy_image_dir
        self.clean_image_dir = clean_image_dir
        self.transform = transform
        self.noisy_images = sorted(os.listdir(noisy_image_dir))
        self.clean_images = sorted(os.listdir(clean_image_dir))
        # Pairing is purely positional, so a count mismatch would shift every
        # pair after the first missing file.
        if len(self.noisy_images) != len(self.clean_images):
            raise ValueError(
                f"Found {len(self.noisy_images)} noisy images in {noisy_image_dir!r} "
                f"but {len(self.clean_images)} clean images in {clean_image_dir!r}; "
                "clean/noisy pairs would be misaligned."
            )

    def __len__(self):
        return len(self.clean_images)

    @staticmethod
    def _load_gray(path):
        """Load the image at *path* as a grayscale numpy array, closing the file."""
        # The context manager closes the file deterministically, including for
        # multi-frame formats where Pillow keeps the handle open after load().
        with Image.open(path) as img:
            return np.array(img.convert('L'))

    def __getitem__(self, index):
        """Return the ``(clean_image, noisy_image)`` pair at *index*."""
        clean_image = self._load_gray(
            os.path.join(self.clean_image_dir, self.clean_images[index]))
        noisy_image = self._load_gray(
            os.path.join(self.noisy_image_dir, self.noisy_images[index]))

        if self.transform is not None:
            # Both images go through one call so random augmentations are
            # shared between the clean target and its noisy input.
            transformed = self.transform(image=clean_image, noisy_image=noisy_image)
            clean_image = transformed['image']
            noisy_image = transformed['noisy_image']

        return clean_image, noisy_image

class NoisyValData(Dataset):
    """Unpaired grayscale images (no clean targets) used for evaluation.

    Args:
        val_dir: Directory containing the images; files are served in sorted
            file-name order.
        transform: Optional callable invoked as ``transform(image=image)`` that
            returns a dict with an ``'image'`` key.
    """

    def __init__(self, val_dir, transform=None):
        self.val_dir = val_dir
        self.transform = transform
        self.images = sorted(os.listdir(val_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return the (optionally transformed) image at *index*."""
        img_path = os.path.join(self.val_dir, self.images[index])
        # Close the file deterministically instead of relying on garbage
        # collection (Pillow keeps multi-frame files open after load()).
        with Image.open(img_path) as img:
            image = np.array(img.convert('L'))

        if self.transform is not None:
            image = self.transform(image=image)['image']
        return image

In [None]:
def getLoaders(clean_dir, noisy_dir, val_dir, batch_size,
               train_transform, val_transform, num_workers,
               pin_memory, shuffle_train=True):
    """Build the training and validation DataLoaders.

    Args:
        clean_dir: Directory with clean training targets.
        noisy_dir: Directory with the matching noisy training inputs.
        val_dir: Directory with validation images (no clean targets).
        batch_size: Batch size for both loaders.
        train_transform: Transform applied jointly to each clean/noisy pair.
        val_transform: Transform applied to each validation image.
        num_workers: Number of DataLoader worker processes.
        pin_memory: Whether to use pinned host memory for batches.
        shuffle_train: Reshuffle the training set every epoch (default True),
            so the model does not see identical batches in the same order on
            every pass. The validation loader is never shuffled, so batch
            ``idx`` always refers to the same images.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_dataset = NoisyTrainData(noisy_image_dir=noisy_dir, clean_image_dir=clean_dir,
                                   transform=train_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                              pin_memory=pin_memory, shuffle=shuffle_train)

    val_dataset = NoisyValData(val_dir=val_dir, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                            pin_memory=pin_memory, shuffle=False)

    return train_loader, val_loader

In [None]:
class DAE(nn.Module):
    """Convolutional denoising autoencoder.

    Encoder: three convolutions whose output channels go
    ``feature_size -> feature_size // 2 -> feature_size // 4`` (the first two
    followed by BatchNorm + ReLU, the third by ReLU only), then a 2x2 max-pool
    that halves height and width.

    Decoder: two stride-1 transposed convolutions that climb back to
    ``feature_size`` channels (each followed by BatchNorm + ReLU), then a
    stride-2 transposed convolution with ``output_padding=1`` that maps to
    ``in_channels``. With the notebook's kernel_size=3 / padding=1, this last
    layer doubles height and width, so even-sized inputs are reconstructed at
    their original size. A sigmoid squashes the output into [0, 1].

    ``fc1_dim`` and ``fc2_dim`` are accepted but unused (they belonged to a
    fully connected bottleneck that is not part of the model); they are kept
    only so existing call sites keep working.
    """

    def __init__(self, in_channels, feature_size, kernel_size, padding, fc1_dim, fc2_dim):
        super().__init__()
        # Channel bookkeeping: every encoder stage halves the width.
        self.inconv2 = feature_size
        self.outconv2 = feature_size // 2
        self.inconv3 = self.outconv2
        self.outconv3 = self.outconv2 // 2

        conv_kw = dict(kernel_size=kernel_size, padding=padding)

        # Layer order/indices are part of the state_dict keys
        # ("encoder.<i>.*", "decoder.<i>.*"); keep them stable.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, feature_size, **conv_kw),
            nn.BatchNorm2d(feature_size),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.inconv2, self.outconv2, **conv_kw),
            nn.BatchNorm2d(self.outconv2),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.inconv3, self.outconv3, **conv_kw),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.outconv3, self.inconv3, **conv_kw),
            nn.BatchNorm2d(self.inconv3),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(self.outconv2, self.inconv2, **conv_kw),
            nn.BatchNorm2d(self.inconv2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(feature_size, in_channels, stride=2, output_padding=1, **conv_kw),
        )

    def forward(self, x):
        """Encode, decode and map the reconstruction into [0, 1]."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return torch.sigmoid(reconstruction)
    

In [None]:
def _unit_range():
    """Scale uint8 pixels to [0, 1]: with mean 0 and std 1 this is x / 255."""
    return A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0)


# Training pipeline: resize, random geometric augmentation, scale, to tensor.
# `additional_targets` registers 'noisy_image' as a second image so it goes
# through the same sampled augmentation parameters as the clean 'image'.
train_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=50, p=1.0),
        A.HorizontalFlip(p=0.3),
        A.VerticalFlip(p=0.1),
        _unit_range(),
        ToTensorV2(),
    ],
    additional_targets={'noisy_image': 'image'},
)

# Evaluation pipeline: deterministic resize + scaling only.
val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        _unit_range(),
        ToTensorV2(),
    ],
)

In [None]:
# Build the training (paired clean/noisy) and validation (noisy only) loaders.
train_loader, val_loader = getLoaders(
    clean_dir=CLEAN_PATH,
    noisy_dir=NOISY_PATH,
    val_dir=VAL_PATH,
    batch_size=BATCH_SIZE,
    train_transform=train_transforms,
    val_transform=val_transforms,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

In [None]:
# Instantiate the denoising autoencoder on the selected device and show its layers.
model = DAE(
    in_channels=IN_CHANNELS,
    feature_size=FEATURE_SIZE,
    kernel_size=KERNEL_SIZE,
    padding=PADDING,
    fc1_dim=FC1_DIM,
    fc2_dim=FC2_DIM,
).to(DEVICE)
print(model)

In [None]:
def train_func(train_loader, model, optimizer, loss_fn, scaler, device, epoch):
    """Run one training epoch with gradient scaling and return the per-batch losses.

    Args:
        train_loader: Iterable yielding ``(clean_img, noisy_img)`` batches.
        model: Network mapping a noisy batch to a reconstructed batch.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(preds, target)`` returning a scalar tensor.
        scaler: Gradient scaler (e.g. ``GradScaler``) providing
            ``scale`` / ``step`` / ``update``.
        device: Device string or ``torch.device`` the batches are moved to.
        epoch: Epoch index, shown in the progress-bar description.

    Returns:
        List of float loss values, one per batch.
    """
    device_type = torch.device(device).type
    p_bar = tqdm(train_loader, desc=f'Epoch {epoch}')
    loss_list = []
    model.train()
    for clean_img, noisy_img in p_bar:
        clean_img = clean_img.float().to(device=device)
        noisy_img = noisy_img.float().to(device=device)

        # Reset gradients from the previous step; without this they
        # accumulate across every batch and every epoch.
        optimizer.zero_grad()

        # Forward pass. Mixed precision is only enabled on CUDA; the loss is
        # computed on float32 predictions.
        with torch.autocast(device_type=device_type, enabled=device_type == 'cuda'):
            preds = model(noisy_img).float()
            loss = loss_fn(preds, clean_img)

        # Backward pass through the scaler (a pass-through when it is disabled).
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        loss_list.append(loss_value)
        p_bar.set_postfix(loss=loss_value)
    return loss_list

In [None]:
def performance(val_loader, model, device, folder, metric=None):
    """Score the model's output on every validation batch; return one value per batch.

    The validation set has no clean targets, so each prediction is compared
    against its own noisy input. The value therefore measures how close the
    output stays to the input, not how well the noise was removed.

    Args:
        val_loader: Iterable of input image batches.
        model: Network to evaluate; it is left in eval mode.
        device: Device the batches are moved to.
        folder: Unused; kept for call-site compatibility.
        metric: Callable ``metric(preds, target)``. Defaults to the
            module-level ``psnr`` metric.

    Returns:
        List with the metric's return value for each batch.
    """
    if metric is None:
        metric = psnr
    model.eval()
    psnr_list = []
    with torch.no_grad():
        for val_img in val_loader:
            val_img = val_img.float().to(device=device)
            preds = model(val_img)
            # Call the metric exactly once per batch. Calling a torchmetrics
            # Metric also updates its running state, so a repeated call would
            # count the batch twice and waste a computation.
            psnr_list.append(metric(preds, val_img))
    return psnr_list


In [None]:
def save_preds_as_imgs(loader, model, device, folder):
    """Save the model's outputs and inputs for every batch of *loader* as PNG grids.

    For batch ``idx``, two files are written into *folder*:
    ``y_cleaned_{idx}.png`` (model output) and ``y_noisy_{idx}.png`` (input).
    Existing files with the same names are overwritten. *folder* is created if
    it does not exist. The model is run in eval mode and switched back to
    train mode afterwards.
    """
    os.makedirs(folder, exist_ok=True)
    model.eval()
    # Inference only: disable gradient tracking / graph construction.
    with torch.no_grad():
        for idx, x in enumerate(loader):
            x = x.to(device=device).float()
            preds = model(x)
            torchvision.utils.save_image(preds, os.path.join(folder, f'y_cleaned_{idx}.png'))
            torchvision.utils.save_image(x, os.path.join(folder, f'y_noisy_{idx}.png'))
    model.train()

In [None]:

optimizer = optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.MSELoss()
# Loss scaling only applies on CUDA; disabling it elsewhere makes the scaler a
# pass-through instead of warning that CUDA is unavailable.
scaler = torch.cuda.amp.GradScaler(enabled=torch.device(DEVICE).type == 'cuda')

# Predictions are written to SAVE_PATH every epoch; create it up front so the
# first save does not fail after a full epoch of training.
os.makedirs(SAVE_PATH, exist_ok=True)

# Per-epoch results. The name `eval` (which shadows the builtin) is kept
# because later cells read it.
eval = []        # one list of per-batch validation metric values per epoch
train_loss = []  # one list of per-batch training losses per epoch
for epoch in range(EPOCHS):
    print('Epoch: ', epoch)
    train_loss.append(train_func(train_loader, model, optimizer, loss_fn, scaler, DEVICE, epoch))
    eval.append(performance(val_loader, model, DEVICE, SAVE_PATH))
    save_preds_as_imgs(val_loader, model, DEVICE, SAVE_PATH)

In [None]:
# Display the per-batch validation metric values from the final epoch.
last_epoch_psnr = eval[-1]
last_epoch_psnr

In [None]:
# train_loss holds one list of per-batch losses per epoch. Plotting it
# directly would draw one line per batch position, so plot a single learning
# curve of the mean loss per epoch instead.
epoch_mean_loss = [float(np.mean(batch_losses)) for batch_losses in train_loss]
plt.xlabel('Epoch')
plt.ylabel('Mean training loss (MSE)')
plt.plot(epoch_mean_loss)

In [None]:
import matplotlib.pyplot as plt

# Pull one (clean, noisy) batch from the training loader through the public
# iterator protocol rather than the private `_next_data()` method.
dataiter = iter(train_loader)
images, noisy = next(dataiter)
images = images.numpy()

# Pick one sample from the batch; clamp the index so a batch with fewer than
# six images does not raise IndexError.
sample_idx = min(5, images.shape[0] - 1)
img = np.squeeze(images[sample_idx])
nois = np.squeeze(noisy[sample_idx].numpy())

# Clean target
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')

# Noisy input
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)
ax.imshow(nois, cmap='gray')
print(noisy)