In [1]:
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import numpy as np
from collections import Counter, defaultdict
import wandb
import tqdm
from torchvision.utils import save_image

In [2]:
device

device(type='cuda')

In [3]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The saved object is a dict with two keys: ``"state_dict"`` (from
    ``model.state_dict()``) and ``"optimizer"`` (from ``optimizer.state_dict()``).
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a file written by ``save_checkpoint``.

    Args:
        checkpoint_file: path to a checkpoint dict with ``"state_dict"`` and
            ``"optimizer"`` entries.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate forced onto every parameter group after loading.
    """
    print("=> Loading checkpoint")
    # Map tensors onto the device the model already lives on instead of reading a
    # notebook-global `config`, which may not exist yet when this is called.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    checkpoint = torch.load(checkpoint_file, map_location=target_device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # The optimizer state carries the learning rate of the old run; without this
    # override training would silently resume with the checkpoint's value.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


# Définition du modèle

In [4]:
class Block(nn.Module):
  """Discriminator stage: 4x4 reflect-padded convolution, InstanceNorm2d, LeakyReLU(0.2)."""

  def __init__(self, nb_in_layers, nb_out_layers, stride):
    super().__init__()
    convolution = nn.Conv2d(
        nb_in_layers,
        nb_out_layers,
        kernel_size=4,
        stride=stride,
        padding=1,
        bias=True,
        padding_mode="reflect",
    )
    normalisation = nn.InstanceNorm2d(nb_out_layers)
    activation = nn.LeakyReLU(0.2)
    # Kept under the attribute name `conv` so state-dict keys stay "conv.<i>.*".
    self.conv = nn.Sequential(convolution, normalisation, activation)

  def forward(self, x):
    out = self.conv(x)
    return out

In [5]:
class Discriminator(nn.Module):
  """Convolutional discriminator that outputs a spatial map of sigmoid scores.

  The network is an initial stride-2 convolution, one `Block` per entry of
  ``features[1:]`` (stride 2 everywhere except the last block, which uses
  stride 1), and a final 1-channel convolution. For a 256x256 input with the
  default ``features`` the output is a 30x30 map (see the notebook's test cell).

  Args:
    nb_in_layers: number of channels of the input image.
    features: channel widths of the successive stages.
  """

  def __init__(self, nb_in_layers=3, features=(64, 128, 256, 512)):
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(nb_in_layers, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(0.2),
    )
    self.layers = nn.ModuleList()
    in_channels = features[0]
    # Decide "last block" by position, not by value: comparing `feature` with
    # `features[-1]` would also give stride 1 to earlier stages that happen to
    # share the last stage's width (e.g. features=(64, 128, 128)).
    last_block_idx = len(features) - 2
    for idx, feature in enumerate(features[1:]):
      stride = 1 if idx == last_block_idx else 2
      # Module names are kept as-is so existing checkpoints still load.
      self.layers.add_module(f'Block{idx}_{feature}', Block(in_channels, feature, stride=stride))
      in_channels = feature
    self.layers.add_module('OutputMapping', nn.Conv2d(in_channels, out_channels=1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))

  def forward(self, x):
    x = self.initial(x)
    for layer in self.layers:
      x = layer(x)
    # Squash the raw scores to (0, 1), one value per output location.
    return torch.sigmoid(x)

In [6]:
# Sanity check of the discriminator on 5 RGB images of size 256x256
x = torch.randn((5,3,256,256))
model = Discriminator()
with torch.no_grad():  # shape check only: no need to build an autograd graph
  preds = model(x)
print(preds.shape)
# Expected: 5 images, a single channel, 30x30 score map
assert preds.shape == (5, 1, 30, 30), preds.shape

torch.Size([5, 1, 30, 30])


In [7]:
class ConvBlock(nn.Module):
  """Single convolution followed by an optional ReLU.

  With ``down=True`` the layer is a reflect-padded ``nn.Conv2d``; otherwise it is
  an ``nn.ConvTranspose2d``. Remaining keyword arguments (kernel size, stride,
  padding, ...) are forwarded to the chosen layer. ``use_act=False`` replaces the
  ReLU with an identity.
  """

  def __init__(self, nb_in_layers, nb_out_layers, down=True, use_act=True, **kwargs):
    super().__init__()
    if down:
      convolution = nn.Conv2d(nb_in_layers, nb_out_layers, padding_mode="reflect", **kwargs)
    else:
      convolution = nn.ConvTranspose2d(nb_in_layers, nb_out_layers, **kwargs)

    if use_act:
      activation = nn.ReLU(inplace=True)
    else:
      activation = nn.Identity()

    self.conv = nn.Sequential(convolution, activation)

  def forward(self, x):
    out = self.conv(x)
    return out

In [8]:
class ResidualBlock(nn.Module):
  """Two 3x3 `ConvBlock`s (the second without activation) added back onto the input."""

  def __init__(self, channels):
    super().__init__()
    shared_kwargs = dict(kernel_size=3, padding=1)
    first = ConvBlock(channels, channels, **shared_kwargs)
    second = ConvBlock(channels, channels, use_act=False, **shared_kwargs)
    self.block = nn.Sequential(first, second)

  def forward(self, x):
    residual = self.block(x)
    return x + residual

In [9]:
class Generator(nn.Module):
  """Encoder / residual / decoder image-to-image generator.

  Layout: 7x7 stem (64 channels) -> two stride-2 `ConvBlock`s (128, 256) ->
  ``num_residuals`` `ResidualBlock`s at 256 channels -> two transposed-conv
  `ConvBlock`s (128, 64) -> 7x7 convolution back to ``img_channels`` -> tanh.
  """

  def __init__(self, img_channels, num_residuals=9):
    super().__init__()
    stem = nn.Conv2d(img_channels, 64, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
    self.initial = nn.Sequential(stem, nn.ReLU(inplace=True))

    down_kwargs = dict(kernel_size=3, down=True, stride=2, padding=1)
    self.down_path = nn.ModuleList(
        [ConvBlock(c_in, c_out, **down_kwargs) for c_in, c_out in ((64, 128), (128, 256))]
    )

    residuals = [ResidualBlock(256) for _ in range(num_residuals)]
    self.residual_blocks = nn.Sequential(*residuals)

    up_kwargs = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.up_path = nn.ModuleList(
        [ConvBlock(c_in, c_out, **up_kwargs) for c_in, c_out in ((256, 128), (128, 64))]
    )

    self.colorize = nn.Conv2d(64, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

  def forward(self, x):
    out = self.initial(x)
    # Down-sampling, residual trunk and up-sampling are applied strictly in sequence.
    for stage in (*self.down_path, self.residual_blocks, *self.up_path):
      out = stage(out)
    return torch.tanh(self.colorize(out))

In [10]:
class UnetGenerator(nn.Module):
  """`Generator` variant with U-Net style skip connections on the up-sampling path.

  The stem output and the first down-sampled feature map are concatenated
  (channel-wise) with the output of the matching up-sampling stage, which is why
  the second up block and the final convolution take twice as many input channels.
  """

  def __init__(self, img_channels, num_residuals=9):
    super().__init__()
    stem = nn.Conv2d(img_channels, 64, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
    self.initial = nn.Sequential(stem, nn.ReLU(inplace=True))

    down_kwargs = dict(kernel_size=3, down=True, stride=2, padding=1)
    self.down_path = nn.ModuleList(
        [ConvBlock(c_in, c_out, **down_kwargs) for c_in, c_out in ((64, 128), (128, 256))]
    )

    residuals = [ResidualBlock(256) for _ in range(num_residuals)]
    self.residual_blocks = nn.Sequential(*residuals)

    up_kwargs = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.up_path = nn.ModuleList(
        [ConvBlock(c_in, c_out, **up_kwargs) for c_in, c_out in ((256, 128), (128 * 2, 64))]
    )

    self.colorize = nn.Conv2d(64 * 2, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

  def forward(self, x):
    x = self.initial(x)
    skips = [x]
    for down in self.down_path:
      x = down(x)
      skips.append(x)

    x = self.residual_blocks(x)

    # The deepest stored map (skips[-1]) is the trunk input itself and is not
    # reused; each up stage is paired with the next shallower one.
    for up, skip in zip(self.up_path, reversed(skips[:-1])):
      x = up(x)
      x = torch.concat((skip, x), 1)
    return torch.tanh(self.colorize(x))

In [11]:
# Sanity check of the residual generator on 5 RGB images of size 256x256
x = torch.randn((5,3,256,256))
model = Generator(img_channels=3)
with torch.no_grad():  # shape check only: no need to build an autograd graph
  preds = model(x)
print(preds.shape)
# Expected: 5 images, 3 channels, same 256x256 size as the input
assert preds.shape == (5, 3, 256, 256), preds.shape

torch.Size([5, 3, 256, 256])


In [12]:
# Sanity check of the U-Net generator on 5 RGB images of size 256x256
x = torch.randn((5,3,256,256))
model = UnetGenerator(img_channels=3)
with torch.no_grad():  # shape check only: no need to build an autograd graph
  preds = model(x)
print(preds.shape)
# Expected: 5 images, 3 channels, same 256x256 size as the input
assert preds.shape == (5, 3, 256, 256), preds.shape

torch.Size([5, 3, 256, 256])


### Advanced U-Net Generator 

In [13]:
class DoubleConvBlock(nn.Module):
    """Two consecutive (3x3 reflect-padded conv, InstanceNorm2d, ReLU) stages.

    The first stage maps ``nb_in_layers`` to ``nb_out_layers`` channels; the
    second keeps ``nb_out_layers``. Convolutions carry no bias.
    """

    def __init__(self, nb_in_layers, nb_out_layers):
        super().__init__()
        stages = []
        for stage_in in (nb_in_layers, nb_out_layers):
            stages.extend([
                nn.Conv2d(
                    stage_in, nb_out_layers,
                    kernel_size=3, stride=1, padding=1,
                    padding_mode="reflect", bias=False,
                ),
                nn.InstanceNorm2d(nb_out_layers),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return out

In [14]:
class UnetGeneratorAdvanced(nn.Module):
    """U-Net generator with max-pool down-sampling and a residual bottleneck.

    Encoder: one `DoubleConvBlock` per entry of ``features``, each followed by a
    2x2 max-pool. Bottleneck: a `DoubleConvBlock` doubling the channel count,
    then ``num_residuals - 1`` `ResidualBlock`s. Decoder: for each feature width
    (deepest first) a transposed convolution, concatenation with the matching
    encoder output, and a `DoubleConvBlock`. A 1x1 convolution maps back to
    ``img_channels``; no output activation is applied.
    """

    def __init__(self, img_channels=3, features=[32, 64, 128], num_residuals=6):
        super().__init__()
        # Registration order (up_path, down_path, pool, ...) is kept unchanged so
        # that parameter ordering matches previously saved optimizer states.
        self.up_path = nn.ModuleList()
        self.down_path = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: channel count grows while the spatial size shrinks.
        encoder_inputs = [img_channels, *features[:-1]]
        for stage_in, stage_out in zip(encoder_inputs, features):
            self.down_path.append(DoubleConvBlock(stage_in, stage_out))

        # Decoder: spatial size grows while the channel count shrinks.
        for width in features[::-1]:
            self.up_path.extend([
                nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2),  # up-sampling
                DoubleConvBlock(width * 2, width),
            ])

        bottleneck_width = features[-1] * 2
        self.bottleneck = nn.Sequential(
            DoubleConvBlock(features[-1], bottleneck_width),
            *(ResidualBlock(bottleneck_width) for _ in range(num_residuals - 1)),
        )
        self.colorize = nn.Conv2d(features[0], img_channels, kernel_size=1)

    def forward(self, x):
        encoder_outputs = []
        for encode in self.down_path:
            x = encode(x)
            encoder_outputs.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # up_path alternates (transposed conv, double conv); walk it in pairs,
        # consuming the encoder outputs from deepest to shallowest.
        decoder_modules = list(self.up_path)
        pairs = zip(decoder_modules[0::2], decoder_modules[1::2], reversed(encoder_outputs))
        for upsample, refine, skip in pairs:
            x = upsample(x)
            x = refine(torch.cat((skip, x), dim=1))

        return self.colorize(x)

In [15]:
# Sanity check of the advanced U-Net generator on 5 RGB images of size 256x256
x = torch.randn((5,3,256,256))
model = UnetGeneratorAdvanced(img_channels=3)
with torch.no_grad():  # shape check only: no need to build an autograd graph
  preds = model(x)
print(preds.shape)
# Expected: 5 images, 3 channels, same 256x256 size as the input
assert preds.shape == (5, 3, 256, 256), preds.shape

torch.Size([5, 3, 256, 256])


# Training Loop

In [16]:
def _flush_train_logs(logs):
  """Average each buffered metric list in ``logs`` and send them to wandb (no-op when empty)."""
  if logs:
    wandb.log({f'Train - {name}': np.mean(values) for name, values in logs.items()})


def train_epoch(disc_A, disc_B, gen_A, gen_B, loader, opt_disc, opt_gen, config, d_scaler, g_scaler, num_epoch):
  """
  Train the two discriminators and the two generators for one epoch.

  For every batch the discriminators are updated first (real vs. detached fake
  images), then the generators (adversarial + cycle + optional identity loss).
  Averaged metrics are sent to wandb every ``config['log_every']`` batches and
  once more at the end of the epoch; sample images are written to
  ``saved_images/`` every ``config['show_image_every']`` batches and to
  ``selected_images_end_epoch/`` at the end of the epoch.

  Args:
    disc_A, disc_B: discriminators for domain A / domain B images.
    gen_A, gen_B: generators producing domain A / domain B images.
    loader: iterable yielding ``(img_A, img_B)`` batches.
    opt_disc, opt_gen: optimizers for the discriminators / generators.
    config: dict read for 'device', 'A_name', 'B_name', 'model_name', the loss
      objects, the loss weights, the logging periods and 'selected_img_A' /
      'selected_img_B'. The paths of the selected images are read from
      'selected_img_A_path' / 'selected_img_B_path' when present, otherwise from
      the notebook globals of the same name.
    d_scaler, g_scaler: AMP gradient scalers for the two optimizers.
    num_epoch: epoch index, only used in output file names.
  """
  import os  # local import: the notebook only imports `os` in a later cell

  loop = tqdm.tqdm(loader, leave=True)
  logs = defaultdict(list)
  device = config['device']
  A_name = config['A_name']
  B_name = config['B_name']
  cycle_loss_fun = config['cycle_loss']
  adversarial_loss_fun = config['adversarial_loss']
  if config['identity_lambda']>0:
    identity_loss_fun = config['identity_loss'].to(device)

  # The validation pass switches every network to eval mode; switch back here so
  # that epochs after the first one really train in training mode.
  for network in (disc_A, disc_B, gen_A, gen_B):
    network.train()

  # save_image does not create missing directories.
  os.makedirs("saved_images", exist_ok=True)
  os.makedirs("selected_images_end_epoch", exist_ok=True)

  for idx, (img_A, img_B) in enumerate(loop):
    metrics = dict()
    img_A = img_A.to(config["device"])
    img_B = img_B.to(config["device"])

    # Train discriminators
    with torch.cuda.amp.autocast():
      # discriminator loss on real / fake domain-B images
      fake_img_B = gen_B(img_A)
      D_B_real = disc_B(img_B)
      D_B_fake = disc_B(fake_img_B.detach())
      D_B_real_loss = adversarial_loss_fun(D_B_real, torch.ones_like(D_B_real))
      D_B_fake_loss = adversarial_loss_fun(D_B_fake, torch.zeros_like(D_B_fake))
      D_B_loss = D_B_real_loss + D_B_fake_loss
      # discriminator loss on real / fake domain-A images
      fake_img_A = gen_A(img_B)
      D_A_real = disc_A(img_A)
      D_A_fake = disc_A(fake_img_A.detach())
      D_A_real_loss = adversarial_loss_fun(D_A_real, torch.ones_like(D_A_real))
      D_A_fake_loss = adversarial_loss_fun(D_A_fake, torch.zeros_like(D_A_fake))
      D_A_loss = D_A_real_loss + D_A_fake_loss
      # average of the two discriminator losses
      D_loss = (D_A_loss + D_B_loss)/2
    opt_disc.zero_grad()  # reset gradients before the backward pass
    d_scaler.scale(D_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()

    # Train Generators
    with torch.cuda.amp.autocast():
      # adversarial losses (fakes are NOT detached here: gradients reach the generators)
      D_A_fake = disc_A(fake_img_A)
      D_B_fake = disc_B(fake_img_B)
      G_A_loss = adversarial_loss_fun(D_A_fake, torch.ones_like(D_A_fake))
      G_B_loss = adversarial_loss_fun(D_B_fake, torch.ones_like(D_B_fake))

      # cycle loss
      cycle_B = gen_B(fake_img_A)
      cycle_A = gen_A(fake_img_B)
      cycle_B_loss = cycle_loss_fun(img_B, cycle_B)
      cycle_A_loss = cycle_loss_fun(img_A, cycle_A)

      # identity loss: a generator fed an image of its own target domain
      # should return it unchanged
      if config['identity_lambda']>0:
        B_identity_loss = identity_loss_fun(gen_B(img_B), img_B)
        A_identity_loss = identity_loss_fun(gen_A(img_A), img_A)
        identity_loss = B_identity_loss + A_identity_loss
      else:
        identity_loss = 0

      # total generator loss
      G_loss = (G_A_loss + G_B_loss +
                (cycle_B_loss + cycle_A_loss)*config["cycle_lambda"] +
                identity_loss * config["identity_lambda"]
                )
    opt_gen.zero_grad()
    g_scaler.scale(G_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()

    metrics['D_B_real_loss'] = D_B_real_loss.item()
    metrics['D_B_fake_loss'] = D_B_fake_loss.item()
    metrics['D_B_loss'] = D_B_loss.item()
    metrics['D_A_real_loss'] = D_A_real_loss.item()
    metrics['D_A_fake_loss'] = D_A_fake_loss.item()
    metrics['D_A_loss'] = D_A_loss.item()
    metrics['D_loss'] = D_loss.item()
    metrics['G_A_loss'] = G_A_loss.item()
    metrics['G_B_loss'] = G_B_loss.item()
    metrics['cycle_B_loss'] = cycle_B_loss.item()
    metrics['cycle_A_loss'] = cycle_A_loss.item()
    metrics['G_loss'] = G_loss.item()

    for name, value in metrics.items():
      logs[name].append(value)

    if idx % config['log_every'] == 0:
      _flush_train_logs(logs)
      logs = defaultdict(list)

    if idx % config['show_image_every'] == 0:
      # 0.5*x+0.5 maps the tanh range [-1, 1] back to [0, 1] for saving
      save_image(0.5*fake_img_A+0.5, f"saved_images/{A_name}_epoch_{num_epoch}_indx_{idx}.png")
      save_image(0.5*fake_img_B+0.5, f"saved_images/{B_name}_epoch_{num_epoch}_indx_{idx}.png")

  # Metrics accumulated after the last logging boundary would otherwise be lost.
  _flush_train_logs(logs)

  # At the end of the epoch, translate and save the hand-picked images.
  paths_A = config['selected_img_A_path'] if 'selected_img_A_path' in config else selected_img_A_path
  paths_B = config['selected_img_B_path'] if 'selected_img_B_path' in config else selected_img_B_path
  with torch.no_grad():  # pure inference: do not build an autograd graph
    for selected_img_A, selected_img_B, sel_img_A_path, sel_img_B_path in zip(config["selected_img_A"], config["selected_img_B"], paths_A, paths_B):
      # File name without directory and extension (replaces the former [18:-4] slicing).
      stem_A = os.path.splitext(os.path.basename(sel_img_A_path))[0]
      stem_B = os.path.splitext(os.path.basename(sel_img_B_path))[0]
      save_image(0.5*gen_B(selected_img_A)+0.5, f"selected_images_end_epoch/{config['model_name']}_{B_name}_{stem_A}_epoch_{num_epoch}.png")
      save_image(0.5*gen_A(selected_img_B)+0.5, f"selected_images_end_epoch/{config['model_name']}_{A_name}_{stem_B}_epoch_{num_epoch}.png")


In [17]:


def eval_metrics(disc_A, disc_B, gen_A, gen_B, img_A, img_B, config):
    """
    Compute every discriminator and generator loss for one batch, without any backward pass.

    Unlike the training step, no optimizer update happens between the
    discriminator losses and the generator losses, so all values are measured
    on the same network weights.

    Returns:
        dict mapping metric name to a Python float: the real / fake / total loss
        of each discriminator, their average ('D_loss'), the adversarial and
        cycle losses of each generator and the weighted total 'G_loss'.
    """
    device = config['device']
    adv_criterion = config['adversarial_loss'].to(device)
    cyc_criterion = config['cycle_loss'].to(device)
    use_identity = config['identity_lambda']>0
    if use_identity:
        idt_criterion = config['identity_loss'].to(device)

    def loss_against_real(scores):
        # target: every score should be 1
        return adv_criterion(scores, torch.ones_like(scores))

    def loss_against_fake(scores):
        # target: every score should be 0
        return adv_criterion(scores, torch.zeros_like(scores))

    with torch.cuda.amp.autocast():
        # --- discriminator side ---
        # translate A -> B and score real / fake B images
        fake_B = gen_B(img_A)
        D_B_real_loss = loss_against_real(disc_B(img_B))
        D_B_fake_loss = loss_against_fake(disc_B(fake_B.detach()))
        D_B_loss = D_B_real_loss + D_B_fake_loss
        # translate B -> A and score real / fake A images
        fake_A = gen_A(img_B)
        D_A_real_loss = loss_against_real(disc_A(img_A))
        D_A_fake_loss = loss_against_fake(disc_A(fake_A.detach()))
        D_A_loss = D_A_real_loss + D_A_fake_loss
        D_loss = (D_A_loss + D_B_loss)/2

        # --- generator side ---
        # adversarial: generators want their fakes scored as real
        G_A_loss = loss_against_real(disc_A(fake_A))
        G_B_loss = loss_against_real(disc_B(fake_B))

        # cycle consistency: B -> A -> B and A -> B -> A
        cycle_B_loss = cyc_criterion(img_B, gen_B(fake_A))
        cycle_A_loss = cyc_criterion(img_A, gen_A(fake_B))

        # identity: a generator fed an image of its own target domain should
        # return it unchanged
        total_identity_loss = 0
        if use_identity:
            B_identity_loss = idt_criterion(gen_B(img_B), img_B)
            A_identity_loss = idt_criterion(gen_A(img_A), img_A)
            total_identity_loss = B_identity_loss + A_identity_loss

        G_loss = (G_A_loss + G_B_loss +
                  (cycle_B_loss + cycle_A_loss)*config["cycle_lambda"] +
                  total_identity_loss * config["identity_lambda"]
                  )

    reported = (
        ('D_B_real_loss', D_B_real_loss),
        ('D_B_fake_loss', D_B_fake_loss),
        ('D_B_loss', D_B_loss),
        ('D_A_real_loss', D_A_real_loss),
        ('D_A_fake_loss', D_A_fake_loss),
        ('D_A_loss', D_A_loss),
        ('D_loss', D_loss),
        ('G_A_loss', G_A_loss),
        ('G_B_loss', G_B_loss),
        ('cycle_B_loss', cycle_B_loss),
        ('cycle_A_loss', cycle_A_loss),
        ('G_loss', G_loss),
    )
    return {name: tensor.item() for name, tensor in reported}



def eval(disc_A, disc_B, gen_A, gen_B, dataloader, config):
  """Run `eval_metrics` on every batch of ``dataloader`` and collect the results.

  The four networks are moved to ``config['device']`` and left in eval mode.
  No gradients are computed.

  Returns:
    defaultdict(list) mapping each metric name to the list of its per-batch
    values (empty when ``dataloader`` yields nothing).
  """
  device = config['device']
  for network in (disc_A, disc_B, gen_A, gen_B):
    network.to(device)
    network.eval()
  logs = defaultdict(list)
  with torch.no_grad():
    for img_A, img_B in dataloader:
      img_A = img_A.to(device)
      img_B = img_B.to(device)
      metrics = eval_metrics(disc_A, disc_B, gen_A, gen_B, img_A, img_B, config)
      # Accumulate inside the batch loop: recording after the loop would keep
      # only the metrics of the very last batch.
      for name, value in metrics.items():
        logs[name].append(value)
  return logs

In [18]:
def _set_learning_rate(optimizer, lr):
  """Apply ``lr`` to every parameter group of ``optimizer``."""
  for param_group in optimizer.param_groups:
    param_group["lr"] = lr


def train(disc_A, disc_B, gen_A, gen_B, train_loader, val_loader, opt_disc, opt_gen, config, d_scaler, g_scaler):
  """Full training loop: train one epoch, validate, log to wandb, checkpoint.

  Runs ``config['epochs']`` epochs. ``config['lr']`` is divided by 4 at epoch 55
  and by 2 at epoch 80 (the dict is updated in place) and the new values are
  pushed to both optimizers. After each epoch the averaged validation metrics
  are printed and logged, and the four networks are saved under ``saved_model/``.
  """
  import os  # local import: the notebook only imports `os` in a later cell

  os.makedirs("saved_model", exist_ok=True)  # torch.save does not create directories

  for n in range(config['epochs']):
    # Step-wise learning-rate decay. The rate must be written into
    # optimizer.param_groups: assigning `optimizer.lr` only creates an unused
    # attribute and leaves the actual learning rate untouched.
    # NOTE(review): here the discriminators get config['lr'] and the generators
    # half of it, while the optimizers are created with the opposite ratio in
    # the notebook -- confirm which one is intended.
    if n==55:
      config['lr'] = config['lr']/4
      _set_learning_rate(opt_disc, config['lr'])
      _set_learning_rate(opt_gen, config['lr']/2)
    if n==80:
      config['lr'] = config['lr']/2
      _set_learning_rate(opt_disc, config['lr'])
      _set_learning_rate(opt_gen, config['lr']/2)
    train_epoch(disc_A, disc_B, gen_A, gen_B, train_loader, opt_disc, opt_gen, config, d_scaler, g_scaler, n)

    logs = eval(disc_A, disc_B, gen_A, gen_B, val_loader, config)
    val_logs = {
        f'Validation - {name}': np.mean(values)
        for name, values in logs.items()
    }
    print("val logs:",val_logs)
    wandb.log(val_logs)
    # Each generator / discriminator is stored together with the state of the
    # optimizer it shares with its sibling network.
    save_checkpoint(gen_A, opt_gen, filename="saved_model/gen_A.pth.tar")
    save_checkpoint(gen_B, opt_gen, filename="saved_model/gen_B.pth.tar")
    save_checkpoint(disc_A, opt_disc, filename="saved_model/disc_A.pth.tar")
    save_checkpoint(disc_B, opt_disc, filename="saved_model/disc_B.pth.tar")


# Config

In [21]:

# model_name_numeroepoch_numeroidx.jpg


# Central experiment configuration, also passed to wandb.init and to the
# training / evaluation functions.
config = dict(
    # run identification and naming of the two image domains
    model_name="Classic",
    A_name='horse',
    B_name='zebra',
    # schedule and logging periods (in batches)
    epochs=120,
    log_every=350,
    show_image_every=500,
    # optimisation
    batch_size=1,
    lr=2e-4,
    betas=(0.9, 0.999),
    device='cuda' if torch.cuda.is_available() else 'cpu',
    # loss weights (identity loss is disabled when its weight is 0)
    cycle_lambda=10,
    identity_lambda=0,
    # loss objects
    adversarial_loss=nn.MSELoss(),
    cycle_loss=nn.L1Loss(),
    identity_loss=nn.L1Loss(),
)


# Chargement des données

In [22]:
from PIL import Image
from torchvision import transforms
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [23]:
# Preprocessing pipeline applied to every image (dataset items and the
# hand-picked sample images).
# NOTE(review): this rebinds the name `transforms`, shadowing the
# `from torchvision import transforms` import of the previous cell; the dataset
# class reads this global by name, so renaming it requires updating both places.
transforms = A.Compose(
    [
        # resize every image to 256x256
        A.Resize(width=256, height=256),
        # with mean=std=0.5 and max_pixel_value=255 this presumably maps pixel
        # values to [-1, 1], consistent with the 0.5*x+0.5 used before saving
        # images -- verify against the albumentations Normalize documentation
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        # convert the numpy image to a torch tensor
        ToTensorV2(),
    ]
)

In [24]:
class Custom_AB_Dataset(Dataset):
  """Unpaired two-domain image dataset.

  Item ``i`` pairs the ``i``-th image of ``root_A`` with the ``i``-th image of
  ``root_B``; the shorter folder wraps around (modulo its length), so its first
  images are seen more often within one epoch.

  Args:
    root_A, root_B: directories containing the images of each domain.
    transform: optional callable used as ``transform(image=array)["image"]``;
      when omitted, the notebook-level ``transforms`` pipeline is used.
  """

  def __init__(self, root_A, root_B, transform=None):
    self.root_A = root_A
    self.root_B = root_B
    self.transform = transform

    # os.listdir returns entries in arbitrary order; sort for a reproducible
    # index -> file mapping across runs and machines.
    self.A_images = sorted(os.listdir(root_A))
    self.B_images = sorted(os.listdir(root_B))
    self.A_length = len(self.A_images)
    self.B_length = len(self.B_images)
    self.length_dataset = max(self.A_length, self.B_length)
    #self.length_dataset = 50 # set to e.g. 50 or 200 to quickly check that the code runs over several epochs

  def __len__(self):
    return self.length_dataset

  @staticmethod
  def _load_rgb(path):
    """Read the image at ``path`` as an RGB numpy array and close the file."""
    with Image.open(path) as img:
      return np.array(img.convert("RGB"))

  def __getitem__(self, index):
    A_path = os.path.join(self.root_A, self.A_images[index % self.A_length])
    B_path = os.path.join(self.root_B, self.B_images[index % self.B_length])

    # Fall back to the module-level pipeline when none was injected.
    transform = self.transform if self.transform is not None else transforms

    A_img = transform(image=self._load_rgb(A_path))["image"]
    B_img = transform(image=self._load_rgb(B_path))["image"]
    return A_img, B_img

In [25]:
# Build the lists of hand-picked images that are translated at the end of every epoch
selected_horse_path = ["horse2zebra/testA/n02381460_1110.jpg",
 "horse2zebra/testA/n02381460_1820.jpg",
 "horse2zebra/testA/n02381460_4550.jpg",
 "horse2zebra/testA/n02381460_7170.jpg",
 "horse2zebra/testA/n02381460_7400.jpg"]
selected_zebra_path = ["horse2zebra/testB/n02391049_490.jpg",
                   "horse2zebra/testB/n02391049_1430.jpg",
                   "horse2zebra/testB/n02391049_1220.jpg",
                   "horse2zebra/testB/n02391049_6780.jpg",
                   "horse2zebra/testB/n02391049_10100.jpg"]
selected_horse, selected_zebra = [], []

for sel_horse_path, sel_zebra_path in zip(selected_horse_path, selected_zebra_path):
      # Load as RGB arrays; the context manager closes the underlying file.
      with Image.open(sel_horse_path) as horse_file:
            sel_horse = np.array(horse_file.convert("RGB"))
      with Image.open(sel_zebra_path) as zebra_file:
            sel_zebra = np.array(zebra_file.convert("RGB"))

      # Each image is transformed on its own (the former combined
      # `transforms(image=..., image0=...)` call produced an unused result).
      # [None,:] adds the leading batch dimension expected by the generators.
      selected_horse.append(transforms(image=sel_horse)["image"].to(device)[None,:])
      selected_zebra.append(transforms(image=sel_zebra)["image"].to(device)[None,:])


# train_epoch reads the two path lists below as notebook globals and the
# tensors from config.
selected_img_A_path = selected_horse_path
selected_img_B_path = selected_zebra_path
config['selected_img_A'] = selected_horse
config['selected_img_B'] = selected_zebra

In [26]:
!ls

'ls' n'est pas reconnu en tant que commande interne
ou externe, un programme exécutable ou un fichier de commandes.


In [27]:
# Training set: images of domain A from trainA, images of domain B from trainB
train_dataset = Custom_AB_Dataset("horse2zebra/trainA", "horse2zebra/trainB")

In [28]:
len(train_dataset)

1334

In [29]:
# Training batches, reshuffled at every epoch; num_workers=0 loads the data in
# the main process.
train_loader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=0,
)

In [30]:
# Validation set, built the same way as the training set from valA / valB;
# the last expression displays its length.
val_dataset = Custom_AB_Dataset("horse2zebra/valA", "horse2zebra/valB")
len(val_dataset)

48

In [31]:
# Validation batches. No shuffling: validation metrics are averaged over the
# whole set, and a fixed order keeps per-batch values comparable across epochs.
val_loader = DataLoader(
    val_dataset,
    batch_size=config["batch_size"],
    shuffle=False,
    num_workers=0,
)

# Training the models

In [32]:
# One discriminator and one generator per domain, all placed on the configured device.
disc_A = Discriminator(nb_in_layers=3).to(config['device']) # real horse or generated one?
disc_B = Discriminator(nb_in_layers=3).to(config['device']) # real zebra or generated one?
gen_A = Generator(img_channels=3, num_residuals=9).to(config['device'])
gen_B = Generator(img_channels=3, num_residuals=9).to(config['device'])

In [33]:
import torch.optim as optim

# A single Adam optimizer drives both discriminators; it starts at half the
# configured learning rate.
opt_disc = optim.Adam(
    list(disc_A.parameters()) + list(disc_B.parameters()),
    lr = config['lr']/2,
    betas=config['betas']
)
# A single Adam optimizer drives both generators at the configured learning rate.
opt_gen = optim.Adam(
    list(gen_A.parameters()) + list(gen_B.parameters()),
    lr=config['lr'],
    betas=config['betas']
)

In [34]:
# Gradient scalers for mixed-precision training: one for the generator
# optimizer, one for the discriminator optimizer.
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

In [35]:
# Checking GPU and logging to wandb

# Authenticate the wandb command-line client (shell command).
!wandb login

# Show the GPU, driver version and current memory usage (shell command).
!nvidia-smi

wandb: Currently logged in as: erichuard-second (pasdutoutlate). Use `wandb login --relogin` to force relogin


Sun Apr 30 00:38:28 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 531.14                 Driver Version: 531.14       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                      TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3060 L...  WDDM | 00000000:01:00.0  On |                  N/A |
| N/A   37C    P8               15W /  N/A|   1492MiB /  6144MiB |     21%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
!wandb online  # online / offline to activate or deactivate WandB logging

with wandb.init(
        config=config,
        project='NOTRE PROJET',  # Title of your project
        group='ZebreCheval',  # In what group of runs do you want this run to be in?
        save_code=True,
        name="Classic",
    ):
    train(disc_A, disc_B, gen_A, gen_B, train_loader, val_loader, opt_disc, opt_gen, config, d_scaler, g_scaler)