In [1]:
import os
import sys
import random
import math
import numpy as np
import argparse
import itertools
import glob
import datetime
import time
from PIL import Image
import matplotlib.pyplot as plt
import functools
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.autograd import Variable


from torch.nn.utils.parametrizations import spectral_norm as SpectralNorm


from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

In [2]:
def to_rgb(image):
    """Return an RGB version of ``image``.

    A new, blank "RGB" image with the same size as ``image`` is created and
    ``image`` is pasted onto it. The input object itself is not modified.
    """
    canvas = Image.new("RGB", image.size)
    canvas.paste(image)
    return canvas


class ImageDataset(Dataset):
    """Two-domain image dataset read from ``<root>/<mode>/A`` and ``<root>/<mode>/B``.

    Each item is a dict ``{"A": ..., "B": ...}`` holding one transformed image
    per domain. The dataset length is the size of the larger domain; the
    smaller one is cycled through with a modulo index.

    Parameters:
        root (str)          -- dataset directory containing the ``<mode>`` sub-folder
        transforms1 (list)  -- transforms composed and applied to domain-A images
                               (``None`` means no transform)
        transforms2 (list)  -- transforms composed and applied to domain-B images
                               (``None`` means no transform)
        unaligned (bool)    -- if True, the B image is drawn at random instead of by index
        mode (str)          -- sub-folder of ``root`` to read ("train", "test", ...)

    Raises:
        ValueError -- if either domain folder contains no image files
    """

    # Accepted file suffixes (compared case-insensitively, leading dot included).
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, root, transforms1=None, transforms2=None, unaligned=False, mode="train"):
        # An empty Compose is an identity pipeline; a ``None`` pipeline cannot be
        # iterated when the transform is eventually applied.
        self.transform1 = transforms.Compose(transforms1 if transforms1 is not None else [])
        self.transform2 = transforms.Compose(transforms2 if transforms2 is not None else [])
        self.unaligned = unaligned

        path_A = f'{root}/{mode}/A'
        path_B = f'{root}/{mode}/B'

        self.files_A = self._list_images(path_A)
        self.files_B = self._list_images(path_B)

        # Fail early with a readable message instead of a ZeroDivisionError /
        # randint error on the first __getitem__ call.
        if not self.files_A or not self.files_B:
            raise ValueError(
                f"No image files found: {len(self.files_A)} in {path_A}, "
                f"{len(self.files_B)} in {path_B}"
            )

    @staticmethod
    def _list_images(directory):
        """Return the image paths directly inside ``directory``, sorted by file name.

        ``os.listdir`` returns entries in arbitrary order, so the names are
        sorted to make index-based (aligned) pairing reproducible.
        """
        return [
            directory + '/' + name
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(ImageDataset.IMAGE_EXTENSIONS)
        ]

    def __getitem__(self, index):
        """Load, convert to RGB if needed, and transform one A/B pair."""
        image_A = Image.open(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Convert grayscale images to rgb
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        item_A = self.transform1(image_A)
        item_B = self.transform2(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        """Number of items: the size of the larger of the two domains."""
        return max(len(self.files_A), len(self.files_B))

In [3]:
class ReplayBuffer:
    """Fixed-size history of generated images used when training the discriminators.

    While the buffer is not full every incoming image is stored and returned
    unchanged. Once full, each incoming image is, with probability 0.5,
    swapped with a randomly chosen stored image (the stored one is returned
    and the new one takes its slot); otherwise it is returned as is.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the buffer and return a batch of the same size.

        Parameters:
            data (Tensor) -- batch of images; iterated along dimension 0

        Returns a tensor that is detached from the autograd graph, containing
        one image (new or previously stored) per input image.
        """
        # detach() replaces the deprecated ``.data``: stored images must not
        # keep the generator's graph alive.
        detached = data.detach()
        to_return = []
        for element in detached:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        if not to_return:
            # torch.cat cannot concatenate an empty list; an empty batch passes through.
            return detached
        return torch.cat(to_return)

In [4]:
class ResnetBlock(nn.Module):
    """Residual block: a two-convolution branch added back onto its input."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Build the convolutional branch; the skip connection lives in ``forward``.

        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super().__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the residual branch as an ``nn.Sequential``.

        Layout: pad -> 3x3 conv -> norm -> ReLU [-> Dropout(0.5)] -> pad -> 3x3 conv -> norm.

        Parameters:
            dim (int)           -- channel count of both convolutions (input == output)
            padding_type (str)  -- reflect | replicate | zero
            norm_layer          -- callable building a normalization layer from a channel count
            use_dropout (bool)  -- insert a Dropout(0.5) after the first ReLU
            use_bias (bool)     -- whether the convolutions carry a bias term

        Raises NotImplementedError for any other ``padding_type``.
        """

        def padded_conv():
            # Explicit padding layers are followed by an unpadded conv;
            # 'zero' relies on the conv's own zero padding instead.
            if padding_type == 'zero':
                return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
            if padding_type == 'reflect':
                pad = nn.ReflectionPad2d(1)
            elif padding_type == 'replicate':
                pad = nn.ReplicationPad2d(1)
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            return [pad, nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]

        layers = padded_conv()
        layers += [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers += padded_conv()
        layers.append(norm_layer(dim))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        return x + self.conv_block(x)

In [5]:
class ResnetGenerator(nn.Module):
    """Encoder / residual-stack / decoder image generator.

    Layout: 7x7 stem conv, two stride-2 downsampling convs, ``n_blocks``
    ResnetBlocks, two (nearest upsample + 3x3 conv) stages, 7x7 output conv, Tanh.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'):
        """Construct the generator.

        Parameters:
            input_nc (int)      -- channels of the input image
            output_nc (int)     -- channels of the output image
            ngf (int)           -- base width: output channels of the stem conv
            norm_layer          -- normalization layer class or functools.partial of one
            use_dropout (bool)  -- forwarded to every ResnetBlock
            n_blocks (int)      -- number of ResnetBlocks (must be >= 0)
            padding_type (str)  -- padding used inside the ResnetBlocks: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super().__init__()

        # Convs feeding a normalization layer only get a bias with InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        # Stem
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        # Downsampling: channels double at every stride-2 conv
        n_downsampling = 2
        width = ngf
        for _ in range(n_downsampling):
            layers.append(nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1, bias=use_bias))
            layers.append(norm_layer(width * 2))
            layers.append(nn.ReLU(True))
            width *= 2

        # Residual stack at the bottleneck width
        for _ in range(n_blocks):
            layers.append(
                ResnetBlock(width, padding_type=padding_type, norm_layer=norm_layer,
                            use_dropout=use_dropout, use_bias=use_bias)
            )

        # Upsampling: nearest-neighbour resize followed by a padded 3x3 conv
        # (these convs keep nn.Conv2d's default bias), halving the channels each time
        for _ in range(n_downsampling):
            halved = width // 2
            layers.extend([
                nn.Upsample(scale_factor = 2, mode='nearest'),
                nn.ReflectionPad2d(1),
                nn.Conv2d(width, halved, kernel_size=3, stride=1, padding=0),
                norm_layer(halved),
                nn.ReLU(True),
            ])
            width = halved

        # Output head
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the sequential model."""
        return self.model(input)

In [6]:
class Discriminator(nn.Module):
    """PatchGAN discriminator: a conv stack producing a 1-channel prediction map."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct the discriminator.

        Parameters:
            input_nc (int)  -- channels of the input image
            ndf (int)       -- output channels of the first conv layer
            n_layers (int)  -- controls depth: the first conv plus ``n_layers - 1``
                               further stride-2 stages precede the stride-1 stages
            norm_layer      -- normalization layer class or functools.partial of one
        """
        super().__init__()

        # Convs followed by a normalization layer only get a bias with InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        kw, padw = 4, 1

        # Channel multipliers between consecutive normalized stages, capped at 8.
        # All stages use stride 2 except the last one, which uses stride 1.
        mults = [1] + [min(2 ** n, 8) for n in range(1, n_layers)] + [min(2 ** n_layers, 8)]
        strides = [2] * (n_layers - 1) + [1]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        for mult_in, mult_out, stride in zip(mults, mults[1:], strides):
            layers.append(
                nn.Conv2d(ndf * mult_in, ndf * mult_out, kernel_size=kw, stride=stride, padding=padw, bias=use_bias)
            )
            layers.append(norm_layer(ndf * mult_out))
            layers.append(nn.LeakyReLU(0.2, True))

        # output 1 channel prediction map
        layers.append(nn.Conv2d(ndf * mults[-1], 1, kernel_size=kw, stride=1, padding=padw))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the sequential model."""
        return self.model(input)

In [7]:
class DiscriminatorSN(nn.Module):
    """PatchGAN-style discriminator with spectral normalization on every conv.

    Same conv layout as a PatchGAN (stride-2 stages, one stride-1 stage, a
    1-channel output conv) but without normalization layers between them.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False):
        """Construct the discriminator.

        Parameters:
            input_nc (int)      -- channels of the input image
            ndf (int)           -- output channels of the first conv layer
            n_layers (int)      -- controls depth: the first conv plus ``n_layers - 1``
                                   further stride-2 stages precede the stride-1 stages
            use_sigmoid (bool)  -- append a Sigmoid to the prediction map
        """
        super().__init__()
        kw, padw = 4, 1

        def sn_conv(in_ch, out_ch, stride, **conv_kwargs):
            # Spectrally normalized 4x4 convolution with padding 1.
            return SpectralNorm(
                nn.Conv2d(in_ch, out_ch, kernel_size=kw, stride=stride, padding=padw, **conv_kwargs)
            )

        # Channel multipliers between consecutive hidden stages, capped at 8.
        # All hidden stages use stride 2 except the last one, which uses stride 1.
        mults = [1] + [min(2**n, 8) for n in range(1, n_layers)] + [min(2**n_layers, 8)]
        strides = [2] * (n_layers - 1) + [1]

        layers = [sn_conv(input_nc, ndf, 2), nn.LeakyReLU(0.2, True)]
        for mult_in, mult_out, stride in zip(mults, mults[1:], strides):
            # hidden convs are bias-free; the first and last convs keep the default bias
            layers.append(sn_conv(ndf * mult_in, ndf * mult_out, stride, bias=False))
            layers.append(nn.LeakyReLU(0.2, True))

        layers.append(sn_conv(ndf * mults[-1], 1, 1))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the sequential model."""
        return self.model(input)

In [8]:
# Spatial size of a 256x256 image after four successive halvings (256 / 2**4).
# NOTE: the training cell in this notebook uses discriminator_output_shape = (1, 30, 30)
# for the adversarial targets; `output_shape` is not referenced by the cells shown here.
height = width = 256
output_shape = (1,) + tuple(side // 2 ** 4 for side in (height, width))
output_shape

(1, 16, 16)

In [9]:
def f(output_size, ksize, stride):
    """Input extent covered by ``output_size`` adjacent outputs of a conv layer.

    The windows start ``stride`` apart, so ``output_size`` of them span
    ``(output_size - 1) * stride`` plus one full kernel of width ``ksize``.
    """
    offset_of_last_window = stride * (output_size - 1)
    return offset_of_last_window + ksize

# Receptive field of one discriminator output unit, computed by walking the
# five 4x4 convolutions from the output back to the input. The trailing number
# on each line is the receptive field after including that layer.
last_layer = f(1, 4, 1)               # 4
fourth_layer = f(last_layer, 4, 1)    # 7
third_layer = f(fourth_layer, 4, 2)   # 16
second_layer = f(third_layer, 4, 2)   # 34
first_layer = f(second_layer, 4, 2)   # 70

print(first_layer)

70


In [10]:
# ---------------
#  Configuration
# ---------------
epoch = 120 # epoch to resume from (0 = start from scratch with freshly initialised weights)
n_epochs = 125 # exclusive upper bound of the epoch loop
dataset_name = "cartoon2pixelart"
experiment = "base_distill_sn_9resnet"
batch_size = 1
lr = 0.0002 # Adam learning rate
b1 = 0.5 # beta1. Adam decay rate of the first-moment (mean) gradient estimate
b2 = 0.999 # beta2. Adam decay rate of the second-moment (squared) gradient estimate
decay_epoch = 50 # from which epoch to begin lr decay
n_cpu = 8 # number of CPU threads. NOTE: the data loaders below pass num_workers=1 directly
image_height = 256
image_width = 256
channels = 3
sample_interval = 1000 # save a sample image grid every this many batches
checkpoint_interval = 30 # save model weights every this many epochs (-1 disables)
n_residual_blocks = 9 # NOTE: not passed to ResnetGenerator below, whose n_blocks default is also 9
lambda_cyc = 10.0 # weight of the cycle-consistency loss
lambda_id = 10.0 # weight of the identity loss
# Shape of the discriminator's prediction map for one image. DiscriminatorSN
# (default n_layers=3) applies three stride-2 4x4 convs with padding 1 (each
# halves the size) and then two stride-1 4x4 convs with padding 1 (each shrinks
# it by one pixel): 256 -> 128 -> 64 -> 32 -> 31 -> 30.
discriminator_output_shape = (1, image_height // 2 ** 3 - 2, image_width // 2 ** 3 - 2)


# Create sample and checkpoint directories
os.makedirs("images/%s/%s" % (dataset_name, experiment), exist_ok=True)
os.makedirs("saved_models/%s/%s" % (dataset_name, experiment), exist_ok=True)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

cuda = torch.cuda.is_available()

input_shape = (channels, image_height, image_width)

# Initialize generator and discriminator
G_AB = ResnetGenerator(input_shape[0], input_shape[0])
G_BA = ResnetGenerator(input_shape[0], input_shape[0])
D_A = DiscriminatorSN(channels)
D_B = DiscriminatorSN(channels)



if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

print(G_AB)
print(D_A)
    
def weights_init_normal(m):
    r"""Initialize convolution weights from $\mathcal{N}(0, 0.02^2)$ (mean 0, std 0.02).

    Intended for ``Module.apply``: modules whose class name does not contain
    "Conv" are left untouched.

    For a module whose ``weight`` is parametrized (e.g. wrapped by
    ``torch.nn.utils.parametrizations.spectral_norm``), ``m.weight`` is
    recomputed from the underlying parameter on every access, so writing to it
    would have no lasting effect. In that case the underlying
    ``parametrizations.weight.original`` parameter is initialized instead.
    """
    classname = m.__class__.__name__
    if classname.find("Conv") == -1:
        return

    parametrizations = getattr(m, "parametrizations", None)
    if parametrizations is not None and "weight" in parametrizations:
        weight = parametrizations["weight"].original
    else:
        weight = m.weight

    # init.normal_ runs without gradient tracking, so no ``.data`` access is needed
    torch.nn.init.normal_(weight, 0.0, 0.02)

if epoch != 0:
    # Load pretrained models. The checkpoints are read onto the CPU first
    # (load_state_dict then copies them into the already-placed parameters),
    # so weights saved on a GPU can also be restored on a CPU-only machine.
    for _ckpt_name, _net in (("G_AB", G_AB), ("G_BA", G_BA), ("D_A", D_A), ("D_B", D_B)):
        _net.load_state_dict(
            torch.load(
                "saved_models/%s/%s/%s_%d.pth" % (dataset_name, experiment, _ckpt_name, epoch),
                map_location="cpu",
            )
        )
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))


# Learning rate update schedulers
# The schedulers are created fresh on every run, so their internal counter
# starts at 0 even when training resumes from a checkpoint. The epoch the run
# starts at is added back in, otherwise a resumed run would restart at the full
# learning rate instead of continuing the linear decay. It is captured in its
# own name because `epoch` is reused as the loop variable of the training loop.
start_epoch = epoch


def lr_decay_factor(e):
    """Multiplier on `lr`: 1.0 until `decay_epoch`, then linear decay towards 0 at `n_epochs`.

    `e` is the number of scheduler steps taken in this run; `start_epoch`
    converts it to the absolute epoch index.
    """
    return 1.0 - max(0, e + start_epoch - decay_epoch) / (n_epochs - decay_epoch)


lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_decay_factor)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lr_decay_factor)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lr_decay_factor)

