# Imports

In [None]:
import os
from os.path import join
import random
import math
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import time
import h5py
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Google Drive Setup

In [None]:
# Colab-only helper for attaching Google Drive to the notebook VM
from google.colab import drive
# mount Drive under /content/drive; force_remount=True remounts even if it is
# already mounted
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [None]:
# project folder on the mounted Drive; later cells build checkpoint, caption
# and image-cache paths from it
datadir = "/content/drive/My Drive/CS444/Final_Project"
# make the project folder the working directory
os.chdir(datadir)
# IPython shell magic: print the working directory to confirm the chdir
!pwd

/content/drive/My Drive/CS444/Final_Project


# Training Parameters

In [None]:
# Global option table read by every later cell (models, data loading, training).
opt = dict(
    # --- architecture ---
    large=0,              # 1 adds an extra residual block per stage in the generator
    # --- logging / checkpointing ---
    save_every=100,       # save models and optimizers every X epochs
    print_every=15,       # print batch statistics every X batches
    # --- loss ---
    cls_weight=0.5,       # weight of the wrong image/text pair loss (fake loss gets 1 - cls_weight)
    # --- paths ---
    checkpoint_dir=f"{datadir}/checkpoints",               # where models and optimizers are saved
    captions_file=f"{datadir}/base_encoded_captions.hdf5", # HDF5 file with the encoded captions
    cache_path=f"{datadir}/image_cache.pt",                # torch-saved cache of dataset images
    # --- tensor sizes ---
    fine_size=64,         # side length of the images fed to the networks
    batch_size=64,        # number of items per batch
    txt_size=384,         # dimensions of the text embeddings (set by the encoder used)
    nc=3,                 # image channels (3 for RGB)
    nt=256,               # dimensions of the projected text features
    nz=100,               # dimensions of the noise vector
    ngf=128,              # base number of generator filters
    ndf=64,               # base number of discriminator filters
    # --- data loading ---
    num_workers=2,        # worker processes for the DataLoader
    # --- optimisation ---
    epochs=600,           # number of training epochs
    lr=0.0002,            # initial learning rate for the Adam optimizers
    lr_decay=0.5,         # factor the learning rate is multiplied by at each decay
    decay_every=100,      # decay the learning rate every X epochs
    beta1=0.5,            # beta1 (momentum term) of Adam
    train_amt=0.75,       # fraction of the dataset used for training
    # --- misc ---
    display=1,            # 1 = show a generated sample after every epoch
    noise='normal',       # noise type: "uniform" or "normal"
    init_g='',            # path to a saved generator ('' = start fresh)
    init_d='',            # path to a saved discriminator ('' = start fresh)
    init_g_opt='',        # path to a saved generator optimizer ('' = start fresh)
    init_d_opt='',        # path to a saved discriminator optimizer ('' = start fresh)
    resume=0,             # 1 = take the starting epoch from the init_g file name
    manual_seed=7,        # seed for reproducible results
)

# Initialization Setup

In [None]:
# every tensor created without an explicit dtype is a 32-bit float
torch.set_default_dtype(torch.float32)

# run on the GPU when one is visible, otherwise on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# seed Python's and PyTorch's generators from the configured seed
random.seed(opt['manual_seed'])
torch.manual_seed(opt['manual_seed'])
if device == 'cuda':
    torch.cuda.manual_seed_all(opt['manual_seed'])

# Generator Definition

In [None]:
# NOTE: removed all inplace=True tags due to runtime errors

