In [None]:
# Mount Google Drive under /content/drive so the dataset and output folders
# referenced later through "./drive/MyDrive/..." are reachable.
# force_remount=True asks Colab to mount again even if a mount already exists.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
import numpy as np
import pandas as pd
import os, math, sys
import glob, itertools
import argparse, random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19, inception_v3
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image, make_grid

import plotly
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split

random.seed(42)
import warnings
warnings.filterwarnings("ignore")
# load pretrained models


In [None]:
# Whether to load previously saved generator/discriminator weights before training
load_pretrained_models = False

# Number of training epochs
n_epochs = 4

# Path to the image dataset
dataset_path = "./drive/MyDrive/images/Test/"

# Batch size
batch_size = 16

# Adam: learning rate
# Adam is an optimization algorithm for stochastic gradient descent, used to train deep learning models.
lr = 0.00008

# Adam: decay rate of the first-order momentum of the gradient
b1 = 0.5

# Adam: decay rate of the second-order momentum of the gradient
b2 = 0.999

# Epoch from which to start the learning-rate decay
# NOTE(review): not referenced anywhere else in the visible notebook
decay_epoch = 100

# Number of CPU threads to use during batch generation (passed as DataLoader num_workers)
n_cpu = 8

# High resolution: image height
hr_height = 256

# High resolution: image width
hr_width = 256

# Number of image channels
channels = 3

#os.makedirs("images", exist_ok=True)
#os.makedirs("saved_models", exist_ok=True)

# True when a CUDA device is available; used to decide where models/tensors live
cuda = torch.cuda.is_available()
# (height, width) of the high-resolution images
hr_shape = (hr_height, hr_width)

### Define Dataset Class

Tous les modèles pré-entraînés attendent des images d'entrée normalisées de la même manière, c'est-à-dire des mini-batchs d'images RVB à 3 canaux de forme (3 x H x W), où H et W doivent être d'au moins 224. Les images doivent être chargées dans une plage de [0, 1] puis normalisées en utilisant mean = [0.485, 0.456, 0.406] et std = [0.229, 0.224, 0.225].

In [None]:
# Paramètres de normalisation pour les modèles PyTorch pré-entraînés
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


class ImageDataset(Dataset):
    """Dataset yielding a low-resolution / high-resolution pair for each image file.

    Args:
        files: sequence of image file paths.
        hr_shape: ``(height, width)`` of the high-resolution image. The
            low-resolution image is a quarter of that size along each side.

    Each item is a dict ``{"lr": tensor, "hr": tensor}``; both tensors are
    normalized with the module-level ``mean`` / ``std``.
    """

    def __init__(self, files, hr_shape):
        hr_height, hr_width = hr_shape
        # Transforms for the low-resolution and the high-resolution images.
        # The width must come from hr_width: using hr_height for both sides
        # silently forces square outputs whenever hr_shape is not square.
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.files = files

    def __getitem__(self, index):
        # Force 3 channels: grayscale / RGBA / palette files would otherwise
        # not match the 3-channel mean/std given to Normalize. The context
        # manager closes the file handle once the pixels have been copied.
        with Image.open(self.files[index % len(self.files)]) as raw_img:
            img = raw_img.convert("RGB")
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        return len(self.files)

### Get Train/Test Dataloaders

Pour empêcher le chargement des données de bloquer l'entraînement, nous pouvons créer des « workers » qui chargent les données de manière asynchrone. Un moyen simple de le faire est de fournir à chaque worker une file d'attente d'indices à charger, et une file d'attente de sortie où il peut placer les données chargées. Tout ce que le worker a à faire est de vérifier à plusieurs reprises sa file d'attente d'indices, et de charger les données si la file d'attente n'est pas vide.

In [None]:
a=[]

In [None]:
# List every file of the dataset folder, sorted by path.
# The listing is retried a bounded number of times instead of looping forever:
# with a wrong or empty dataset_path the original `while` never terminated.
# NOTE(review): the retry presumably works around Drive returning an empty
# listing right after mounting — confirm whether it is still needed.
MAX_LISTING_ATTEMPTS = 10
for _attempt in range(MAX_LISTING_ATTEMPTS):
  a = sorted(glob.glob(dataset_path + "/*.*"))
  if a:
    break
