# Définitions des modèles

In [1]:
import os
from collections import Counter, defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision.utils import save_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
class Block(nn.Module):
  """Convolution stage: 4x4 reflect-padded conv -> InstanceNorm2d -> LeakyReLU(0.2).

  Args:
    nb_in_layers: number of input channels.
    nb_out_layers: number of output channels.
    stride: stride of the convolution.
  """

  def __init__(self, nb_in_layers, nb_out_layers, stride):
    super().__init__()
    convolution = nn.Conv2d(
        in_channels=nb_in_layers,
        out_channels=nb_out_layers,
        kernel_size=4,
        stride=stride,
        padding=1,
        padding_mode="reflect",
        bias=True,
    )
    normalisation = nn.InstanceNorm2d(nb_out_layers)
    activation = nn.LeakyReLU(negative_slope=0.2)
    # Kept as a Sequential named `conv` so parameter names stay `conv.0.*`.
    self.conv = nn.Sequential(convolution, normalisation, activation)

  def forward(self, x):
    """Apply the conv / norm / activation stack to `x`."""
    features = self.conv(x)
    return features

In [3]:
class Discriminator(nn.Module):
  """Convolutional discriminator that outputs a spatial map of sigmoid scores.

  Args:
    nb_in_layers: number of channels of the input images.
    features: output channels of the successive stages. The first entry feeds the
      initial conv (no normalisation); every following entry adds one `Block`.
      All blocks use stride 2 except the last one, which uses stride 1.
      The default list is only read, never mutated.
  """

  def __init__(self, nb_in_layers=3, features=[64, 128, 256, 512]):
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(nb_in_layers, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(0.2),
    )
    self.layers = nn.ModuleList()
    in_channels = features[0]
    # Position of the last block within features[1:]. The last stage is detected by
    # position, not by value, so repeated widths (e.g. [64, 128, 128]) still get
    # stride 2 everywhere except on the final block.
    last_idx = len(features) - 2
    for idx, feature in enumerate(features[1:]):
      stride = 1 if idx == last_idx else 2
      # Sub-module names are part of the state_dict keys; keep them unchanged so that
      # previously saved checkpoints still load.
      self.layers.add_module(f'Block{idx}_{feature}', Block(in_channels, feature, stride=stride))
      in_channels = feature
    self.layers.add_module('OutputMapping', nn.Conv2d(in_channels, out_channels=1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))

  def forward(self, x):
    """Return the sigmoid-activated single-channel score map for the batch `x`."""
    x = self.initial(x)
    for layer in self.layers:
      x = layer(x)
    return torch.sigmoid(x)

In [4]:
class ConvBlock(nn.Module):
  """A single convolution (or transposed convolution) optionally followed by ReLU.

  Args:
    nb_in_layers: number of input channels.
    nb_out_layers: number of output channels.
    down: if True use a reflect-padded `nn.Conv2d`, otherwise an `nn.ConvTranspose2d`.
    use_act: if True append an in-place ReLU, otherwise an identity.
    **kwargs: forwarded to the convolution constructor (kernel_size, stride, padding, ...).
  """

  def __init__(self, nb_in_layers, nb_out_layers, down=True, use_act=True, **kwargs):
    super().__init__()
    if down:
      convolution = nn.Conv2d(nb_in_layers, nb_out_layers, padding_mode="reflect", **kwargs)
    else:
      convolution = nn.ConvTranspose2d(nb_in_layers, nb_out_layers, **kwargs)

    if use_act:
      activation = nn.ReLU(inplace=True)
    else:
      activation = nn.Identity()

    # Two-element Sequential named `conv`: parameter names stay `conv.0.*`.
    self.conv = nn.Sequential(convolution, activation)

  def forward(self, x):
    """Run the convolution and the (optional) activation on `x`."""
    out = self.conv(x)
    return out

