In [1]:
# Discriminator

import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride,
                      1, bias=True, padding_mode="reflect"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        self.inital = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

				# Compile Disc Layers
        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(Block(in_channels, feature,
                          stride=1 if feature == features[-1] else 2))
            in_channels = feature
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4,
                      stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)
        
    def forward(self, x):
      x = self.inital(x)
      return torch.sigmoid(self.model(x))
        
def test():
  x = torch.randn((1, 3, 256, 256))
  model = Discriminator(in_channels=3)
  preds = model(x)
  print(preds.shape)
  
test()

torch.Size([1, 1, 30, 30])


In [2]:
# Generator

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs) if down else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity()
        )

    def forward(self, x):
        return self.conv(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, use_act=False,
                      kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)


class Generator(nn.Module):
    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, num_features, kernel_size=7,
                      stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True)
        )
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(num_features, num_features*2, down=True,
                          kernel_size=3, stride=2, padding=1),
                ConvBlock(num_features * 2, num_features * 4,
                          down=True, kernel_size=3, stride=2, padding=1),
            ]
        )

        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features*4) for _ in range(num_residuals)]
        )

        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(num_features * 4, num_features * 2, down=False,
                          kernel_size=3, stride=2, padding=1, output_padding=1),
                ConvBlock(num_features * 2, num_features, down=False,
                          kernel_size=3, stride=2, padding=1, output_padding=1),
            ]
        )

        self.last = nn.Conv2d(num_features, img_channels, kernel_size=7,
                              stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        x = self.initial(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.residual_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)

        return torch.tanh(self.last(x))


def test():
    img_channels = 3
    img_size = 256
    x = torch.randn((2, img_channels, img_size, img_size))
    gen = Generator(img_channels, 9)
    print(gen(x).shape)


test()

torch.Size([2, 3, 256, 256])


In [3]:
# Utils

import random, torch, os, numpy as np
import torch.nn as nn
import copy
import albumentations as A
from albumentations.pytorch import ToTensorV2

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr, device):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def seed_everything(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
     ],
    additional_targets={"image0": "image"},
)

In [4]:
# Dataloader

from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset


class ABDataset(Dataset):
    def __init__(self, root_A, root_B, transform=None):
        self.root_A = root_A
        self.root_B = root_B
        self.transform = transform

        self.a_images = os.listdir(root_A)
        self.b_images = os.listdir(root_B)
        
        self.length_dataset = max(len(self.a_images), len(self.b_images))
        self.a_len = len(self.a_images)
        self.b_len = len(self.b_images)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, idx):
        a_img = self.a_images[idx % self.a_len]
        b_img = self.b_images[idx % self.b_len]
        a_pth = os.path.join(self.root_A, a_img)
        b_pth = os.path.join(self.root_B, b_img)
        a_img = np.array(Image.open(a_pth).convert("RGB"))
        b_img = np.array(Image.open(b_pth).convert("RGB"))
        if self.transform:
            augs = self.transform(image=a_img, image0=b_img)
            a_img = augs["image"]
            b_img = augs["image0"]
        
        return a_img, b_img


In [5]:
# Config + Training

from torchvision.datasets import ImageFolder
import torch.optim as optim
from torchvision.utils import save_image
from tqdm import tqdm
from torch.utils.data import DataLoader

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_DIR = "./../_datasets/cyclegan/horse2zebra/horse2zebra"
BATCH_SIZE = 1
LEARNING_RATE = 2e-4
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 0
NUM_EPOCHS = 150
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_B = "genB.pth.tar"  # B = Horse
CHECKPOINT_GEN_A = "genA.pth.tar"  # A = Zebra
CHECKPOINT_CRITIC_B = "criticB.pth.tar"
CHECKPOINT_CRITIC_A = "criticA.pth.tar"