else:
  raise FileNotFoundError(
      f"No files found in {dataset_path!r} after {MAX_LISTING_ATTEMPTS} attempts"
  )

In [None]:
len(a)

159618

In [None]:
dataset =a #[:100000]

In [None]:
len(dataset)

100000

In [None]:
# Hold out 2% of the file paths for evaluation (fixed seed => reproducible split)
train_paths, test_paths = train_test_split(dataset, test_size=0.02, random_state=42)

# Evaluation batches are 75% of the training batch size
train_dataloader = DataLoader(
    ImageDataset(train_paths, hr_shape=hr_shape),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)
test_dataloader = DataLoader(
    ImageDataset(test_paths, hr_shape=hr_shape),
    batch_size=int(batch_size*0.75),
    shuffle=True,
    num_workers=n_cpu,
)

### Define Model Classes

https://www.cl.cam.ac.uk/research/rainbow/projects/mdf/

Using VGG19 model to Extract *Features*

In [None]:
class FeatureExtractor(nn.Module):
    """Frozen feature extractor built from the first 18 layers of a pretrained VGG19.

    The extracted feature maps are compared between generated and real images
    to form the content loss.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])
        # The VGG weights are never handed to an optimizer. Freezing them stops
        # autograd from computing and storing their gradients on every generator
        # backward pass; gradients still flow through to the input image.
        for param in self.feature_extractor.parameters():
            param.requires_grad_(False)

    def forward(self, img):
        """Return the VGG19 feature maps of ``img``."""
        return self.feature_extractor(img)


Using Inception model to Extract *Features*

In [None]:
# class MyInceptionFeatureExtractor(nn.Module):
#     def __init__(self, inception, transform_input=False):
#         super(MyInceptionFeatureExtractor, self).__init__()
#         self.transform_input = transform_input
#         self.Conv2d_1a_3x3 = inception.Conv2d_1a_3x3
#         self.Conv2d_2a_3x3 = inception.Conv2d_2a_3x3
#         self.Conv2d_2b_3x3 = inception.Conv2d_2b_3x3
#         self.Conv2d_3b_1x1 = inception.Conv2d_3b_1x1
#         self.Conv2d_4a_3x3 = inception.Conv2d_4a_3x3
#         self.Mixed_5b = inception.Mixed_5b
#         # stop where you want, copy paste from the model def

#     def forward(self, x):
#         if self.transform_input:
#             x = x.clone()
#             x[0] = x[0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
#             x[1] = x[1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
#             x[2] = x[2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
#         # 299 x 299 x 3
#         x = self.Conv2d_1a_3x3(x)
#         # 149 x 149 x 32
#         x = self.Conv2d_2a_3x3(x)
#         # 147 x 147 x 32
#         x = self.Conv2d_2b_3x3(x)
#         # 147 x 147 x 64
#         x = F.max_pool2d(x, kernel_size=3, stride=2)
#         # 73 x 73 x 64
#         x = self.Conv2d_3b_1x1(x)
#         # 73 x 73 x 80
#         x = self.Conv2d_4a_3x3(x)
#         # 71 x 71 x 192
#         x = F.max_pool2d(x, kernel_size=3, stride=2)
#         # 35 x 35 x 192
#         x = self.Mixed_5b(x)
#         # copy paste from model definition, just stopping where you want
#         return x




In [None]:
import torchvision

In [None]:
# class FeatureExtractor(nn.Module):
#     def __init__(self):
#         super(FeatureExtractor, self).__init__()
#         inception_model = my_inception
#         self.feature_extractor = nn.Sequential(*list(inception_model.children()))

#     def forward(self, img):
#         return self.feature_extractor(img)

In [None]:
class ResidualBlock(nn.Module):
    """Residual block: conv -> batch norm -> PReLU -> conv -> batch norm, plus a skip connection.

    Args:
        in_features: number of input channels; the output has the same number.
    """

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        # The second positional argument of nn.BatchNorm2d is `eps`, not
        # `momentum`. The former `nn.BatchNorm2d(in_features, 0.8)` therefore
        # added 0.8 to the variance before normalizing. The default eps is
        # restored here.
        # NOTE(review): if a specific running-statistics momentum was intended,
        # pass it explicitly as `momentum=...`. Checkpoints trained with
        # eps=0.8 still load (same state-dict keys) but will behave differently.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        return x + self.conv_block(x)


class GeneratorResNet(nn.Module):
    """Generator that upsamples its input by a factor of 4 along each spatial side.

    Args:
        in_channels: channels of the low-resolution input.
        out_channels: channels of the generated image.
        n_residual_blocks: number of ``ResidualBlock`` modules in the trunk.

    The output passes through ``Tanh`` and therefore lies in [-1, 1].
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()

        # First layer
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU())

        # Residual blocks
        self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(n_residual_blocks)])

        # Second conv layer post residual blocks.
        # nn.BatchNorm2d's second positional argument is `eps`; the former
        # `nn.BatchNorm2d(64, 0.8)` set eps=0.8 by accident. Default eps restored.
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64))

        # Upsampling layers: two stages, each doubling the resolution.
        # The conv expands 64 -> 256 channels and PixelShuffle(2) folds them
        # back to 64 channels at twice the height and width.
        upsampling = []
        for _ in range(2):
            upsampling += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)

        # Final output layer
        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())

    def forward(self, x):
        """Map a low-resolution batch to its 4x upsampled reconstruction."""
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        # Long skip connection around the whole residual trunk
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out