# reimplementation of ConcatTable & CAddTable block in original generator code
# applies conv branch and elementwise adds the identity
class ConcatAddBlock(nn.Module):
    """Residual bottleneck that returns ``x + conv_branch(x)``.

    The branch is 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU ->
    3x3 conv -> BN. Every convolution is bias-free with stride 1, so the
    spatial size is preserved. The elementwise sum in ``forward`` needs
    ``out_channels`` to match the channel count of the input.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # layers are listed in execution order; their positions become the
        # indices used in the state_dict keys (conv_branch.0 ... conv_branch.7)
        stages = [
            # squeeze channels with a pointwise convolution
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            # spatial mixing at the reduced width
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            # expand to the output width; no activation before the sum
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.conv_branch = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.conv_branch(x)
        return x + residual

# generator definition
# as close to 1:1 reimplementation as possible
class Generator(nn.Module):
    """Text-conditioned generator producing ``nc`` x 64 x 64 images in [-1, 1].

    The text embedding is projected to ``nt`` features, concatenated with the
    noise vector, and upsampled 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
    by transposed convolutions. Residual blocks are applied at the 4x4 and
    8x8 stages; ``opt['large'] == 1`` adds a second block at each of them.

    Reads ``nz``, ``nt``, ``txt_size``, ``ngf``, ``nc`` and ``large`` from
    ``opt``.
    """

    def __init__(self, opt):
        super().__init__()
        self.nz = opt['nz']
        self.nt = opt['nt']
        self.txt_size = opt['txt_size']
        self.ngf = opt['ngf']
        self.nc = opt['nc']
        self.large = opt['large']

        width = self.ngf
        extra_blocks = self.large == 1

        def upsample(c_in, c_out):
            # stride-2 transposed conv that doubles the spatial size
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)

        # projection of the raw text embedding down to nt features
        self.fcG = nn.Sequential(
            nn.Linear(self.txt_size, self.nt),
            nn.LeakyReLU(0.2),
        )

        # (nz + nt) x 1 x 1  ->  (ngf*8) x 4 x 4
        self.deconv1 = nn.ConvTranspose2d(self.nz + self.nt, width * 8, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(width * 8)

        # residual refinement at 4x4
        self.resblock1 = ConcatAddBlock(width * 8, width * 2, width * 8)
        if extra_blocks:
            self.resblock1b = ConcatAddBlock(width * 8, width * 2, width * 8)

        # 4x4 -> 8x8
        self.deconv2 = upsample(width * 8, width * 4)
        self.bn2 = nn.BatchNorm2d(width * 4)

        # residual refinement at 8x8
        self.resblock2 = ConcatAddBlock(width * 4, width, width * 4)
        if extra_blocks:
            self.resblock2b = ConcatAddBlock(width * 4, width, width * 4)

        # 8x8 -> 16x16
        self.deconv3 = upsample(width * 4, width * 2)
        self.bn3 = nn.BatchNorm2d(width * 2)

        # 16x16 -> 32x32
        self.deconv4 = upsample(width * 2, width)
        self.bn4 = nn.BatchNorm2d(width)

        # 32x32 -> 64x64, straight to image channels
        self.deconv5 = upsample(width, self.nc)
        self.tanh = nn.Tanh()

    def forward(self, noise, txt):
        n = noise.size(0)

        # project the text and shape it like the noise: (batch, nt, 1, 1)
        cond = self.fcG(txt).view(n, self.nt, 1, 1)

        # (batch, nz + nt, 1, 1)
        h = torch.cat([noise, cond], dim=1)

        # 4x4 stage: deconv + BN, residual block(s), then ReLU
        h = self.bn1(self.deconv1(h))
        h = self.resblock1(h)
        if self.large == 1:
            h = self.resblock1b(h)
        h = F.relu(h)

        # 8x8 stage, same pattern
        h = self.bn2(self.deconv2(h))
        h = self.resblock2(h)
        if self.large == 1:
            h = self.resblock2b(h)
        h = F.relu(h)

        # 16x16 and 32x32 stages: deconv + BN + ReLU
        for deconv, bn in ((self.deconv3, self.bn3), (self.deconv4, self.bn4)):
            h = F.relu(bn(deconv(h)))

        # 64x64 image squashed into [-1, 1]
        return self.tanh(self.deconv5(h))

# Discriminator Definition

In [None]:
# NOTE: removed all inplace=True tags due to runtime errors

# reimplementation of Replicate in original discriminator code
# replicates a vector spatially to (HxW) feature map
class Replicate2d(nn.Module):
    """Tile a per-sample feature vector over an H x W grid.

    ``forward`` reshapes its input to (batch, channels, 1, 1) and repeats it
    to (batch, channels, H, W).
    """

    def __init__(self, H, W):
        super().__init__()
        self.H, self.W = H, W

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        column = x.reshape(batch, channels, 1, 1)
        return column.repeat(1, 1, self.H, self.W)

# image branch of discriminator
class ImageDiscriminator(nn.Module):
    """Image branch of the discriminator.

    Four stride-2 convolutions take an (nc) x 64 x 64 image down to
    (ndf*8) x 4 x 4. A bottleneck residual branch is then added to that
    feature map and the sum goes through LeakyReLU(0.2).
    """

    def __init__(self, nc, ndf):
        super().__init__()

        def down(c_in, c_out, bias=False):
            # stride-2 conv that halves the spatial size
            return nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=bias)

        # (nc) x 64 x 64 -> (ndf) x 32 x 32; only this conv has a bias and no BN
        self.layer1 = nn.Sequential(
            down(nc, ndf, bias=True),
            nn.LeakyReLU(0.2),
        )
        # -> (ndf*2) x 16 x 16
        self.layer2 = nn.Sequential(
            down(ndf, ndf * 2),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2),
        )
        # -> (ndf*4) x 8 x 8
        self.layer3 = nn.Sequential(
            down(ndf * 2, ndf * 4),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2),
        )
        # -> (ndf*8) x 4 x 4; the activation is applied after the residual sum
        self.layer4 = nn.Sequential(
            down(ndf * 4, ndf * 8),
            nn.BatchNorm2d(ndf * 8),
        )
        # bottleneck residual branch at 4x4: 1x1 squeeze, 3x3, 3x3 expand
        self.res_branch = nn.Sequential(
            nn.Conv2d(ndf * 8, ndf * 2, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 2, ndf * 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 2, ndf * 8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 8),
        )

    def forward(self, x):
        # downsampling trunk
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        # identity + residual branch (ConcatTable + CAddTable), then activation
        summed = x + self.res_branch(x)
        return F.leaky_relu(summed, 0.2)

# text branch of discriminator
class TextDiscriminator(nn.Module):
    """Text branch of the discriminator.

    Projects a text embedding from ``txt_size`` to ``nt`` features with
    Linear + LeakyReLU(0.2), then tiles the result over a 4 x 4 grid.
    """

    def __init__(self, txt_size, nt):
        super().__init__()
        projection = [nn.Linear(txt_size, nt), nn.LeakyReLU(0.2)]
        self.fc = nn.Sequential(*projection)
        self.replicate = Replicate2d(4, 4)

    def forward(self, txt):
        # (batch, nt) -> (batch, nt, 4, 4)
        projected = self.fc(txt)
        return self.replicate(projected)

# discriminator definition (combines previous discriminators)
class Discriminator(nn.Module):
    """Joint image/text discriminator.

    Image features (ndf*8 x 4 x 4) and tiled text features (nt x 4 x 4) are
    concatenated along channels, fused by a 1x1 convolution, and reduced by a
    4x4 convolution plus sigmoid to one score per sample, returned with
    shape (batch, 1).

    Reads ``nc``, ``ndf``, ``nt`` and ``txt_size`` from ``opt``.
    """

    def __init__(self, opt):
        super().__init__()
        self.nc = opt['nc']
        self.ndf = opt['ndf']
        self.nt = opt['nt']
        self.txt_size = opt['txt_size']

        feat = self.ndf * 8

        self.image_net = ImageDiscriminator(self.nc, self.ndf)
        self.text_net = TextDiscriminator(self.txt_size, self.nt)

        # fuse the (ndf*8 + nt) concatenated channels back down to ndf*8
        self.joint_conv1 = nn.Sequential(
            nn.Conv2d(feat + self.nt, feat, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(feat),
            nn.LeakyReLU(0.2),
        )
        # collapse the 4x4 map to a single sigmoid score
        self.joint_conv2 = nn.Sequential(
            nn.Conv2d(feat, 1, kernel_size=4, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, img, txt):
        # each modality goes through its own branch first
        branches = [
            self.image_net(img),  # (batch, ndf*8, 4, 4)
            self.text_net(txt),   # (batch, nt, 4, 4)
        ]
        # (batch, ndf*8 + nt, 4, 4)
        fused = torch.cat(branches, dim=1)
        # (batch, ndf*8, 4, 4) -> (batch, 1, 1, 1)
        score = self.joint_conv2(self.joint_conv1(fused))
        # flatten to (batch, 1)
        return score.reshape(score.size(0), -1)

# Generator & Discriminator Creation

In [None]:
# weight initialization function taken from original code
def weights_init(m):
    """Re-initialize one module in place; meant for ``net.apply(weights_init)``.

    Modules whose class name contains ``Conv`` get weights ~ N(0, 0.02);
    those containing ``BatchNorm`` get weights ~ N(1, 0.02). In both cases an
    existing bias is set to 0. Any other module is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        weight_mean = 0.0
    elif 'BatchNorm' in kind:
        weight_mean = 1.0
    else:
        return

    # weight / bias can be missing or None (e.g. bias=False layers)
    weight = getattr(m, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight, weight_mean, 0.02)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)