def train_fn(disc_A, disc_B, gen_B, gen_A, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler, epoch):
    loop = tqdm(loader, leave=True)

    for idx, (a_img, b_img) in enumerate(loop):
        a_img = a_img.to(DEVICE)
        b_img = b_img.to(DEVICE)  # a_img = horse, b_img = zebra

        # Train Discriminators
        with torch.cuda.amp.autocast():
            fake_a = gen_B(b_img)
            D_A_real = disc_A(a_img)
            D_A_fake = disc_A(fake_a.detach())
            D_A_real_loss = mse(D_A_real, torch.ones_like(D_A_real))
            D_A_fake_loss = mse(D_A_fake, torch.zeros_like(D_A_fake))
            D_A_loss = D_A_real_loss + D_A_fake_loss

            fake_b = gen_A(a_img)
            D_B_real = disc_B(b_img)
            D_B_fake = disc_B(fake_b.detach())
            D_B_real_loss = mse(D_B_real, torch.ones_like(D_B_real))
            D_B_fake_loss = mse(D_B_fake, torch.zeros_like(D_B_fake))
            D_B_loss = D_B_real_loss + D_B_fake_loss

            # Put the loss together
            D_loss = (D_A_loss + D_B_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators
        with torch.cuda.amp.autocast():
            # Adversarial Loss for Both Generators
            D_A_fake = disc_A(fake_a)
            D_B_fake = disc_B(fake_b)
            loss_G_A = mse(D_A_fake, torch.ones_like(D_A_fake))
            loss_G_B = mse(D_B_fake, torch.ones_like(D_B_fake))

            # Cycle Loss
            cycle_b = gen_B(fake_a)
            cycle_a = gen_A(fake_b)
            cycle_b_loss = L1(b_img, cycle_b)
            cycle_a_loss = L1(a_img, cycle_a)

            # Identity Loss
            id_b = gen_B(b_img)
            id_a = gen_A(a_img)
            id_b_loss = L1(b_img, id_b)
            id_a_loss = L1(a_img, id_a)

            # Add together
            G_loss = (
                loss_G_B
                +
                loss_G_A
                +
                cycle_b_loss * LAMBDA_CYCLE
                +
                cycle_a_loss * LAMBDA_CYCLE
                +
                id_a_loss * LAMBDA_IDENTITY
                +
                id_b_loss * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        loop.set_postfix(e=epoch, d_loss=D_loss.item(), g_loss=G_loss.item())
        # Print
        if idx % 100 == 0:
            save_image(torch.cat((a_img*0.5+0.5, b_img*0.5+0.5,fake_a*0.5+0.5, fake_b*0.5+0.5), 0), f"image/grid{epoch}-{idx}.png", nrow=2, normalize=True)


def main():
    disc_A = Discriminator(3).to(DEVICE)
    disc_B = Discriminator(3).to(DEVICE)
    gen_A = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_B = Generator(img_channels=3, num_residuals=9).to(DEVICE)

    opt_disc = optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999)
    )

    opt_gen = optim.Adam(
        list(gen_A.parameters()) + list(gen_B.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999)
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN_A, gen_A, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_GEN_B, gen_B, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_A, disc_A, opt_disc, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_B, disc_B, opt_disc, LEARNING_RATE,
        )

    dataset = ABDataset(root_A="./../_datasets/cyclegan/horse2zebra/horse2zebra/trainB",
                        root_B="./../_datasets/cyclegan/horse2zebra/horse2zebra/trainA", transform=transforms)

    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(disc_A, disc_B, gen_B, gen_A, loader, opt_disc,
                 opt_gen, L1, mse, d_scaler, g_scaler, epoch)

        if SAVE_MODEL:
            save_checkpoint(gen_A, opt_gen, CHECKPOINT_GEN_A)
            save_checkpoint(gen_B, opt_gen, CHECKPOINT_GEN_B)
            save_checkpoint(disc_A, opt_disc, CHECKPOINT_CRITIC_A)
            save_checkpoint(disc_B, opt_disc, CHECKPOINT_CRITIC_B)


main()


