In [1]:
import os
import pickle

In [2]:
# Mount Google Drive into the Colab runtime at /content/drive so the
# dataset and checkpoints stored there can be read and written.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Pin scipy to 1.1.0; the import cell below pulls `imresize` from `scipy.misc`.
# NOTE(review): pip reports that this pin conflicts with tensorflow, umap-learn
# and plotnine, and `imresize` is not called in the visible cells - confirm the
# downgrade is still needed.
!pip install scipy==1.1.0

Collecting scipy==1.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/a8/0b/f163da98d3a01b3e0ef1cab8dd2123c34aee2bafbb1c5bffa354cc8a1730/scipy-1.1.0-cp36-cp36m-manylinux1_x86_64.whl (31.2MB)
[K     |████████████████████████████████| 31.2MB 190kB/s 
[31mERROR: umap-learn 0.4.6 has requirement scipy>=1.3.1, but you'll have scipy 1.1.0 which is incompatible.[0m
[31mERROR: tensorflow 2.3.0 has requirement scipy==1.4.1, but you'll have scipy 1.1.0 which is incompatible.[0m
[31mERROR: plotnine 0.6.0 has requirement scipy>=1.2.0, but you'll have scipy 1.1.0 which is incompatible.[0m
[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.[0m
Installing collected packages: scipy
  Found existing installation: scipy 1.4.1
    Uninstalling scipy-1.4.1:
      Successfully uninstalled scipy-1.4.1
Successfully installed scipy-1.1.0


In [4]:
import itertools, imageio, torch, random
import time
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from torchvision import datasets
from scipy.misc import imresize
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
  
import torch.nn as nn
import torch.nn.functional as F

## Config for CartoonGAN

In [None]:
# Central configuration dictionary read by the cells below.
# NOTE(review): 'src_data', 'tgt_data', 'vgg_model' and the two 'latest_*'
# entries are not read by the visible cells - confirm they are still needed.
args = dict(
    # run name (prefix of the results folder and of saved sample images)
    name='sample_data',
    src_data='src_data_path',
    tgt_data='tgt_data_path',
    vgg_model='pre_trained_VGG19_model_path/vgg19.pth',
    # channel counts: generator in/out, discriminator in/out
    in_ngc=3,
    out_ngc=3,
    in_ndc=3,
    out_ndc=1,
    # batch size, base widths of generator / discriminator, residual blocks
    batch_size=8,
    ngf=64,
    ndf=32,
    nb=8,
    # side length source images are resized to
    input_size=256,
    # epochs of adversarial training and of generator pre-training
    train_epoch=30,
    pre_train_epoch=10,
    # Adam learning rates and betas; weight of the content loss
    lrD=0.0002,
    lrG=0.0002,
    con_lambda=10,
    beta1=0.5,
    beta2=0.999,
    latest_generator_model='',
    latest_discriminator_model='',
)

## Define and Create Dataloaders

In [None]:
def data_load(path, subfolder, transform, batch_size, shuffle=False, drop_last=True):
    """Build a DataLoader over a single class sub-folder of an ImageFolder tree.

    Args:
        path: Root directory handed to ``datasets.ImageFolder``.
        subfolder: Name of the class directory under ``path`` whose images
            should be kept; every other class is dropped.
        transform: Transform handed to ``datasets.ImageFolder``.
        batch_size: Batch size of the returned loader.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the images of ``subfolder`` only.

    Raises:
        KeyError: If ``subfolder`` is not one of the class directories found
            under ``path``.
    """
    dset = datasets.ImageFolder(path, transform)
    if subfolder not in dset.class_to_idx:
        raise KeyError('%r is not a class folder under %r (found: %s)'
                       % (subfolder, path, sorted(dset.class_to_idx)))
    ind = dset.class_to_idx[subfolder]

    # Filter in a single pass instead of deleting while indexing (quadratic).
    # Slice assignment mutates the existing list, so any other attribute that
    # refers to the same list object (torchvision aliases `imgs` and `samples`)
    # sees the filtered content, exactly as the in-place `del` did.
    dset.imgs[:] = [item for item in dset.imgs if item[1] == ind]

    # Keep the label list consistent with the filtered samples when present.
    if hasattr(dset, 'targets'):
        dset.targets = [label for _, label in dset.imgs]

    return torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

def print_network(net):
    """Print the module's repr followed by its total number of parameter elements.

    Args:
        net: Any object exposing ``parameters()`` whose items have ``numel()``
            (e.g. an ``nn.Module``).
    """
    total = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)

In [None]:
# Transforms for the data loaders.
# Source images are resized to input_size x input_size; target images are not
# resized. Both are converted to tensors and normalised with mean 0.5 / std 0.5
# per channel, i.e. mapped from [0, 1] to [-1, 1].
src_transform = transforms.Compose([
    transforms.Resize((args['input_size'],) * 2),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])
tgt_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

In [None]:
# Root data folder on the mounted Drive; later cells read the `src_data` and
# `tgt_data` image folders and the VGG19 checkpoint from it.
data_path = '/content/drive/MyDrive/archive/ds_data'

In [None]:
# Source-domain training batches (class folder 'train' under src_data).
train_loader_src = data_load(
    path=os.path.join(data_path, 'src_data'),
    subfolder='train',
    transform=src_transform,
    batch_size=args['batch_size'],
    shuffle=True,
    drop_last=True,
)
# Target-domain training batches (class folder 'pair' under tgt_data).
train_loader_tgt = data_load(
    path=os.path.join(data_path, 'tgt_data'),
    subfolder='pair',
    transform=tgt_transform,
    batch_size=args['batch_size'],
    shuffle=True,
    drop_last=True,
)
# Source-domain test images, one per batch (class folder 'test' under src_data).
test_loader_src = data_load(
    path=os.path.join(data_path, 'src_data'),
    subfolder='test',
    transform=src_transform,
    batch_size=1,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Number of batches the source training loader yields per epoch.
len(train_loader_src)

379

## Define Generator, Discriminator, and VGG Network (feature extractor)

In [None]:
def initialize_weights(net):
    """Re-initialise the conv / linear / batch-norm layers of ``net`` in place.

    Conv2d, ConvTranspose2d and Linear weights are drawn from N(0, 0.02) and
    their biases are zeroed; BatchNorm2d weights are set to 1 and biases to 0.
    Layers created without a bias (``bias=False``) or without affine
    parameters (``affine=False``) are handled by skipping the missing tensor
    instead of raising ``AttributeError``.

    Args:
        net: Module whose sub-modules (as returned by ``net.modules()``) are
            re-initialised.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:  # bias=False layers have no bias tensor
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False leaves both weight and bias as None
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
            
class resnet_block(nn.Module):
    """Residual block: conv -> instance norm -> ReLU -> conv -> instance norm, plus skip.

    The block output is ``input + branch(input)``, so ``kernel``, ``stride``
    and ``padding`` must be chosen such that the two convolutions preserve the
    spatial size (e.g. 3, 1, 1), otherwise the element-wise sum fails.

    Args:
        channel: Number of input and output channels of both convolutions.
        kernel: Kernel size of both convolutions.
        stride: Stride of both convolutions.
        padding: Padding of both convolutions.
    """

    def __init__(self, channel, kernel, stride, padding):
        super(resnet_block, self).__init__()
        self.channel = channel
        self.kernel = kernel
        self.stride = stride
        # Misspelled attribute kept as an alias so existing readers of
        # `strdie` keep working.
        self.strdie = stride
        self.padding = padding
        self.conv1 = nn.Conv2d(channel, channel, kernel, stride, padding)
        self.conv1_norm = nn.InstanceNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, channel, kernel, stride, padding)
        self.conv2_norm = nn.InstanceNorm2d(channel)

        initialize_weights(self)

    def forward(self, input):
        """Return ``input`` plus the output of the two-convolution branch."""
        x = F.relu(self.conv1_norm(self.conv1(input)), True)
        x = self.conv2_norm(self.conv2(x))

        return input + x  # element-wise sum (skip connection)
 

class generator(nn.Module):
    """Image-to-image generator: strided encoder, residual bottleneck, transposed-conv decoder.

    The encoder halves the spatial size twice while widening to ``nf * 4``
    channels, ``nb`` ``resnet_block`` modules run at that resolution, and the
    decoder doubles the spatial size twice before a final 7x7 convolution and
    ``Tanh``.

    Args:
        in_nc: Channels of the input image.
        out_nc: Channels of the generated image.
        nf: Base number of feature maps.
        nb: Number of residual blocks.
    """

    def __init__(self, in_nc, out_nc, nf=32, nb=6):
        super(generator, self).__init__()
        self.input_nc, self.output_nc = in_nc, out_nc
        self.nf, self.nb = nf, nb

        encoder_layers = [
            nn.Conv2d(in_nc, nf, 7, 1, 3),            # 7x7, nf maps, stride 1
            nn.InstanceNorm2d(nf),
            nn.ReLU(True),
            nn.Conv2d(nf, nf * 2, 3, 2, 1),           # 3x3, 2*nf maps, stride 2
            nn.Conv2d(nf * 2, nf * 2, 3, 1, 1),       # 3x3, 2*nf maps, stride 1
            nn.InstanceNorm2d(nf * 2),
            nn.ReLU(True),
            nn.Conv2d(nf * 2, nf * 4, 3, 2, 1),       # 3x3, 4*nf maps, stride 2
            nn.Conv2d(nf * 4, nf * 4, 3, 1, 1),       # 3x3, 4*nf maps, stride 1
            nn.InstanceNorm2d(nf * 4),
            nn.ReLU(True),
        ]
        self.down_convs = nn.Sequential(*encoder_layers)

        # nb identical residual blocks at the bottleneck width
        self.resnet_blocks = nn.Sequential(
            *[resnet_block(nf * 4, 3, 1, 1) for _ in range(nb)]
        )

        decoder_layers = [
            nn.ConvTranspose2d(nf * 4, nf * 2, 3, 2, 1, 1),  # 3x3, 2*nf maps, upsample x2
            nn.Conv2d(nf * 2, nf * 2, 3, 1, 1),              # 3x3, 2*nf maps, stride 1
            nn.InstanceNorm2d(nf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(nf * 2, nf, 3, 2, 1, 1),      # 3x3, nf maps, upsample x2
            nn.Conv2d(nf, nf, 3, 1, 1),                      # 3x3, nf maps, stride 1
            nn.InstanceNorm2d(nf),
            nn.ReLU(True),
            nn.Conv2d(nf, out_nc, 7, 1, 3),                  # 7x7, out_nc maps, stride 1
            nn.Tanh(),
        ]
        self.up_convs = nn.Sequential(*decoder_layers)

        initialize_weights(self)

    def forward(self, input):
        """Run encoder, residual blocks and decoder; returns the Tanh-bounded image."""
        encoded = self.down_convs(input)
        transformed = self.resnet_blocks(encoded)
        return self.up_convs(transformed)


class discriminator(nn.Module):
    """Fully convolutional discriminator producing a map of per-location scores.

    Two of the convolutions use stride 2, so the output map is a quarter of the
    input's height and width; the final ``Sigmoid`` bounds every score to
    [0, 1].

    Args:
        in_nc: Channels of the input image.
        out_nc: Channels of the output score map.
        nf: Base number of feature maps.
    """

    def __init__(self, in_nc, out_nc, nf=32):
        super(discriminator, self).__init__()
        self.input_nc = in_nc
        self.output_nc = out_nc
        self.nf = nf

        def leaky():
            # every hidden activation is an in-place LeakyReLU with slope 0.2
            return nn.LeakyReLU(0.2, True)

        stack = [
            nn.Conv2d(in_nc, nf, 3, 1, 1), leaky(),
            nn.Conv2d(nf, nf * 2, 3, 2, 1), leaky(),
            nn.Conv2d(nf * 2, nf * 4, 3, 1, 1), nn.InstanceNorm2d(nf * 4), leaky(),
            nn.Conv2d(nf * 4, nf * 4, 3, 2, 1), leaky(),
            nn.Conv2d(nf * 4, nf * 8, 3, 1, 1), nn.InstanceNorm2d(nf * 8), leaky(),
            nn.Conv2d(nf * 8, nf * 8, 3, 1, 1), nn.InstanceNorm2d(nf * 8), leaky(),
            nn.Conv2d(nf * 8, out_nc, 3, 1, 1), nn.Sigmoid(),
        ]
        self.convs = nn.Sequential(*stack)

        initialize_weights(self)

    def forward(self, input):
        """Return the sigmoid score map for ``input``."""
        return self.convs(input)


class VGG19(nn.Module):
    """VGG-19 network usable as a classifier or as a fixed feature extractor.

    Args:
        init_weights: Optional path to a state-dict checkpoint to load. The
            checkpoint is mapped to CPU while loading; move the module with
            ``.to(device)`` afterwards.
        feature_mode: When True, ``forward`` returns the activations after the
            first ``_FEATURE_LAYERS`` layers of ``features`` instead of class
            scores.
        batch_norm: When True, a ``BatchNorm2d`` is inserted after every
            convolution in ``features``.
        num_classes: Size of the classifier output.
    """

    # Number of leading `features` layers run in feature mode. With
    # batch_norm=False this ends at the conv4_4 convolution (its ReLU is not
    # applied). NOTE(review): with batch_norm=True the same count ends at an
    # earlier layer - confirm before using feature mode with batch norm.
    _FEATURE_LAYERS = 26

    def __init__(self, init_weights=None, feature_mode=False, batch_norm=False, num_classes=1000):
        super(VGG19, self).__init__()
        self.cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
        self.init_weights = init_weights
        self.feature_mode = feature_mode
        self.batch_norm = batch_norm
        self.num_classes = num_classes
        # Misspelled attribute kept as an alias so existing readers keep working.
        self.num_clases = num_classes
        self.features = self.make_layers(self.cfg, batch_norm)
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights is not None:
            # map_location lets a checkpoint saved from a GPU load on a
            # CPU-only machine; load_state_dict copies into the existing
            # (CPU) parameters either way.
            self.load_state_dict(torch.load(init_weights, map_location='cpu'))

    def make_layers(self, cfg, batch_norm=False):
        """Build the convolutional stack described by ``cfg``.

        Args:
            cfg: Sequence whose items are either ``'M'`` (2x2 max-pool, stride 2)
                or an int (3x3 convolution with that many output channels,
                followed by ReLU).
            batch_norm: Insert ``BatchNorm2d`` between each convolution and its ReLU.

        Returns:
            ``nn.Sequential`` of the generated layers; the first convolution
            expects 3 input channels.
        """
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return truncated feature maps (feature mode) or class scores.

        In classification mode the input is first passed through the full
        ``features`` stack, flattened, and then fed to ``classifier``.
        """
        if self.feature_mode:
            for layer in list(self.features.children())[:self._FEATURE_LAYERS]:
                x = layer(x)
            return x

        # Classification path: the convolutional stack must run before the
        # flattened activations reach the fully connected layers.
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

## Instantiate Generator and Discriminator

In [None]:
# Build the two trainable networks from the channel / width settings in args.
Gen = generator(
    in_nc=args['in_ngc'],
    out_nc=args['out_ngc'],
    nf=args['ngf'],
    nb=args['nb'],
)
Dis = discriminator(
    in_nc=args['in_ndc'],
    out_nc=args['out_ndc'],
    nf=args['ndf'],
)

## Load Pre-trained VGG Model

In [None]:
# List the data folder to check that the datasets and the VGG19 checkpoint are present.
os.listdir(data_path)

['vgg19-dcbb9e9d.pth', 'src_data', 'tgt_data', 'generator_latest.pkl']

In [None]:
# Location of the VGG19 checkpoint inside the data folder.
vgg_model_path = os.path.join(data_path,'vgg19-dcbb9e9d.pth')

In [None]:
# Build VGG19 in feature mode and load its weights from vgg_model_path; the
# training cells use it to compare features of source and generated images.
VGG = VGG19(init_weights=vgg_model_path, feature_mode=True)

## Move Models to Device (GPU if available)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Gen.to(device)
Dis.to(device)
VGG.to(device)

# Generator and discriminator are trained; VGG is only evaluated.
Gen.train()
Dis.train()
VGG.eval()

# No optimizer ever updates VGG, so stop tracking gradients for its weights.
# Gradients still flow through VGG to the generator output; this only avoids
# computing and accumulating unused weight gradients on every backward pass.
for vgg_param in VGG.parameters():
    vgg_param.requires_grad_(False)

print('---------- Networks initialized -------------')
print_network(Gen)
print_network(Dis)
print_network(VGG)
print('-----------------------------------------------')

---------- Networks initialized -------------
generator(
  (down_convs): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): ReLU(inplace=True)
  )
  (resnet_blocks): Sequential(
    (0): resnet_block(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv1_norm): Inst

In [None]:
# Loss functions: binary cross-entropy for the discriminator scores and L1 for
# the distance between VGG feature maps.
BCE_loss = nn.BCELoss().to(device)
L1_loss = nn.L1Loss().to(device)

In [None]:
# Adam for both networks, sharing the beta pair from the config.
adam_betas = (args['beta1'], args['beta2'])
G_optimizer = optim.Adam(Gen.parameters(), lr=args['lrG'], betas=adam_betas)
D_optimizer = optim.Adam(Dis.parameters(), lr=args['lrD'], betas=adam_betas)

# Both learning rates are multiplied by 0.1 at the same two epoch milestones,
# derived from train_epoch with integer division.
lr_milestones = [args['train_epoch'] // 2, args['train_epoch'] // 4 * 3]
G_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=G_optimizer,
    milestones=list(lr_milestones),
    gamma=0.1,
)
D_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=D_optimizer,
    milestones=list(lr_milestones),
    gamma=0.1,
)


### Pre-training Step

In [None]:
# Inspect the local ./sample_data folder.
# NOTE(review): results_path is not read by the visible cells below; outputs
# are written under args['name'] + '_results' instead - confirm it is needed.
results_path = './sample_data'
os.listdir(results_path)

['anscombe.json',
 'README.md',
 'mnist_test.csv',
 'california_housing_train.csv',
 'mnist_train_small.csv',
 'california_housing_test.csv']

In [None]:
# Create the folders that receive reconstruction and style-transfer samples
# (no-op for folders that already exist).
for out_subdir in ('Reconstruction', 'Transfer'):
    os.makedirs(os.path.join(args['name'] + '_results', out_subdir), exist_ok=True)

In [None]:
# Generator pre-training: optimise only the VGG-feature reconstruction loss so
# the generator first learns to reproduce the content of its input.
pre_train_hist = {
    'Recon_loss': [],       # one entry per batch
    'per_epoch_time': [],   # one entry per epoch (seconds)
    'total_time': [],       # filled by the cell that saves the history
}

print('Pre-training start!')
start_time = time.time()
Gen.train()  # make re-running this cell independent of a previous Gen.eval()
for epoch in range(args['pre_train_epoch']):
    epoch_start_time = time.time()
    Recon_losses = []

    for x, _ in train_loader_src:
        x = x.to(device)

        # train generator G
        G_optimizer.zero_grad()

        # Inputs are in [-1, 1] after normalisation; VGG is fed (x + 1) / 2.
        # The target features are constants, so no graph is built for them.
        with torch.no_grad():
            x_feature = VGG((x + 1) / 2)
        G_ = Gen(x)
        G_feature = VGG((G_ + 1) / 2)

        Recon_loss = 10 * L1_loss(G_feature, x_feature)
        Recon_losses.append(Recon_loss.item())
        pre_train_hist['Recon_loss'].append(Recon_loss.item())

        Recon_loss.backward()
        G_optimizer.step()

    per_epoch_time = time.time() - epoch_start_time
    pre_train_hist['per_epoch_time'].append(per_epoch_time)
    print('[%d/%d] - time: %.2f, Recon loss: %.3f' % ((epoch + 1),
                                                      args['pre_train_epoch'],
                                                      per_epoch_time,
                                                      torch.mean(torch.FloatTensor(Recon_losses))))

    # Checkpoint after every epoch so an interrupted run keeps its progress.
    torch.save(Gen.state_dict(), os.path.join(args['name'] + '_results', 'generator_latest.pkl'))

total_time = time.time() - start_time


Pre-training start!


  "Palette images with Transparency expressed in bytes should be "


[1/10] - time: 1541.07, Recon loss: 18.384
[2/10] - time: 173.42, Recon loss: 8.683
[3/10] - time: 171.74, Recon loss: 6.747
[4/10] - time: 171.90, Recon loss: 5.731
[5/10] - time: 171.60, Recon loss: 5.082
[6/10] - time: 172.22, Recon loss: 4.652
[7/10] - time: 172.08, Recon loss: 4.350
[8/10] - time: 171.45, Recon loss: 4.106
[9/10] - time: 171.35, Recon loss: 3.935
[10/10] - time: 171.46, Recon loss: 3.707


In [None]:
# Also store the pre-trained generator weights in the Drive data folder.
torch.save(Gen.state_dict(), os.path.join(data_path, 'generator_latest.pkl'))

In [None]:
# Persist the pre-training history.
pre_train_hist['total_time'].append(total_time)
hist_file = os.path.join(args['name'] + '_results',  'pre_train_hist.pkl')
with open(hist_file, 'wb') as f:
    pickle.dump(pre_train_hist, f)

# Save side-by-side (input | reconstruction) images for the first five batches
# of the training loader and then of the test loader.
with torch.no_grad():
    Gen.eval()
    for name_tag, loader in (('_train_recon_', train_loader_src),
                             ('_test_recon_', test_loader_src)):
        for n, (x, _) in enumerate(itertools.islice(loader, 5)):
            x = x.to(device)
            G_recon = Gen(x)
            # first image of the batch and its reconstruction, joined along width
            result = torch.cat((x[0], G_recon[0]), 2)
            path = os.path.join(args['name'] + '_results', 'Reconstruction', args['name'] + name_tag + str(n + 1) + '.png')
            # CHW in [-1, 1] -> HWC in [0, 1] for imsave
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)

## Training Generator and Discriminator together

In [None]:
# Adversarial training of generator and discriminator.
train_hist = {
    'Disc_loss': [],        # per batch
    'Gen_loss': [],         # per batch (adversarial part of the generator loss)
    'Con_loss': [],         # per batch (weighted content loss)
    'per_epoch_time': [],   # per epoch (seconds)
    'total_time': [],       # filled by a later cell
}
print('training start!')
start_time = time.time()

# Constant target maps for the BCE loss. The discriminator downsamples twice
# by 2, hence input_size // 4 per side; drop_last=True on the loaders keeps
# every batch at batch_size so these fixed shapes always match.
label_shape = (args['batch_size'], 1, args['input_size'] // 4, args['input_size'] // 4)
real = torch.ones(label_shape).to(device)
fake = torch.zeros(label_shape).to(device)

for epoch in range(args['train_epoch']):
    epoch_start_time = time.time()
    Gen.train()
    Disc_losses = []
    Gen_losses = []
    Con_losses = []
    for (x, _), (y, _) in zip(train_loader_src, train_loader_tgt):
        # Each target tensor is split along width at input_size: the left part
        # is used as the "real" example, the right part as an extra "fake" one.
        # NOTE(review): presumably the right half is the edge-smoothed version
        # of the left half - confirm against the dataset preparation.
        e = y[:, :, :, args['input_size']:]
        y = y[:, :, :, :args['input_size']]
        x, y, e = x.to(device), y.to(device), e.to(device)

        # train D
        D_optimizer.zero_grad()

        D_real = Dis(y)
        D_real_loss = BCE_loss(D_real, real)

        # The generated batch is a constant for the discriminator update, so
        # do not build (or later backpropagate through) the generator graph.
        with torch.no_grad():
            G_ = Gen(x)
        D_fake = Dis(G_)
        D_fake_loss = BCE_loss(D_fake, fake)

        D_edge = Dis(e)
        D_edge_loss = BCE_loss(D_edge, fake)

        Disc_loss = D_real_loss + D_fake_loss + D_edge_loss
        Disc_losses.append(Disc_loss.item())
        train_hist['Disc_loss'].append(Disc_loss.item())

        Disc_loss.backward()
        D_optimizer.step()

        # train G
        G_optimizer.zero_grad()

        G_ = Gen(x)
        D_fake = Dis(G_)
        D_fake_loss = BCE_loss(D_fake, real)

        # Content loss on VGG features; the source features are constants.
        with torch.no_grad():
            x_feature = VGG((x + 1) / 2)
        G_feature = VGG((G_ + 1) / 2)
        Con_loss = args['con_lambda'] * L1_loss(G_feature, x_feature)

        Gen_loss = D_fake_loss + Con_loss
        Gen_losses.append(D_fake_loss.item())
        train_hist['Gen_loss'].append(D_fake_loss.item())
        Con_losses.append(Con_loss.item())
        train_hist['Con_loss'].append(Con_loss.item())

        Gen_loss.backward()
        G_optimizer.step()

    # Advance the LR schedules once per epoch, AFTER the optimizer steps.
    # Stepping before any optimizer.step() skips the first value of the
    # schedule in PyTorch >= 1.1 and applies each milestone one epoch early.
    G_scheduler.step()
    D_scheduler.step()

    per_epoch_time = time.time() - epoch_start_time
    train_hist['per_epoch_time'].append(per_epoch_time)
    print(
    '[%d/%d] - time: %.2f, Disc loss: %.3f, Gen loss: %.3f, Con loss: %.3f' % ((epoch + 1),
        args['train_epoch'], per_epoch_time, torch.mean(torch.FloatTensor(Disc_losses)),
        torch.mean(torch.FloatTensor(Gen_losses)), torch.mean(torch.FloatTensor(Con_losses))))

    # Every second epoch (and after the last one): save sample images and checkpoints.
    if epoch % 2 == 1 or epoch == args['train_epoch'] - 1:
        with torch.no_grad():
            Gen.eval()
            # First five batches of the training loader, then of the test loader.
            for name_tag, loader in (('_train_', train_loader_src), ('_test_', test_loader_src)):
                for n, (x, _) in enumerate(itertools.islice(loader, 5)):
                    x = x.to(device)
                    G_recon = Gen(x)
                    # input and generated image of the first sample, side by side
                    result = torch.cat((x[0], G_recon[0]), 2)
                    path = os.path.join(args['name'] + '_results', 'Transfer', str(epoch+1) + '_epoch_' + args['name'] + name_tag + str(n + 1) + '.png')
                    plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)

        # Latest weights go to the local results folder and to the Drive folder.
        torch.save(Gen.state_dict(), os.path.join(args['name'] + '_results', 'generator_latest.pkl'))
        torch.save(Dis.state_dict(), os.path.join(args['name'] + '_results', 'discriminator_latest.pkl'))
        torch.save(Gen.state_dict(), os.path.join(data_path, 'generator_latest.pkl'))
        torch.save(Dis.state_dict(), os.path.join(data_path, 'discriminator_latest.pkl'))

training start!


  "Palette images with Transparency expressed in bytes should be "


[1/100] - time: 1592.92, Disc loss: 1.745, Gen loss: 1.671, Con loss: 3.734
[2/100] - time: 294.88, Disc loss: 1.564, Gen loss: 2.366, Con loss: 3.882
[3/100] - time: 291.91, Disc loss: 1.461, Gen loss: 2.603, Con loss: 3.939
[4/100] - time: 290.91, Disc loss: 1.234, Gen loss: 2.735, Con loss: 4.104
[5/100] - time: 291.80, Disc loss: 0.769, Gen loss: 3.238, Con loss: 4.944
[6/100] - time: 291.47, Disc loss: 0.628, Gen loss: 3.612, Con loss: 5.821
[7/100] - time: 290.61, Disc loss: 0.565, Gen loss: 3.768, Con loss: 6.510
[8/100] - time: 289.23, Disc loss: 0.513, Gen loss: 3.869, Con loss: 7.085
[9/100] - time: 290.33, Disc loss: 0.504, Gen loss: 3.852, Con loss: 7.468
[10/100] - time: 291.33, Disc loss: 0.569, Gen loss: 3.517, Con loss: 7.588
[11/100] - time: 291.56, Disc loss: 0.494, Gen loss: 3.541, Con loss: 7.999
[12/100] - time: 291.56, Disc loss: 0.493, Gen loss: 3.565, Con loss: 8.272
[13/100] - time: 291.39, Disc loss: 0.707, Gen loss: 3.278, Con loss: 7.617
[14/100] - time: 291

In [None]:
# Wall-clock time since the training cell set start_time; because this runs in
# its own cell, any delay before executing it is included.
total_time = time.time() - start_time
train_hist['total_time'].append(total_time)

In [None]:
# Report timing, then write the final weights and the loss history to the results folder.
mean_epoch_time = torch.mean(torch.FloatTensor(train_hist['per_epoch_time']))
print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (mean_epoch_time, args['train_epoch'], total_time))
print("Training finish!... save training results")

results_dir = args['name'] + '_results'
for net, weight_file in ((Gen, 'generator_param.pkl'), (Dis, 'discriminator_param.pkl')):
    torch.save(net.state_dict(), os.path.join(results_dir, weight_file))
with open(os.path.join(results_dir, 'train_hist.pkl'), 'wb') as f:
    pickle.dump(train_hist, f)

Avg one epoch time: 301.27, total 100 epochs time: 30302.76
Training finish!... save training results


In [None]:
# Copy the local results folder (samples, checkpoints, histories) into the Drive data folder.
!cp -r './sample_data_results' '/content/drive/MyDrive/archive/ds_data'