# Tensor type used to place input batches on the GPU when one is available
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Image transformations
# NOTE(review): Resize with a single int scales the shorter side and keeps the
# aspect ratio, so this presumably relies on the dataset images being square --
# confirm against the dataset.
# Domain A is resized with bicubic interpolation.
transforms1 = [
    transforms.Resize(int(image_height), transforms.InterpolationMode.BICUBIC),
  #  transforms.RandomCrop((192, 192)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Domain B is resized with nearest-neighbour interpolation.
transforms2 = [
    transforms.Resize(int(image_height), transforms.InterpolationMode.NEAREST),
 #   transforms.RandomCrop((192, 192)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]


# Training data loader
dataloader = DataLoader(
    ImageDataset("dataset/%s" % dataset_name, transforms1=transforms1,transforms2=transforms2, unaligned=True),
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)
# Test data loader
val_dataloader = DataLoader(
    ImageDataset("dataset/%s" % dataset_name, transforms1=transforms1,transforms2=transforms2, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)


def sample_images(batches_done):
    """Saves a generated sample from the test set.

    Takes one batch from ``val_dataloader``, translates its A images with
    ``G_AB`` and writes a grid with three rows (real A, generated B, real B)
    to ``images/<dataset_name>/<experiment>/<batches_done>.png``.

    Parameters:
        batches_done -- value used as the output file name
    """
    imgs = next(iter(val_dataloader))
    real_A = imgs["A"].type(Tensor)
    real_B = imgs["B"].type(Tensor)

    # Generate in eval mode and without building an autograd graph, then put
    # the generator back into whatever mode it was in.
    was_training = G_AB.training
    G_AB.eval()
    with torch.no_grad():
        fake_B = G_AB(real_A)
    G_AB.train(was_training)

    # Arange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B), 1)
    save_image(image_grid, "images/%s/%s/%s.png" % (dataset_name, experiment ,batches_done), normalize=False)



# ----------
#  Training
# ----------
# torch.cuda.current_device() raises on a machine without a usable CUDA device,
# so it is only queried when `cuda` is set.
if cuda:
    print(torch.cuda.current_device())

print(torch.cuda.device_count())

#Additional Info when using cuda
if cuda:
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')
# Sanity check that the training folder for domain A exists
print(os.path.exists("dataset/%s/train/A" % dataset_name))

# Per-epoch mean of every tracked loss (one entry appended per finished epoch)
epoch_losses_G_AB = []
epoch_losses_G_BA = []
epoch_losses_D_A = []
epoch_losses_D_B = []
epoch_losses_D = []
epoch_losses_G = []
epoch_losses_GAN = []
epoch_losses_cycle = []
epoch_losses_identity = []


torch.cuda.empty_cache()
prev_time = time.time()
# NOTE: checkpoint N is written after epoch N has been trained, while starting
# with epoch = N trains epoch N again.
for epoch in range(epoch, n_epochs):
    # Per-batch losses of the current epoch
    losses_G_AB = []
    losses_G_BA = []
    losses_D_A = []
    losses_D_B = []
    losses_D = []
    losses_G = []
    losses_GAN = []
    losses_cycle = []
    losses_identity = []

    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = batch["A"].type(Tensor)
        real_B = batch["B"].type(Tensor)

        # Adversarial ground truths (constant targets, no gradient needed)
        target_shape = (real_A.size(0), *discriminator_output_shape)
        valid = torch.ones(target_shape, device=real_A.device)
        fake = torch.zeros(target_shape, device=real_A.device)

        # ------------------
        #  Train Generators
        # ------------------

        # sample_images() may have switched the generator to eval mode
        G_AB.train()
        G_BA.train()

        optimizer_G.zero_grad()

        # Identity loss
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity

        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        optimizer_D_B.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left (duration of the last batch times batches remaining)
        batches_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        losses_G_AB.append(loss_GAN_AB.item())
        losses_G_BA.append(loss_GAN_BA.item())
        losses_D_A.append(loss_D_A.item())
        losses_D_B.append(loss_D_B.item())
        losses_D.append(loss_D.item())
        losses_G.append(loss_G.item())
        losses_GAN.append(loss_GAN.item())
        losses_cycle.append(loss_cycle.item())
        losses_identity.append(loss_identity.item())

        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_GAN.item(),
                loss_cycle.item(),
                loss_identity.item(),
                time_left,
            )
        )

        # If at sample interval save image
        if batches_done % sample_interval == 0:
            sample_images(batches_done)

    # Record the mean of every loss over this epoch
    epoch_losses_G_AB.append(sum(losses_G_AB)/len(losses_G_AB))
    epoch_losses_G_BA.append(sum(losses_G_BA)/len(losses_G_BA))
    epoch_losses_D_A.append(sum(losses_D_A)/len(losses_D_A))
    epoch_losses_D_B.append(sum(losses_D_B)/len(losses_D_B))

    epoch_losses_D.append(sum(losses_D)/len(losses_D))
    epoch_losses_G.append(sum(losses_G)/len(losses_G))
    epoch_losses_GAN.append(sum(losses_GAN)/len(losses_GAN))
    epoch_losses_cycle.append(sum(losses_cycle)/len(losses_cycle))
    epoch_losses_identity.append(sum(losses_identity)/len(losses_identity))

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    # Save on every checkpoint_interval-th epoch and also after the final epoch,
    # so the weights produced by the last epochs of the run are never lost.
    is_last_epoch = epoch == n_epochs - 1
    if checkpoint_interval != -1 and (epoch % checkpoint_interval == 0 or is_last_epoch):
        # Save model checkpoints
        for _ckpt_name, _net in (("G_AB", G_AB), ("G_BA", G_BA), ("D_A", D_A), ("D_B", D_B)):
            torch.save(
                _net.state_dict(),
                "saved_models/%s/%s/%s_%d.pth" % (dataset_name, experiment, _ckpt_name, epoch),
            )

ResnetGenerator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), bias=False)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): ResnetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): Reflect