100%|██████████| 1334/1334 [05:16<00:00,  4.22it/s, d_loss=0.496, e=0, g_loss=5.1] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:08<00:00,  4.33it/s, d_loss=0.362, e=1, g_loss=5.88]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:08<00:00,  4.33it/s, d_loss=0.375, e=2, g_loss=4.27]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:11<00:00,  4.29it/s, d_loss=0.755, e=3, g_loss=5.72] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:10<00:00,  4.30it/s, d_loss=0.263, e=4, g_loss=4.84] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:04<00:00,  4.38it/s, d_loss=0.231, e=5, g_loss=3.77] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:04<00:00,  4.39it/s, d_loss=0.182, e=6, g_loss=5.14] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:04<00:00,  4.39it/s, d_loss=0.214, e=7, g_loss=3.79]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.514, e=8, g_loss=5.09] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.39it/s, d_loss=0.359, e=9, g_loss=3.75] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.328, e=10, g_loss=3.73] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.314, e=11, g_loss=4.09] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.40it/s, d_loss=0.261, e=12, g_loss=4.05] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.0604, e=13, g_loss=3.53]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.40it/s, d_loss=0.209, e=14, g_loss=3.24] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.515, e=15, g_loss=3.43] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.111, e=16, g_loss=3.75] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.379, e=17, g_loss=3.08] 