class Discriminator(nn.Module):
    """Patch discriminator producing one score per spatial patch.

    Args:
        input_shape: ``(channels, height, width)`` of the images to classify.

    Attributes:
        input_shape: the shape passed at construction.
        output_shape: ``(1, patch_h, patch_w)``, the per-sample shape of the
            tensor returned by ``forward`` for inputs of ``input_shape``.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        block_filters = [64, 128, 256, 512]

        # Each block ends with a kernel-3 / stride-2 / padding-1 convolution,
        # which maps a side of n pixels to ceil(n / 2). Flooring H / 2**4 (the
        # previous formula) gives a wrong output_shape whenever a side is not
        # divisible by 16, so the halving is applied once per block instead.
        patch_h, patch_w = int(in_height), int(in_width)
        for _ in block_filters:
            patch_h, patch_w = (patch_h + 1) // 2, (patch_w + 1) // 2
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            """Two convolutions; the second one halves the spatial resolution."""
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            # No batch norm directly after the very first convolution
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate(block_filters):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        # Collapse the final feature maps to a single-channel score map
        layers.append(nn.Conv2d(in_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the patch score map for ``img``."""
        return self.model(img)

### Train Super Resolution GAN (SRGAN)

In [None]:
# Device everything is placed on
device = torch.device("cuda" if cuda else "cpu")

# Initialize generator and discriminator
generator = GeneratorResNet().to(device)
discriminator = Discriminator(input_shape=(channels, *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# Set feature extractor to inference mode
feature_extractor.eval()

# Losses: MSE on the discriminator patch scores, L1 on the extracted features
criterion_GAN = torch.nn.MSELoss().to(device)
criterion_content = torch.nn.L1Loss().to(device)

# Load pretrained models
# map_location lets checkpoints saved on a GPU be restored on a CPU-only runtime.
# NOTE(review): these paths point to "saved_models/", while the training cell
# writes its checkpoints to "vgg/" — confirm which folder is intended.
if load_pretrained_models:
    generator.load_state_dict(
        torch.load("./drive/MyDrive/saved_models/generator.pth", map_location=device)
    )
    discriminator.load_state_dict(
        torch.load("./drive/MyDrive/saved_models/discriminator.pth", map_location=device)
    )

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Tensor type used by the training cell to cast batches onto the right device
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))




Compute psnr and SSIM (quantify reconstruction quality for images)

In [None]:
# Metric stores filled during evaluation, keyed by the running test-iteration
# index (epoch_iter is keyed by epoch and holds the last iteration seen).
psnr_values, ssim_values, ms_ssim_values = dict(), dict(), dict()
epoch_iter, mse_values = dict(), dict()