In [11]:
def test_on_image(dataset_):
    """Translate the first test image of ``dataset/<dataset_>`` with ``G_AB`` and save it.

    The generated image is written to ``images/<dataset_name>/<dataset_>.png``.

    Parameters:
        dataset_ (str) -- dataset folder name under ``dataset/``; its ``test/A``
                          and ``test/B`` sub-folders are read by ImageDataset

    NOTE: the notebook's ``transforms1`` / ``transforms2`` pipelines are reused
    here; they include RandomHorizontalFlip, so the saved image may be mirrored.
    """
    test_dataloader = DataLoader(
        ImageDataset("dataset/%s" % dataset_, transforms1=transforms1, transforms2=transforms2, unaligned=True, mode="test"),
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )
    imgs = next(iter(test_dataloader))
    real_A_ = imgs["A"].type(Tensor)

    G_AB.eval()
    # Inference only: no autograd graph is needed
    with torch.no_grad():
        fake_B = G_AB(real_A_)

    # Rescale the single generated image to [0, 1] before saving
    fake_B = make_grid(fake_B[0], nrow=1, normalize=True)
    save_image(fake_B, "images/%s/%s.png" % (dataset_name, dataset_), normalize=False)

In [12]:
#test_on_image("test")

In [13]:
# current_device() raises when no CUDA device is usable, so only query it on GPU machines.
torch.cuda.current_device() if torch.cuda.is_available() else None