In [5]:
class ResidualBlock(nn.Module):
  """Two 3x3 `ConvBlock`s (the second without activation) added back onto the input.

  Args:
    channels: number of input and output channels.
  """

  def __init__(self, channels):
    super().__init__()
    conv_kwargs = dict(kernel_size=3, padding=1)
    activated = ConvBlock(channels, channels, **conv_kwargs)
    linear = ConvBlock(channels, channels, use_act=False, **conv_kwargs)
    self.block = nn.Sequential(activated, linear)

  def forward(self, x):
    """Return `x` plus the output of the two-convolution branch."""
    residual = self.block(x)
    return x + residual

In [6]:
class Generator(nn.Module):
  """Residual image-to-image generator.

  Pipeline: 7x7 conv + ReLU -> two stride-2 down-convolutions (64 -> 128 -> 256)
  -> `num_residuals` residual blocks at 256 channels -> two stride-2 transposed
  convolutions (256 -> 128 -> 64) -> 7x7 conv -> tanh.

  Args:
    img_channels: number of channels of the input and output images.
    num_residuals: number of residual blocks in the middle of the network.
  """

  def __init__(self, img_channels, num_residuals=9):
    super().__init__()
    stem = nn.Conv2d(img_channels, 64, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
    self.initial = nn.Sequential(stem, nn.ReLU(inplace=True))

    self.down_path = nn.ModuleList([
        ConvBlock(c_in, c_out, kernel_size=3, down=True, stride=2, padding=1)
        for c_in, c_out in ((64, 128), (128, 256))
    ])

    self.residual_blocks = nn.Sequential(
        *(ResidualBlock(256) for _ in range(num_residuals))
    )

    self.up_path = nn.ModuleList([
        ConvBlock(c_in, c_out, down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        for c_in, c_out in ((256, 128), (128, 64))
    ])

    self.colorize = nn.Conv2d(64, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

  def forward(self, x):
    """Translate the batch `x`; the output is squashed to [-1, 1] by tanh."""
    out = self.initial(x)
    # Encoder, residual trunk and decoder are applied strictly one after the other.
    for stage in (*self.down_path, self.residual_blocks, *self.up_path):
      out = stage(out)
    return torch.tanh(self.colorize(out))

In [7]:
class UnetGenerator(nn.Module):
  """Residual generator whose decoder stages are concatenated with encoder activations.

  Same encoder / residual trunk as the plain residual generator, but the output of
  each up-sampling stage is concatenated (channel axis) with the encoder activation
  of matching resolution, which doubles the input channels of the following layer.

  Args:
    img_channels: number of channels of the input and output images.
    num_residuals: number of residual blocks in the middle of the network.
  """

  def __init__(self, img_channels, num_residuals=9):
    super().__init__()
    stem = nn.Conv2d(img_channels, 64, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
    self.initial = nn.Sequential(stem, nn.ReLU(inplace=True))

    self.down_path = nn.ModuleList([
        ConvBlock(c_in, c_out, kernel_size=3, down=True, stride=2, padding=1)
        for c_in, c_out in ((64, 128), (128, 256))
    ])

    self.residual_blocks = nn.Sequential(
        *(ResidualBlock(256) for _ in range(num_residuals))
    )

    # The second decoder stage and the final conv receive concatenated tensors,
    # hence the doubled input widths (128*2 and 64*2).
    self.up_path = nn.ModuleList([
        ConvBlock(c_in, c_out, down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        for c_in, c_out in ((256, 128), (128 * 2, 64))
    ])

    self.colorize = nn.Conv2d(64 * 2, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

  def forward(self, x):
    """Translate the batch `x` (expects a batch axis: concatenation is done on dim 1)."""
    x = self.initial(x)
    skips = [x]
    for down in self.down_path:
      x = down(x)
      skips.append(x)

    x = self.residual_blocks(x)

    # The deepest activation (skips[-1]) is the trunk input itself and is not reused;
    # the remaining ones are consumed from the deepest to the shallowest.
    for up, skip in zip(self.up_path, reversed(skips[:-1])):
      x = torch.cat((skip, up(x)), dim=1)

    return torch.tanh(self.colorize(x))

In [8]:
class DoubleConvBlock(nn.Module):
    """Two consecutive (3x3 reflect-padded conv, InstanceNorm2d, ReLU) stages.

    Both convolutions use stride 1 / padding 1 and no bias, so the spatial size of
    the input is preserved.

    Args:
        nb_in_layers: number of input channels.
        nb_out_layers: number of channels produced by both convolutions.
    """

    def __init__(self, nb_in_layers, nb_out_layers):
        super().__init__()
        stages = []
        # First stage maps in -> out channels, second stage maps out -> out.
        for stage_in in (nb_in_layers, nb_out_layers):
            stages += [
                nn.Conv2d(
                    stage_in, nb_out_layers,
                    kernel_size=3, stride=1, padding=1,
                    padding_mode="reflect", bias=False,
                ),
                nn.InstanceNorm2d(nb_out_layers),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv / norm / ReLU stages to `x`."""
        out = self.conv(x)
        return out

class UnetGeneratorAdvanced(nn.Module):
    """U-Net generator with max-pool down-sampling and a residual bottleneck.

    Encoder: one `DoubleConvBlock` per entry of `features`, each followed by a 2x2
    max-pool. Bottleneck: a `DoubleConvBlock` doubling the channels followed by
    `num_residuals - 1` residual blocks. Decoder: for each feature (deepest first) a
    2x2 stride-2 transposed convolution, concatenation with the matching encoder
    activation, then a `DoubleConvBlock`. A final 1x1 conv maps back to image channels.

    NOTE(review): unlike the other generators in this notebook, the output is not
    passed through tanh; left unchanged so existing checkpoints behave the same.

    Args:
        img_channels: number of channels of the input and output images.
        features: channel widths of the encoder stages (only read, never mutated).
        num_residuals: bottleneck depth (1 DoubleConvBlock + num_residuals - 1 residual blocks).
    """

    def __init__(
            self, img_channels=3, features=[32, 64, 128], num_residuals=6
    ):
        super(UnetGeneratorAdvanced, self).__init__()
        self.up_path = nn.ModuleList()
        self.down_path = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: increase the number of channels while the pooling shrinks the images.
        in_channels = img_channels
        for feature in features:
            self.down_path.append(DoubleConvBlock(in_channels, feature))
            in_channels = feature

        # Decoder: grow the images back while decreasing the number of channels.
        # up_path alternates [upsampling, DoubleConvBlock, upsampling, DoubleConvBlock, ...].
        for feature in reversed(features):
            self.up_path.append(
                nn.ConvTranspose2d(  # performs the upsampling
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            # Input is feature*2 because of the concatenation with the skip connection.
            self.up_path.append(DoubleConvBlock(feature*2, feature))

        self.bottleneck = nn.Sequential(
            DoubleConvBlock(features[-1], features[-1]*2),
            *[ResidualBlock(features[-1]*2) for _ in range(num_residuals-1)])
        self.colorize = nn.Conv2d(features[0], img_channels, kernel_size=1)

    def forward(self, x):
        """Translate the batch `x` (N, C, H, W) and return the raw 1x1-conv output."""
        skip_connections = []

        for down in self.down_path:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)

        # Skips are consumed from the deepest to the shallowest.
        for idx, skip_connection in enumerate(reversed(skip_connections)):
            x = self.up_path[2 * idx](x)
            # Max-pooling floors odd sizes, so after upsampling x can be one pixel
            # smaller than the skip; resize it so the concatenation is always valid.
            # No-op when H and W are divisible by 2**len(features) (e.g. 256x256).
            if x.shape[2:] != skip_connection.shape[2:]:
                x = nn.functional.interpolate(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.up_path[2 * idx + 1](concat_skip)

        return self.colorize(x)

# Chargement des modèles

In [9]:
# Same assignment as in the imports cell; re-evaluated here right before the models are built.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
# Edit this cell so the architectures match the model being evaluated.
# disc_A: real horse or generated?   disc_B: real zebra or generated?
disc_A, disc_B = (Discriminator(nb_in_layers=3).to(device) for _ in range(2))
gen_A, gen_B = (Generator(img_channels=3, num_residuals=9).to(device) for _ in range(2))

In [15]:
# One Adam optimizer per network family, each covering the parameters of both domains.
opt_disc = optim.Adam([*disc_A.parameters(), *disc_B.parameters()], lr=1e-4, betas=(0.9, 0.999))
opt_gen = optim.Adam([*gen_A.parameters(), *gen_B.parameters()], lr=1e-4, betas=(0.9, 0.999))

In [16]:
def load_checkpoint(checkpoint_file, model, optimizer=None, lr=None):
    """Restore a model (and optionally its optimizer) from a checkpoint file.

    The checkpoint is expected to be a dict with a "state_dict" entry and, when an
    optimizer is restored, an "optimizer" entry.

    Args:
        checkpoint_file: path of the checkpoint to read.
        model: module whose weights are restored in place.
        optimizer: optimizer whose state is restored in place; pass None (default)
            to restore only the weights, e.g. for inference.
        lr: learning rate forced on every parameter group after the optimizer state
            is restored; None keeps the rate stored in the checkpoint.
    """
    print("=> Loading checkpoint")
    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])

    if optimizer is None:
        return
    optimizer.load_state_dict(checkpoint["optimizer"])

    # Without this the optimizer keeps the learning rate stored in the old checkpoint.
    if lr is not None:
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr


In [17]:
# Colab only: mount Google Drive in the runtime's filesystem.
from google.colab import drive
drive.mount('/content/drive')

# Relative paths used in the following cells are resolved from this project folder.
%cd "/content/drive/MyDrive/INF8225 Groupe/PROJET"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/INF8225 Groupe/PROJET


In [18]:
# Restore the four networks from the "sans_cycle_loss_ZH_60" checkpoints (paths are
# relative to the current working directory). Each call also loads the checkpoint's
# optimizer state into the shared optimizer and resets its learning rate to 1e-4.
load_checkpoint("saved_model2/sans_cycle_loss_ZH_60_disc_A.pth.tar", disc_A, opt_disc, 1e-4)
load_checkpoint("saved_model2/sans_cycle_loss_ZH_60_disc_B.pth.tar", disc_B, opt_disc, 1e-4)
load_checkpoint("saved_model2/sans_cycle_loss_ZH_60_gen_A.pth.tar", gen_A, opt_gen, 1e-4)
load_checkpoint("saved_model2/sans_cycle_loss_ZH_60_gen_B.pth.tar", gen_B, opt_gen, 1e-4)

=> Loading checkpoint
=> Loading checkpoint
=> Loading checkpoint
=> Loading checkpoint


# Chargement des jeux d'images à transformer

In [19]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os

In [20]:
# Preprocessing applied to every input image: resize to 256x256, map pixel values
# from [0, 255] to [-1, 1] (mean 0.5 / std 0.5 per channel), convert to a tensor.
transforms = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

In [21]:
# Test-split folders, relative to the working directory (B = zebras, A = horses).
zebra_dir_path = "horse2zebra/testB"
horse_dir_path = "horse2zebra/testA"

In [22]:
# os.listdir returns entries in arbitrary order; sort them so the processing order
# is the same on every run and every filesystem.
zebra_imgs = sorted(os.listdir(zebra_dir_path))
horse_imgs = sorted(os.listdir(horse_dir_path))


In [23]:
from torch.utils.data import Dataset, DataLoader

class Custom_AB_Dataset(Dataset):
  """Unpaired two-domain dataset: item `i` is (image from root_A, image from root_B).

  The dataset length is the size of the larger folder; the smaller folder is cycled
  through with a modulo, so some of its images are returned more than once per pass.
  Images are read as RGB and passed through the notebook-level `transforms` pipeline.

  Args:
    root_A: folder containing the domain-A images.
    root_B: folder containing the domain-B images.

  Raises:
    ValueError: if exactly one of the two folders is empty (indexing would otherwise
      fail with a ZeroDivisionError).
  """

  def __init__(self, root_A, root_B):
    self.root_A = root_A
    self.root_B = root_B

    # os.listdir returns entries in arbitrary order; sort so that a given index
    # always maps to the same pair of files.
    self.A_images = sorted(os.listdir(root_A))
    self.B_images = sorted(os.listdir(root_B))
    self.A_length = len(self.A_images)
    self.B_length = len(self.B_images)
    self.length_dataset = max(self.A_length, self.B_length)
    # Tip: temporarily set length_dataset to ~50-200 to smoke-test a run quickly.

    if self.length_dataset > 0 and min(self.A_length, self.B_length) == 0:
      raise ValueError(
          f"One image folder is empty: {root_A!r} has {self.A_length} file(s), "
          f"{root_B!r} has {self.B_length} file(s)"
      )

  def __len__(self):
    return self.length_dataset

  def __getitem__(self, index):
    # The modulo cycles through the smaller folder: its low-index images may be
    # seen one more time per pass than the others.
    A_img = self.A_images[index % self.A_length]
    B_img = self.B_images[index % self.B_length]

    A_path = os.path.join(self.root_A, A_img)
    B_path = os.path.join(self.root_B, B_img)

    A_img = np.array(Image.open(A_path).convert("RGB"))
    B_img = np.array(Image.open(B_path).convert("RGB"))

    A_img = transforms(image=A_img)["image"]
    B_img = transforms(image=B_img)["image"]
    return A_img, B_img

# Reuse the path constants defined above instead of repeating the literals, so the
# loader and the per-file loops cannot drift apart (A = horses, B = zebras).
test_dataset = Custom_AB_Dataset(horse_dir_path, zebra_dir_path)

# batch_size=1 and shuffle=False: one image pair per step, in dataset order.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

# Inférences

In [24]:
from PIL import Image
import matplotlib.pyplot as plt

In [25]:
# Helper that generates images from a dataloader. The loader adds the batch axis
# that the U-Net / skip-connection generators need (they concatenate on dim=1).
# Known limitation: images of the smaller dataset are generated more than once.
import tqdm
def gen_images(gen_A, gen_B, loader):
  """Run both generators over `loader` and save the results as PNG files.

  For every batch `(img_A, img_B)` yielded by the loader, `gen_A(img_B)` is written
  to "horse2zebra/horse_gen/_{idx}.png" and `gen_B(img_A)` to
  "horse2zebra/zebra_gen/_{idx}.png", where `idx` is the batch index. Outputs are
  mapped from [-1, 1] back to [0, 1] before saving.

  Args:
    gen_A: generator applied to the domain-B images.
    gen_B: generator applied to the domain-A images.
    loader: iterable yielding `(img_A, img_B)` tensor batches.
  """
  horse_out_dir = "horse2zebra/horse_gen"
  zebra_out_dir = "horse2zebra/zebra_gen"
  # save_image does not create missing folders.
  os.makedirs(horse_out_dir, exist_ok=True)
  os.makedirs(zebra_out_dir, exist_ok=True)

  loop = tqdm.tqdm(loader, leave=True)
  # Inference only: skip autograd bookkeeping to save memory and time.
  with torch.no_grad():
    for idx, (img_A, img_B) in enumerate(loop):
      img_A = img_A.to(device)
      img_B = img_B.to(device)
      fake_A = gen_A(img_B)
      fake_B = gen_B(img_A)
      save_image(0.5*fake_A+0.5, f"{horse_out_dir}/_{idx}.png")
      save_image(0.5*fake_B+0.5, f"{zebra_out_dir}/_{idx}.png")


In [None]:
# Use this cell when working with the U-Net generator or the generator with skip connections.
gen_images(gen_A, gen_B, test_loader)

In [27]:
# Per-file inference: zebra test images -> generated horses (gen_A), saved under the
# source file name. The batch axis added below makes this loop valid for every
# generator, including those that concatenate skip connections on dim=1.
os.makedirs("horse2zebra/horse_gen", exist_ok=True)  # save_image does not create folders
with torch.no_grad():  # inference only: no autograd graph
  for img_file in zebra_imgs:
    with Image.open(os.path.join(zebra_dir_path, img_file)) as pil_img:
      img0 = np.array(pil_img.convert("RGB"))
    augmentations = transforms(image=img0)
    # (C, H, W) -> (1, C, H, W)
    img = augmentations['image'].unsqueeze(0).to(device)
    h = gen_A(img)
    save_image(0.5*h+0.5, f"horse2zebra/horse_gen/{img_file}")


In [28]:
# Per-file inference: horse test images -> generated zebras (gen_B), saved under the
# source file name. The batch axis added below makes this loop valid for every
# generator, including those that concatenate skip connections on dim=1.
os.makedirs("horse2zebra/zebra_gen", exist_ok=True)  # save_image does not create folders
with torch.no_grad():  # inference only: no autograd graph
  for img_file in horse_imgs:
    with Image.open(os.path.join(horse_dir_path, img_file)) as pil_img:
      img0 = np.array(pil_img.convert("RGB"))
    augmentations = transforms(image=img0)
    # (C, H, W) -> (1, C, H, W)
    img = augmentations['image'].unsqueeze(0).to(device)
    z = gen_B(img)
    save_image(0.5*z+0.5, f"horse2zebra/zebra_gen/{img_file}")

In [29]:
!pwd

/content/drive/MyDrive/INF8225 Groupe/PROJET


# Calcul FID Score

In [30]:
!pip install pytorch-fid

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-fid
  Downloading pytorch_fid-0.3.0-py3-none-any.whl (15 kB)
Installing collected packages: pytorch-fid
Successfully installed pytorch-fid-0.3.0


Dataset avec un A → c'est des chevaux

Dataset avec un B → c'est des zèbres

### Distance génération zèbres / zèbres de train

In [31]:
!python -m pytorch_fid horse2zebra/trainB horse2zebra/zebra_gen # FID score pour generation de zebres

Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100% 91.2M/91.2M [00:02<00:00, 46.8MB/s]
100% 27/27 [00:23<00:00,  1.17it/s]
100% 2/2 [00:00<00:00,  3.24it/s]
FID:  219.25219956468985


### Distance génération zèbres / chevaux utilisés pour générer

In [32]:
!python -m pytorch_fid horse2zebra/testA horse2zebra/zebra_gen # A quel point c'est resté des chevaux?

100% 2/2 [00:00<00:00,  2.14it/s]
100% 2/2 [00:00<00:00,  3.16it/s]
FID:  338.6357691060628


### Distance génération chevaux / chevaux

In [33]:
!python -m pytorch_fid horse2zebra/trainA horse2zebra/horse_gen # FID score pour generation de chevaux

100% 22/22 [00:19<00:00,  1.10it/s]
100% 2/2 [00:00<00:00,  2.55it/s]
FID:  229.86894079246275


### Distance génération chevaux / zèbres utilisés pour générer

In [34]:
!python -m pytorch_fid horse2zebra/testB horse2zebra/horse_gen # a quel point les chevaux n'ont pas été changés? A quel point c'est resté des zebres?

100% 2/2 [00:01<00:00,  1.89it/s]
100% 2/2 [00:00<00:00,  2.60it/s]
FID:  324.25504745078604


## Distances entre les datasets d'origine

In [None]:
!python -m pytorch_fid horse2zebra/testB horse2zebra/trainA # similarité des photos de zebres avec les photos de chevaux

100% 2/2 [00:02<00:00,  1.07s/it]
100% 22/22 [00:10<00:00,  2.10it/s]
FID:  223.14844778846697


In [None]:
!python -m pytorch_fid horse2zebra/testA horse2zebra/trainB # l'inverse

100% 2/2 [00:00<00:00,  2.09it/s]
100% 27/27 [00:08<00:00,  3.13it/s]
FID:  244.5187583190737


In [None]:
!python -m pytorch_fid horse2zebra/trainB horse2zebra/testB # similarité zebre zebre original dataset, c'est proche de 0 donc ok

100% 27/27 [00:10<00:00,  2.61it/s]
100% 2/2 [00:00<00:00,  2.41it/s]
FID:  32.960015374379466


In [None]:
!python -m pytorch_fid horse2zebra/trainA horse2zebra/testA # horse/horse similarity within the original dataset (train vs test); NOTE(review): FID is ~102 here, well above the zebra/zebra value

100% 22/22 [00:07<00:00,  2.80it/s]
100% 2/2 [00:00<00:00,  2.79it/s]
FID:  102.75344179669153