# define generator
if opt['init_g'] == '':
    netG = Generator(opt).to(device)
    netG.apply(weights_init)
else:
    netG = torch.load(opt['init_g'])

# define discriminator
if opt['init_d'] == '':
    netD = Discriminator(opt).to(device)
    netD.apply(weights_init)
else:
    netD = torch.load(opt['init_d'])

# define optimizers
if opt['init_g_opt'] == '':
    optimizerG = torch.optim.Adam(netG.parameters(), lr=opt['lr'], betas=(opt['beta1'], 0.999))
else:
    optimizerG = torch.optim.Adam(netG.parameters(), lr=opt['lr'], betas=(opt['beta1'], 0.999))
    optimizerG.load_state_dict(torch.load(opt['init_g_opt']))

if opt['init_d_opt'] == '':
    optimizerD = torch.optim.Adam(netD.parameters(), lr=opt['lr'], betas=(opt['beta1'], 0.999))
else:
    optimizerD = torch.optim.Adam(netD.parameters(), lr=opt['lr'], betas=(opt['beta1'], 0.999))
    optimizerD.load_state_dict(torch.load(opt['init_d_opt']))

In [None]:
# adding spectral normalization to discriminator
# wrap every Conv2d layer of the discriminator with spectral normalization
for layer in netD.modules():
    if not isinstance(layer, nn.Conv2d):
        continue
    spectral_norm(layer)

