# Pix2Pix GAN implementation

[Original video](https://youtu.be/SuddDSqGRzg)

[Paper walkthrough video](https://youtu.be/9SGs4Nm0VR4)

[Pix2Pix paper](https://arxiv.org/abs/1611.07004)

[Datasets](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/)

## Get dataset. Import libraries

In [None]:
# # Get pretrained models

# from google.colab import drive
# drive.mount('/content/gdrive')

# !cp -rf '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial/gen.pth.tar' .
# !cp -rf '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial/disc.pth.tar' .

# !ls -hal gen.pth.tar
# !ls -hal disc.pth.tar

In [None]:
# Google Colab ships an old version of the albumentations library, so update it.
# You may need to restart the runtime after the update.
!pip install -U git+https://github.com/albu/albumentations --no-cache-dir

Collecting git+https://github.com/albu/albumentations
  Cloning https://github.com/albu/albumentations to /tmp/pip-req-build-y_zv4_d9
  Running command git clone -q https://github.com/albu/albumentations /tmp/pip-req-build-y_zv4_d9
Collecting imgaug>=0.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/66/b1/af3142c4a85cba6da9f4ebb5ff4e21e2616309552caca5e8acefe9840622/imgaug-0.4.0-py2.py3-none-any.whl (948kB)
[K     |████████████████████████████████| 952kB 19.5MB/s 
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.5.2-cp37-none-any.whl size=88144 sha256=12735548a4bd7ea0157b827744152275b75fc38f2f1a56bba7be9b45f320cba2
  Stored in directory: /tmp/pip-ephem-wheel-cache-5114prfv/wheels/45/8b/e4/2837bbcf517d00732b8e394f8646f22b8723ac00993230188b
Successfully built albumentations
Installing collected packages: imgaug, albumentations
  Found exi

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import multiprocessing
import albumentations as A

from PIL import Image
from torchvision.utils import save_image
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm  # for the progressbar

In [None]:
# Get maps dataset
!wget 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz'

--2021-03-18 08:05:55--  http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz
Resolving efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)... 128.32.244.190
Connecting to efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)|128.32.244.190|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 250242400 (239M) [application/x-gzip]
Saving to: ‘maps.tar.gz’


2021-03-18 08:07:09 (3.27 MB/s) - ‘maps.tar.gz’ saved [250242400/250242400]



In [None]:
# Extract data
import zipfile
import tarfile

def extract(fname):
    """Extract the archive ``fname`` into the current working directory.

    The archive type is chosen from the file extension. Supported extensions:
    ``.tar.gz`` / ``.tgz``, ``.tar``, ``.tar.bz2`` / ``.tbz`` and ``.zip``.

    Args:
        fname: Path to the archive file.

    Raises:
        ValueError: If ``fname`` does not end with a supported extension.
    """
    if fname.endswith(('.tar.gz', '.tgz')):
        opener, mode = tarfile.open, 'r:gz'
    elif fname.endswith('.tar'):
        opener, mode = tarfile.open, 'r:'
    elif fname.endswith(('.tar.bz2', '.tbz')):
        opener, mode = tarfile.open, 'r:bz2'
    elif fname.endswith('.zip'):
        opener, mode = zipfile.ZipFile, 'r'
    else:
        # Fail with a clear message instead of hitting an unbound local below.
        raise ValueError(f'Unsupported archive format: {fname!r}')

    # The context manager closes the archive even when extraction fails.
    with opener(fname, mode=mode) as ref:
        ref.extractall()

extract('maps.tar.gz')

## Discriminator model

In [None]:
class CNNBlock(nn.Module):
    """Discriminator building block: 4x4 convolution -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution (default 2).

    Note:
        No ``padding`` is passed to ``nn.Conv2d``, so its default of 0 applies
        and ``padding_mode='reflect'`` has no effect in this block.
    """

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()

        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding_mode='reflect',
            bias=False,
        )
        # Instance normalisation works per sample rather than across the batch;
        # it was preferred over nn.BatchNorm2d because it gave results without artifacts.
        normalisation = nn.InstanceNorm2d(out_channels, affine=True)
        activation = nn.LeakyReLU(0.2)

        # Keep the attribute name and layer order: checkpoints rely on these keys.
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        features = self.conv(x)
        return features


class Discriminator(nn.Module):
    """Patch-based conditional discriminator.

    The conditioning image ``x`` and the candidate image ``y`` (real target or
    generated fake) are concatenated along the channel axis and reduced to a
    single-channel map of per-patch logits. With the default widths the
    accompanying test shows a 286x286 input yielding a 30x30 map and a
    256x256 input yielding a 26x26 map.

    Args:
        in_channels: Number of channels of each of the two input images.
        features: Channel widths of the successive convolution stages. The
            first entry is used by the initial (un-normalised) stage; each
            remaining entry adds one ``CNNBlock``.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default above avoids sharing one list between instances;
        # a local copy lets callers pass any sequence.
        features = list(features)

        self.initial = nn.Sequential(
            # x and y are concatenated along the channels, hence in_channels * 2.
            nn.Conv2d(in_channels*2, features[0], kernel_size=4, stride=2,
                      padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Only the final block uses stride 1. Selecting it by position
            # (not by value) keeps this correct when widths repeat.
            stride = 1 if index == last_index else 2
            layers.append(CNNBlock(in_channels, feature, stride=stride))
            in_channels = feature

        # Final projection to one logit per patch.
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1,
                                padding=1, padding_mode='reflect'))

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Return the patch logits for the pair ``(x, y)``."""
        x = torch.cat([x, y], dim=1)  # concatenate along channels
        x = self.initial(x)
        return self.model(x)


def test():
    """Smoke-test the discriminator output shape for two input resolutions.

    Both resolutions are asserted: 286x286 inputs must give a 30x30 patch map
    and 256x256 inputs (the training resolution) must give a 26x26 patch map.
    """
    for size, patches in ((286, 30), (256, 26)):
        x = torch.randn((8, 3, size, size))
        y = torch.randn((8, 3, size, size))
        model = Discriminator()
        predictions = model(x, y)
        print(predictions.shape)
        assert predictions.shape == (8, 1, patches, patches)

    print('Test - OK')

test()

torch.Size([8, 1, 30, 30])
torch.Size([8, 1, 26, 26])
Test - OK


## Generator model

In [None]:
class Block(nn.Module):
    """Generator building block that halves or doubles the spatial size.

    Layout: (Conv2d | ConvTranspose2d) -> InstanceNorm -> activation, with
    optional dropout (p=0.5) applied to the result.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
        down: ``True`` for a stride-2 convolution (downsampling), ``False``
            for a stride-2 transposed convolution (upsampling).
        act: ``'relu'`` selects ``nn.ReLU``; any other value selects
            ``nn.LeakyReLU(0.2)``.
        use_dropout: Whether ``forward`` applies dropout to the output.
    """

    def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
        super().__init__()

        shared = dict(kernel_size=4, stride=2, padding=1, bias=False)
        if not down:
            # ConvTranspose2d does not support padding_mode='reflect'.
            resample = nn.ConvTranspose2d(in_channels, out_channels, **shared)
        else:
            resample = nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **shared)

        # Per-sample normalisation; preferred over nn.BatchNorm2d because it
        # gave results without artifacts.
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        activation = nn.LeakyReLU(0.2) if act != 'relu' else nn.ReLU()

        self.conv = nn.Sequential(resample, norm, activation)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Resample ``x`` and, if enabled, apply dropout."""
        out = self.conv(x)
        if self.use_dropout:
            return self.dropout(out)
        return out


class Generator(nn.Module):
    """U-Net style generator: 7 downsampling stages, a bottleneck, 8 upsampling stages.

    Every decoder stage after the first receives the previous decoder output
    concatenated (along channels) with the encoder feature map of matching
    resolution. The output has ``in_channels`` channels, the same spatial size
    as the input, and is squashed to [-1, 1] by ``tanh``.

    Args:
        in_channels: Channels of the input image (also used for the output).
        features: Base channel width; deeper stages use multiples of it.

    Note:
        The size comments below assume a 256x256 input. The bottleneck uses
        stride 2, so it halves the last encoder map like every other stage.
        For a 256x256 input (2x2 map before the bottleneck) this is identical
        to stride 1, and it also keeps the skip connections aligned for input
        sizes that are larger multiples of 256.
    """

    def __init__(self, in_channels=3, features=64):
        super().__init__()
        # Input: 256
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=4, stride=2, padding=1,
                      padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )  # 128

        self.down1 = Block(features,   features*2, down=True, act='leaky', use_dropout=False)  # 64x64
        self.down2 = Block(features*2, features*4, down=True, act='leaky', use_dropout=False)  # 32x32
        self.down3 = Block(features*4, features*8, down=True, act='leaky', use_dropout=False)  # 16x16
        self.down4 = Block(features*8, features*8, down=True, act='leaky', use_dropout=False)  # 8x8
        self.down5 = Block(features*8, features*8, down=True, act='leaky', use_dropout=False)  # 4x4
        self.down6 = Block(features*8, features*8, down=True, act='leaky', use_dropout=False)  # 2x2
        self.bottleneck = nn.Sequential(
            # stride=2 (was 1): same 1x1 result for a 2x2 map, but a true
            # halving step for larger maps so that up1 matches d7 in size.
            nn.Conv2d(features*8, features*8, kernel_size=4, stride=2, padding=1,
                      padding_mode='reflect'),
            nn.ReLU(),
        )  # 1x1
        # Dropout is enabled only on the first three decoder stages.
        self.up1 = Block(features*8,   features*8, down=False, act='relu', use_dropout=True)   # 2x2
        self.up2 = Block(features*8*2, features*8, down=False, act='relu', use_dropout=True)   # 4x4
        self.up3 = Block(features*8*2, features*8, down=False, act='relu', use_dropout=True)   # 8x8
        self.up4 = Block(features*8*2, features*8, down=False, act='relu', use_dropout=False)  # 16x16
        self.up5 = Block(features*8*2, features*4, down=False, act='relu', use_dropout=False)  # 32x32
        self.up6 = Block(features*4*2, features*2, down=False, act='relu', use_dropout=False)  # 64x64
        self.up7 = Block(features*2*2, features,   down=False, act='relu', use_dropout=False)  # 128x128
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )  # 256x256

    def forward(self, x):  # U-Net shape-like structure
        """Translate the image batch ``x`` and return a tensor of the same shape."""
        # Encoder: keep every intermediate map for the skip connections.
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)
        # Decoder: each stage consumes the previous output plus the mirrored encoder map.
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], dim=1))
        up3 = self.up3(torch.cat([up2, d6], dim=1))
        up4 = self.up4(torch.cat([up3, d5], dim=1))
        up5 = self.up5(torch.cat([up4, d4], dim=1))
        up6 = self.up6(torch.cat([up5, d3], dim=1))
        up7 = self.up7(torch.cat([up6, d2], dim=1))
        return self.final_up(torch.cat([up7, d1], dim=1))


def test():
    """Smoke-test the generator: the output must have the same shape as the input."""
    expected_shape = (8, 3, 256, 256)
    sample = torch.randn(expected_shape)
    generator = Generator()
    output = generator(sample)
    print(output.shape)
    assert output.shape == expected_shape
    print('Test - OK')

test()

torch.Size([8, 3, 256, 256])
Test - OK


## Configuration

In [None]:
# Compute device: use the GPU whenever CUDA is available.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Optimisation hyper-parameters.
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_EPOCHS = 200
L1_LAMBDA = 100  # weight of the L1 term in the generator loss

# Data loading settings.
NUM_WORKERS = multiprocessing.cpu_count()
IMAGE_SIZE = 256
CHANNELS_IMG = 3

# Checkpoint handling.
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_DISC = 'disc.pth.tar'
CHECKPOINT_GEN = 'gen.pth.tar'

# Geometric transforms applied to the input/target pair in a single call
# (the target is passed under the additional 'image0' key).
both_transform = A.Compose(
    [
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.Flip(p=0.5),
    ],
    additional_targets={'image0': 'image'},
)

# Per-image transforms: occasional colour jitter, scaling to [-1, 1],
# conversion to a tensor.
transform = A.Compose(
    [
        A.ColorJitter(p=0.1),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

In [None]:
def save_example(gen, test_loader, epoch, folder):
    """Save the generator output for the first batch of ``test_loader``.

    Writes ``y_gen_<epoch>.jpg`` into ``folder`` (created if missing). At
    epoch 0 the input and the label images are saved as well. The generator
    is put into eval mode for inference and into train mode afterwards.

    Args:
        gen: Generator network.
        test_loader: Loader yielding ``(input, target)`` batches.
        epoch: Epoch index used in the output file name.
        folder: Destination directory.
    """
    def denormalize(batch):
        # Undo the (mean=0.5, std=0.5) normalisation: [-1, 1] -> [0, 1].
        return batch * 0.5 + 0.5

    inputs, targets = next(iter(test_loader))
    inputs = inputs.to(DEVICE)
    targets = targets.to(DEVICE)
    os.makedirs(folder, exist_ok=True)

    gen.eval()
    with torch.no_grad():
        generated = denormalize(gen(inputs))
        save_image(generated, folder + f'/y_gen_{epoch}.jpg')
        if epoch == 0:
            save_image(denormalize(inputs), folder + f'/_input_.jpg')
            save_image(denormalize(targets), folder + f'/_label_.jpg')
    gen.train()


def save_checkpoint(model, optimizer, filename):
    """Write the model and optimizer state dicts to ``filename``.

    The saved object is a dict with the keys ``'state_dict'`` (model) and
    ``'optimizer'``.
    """
    print('=> Saving checkpoint')
    payload = {}
    payload['state_dict'] = model.state_dict()
    payload['optimizer'] = optimizer.state_dict()
    torch.save(payload, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from ``checkpoint_file``.

    The checkpoint is mapped onto ``DEVICE``. After loading, the learning
    rate of every optimizer parameter group is overwritten with ``lr`` so
    the stored value does not take precedence.
    """
    print('=> Loading checkpoint')
    stored = torch.load(checkpoint_file, map_location=DEVICE)

    model.load_state_dict(stored['state_dict'])
    optimizer.load_state_dict(stored['optimizer'])

    # Replace the learning rate that came with the saved optimizer state.
    for group in optimizer.param_groups:
        group.update(lr=lr)