0

In [14]:
# Per-epoch loss histories gathered into one list; the position of each entry
# is the only label, so the order below must be kept when reading it back.
losses = [
    epoch_losses_D,
    epoch_losses_G,
    epoch_losses_GAN,
    epoch_losses_cycle,
    epoch_losses_identity,
    epoch_losses_G_AB,
    epoch_losses_G_BA,
    epoch_losses_D_A,
    epoch_losses_D_B,
]

In [15]:
import pickle

# Persist the loss histories to a pickle file named after the experiment
# (written to the current working directory, without a file extension).
losses_file = f'cyclegan_{experiment}'
with open(losses_file, 'wb') as fp:
    pickle.dump(losses, fp)
    print('Done writing list into a binary file')

Done writing list into a binary file


In [16]:
# get_device_name() raises when no CUDA device is usable, so only query it on GPU machines.
torch.cuda.get_device_name(0) if torch.cuda.is_available() else None

'NVIDIA A100-PCIE-40GB'

In [17]:
# Number of CUDA devices reported by PyTorch in this session.
torch.cuda.device_count()

1

In [18]:
# Report the GPU in use and how much of its memory PyTorch currently holds.
# Nothing is printed when `cuda` is False.
if cuda:
    bytes_per_gib = 1024 ** 3
    allocated_gib = round(torch.cuda.memory_allocated(0) / bytes_per_gib, 1)
    reserved_gib = round(torch.cuda.memory_reserved(0) / bytes_per_gib, 1)

    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', allocated_gib, 'GB')
    print('Cached:   ', reserved_gib, 'GB')

NVIDIA A100-PCIE-40GB
Memory Usage:
Allocated: 0.5 GB
Cached:    5.3 GB