# Training Preparation

In [None]:
# define loss criterion
# binary cross-entropy on the discriminator's sigmoid output
criterion = nn.BCELoss().to(device)

# placeholder input buffers, allocated directly on the target device
# (contents are uninitialized until the training loop fills them)
input_img = torch.empty(opt['batch_size'], opt['nc'], opt['fine_size'], opt['fine_size'], device=device)
input_img2 = torch.empty(opt['batch_size'], opt['nc'], opt['fine_size'], opt['fine_size'], device=device)
input_txt_emb1 = torch.empty(opt['batch_size'], opt['txt_size'], device=device)
noise = torch.empty(opt['batch_size'], opt['nz'], 1, 1, device=device)

# module-level loss values so every cell can reference them
errD = errG = errW = None

# target values for real and fake samples
real_label, fake_label = 1.0, 0.0

# running average of the discriminator's output on generated images
fake_score = 0.5

# Dataset Loading

In [None]:
def load_data():
    """Read the caption embeddings and pick the (shuffled) training images.

    Every dataset in ``opt['captions_file']`` is copied into memory as a NumPy
    array, keyed by its name. The sorted names are cut at
    ``opt['train_amt']`` of their count, with the cut rounded up to the next
    multiple of ``opt['batch_size']``. If that exceeds the number of names the
    slice simply stops at the end of the list.

    Returns:
        dict with
          'image_list': shuffled list of training image names
          'captions':   dict mapping every image name to its caption array
    """
    # open read-only and close the file as soon as the arrays are in memory
    with h5py.File(opt['captions_file'], 'r') as h:
        flower_captions = {key: np.array(ds) for key, ds in h.items()}

    images = sorted(flower_captions)
    images_train = int(len(images) * opt['train_amt'])
    # make sure that number of train images splits perfectly into batches
    remainder = images_train % opt['batch_size']
    if remainder != 0:
        images_train += opt['batch_size'] - remainder
    training_images = images[:images_train]
    random.shuffle(training_images)

    return {
        'image_list': training_images,
        'captions': flower_captions,
    }