## Dataset

In [None]:
class MapDataset(Dataset):
    """Paired-image dataset read from a directory of side-by-side images.

    Each file holds the input image in the left 600 pixel columns and the
    target image in the remaining columns. Both halves go through
    ``both_transform`` together and then through ``transform`` separately.

    Args:
        root: Directory containing the image files.
    """

    def __init__(self, root):
        super().__init__()
        self.root = root
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping (and the sample used for evaluation) stable.
        self.list_files = sorted(os.listdir(self.root))

    def __len__(self):
        """Return the number of files found in ``root``."""
        return len(self.list_files)

    def __getitem__(self, index):
        """Return the ``(input_image, target_image)`` pair for ``index``."""
        filepath = os.path.join(self.root, self.list_files[index])
        # Close the file handle deterministically and force a 3-channel
        # array so the channel slicing below is always valid.
        with Image.open(filepath) as img:
            image = np.array(img.convert('RGB'))

        # Left part is the input, right part is the target.
        input_image = image[:, :600, :]
        target_image = image[:, 600:, :]

        # Joint transforms: both halves receive the same resize/flip.
        augmentations = both_transform(image=input_image, image0=target_image)
        input_image, target_image = augmentations['image'], augmentations['image0']

        # Per-image transforms are applied to each half independently.
        input_image = transform(image=input_image)['image']
        target_image = transform(image=target_image)['image']

        return input_image, target_image