In [None]:
def compute_psnr(epoch, i, original_image, generated_image):
    """Record PSNR and MSE for one randomly chosen image of the batch and save the generated batch.

    Both batches are expected in the space produced by
    ``transforms.Normalize(mean, std)``. They are mapped back to [0, 1] before
    the metrics are computed, so the PSNR peak value is 1.0. The previous
    version used a peak of 255 on normalized tensors, which made the reported
    PSNR meaningless.

    Args:
        epoch: current epoch; ``epoch_iter[epoch]`` is set to ``i``.
        i: running test-iteration index, used as key and in the file name.
        original_image: (N, 3, H, W) batch of reference images.
        generated_image: (N, 3, H, W) batch of generated images.

    Side effects:
        Updates ``psnr_values``, ``mse_values`` and ``epoch_iter``, and writes
        the de-normalized generated batch to
        ``./drive/MyDrive/vgg/images/generated{i}.png``.
    """
    rand = random.randint(0, len(original_image) - 1)
    with torch.no_grad():
        # Undo Normalize(mean, std): x * std + mean, broadcast over (N, C, H, W)
        ch_mean = torch.as_tensor(mean, dtype=original_image.dtype, device=original_image.device).view(1, -1, 1, 1)
        ch_std = torch.as_tensor(std, dtype=original_image.dtype, device=original_image.device).view(1, -1, 1, 1)
        original = (original_image * ch_std + ch_mean).clamp(0, 1)
        generated = (generated_image * ch_std + ch_mean).clamp(0, 1)

        mse = torch.mean((original[rand] - generated[rand]) ** 2)
        # PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1; the floor on the MSE
        # avoids an infinite value for identical images.
        psnr = 10 * torch.log10(1.0 / torch.clamp(mse, min=1e-10))

    psnr_values[i] = psnr.item()
    mse_values[i] = mse.item()
    epoch_iter[epoch] = i
    # Save the de-normalized batch: with normalize=False, values outside
    # [0, 1] would otherwise be clipped when the image is written.
    save_image(generated, f"./drive/MyDrive/vgg/images/generated{i}.png", normalize=False)

In [None]:
pip install pytorch-msssim

Collecting pytorch-msssim
  Downloading https://files.pythonhosted.org/packages/9d/d3/3cb0f397232cf79e1762323c3a8862e39ad53eca0bb5f6be9ccc8e7c070e/pytorch_msssim-0.2.1-py3-none-any.whl
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-0.2.1


In [None]:
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM


# original / generated: (N,3,H,W) batches in the space produced by
# transforms.Normalize(mean, std); they are mapped back to [0, 1] before scoring.

def compute_ssim(i, original, generated):
    """Record the batch-mean SSIM and MS-SSIM between reference and generated images.

    The previous version called ``ssim`` / ``ms_ssim`` with ``data_range=255``
    on normalized (partly negative) tensors, repeated the computation several
    times, and stored ``1 - SSIM`` under ``ssim_values``. Here both batches are
    de-normalized to [0, 1], each metric is computed once with
    ``data_range=1``, and the similarity itself is stored.

    Args:
        i: running test-iteration index, used as dictionary key.
        original: (N, 3, H, W) batch of reference images.
        generated: (N, 3, H, W) batch of generated images.

    Side effects:
        ``ssim_values[i]`` and ``ms_ssim_values[i]`` are set to Python floats
        (mean over the batch), so every iteration contributes one point to the
        metric plots.
    """
    with torch.no_grad():
        # Undo Normalize(mean, std): x * std + mean, broadcast over (N, C, H, W)
        ch_mean = torch.as_tensor(mean, dtype=original.dtype, device=original.device).view(1, -1, 1, 1)
        ch_std = torch.as_tensor(std, dtype=original.dtype, device=original.device).view(1, -1, 1, 1)
        X = (original * ch_std + ch_mean).clamp(0, 1)
        Y = (generated * ch_std + ch_mean).clamp(0, 1)

        # size_average=True returns a scalar averaged over the batch
        ssim_val = ssim(X, Y, data_range=1, size_average=True)
        ms_ssim_val = ms_ssim(X, Y, data_range=1, size_average=True)

    ssim_values[i] = ssim_val.item()
    ms_ssim_values[i] = ms_ssim_val.item()




In [None]:
# Per-batch training losses and the number of training examples seen at each point
train_gen_losses, train_disc_losses, train_counter = [], [], []
# Per-epoch mean evaluation losses
test_gen_losses, test_disc_losses = [], []
# x-positions of the evaluation points: examples seen at the end of each epoch
test_counter = [idx*len(train_dataloader.dataset) for idx in range(1, n_epochs+1)]