# read the captions and the shuffled training image list once; the dataset
# created below is built from this dict
loaded_data = load_data()

In [None]:
# transformation to be applied every time an image is retrieved
# augmentation pipeline run on every image fetch:
# resize the shorter side to 74, take a random 64x64 crop, flip horizontally
# half of the time, convert to a tensor and map [0, 1] to [-1, 1] per channel
transform = transforms.Compose([
    transforms.Resize(74),
    transforms.RandomCrop(64),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

class FlowerTextImageDataset(Dataset):
    """Yields (matching image, mismatching image, caption embedding) triples."""

    def __init__(self, cache_path, captions_dict, image_list, transform):
        """
        cache_path: torch-saved file mapping name to raw image data
        captions_dict: mapping name to caption array
        image_list: list of names (keys for both the cache & captions_dict)
        transform: torchvision transforms to apply on each call
        """
        super().__init__()
        self.cache = torch.load(cache_path)
        self.captions = captions_dict
        self.image_list = image_list
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def _augmented(self, name):
        """Look up ``name`` in the cache and return its transformed tensor."""
        raw = self.cache[name]
        # scale to 0..255 and convert to a uint8 (H, W, C) array for PIL
        # NOTE(review): assumes cached images are (C, H, W) floats in [0, 1] — confirm
        as_bytes = (raw * 255).byte().cpu()
        pil_img = Image.fromarray(as_bytes.permute(1, 2, 0).numpy())
        return self.transform(pil_img)

    def __getitem__(self, idx):
        # matching image for this index
        name = self.image_list[idx]
        real_img_tensor = self._augmented(name)

        # one caption chosen at random among those of the matching image
        candidates = self.captions[name]
        chosen = candidates[random.randrange(len(candidates))]
        txt_emb = torch.from_numpy(chosen).float()

        # draw from the other len - 1 positions, then step over idx so the
        # mismatching image is never at the same position as the matching one
        wrong_idx = random.randint(0, len(self.image_list) - 2)
        if wrong_idx >= idx:
            wrong_idx += 1
        wrong_img_tensor = self._augmented(self.image_list[wrong_idx])

        return real_img_tensor, wrong_img_tensor, txt_emb

# create dataset
# dataset over the training image names with the augmentation pipeline
dataset = FlowerTextImageDataset(
    opt['cache_path'],
    loaded_data['captions'],
    loaded_data['image_list'],
    transform,
)

# shuffled batches, loaded by worker processes into pinned memory
loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=opt['batch_size'],
    pin_memory=True,
    num_workers=opt['num_workers'],
)

# Closures

In [None]:
# generator closure
def fGx():
    """Generator closure: accumulate generator gradients for the current batch.

    Samples fresh noise, generates images conditioned on ``input_txt_emb1``,
    scores them with ``netD`` and back-propagates the BCE loss against the
    *real* label. The caller is expected to run ``optimizerG.step()``.

    Side effects: zeroes the conv biases of both networks, zeroes the
    generator optimizer's gradients, updates the running ``fake_score`` and
    leaves the generated batch in the global ``input_img``.

    Returns:
        float: the generator loss for this batch.
    """
    # only names that are re-bound here need a global declaration
    global fake_score, input_img, noise

    # zero biases in all conv layers for both netD and netG
    for net in (netD, netG):
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)) and m.bias is not None:
                m.bias.data.zero_()

    optimizerG.zero_grad()

    # keep the noise buffer in step with the actual batch (e.g. a short last batch)
    batch = input_txt_emb1.size(0)
    if noise.size(0) != batch:
        noise = noise.new_empty(batch, noise.size(1), 1, 1)

    # generate random noise
    if opt['noise'] == 'uniform':
        noise.uniform_(-1, 1)
    elif opt['noise'] == 'normal':
        noise.normal_(0, 1)

    # generate fake images from netG; they must stay attached to the graph so
    # the loss can back-propagate through the discriminator into the generator
    fake = netG(noise, input_txt_emb1)
    input_img = fake

    # for generator loss, use real label
    label_real = torch.full((batch, 1), real_label, dtype=torch.float32, device=device)

    # compute netD output on the fake images
    output = netD(input_img, input_txt_emb1)
    # update fake score (exponential moving average of D's output on fakes)
    cur_score = output.mean().item()
    fake_score = 0.99 * fake_score + 0.01 * cur_score

    errG = criterion(output, label_real)
    errG.backward()

    return errG.item()