## Train

In [None]:
def train(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce_loss, d_scaler, g_scaler):
    """Run one training epoch over ``loader`` with mixed precision.

    For every batch the discriminator is updated first (real pair vs.
    detached fake pair), then the generator (adversarial loss plus
    ``L1_LAMBDA`` times the L1 distance to the target).

    Args:
        disc: Discriminator network.
        gen: Generator network.
        loader: Iterable of ``(x, y)`` batches.
        opt_disc: Optimizer for the discriminator.
        opt_gen: Optimizer for the generator.
        l1_loss: Reconstruction loss callable.
        bce_loss: Adversarial loss callable operating on logits.
        d_scaler: Gradient scaler used for the discriminator step.
        g_scaler: Gradient scaler used for the generator step.
    """
    loop = tqdm(loader, leave=True)

    for x, y in loop:
        x, y = x.to(DEVICE), y.to(DEVICE)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            d_real = disc(x, y)
            # detach: the discriminator update must not backpropagate into the generator
            d_fake = disc(x, y_fake.detach())
            d_real_loss = bce_loss(d_real, torch.ones_like(d_real))
            d_fake_loss = bce_loss(d_fake, torch.zeros_like(d_fake))
            d_loss = (d_real_loss + d_fake_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(d_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generator
        with torch.cuda.amp.autocast():
            d_fake = disc(x, y_fake)
            g_fake_loss = bce_loss(d_fake, torch.ones_like(d_fake))
            l1 = l1_loss(y_fake, y) * L1_LAMBDA
            g_loss = g_fake_loss + l1

        opt_gen.zero_grad()
        # Use the generator's own scaler (previously d_scaler was used here,
        # leaving g_scaler unused and updating d_scaler twice per iteration).
        g_scaler.scale(g_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()


def main():
    """Build the networks, optionally resume from checkpoints, and train.

    Runs ``NUM_EPOCHS + 1`` epochs (indices 0..NUM_EPOCHS). When
    ``SAVE_MODEL`` is set, both checkpoints are written every 10th epoch.
    An example generation is saved to ``evaluation/`` after every epoch.
    """
    disc = Discriminator().to(DEVICE)
    gen = Generator().to(DEVICE)

    adam_settings = dict(lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_disc = optim.Adam(disc.parameters(), **adam_settings)
    opt_gen = optim.Adam(gen.parameters(), **adam_settings)

    bce_loss = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()

    # (checkpoint path, network, optimizer) for both players, discriminator first.
    players = (
        (CHECKPOINT_DISC, disc, opt_disc),
        (CHECKPOINT_GEN, gen, opt_gen),
    )

    # Resume only when both checkpoint files are present.
    if LOAD_MODEL and all(os.path.exists(path) for path, _, _ in players):
        for path, network, optimizer in players:
            load_checkpoint(path, network, optimizer, LEARNING_RATE)

    train_dataset = MapDataset(root='./maps/train')
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )
    test_dataset = MapDataset(root='./maps/val')
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )

    d_scaler = torch.cuda.amp.GradScaler()
    g_scaler = torch.cuda.amp.GradScaler()

    disc.train()
    gen.train()

    for epoch in range(NUM_EPOCHS+1):
        train(disc, gen, train_loader, opt_disc, opt_gen, l1_loss, bce_loss, d_scaler, g_scaler)

        is_checkpoint_epoch = epoch % 10 == 0
        if SAVE_MODEL and is_checkpoint_epoch:
            print(f'epoch: {epoch}')
            for path, network, optimizer in players:
                save_checkpoint(network, optimizer, path)

        save_example(gen, test_loader, epoch, folder='evaluation')


if __name__ == '__main__':
    main()

HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 0
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 10
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 20
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 30
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 40
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 50
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 60
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 70
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 80
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 90
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 100
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 110
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 120
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 130
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 140
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 150
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 160
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 170
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 180
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 190
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


epoch: 200
=> Saving checkpoint
=> Saving checkpoint


In [None]:
# Mount Google Drive so the results outlive the Colab runtime.
from google.colab import drive
drive.mount('/content/gdrive')

# Bundle the per-epoch example images into a single archive.
!zip -qr evaluation.zip evaluation/

# Copy the archive and both checkpoints to the Drive folder.
!cp -rf evaluation.zip   '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!cp -rf $CHECKPOINT_DISC '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!cp -rf $CHECKPOINT_GEN  '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