# Make sure the output folders exist before the first image/checkpoint is written
os.makedirs("./drive/MyDrive/vgg/images", exist_ok=True)

# Running evaluation-batch index across epochs; key of the metric dictionaries
i = 0
for epoch in range(n_epochs):

    ### Training
    generator.train(); discriminator.train()
    gen_loss, disc_loss = 0, 0
    tqdm_bar = tqdm(train_dataloader, desc=f'Training Epoch {epoch} ', total=int(len(train_dataloader)))
    for batch_idx, imgs in enumerate(tqdm_bar):
        # Configure model input
        imgs_lr = imgs["lr"].type(Tensor)
        imgs_hr = imgs["hr"].type(Tensor)
        # Adversarial ground truths (one target per discriminator patch)
        valid = torch.ones((imgs_lr.size(0), *discriminator.output_shape)).type(Tensor)
        fake = torch.zeros((imgs_lr.size(0), *discriminator.output_shape)).type(Tensor)

        ### Train Generator
        optimizer_G.zero_grad()
        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)
        # Adversarial loss
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
        # Content loss; the target features need no autograd graph
        gen_features = feature_extractor(gen_hr)
        with torch.no_grad():
            real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features)
        # Total loss
        loss_G = loss_content + 1e-3 * loss_GAN
        loss_G.backward()
        optimizer_G.step()

        ### Train Discriminator
        optimizer_D.zero_grad()
        # Loss of real and fake images (gen_hr detached: no generator update here)
        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
        # Total loss
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        gen_loss += loss_G.item()
        train_gen_losses.append(loss_G.item())
        disc_loss += loss_D.item()
        train_disc_losses.append(loss_D.item())
        train_counter.append(batch_idx*batch_size + imgs_lr.size(0) + epoch*len(train_dataloader.dataset))
        tqdm_bar.set_postfix(gen_loss=gen_loss/(batch_idx+1), disc_loss=disc_loss/(batch_idx+1))

    # Testing
    generator.eval(); discriminator.eval()
    gen_loss, disc_loss = 0, 0
    tqdm_bar = tqdm(test_dataloader, desc=f'Testing Epoch {epoch} ', total=int(len(test_dataloader)))
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for batch_idx, imgs in enumerate(tqdm_bar):
            # Configure model input
            imgs_lr = imgs["lr"].type(Tensor)
            imgs_hr = imgs["hr"].type(Tensor)
            # Adversarial ground truths
            valid = torch.ones((imgs_lr.size(0), *discriminator.output_shape)).type(Tensor)
            fake = torch.zeros((imgs_lr.size(0), *discriminator.output_shape)).type(Tensor)

            ### Eval Generator
            # Generate a high resolution image from low resolution input
            gen_hr = generator(imgs_lr)
            # Adversarial loss
            loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
            # Content loss
            gen_features = feature_extractor(gen_hr)
            real_features = feature_extractor(imgs_hr)
            loss_content = criterion_content(gen_features, real_features)
            # Total loss
            loss_G = loss_content + 1e-3 * loss_GAN

            ### Eval Discriminator
            # Loss of real and fake images
            loss_real = criterion_GAN(discriminator(imgs_hr), valid)
            loss_fake = criterion_GAN(discriminator(gen_hr), fake)
            # Total loss
            loss_D = (loss_real + loss_fake) / 2

            gen_loss += loss_G.item()
            disc_loss += loss_D.item()
            compute_psnr(epoch, i,imgs_hr, gen_hr)
            compute_ssim(i, imgs_hr, gen_hr)
            i+=1
            tqdm_bar.set_postfix(gen_loss=gen_loss/(batch_idx+1), disc_loss=disc_loss/(batch_idx+1))

            # Save image grid with upsampled inputs and SRGAN outputs
            # (for roughly 10% of the evaluation batches)
            if random.uniform(0,1)<0.1:
                imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
                imgs_hr = make_grid(imgs_hr, nrow=1, normalize=True)
                gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
                imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
                img_grid = torch.cat((imgs_hr, imgs_lr, gen_hr), -1)
                save_image(img_grid, f"./drive/MyDrive/vgg/{batch_idx}.png", normalize=False)

    test_gen_losses.append(gen_loss/len(test_dataloader))
    test_disc_losses.append(disc_loss/len(test_dataloader))

    # Save model checkpoints only when this epoch has the best test generator loss so far
    if np.argmin(test_gen_losses) == len(test_gen_losses)-1:
        torch.save(generator.state_dict(), "./drive/MyDrive/vgg/generator.pth")
        torch.save(discriminator.state_dict(), "./drive/MyDrive/vgg/discriminator.pth")
        