In [None]:
# discriminator closure
def fDx():
    """Discriminator closure: accumulate discriminator gradients for the batch.

    The loss is the sum of three BCE terms:
      * real image + matching text against the real label,
      * wrong image + text against the fake label, weighted by
        ``opt['cls_weight']`` (skipped when that weight is not positive),
      * generated image + text against the fake label, weighted by
        ``1 - opt['cls_weight']``.
    The caller is expected to run ``optimizerD.step()``.

    Side effects: zeroes the conv biases of both networks, zeroes the
    discriminator optimizer's gradients, updates the running ``fake_score``
    and leaves the generated batch in the global ``input_img``.

    Returns:
        tuple of floats: (total loss, real loss, wrong-pair loss, fake loss).
    """
    # only names that are re-bound here need a global declaration
    global fake_score, input_img, noise

    # zero biases in all conv layers for both netD and netG
    for net in (netD, netG):
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)) and m.bias is not None:
                m.bias.data.zero_()

    optimizerD.zero_grad()

    batch = input_img.shape[0]

    # train with real examples
    label_real = torch.full((batch, 1), real_label, dtype=torch.float32, device=device)
    label_fake = torch.full((batch, 1), fake_label, dtype=torch.float32, device=device)
    output = netD(input_img, input_txt_emb1)
    errD_real = criterion(output, label_real)

    # train with wrong image/text pairs; kept as a tensor even when disabled
    # so the sum and the .item() call below work in both cases
    errD_wrong = torch.zeros((), dtype=torch.float32, device=device)
    if opt['cls_weight'] > 0:
        output_wrong = netD(input_img2, input_txt_emb1)
        errD_wrong = opt['cls_weight'] * criterion(output_wrong, label_fake)

    # train with fake examples
    # keep the noise buffer in step with the actual batch (e.g. a short last batch)
    if noise.size(0) != batch:
        noise = noise.new_empty(batch, noise.size(1), 1, 1)

    # generate random noise
    if opt['noise'] == 'uniform':
        noise.uniform_(-1, 1)
    elif opt['noise'] == 'normal':
        noise.normal_(0, 1)

    # generate fake images from netG without building a graph: this step
    # trains only the discriminator, so no gradient should reach netG
    with torch.no_grad():
        fake = netG(noise, input_txt_emb1)
    input_img = fake

    # compute netD output on the fake images
    output_fake = netD(input_img, input_txt_emb1)
    # update fake score (exponential moving average of D's output on fakes)
    cur_score = output_fake.mean().item()
    fake_score = 0.99 * fake_score + 0.01 * cur_score

    fake_weight = 1 - opt['cls_weight']
    errD_fake = criterion(output_fake, label_fake) * fake_weight

    # calculate total loss for the discriminator
    errD = errD_real + errD_wrong + errD_fake
    errD.backward()

    return errD.item(), errD_real.item(), errD_wrong.item(), errD_fake.item()

# Training

In [None]:
# calculate number of batches
total_batches = math.floor(len(dataset) / opt['batch_size'])

# create fixed noise and label for generator display every epoch
fixed_z = torch.randn(1, opt['nz'], 1, 1, device=device)
idx = opt['manual_seed'] % len(dataset)
_, _, txt = dataset[idx]
fixed_txt = txt.float().to(device)

# add noise starting from this epoch
start_noise = 20
# for this many epochs
noise_epochs = 50
# linearly decay from this to 0
initial_sigma = 0.3