=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.764, e=18, g_loss=2.47] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.40it/s, d_loss=0.321, e=19, g_loss=2.7]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.193, e=20, g_loss=3.32] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.242, e=21, g_loss=2.88] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.11, e=22, g_loss=3.42]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.0933, e=23, g_loss=3.57]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.428, e=24, g_loss=3.13] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.384, e=25, g_loss=3.21] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.422, e=26, g_loss=2.98] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.43it/s, d_loss=0.288, e=27, g_loss=3.34] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.228, e=28, g_loss=3.43] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.112, e=29, g_loss=3.03] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.426, e=30, g_loss=3.21] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.254, e=31, g_loss=3.13] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.141, e=32, g_loss=3.28] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.43it/s, d_loss=0.555, e=33, g_loss=3.7]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.43it/s, d_loss=0.415, e=34, g_loss=2.59] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.43it/s, d_loss=0.602, e=35, g_loss=2.92] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.253, e=36, g_loss=3.39] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.239, e=37, g_loss=3.33] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.43it/s, d_loss=0.269, e=38, g_loss=2.98] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.628, e=39, g_loss=3.07]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.466, e=40, g_loss=2.95] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.207, e=41, g_loss=3.37] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.233, e=42, g_loss=3.36] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.247, e=43, g_loss=2.8]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.192, e=44, g_loss=2.98] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.379, e=45, g_loss=2.84] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.43it/s, d_loss=0.331, e=46, g_loss=3.55] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.43it/s, d_loss=0.434, e=47, g_loss=3.09] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.164, e=48, g_loss=3.06] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.40it/s, d_loss=0.121, e=49, g_loss=2.99] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.271, e=50, g_loss=2.83] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.261, e=51, g_loss=3.68] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.335, e=52, g_loss=3.04] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.215, e=53, g_loss=3]    


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.582, e=54, g_loss=3.16] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.089, e=55, g_loss=3.03] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.166, e=56, g_loss=2.85] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.434, e=57, g_loss=2.71] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.218, e=58, g_loss=3.11] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.312, e=59, g_loss=2.68] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.359, e=60, g_loss=2.93] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.482, e=61, g_loss=2.68] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.249, e=62, g_loss=2.99] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.435, e=63, g_loss=2.59] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.378, e=64, g_loss=2.87] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.349, e=65, g_loss=2.78] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.578, e=66, g_loss=2.68] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.257, e=67, g_loss=2.73] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.444, e=68, g_loss=2.86] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.418, e=69, g_loss=2.76] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.398, e=70, g_loss=2.48] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.367, e=71, g_loss=3.12] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.285, e=72, g_loss=2.58] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.245, e=73, g_loss=2.87] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.528, e=74, g_loss=3.08] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.233, e=75, g_loss=2.58] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.322, e=76, g_loss=3.14] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.339, e=77, g_loss=2.86] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.113, e=78, g_loss=2.96] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:01<00:00,  4.42it/s, d_loss=0.168, e=79, g_loss=2.95] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.247, e=80, g_loss=2.94] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.138, e=81, g_loss=3.05] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.207, e=82, g_loss=3.14] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.128, e=83, g_loss=3.23] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.49, e=84, g_loss=2.69]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.145, e=85, g_loss=3.06] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.151, e=86, g_loss=2.9]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.281, e=87, g_loss=2.88] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.0542, e=88, g_loss=2.67]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.49, e=89, g_loss=2.78] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.17, e=90, g_loss=2.73]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.37, e=91, g_loss=2.86]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.42it/s, d_loss=0.358, e=92, g_loss=3.21] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.257, e=93, g_loss=2.84] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.155, e=94, g_loss=3.12] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.276, e=95, g_loss=2.94] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.40it/s, d_loss=0.405, e=96, g_loss=2.22] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.376, e=97, g_loss=2.82] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.324, e=98, g_loss=2.48] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.277, e=99, g_loss=3.21] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.303, e=100, g_loss=3.07] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.541, e=101, g_loss=2.93] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.402, e=102, g_loss=2.72] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.216, e=103, g_loss=2.67] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.265, e=104, g_loss=2.73] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.259, e=105, g_loss=2.67] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.149, e=106, g_loss=2.76] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.118, e=107, g_loss=2.84] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.338, e=108, g_loss=2.95] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.145, e=109, g_loss=3.12] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.39it/s, d_loss=0.565, e=110, g_loss=2.84] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.39it/s, d_loss=0.305, e=111, g_loss=2.93] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.40it/s, d_loss=0.373, e=112, g_loss=2.97] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.387, e=113, g_loss=2.87] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.269, e=114, g_loss=2.66] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.41it/s, d_loss=0.246, e=115, g_loss=3.1]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:03<00:00,  4.40it/s, d_loss=0.484, e=116, g_loss=2.56] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:02<00:00,  4.40it/s, d_loss=0.267, e=117, g_loss=2.99] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:45<00:00,  3.86it/s, d_loss=0.245, e=118, g_loss=2.78] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:06<00:00,  3.64it/s, d_loss=0.12, e=119, g_loss=2.64]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:38<00:00,  3.94it/s, d_loss=0.271, e=120, g_loss=2.68] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:37<00:00,  3.95it/s, d_loss=0.424, e=121, g_loss=3.02] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:11<00:00,  3.59it/s, d_loss=0.579, e=122, g_loss=2.6]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:16<00:00,  3.54it/s, d_loss=0.359, e=123, g_loss=2.49] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:50<00:00,  3.81it/s, d_loss=0.269, e=124, g_loss=2.68] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:35<00:00,  3.97it/s, d_loss=0.38, e=125, g_loss=2.58]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:32<00:00,  4.01it/s, d_loss=0.244, e=126, g_loss=2.75] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:58<00:00,  3.72it/s, d_loss=0.284, e=127, g_loss=2.77] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:54<00:00,  3.77it/s, d_loss=0.336, e=128, g_loss=2.51] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [05:45<00:00,  3.86it/s, d_loss=0.329, e=129, g_loss=2.8]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:37<00:00,  3.35it/s, d_loss=0.42, e=130, g_loss=2.72]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:27<00:00,  3.44it/s, d_loss=0.339, e=131, g_loss=3.06] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:37<00:00,  3.36it/s, d_loss=0.237, e=132, g_loss=2.63] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:15<00:00,  3.55it/s, d_loss=0.253, e=133, g_loss=2.68] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:41<00:00,  3.32it/s, d_loss=0.301, e=134, g_loss=2.41] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:36<00:00,  3.37it/s, d_loss=0.264, e=135, g_loss=2.57] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [07:00<00:00,  3.18it/s, d_loss=0.22, e=136, g_loss=2.8]   


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [07:12<00:00,  3.09it/s, d_loss=0.53, e=137, g_loss=2.66]  


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [07:35<00:00,  2.93it/s, d_loss=0.367, e=138, g_loss=2.93] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [07:45<00:00,  2.87it/s, d_loss=0.487, e=139, g_loss=2.59] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:48<00:00,  3.26it/s, d_loss=0.281, e=140, g_loss=2.53] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:35<00:00,  3.37it/s, d_loss=0.1, e=141, g_loss=2.78]   


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:37<00:00,  3.35it/s, d_loss=0.243, e=142, g_loss=2.78] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1334/1334 [06:28<00:00,  3.43it/s, d_loss=0.162, e=143, g_loss=2.84] 


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


 55%|█████▍    | 728/1334 [04:07<05:44,  1.76it/s, d_loss=0.238, e=144, g_loss=3.19] 