HBox(children=(FloatProgress(value=0.0, description='Training Epoch 0 ', max=2932.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Testing Epoch 0 ', max=80.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Training Epoch 1 ', max=2932.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Testing Epoch 1 ', max=80.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Training Epoch 2 ', max=2932.0, style=ProgressStyle(descr…

In [None]:
# Generator loss: per-batch training curve plus one marker per evaluated epoch.
# test_counter is pre-filled for all n_epochs; it is truncated to the epochs
# that were actually evaluated so x and y stay aligned if training stopped early.
n_evaluated = len(test_gen_losses)
fig = go.Figure()
fig.add_trace(go.Scatter(x=train_counter, y=train_gen_losses, mode='lines', name='Train Generator Loss'))
fig.add_trace(go.Scatter(x=test_counter[:n_evaluated], y=test_gen_losses, marker_symbol='star-diamond',
                         marker_color='orange', marker_line_width=1, marker_size=9, mode='markers', name='Test Generator Loss'))
fig.update_layout(
    width=1000,
    height=500,
    title="Train vs. Test Generator Loss",
    xaxis_title="Number of training examples seen",
    yaxis_title="Adversarial + Content Loss")
fig.show()

In [None]:
# Discriminator loss: per-batch training curve plus one marker per evaluated epoch.
# test_counter is pre-filled for all n_epochs; it is truncated to the epochs
# that were actually evaluated so x and y stay aligned if training stopped early.
n_evaluated = len(test_disc_losses)
fig = go.Figure()
fig.add_trace(go.Scatter(x=train_counter, y=train_disc_losses, mode='lines', name='Train Discriminator Loss'))
fig.add_trace(go.Scatter(x=test_counter[:n_evaluated], y=test_disc_losses, marker_symbol='star-diamond',
                         marker_color='orange', marker_line_width=1, marker_size=9, mode='markers', name='Test Discriminator Loss'))
fig.update_layout(
    width=1000,
    height=500,
    title="Train vs. Test Discriminator Loss",
    xaxis_title="Number of training examples seen",
    yaxis_title="Adversarial Loss")
fig.show()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# PSNR recorded at each test iteration
psnr_iterations = list(psnr_values.keys())
psnr_scores = list(psnr_values.values())
plt.plot(psnr_iterations, psnr_scores, label='psnr')
plt.title('PSNR ')
plt.xlabel('iteration')
plt.ylabel('psnr')
plt.legend(loc='upper left')
plt.show()

In [None]:
# Values stored in ssim_values at each test iteration
ssim_iterations = list(ssim_values.keys())
ssim_scores = list(ssim_values.values())
plt.plot(ssim_iterations, ssim_scores, label='ssim')
plt.title('SSIM')
plt.xlabel('iteration')
plt.ylabel('ssim')
plt.legend(loc='upper left')
plt.show()

In [None]:
# MSE recorded at each test iteration
mse_iterations = list(mse_values.keys())
mse_scores = list(mse_values.values())
plt.plot(mse_iterations, mse_scores, label='mse')
plt.title('MSE')
plt.xlabel('iteration')
plt.ylabel('mse')
plt.legend(loc='upper left')
plt.show()

In [None]:
# MS-SSIM recorded at each test iteration.
# An entry may be a scalar or a per-image array; averaging gives exactly one
# point per iteration either way (arrays of different batch sizes cannot be
# plotted directly).
ms_ssim_iterations = list(ms_ssim_values.keys())
ms_ssim_scores = [float(np.mean(value)) for value in ms_ssim_values.values()]
plt.plot(ms_ssim_iterations, ms_ssim_scores)
plt.title('MSSSIM')
plt.ylabel('msssim')
plt.xlabel('iteration')
# A single curve is drawn, so the legend names that curve only
plt.legend(['msssim'], loc='upper left')
plt.show()

In [None]:
psnr_values

In [None]:
ssim_values

In [None]:
ms_ssim_values

In [None]:
epoch_iter

In [None]:
mse_values