# determine which epoch to start at
start = 1
if opt['resume'] == 1:
    start = int(opt['init_g'].split('/')[-1].split('_')[0])

for epoch in range(start, opt['epochs'] + 1):
    netG.train()
    netD.train()

    # determine whether to add noise and how much
    if epoch < (start_noise + noise_epochs) and epoch >= start_noise:
        sigma = initial_sigma * max(0.0, ((start_noise + noise_epochs) - epoch) / noise_epochs)
    else:
        sigma = 0

    epoch_start_time = time.time()

    # decay the learning rate at the specified interval
    if epoch % opt['decay_every'] == 0:
        for param_group in optimizerG.param_groups:
            param_group['lr'] *= opt['lr_decay']
        for param_group in optimizerD.param_groups:
            param_group['lr'] *= opt['lr_decay']

    for i, (real_images, wrong_images, captions) in enumerate(loader):
        iter_start_time = time.time()

        input_img = real_images.float().to(device)
        input_img2 = wrong_images.float().to(device)
        input_txt_emb1 = captions.float().to(device)

        # apply one-sided label smoothing
        real_label = round(random.uniform(0.85, 1.00), 2)
        # fake_label = round(random.uniform(0.00, 0.20), 2)

        if sigma > 0:
            input_img += sigma * torch.randn_like(input_img)
            input_img2 += sigma * torch.randn_like(input_img2)

        # discriminator gradients and loss
        errD, errDreal, errW, errDfake = fDx()
        optimizerD.step()

        # generator gradients and loss
        # do X G updates for every discriminator update
        for _ in range(1):
            errG = fGx()
            optimizerG.step()

        # log batch statistics
        if (i % opt['print_every']) == 0:
            iter_time = time.time() - iter_start_time
            current_lr = optimizerG.param_groups[0]['lr']
            print(f"[{epoch}][{i}/{total_batches}] T:{iter_time:.3f} lr:{current_lr:.4g} "
                  f"G:{errG if errG is not None else -1:.3f}  D:{errD if errD is not None else -1:.3f}  "
                  f"Dr:{errDreal if errDreal is not None else -1:.3f}  Df:{errDfake if errDfake is not None else -1:.3f}  "
                  f"W:{errW if errW is not None else -1:.3f}  fs:{fake_score:.2f}")

    # display sample generation
    if opt['display'] == 1:
        netG.eval()
        with torch.no_grad():
            fake = netG(fixed_z, fixed_txt)
        # compute statistics
        mn, mx = fake.min().item(), fake.max().item()
        mean, std = fake.mean().item(), fake.std().item()
        print("")
        print(f"[{epoch}]  generator: min {mn:.4f}, max {mx:.4f}, mean {mean:.4f}, std {std:.4f}")
        print("")
        # convert for plotting
        img = (fake[0].cpu() + 1) * 0.5
        img = img.permute(1, 2, 0).numpy()

        # visualize
        plt.figure(figsize=(3,3))
        plt.imshow(img)
        plt.axis('off')
        plt.title(f"Epoch {epoch}")
        plt.show()
        print("")
        netG.train()

    # save checkpoints at intervals
    if epoch % opt['save_every'] == 0:
        os.makedirs(opt['checkpoint_dir'], exist_ok=True)
        # save model & optimizer state dicts
        torch.save(netG.state_dict(), f"{opt['checkpoint_dir']}/{epoch}_net_G_basic_updated.pth")
        torch.save(netD.state_dict(), f"{opt['checkpoint_dir']}/{epoch}_net_D_basic_updated.pth")
        torch.save(optimizerG.state_dict(), f"{opt['checkpoint_dir']}/{epoch}_opt_G_basic_updated.pth")
        torch.save(optimizerD.state_dict(), f"{opt['checkpoint_dir']}/{epoch}_opt_D_basic_updated.pth")
    epoch_duration = time.time() - epoch_start_time
    print(f"End of epoch {epoch} / {opt['epochs']} \t Time Taken: {epoch_duration:.3f}")
    print("")

Output hidden; open in https://colab.research.google.com to view.