In [None]:
# Mount Google Drive. The relative paths "drive/MyDrive/..." used later in this
# notebook (pickled dataset, model checkpoints) resolve through this mount point.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Sanity check: list the working directory to confirm that images.zip and the
# helper modules model.py / data_vector.py (imported below) are present.
ls

data_vector.py  drive/  images.zip  model.py  sample_data/


In [None]:
# Upload images folder containing the all the image data to the session storage or we can directly use it from drive as well
!unzip images

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: images/Nashville Warbler/Nashville_Warbler_0041_167534.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0042_167346.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0044_167357.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0048_167071.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0050_167475.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0051_167250.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0053_167403.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0054_167258.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0055_167331.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0056_167123.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0057_167008.jpg  
  inflating: images/Nashville Warbler/Nashville_Warbler_0060_167347.jpg  
  inflating: images/Nashville Warbler/Nashville

In [None]:
# Move the extracted image folder under CUB_200/.
# NOTE(review): if CUB_200/ does not already exist, `mv` renames `images` to
# `CUB_200` instead of producing CUB_200/images. Confirm which layout
# ReadFromVec expects (the pickled records store paths like 'images/...').
# Running `!mkdir -p CUB_200` first would make the result CUB_200/images
# deterministically. This cell also fails if run twice, because `images` is gone.
!mv images CUB_200/

In [None]:
import os
import argparse
import pickle

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data as data
import torchvision.transforms as transforms

from model import Generator, Discriminator
from data_vector import ReadFromVec

In [None]:
def label_like(label, x):
    """Build a constant target tensor matching ``x``.

    Args:
        label: the fill value; only 0 or 1 is accepted.
        x: reference tensor whose shape, dtype and device are copied.

    Returns:
        A new tensor like ``x`` with every element equal to ``label``.
    """
    assert label in (0, 1)
    # full_like already allocates on x's device, so no extra .to() is needed.
    return torch.full_like(x, label)

def zeros_like(x):
    """All-zeros target tensor with the same shape, dtype and device as ``x``."""
    return label_like(label=0, x=x)

def ones_like(x):
    """All-ones target tensor with the same shape, dtype and device as ``x``."""
    return label_like(label=1, x=x)

In [None]:
# Use the GPU when one is present; otherwise warn and fall back to the CPU.
use_cuda = torch.cuda.is_available()
if not use_cuda:
    print('Cuda is not available on this machine.')
device = torch.device('cuda' if use_cuda else 'cpu')

In [None]:
# Optional override: uncomment the next line to force CPU execution even when CUDA is available.
#device = torch.device('cpu')

In [None]:
# Display which device the models and batches will be moved to.
device

device(type='cuda')

In [None]:
# Load the pickled training records. The result is indexable, and each record is
# dict-like with an 'img' path (see the next cell).
# SECURITY: pickle.load can execute arbitrary code; only load files you trust.
# The `with` block guarantees the file handle is closed after loading.
with open("drive/MyDrive/new_dict_data", 'rb') as get_file:
    train_data = pickle.load(get_file)

In [None]:
# Peek at one raw record: its 'img' entry is an image path such as
# 'images/<species>/<file>.jpg'.
SAMPLE_INDEX = 500
datum = train_data[SAMPLE_INDEX]
datum['img']

'images/Dark Eyed Junco/Dark_Eyed_Junco_0048_66981.jpg'

In [None]:
# Training-time image pipeline: crop, random flip/rotation, then convert to a tensor.
# NOTE(review): CenterCrop(128) is applied to the image as ReadFromVec loads it.
# If data_vector.py does not resize first, this keeps only the central 128x128
# patch of each photo. Confirm; a Resize before the crop may be intended.
train_transform = transforms.Compose([
    transforms.CenterCrop(128),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])
train_data = ReadFromVec("drive/MyDrive/new_dict_data", train_transform)

In [None]:
# Shuffled mini-batches; each batch unpacks as (img, txt, len_txt) in the training loops.
BATCH_SIZE = 64
loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
train_loader = DataLoader(train_data, **loader_kwargs)

In [None]:
# Display the DataLoader (its repr only; no batches are drawn here).
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7fb89b2d3bd0>

In [None]:
# Instantiate the generator and discriminator directly on the selected device
# (nn.Module.to moves the parameters and returns the module itself).
G = Generator().to(device)
D = Discriminator().to(device)

In [None]:
# Optional: resume from the checkpoints written by the training cells below.
# NOTE(review): torch.load without map_location fails on a CPU-only runtime when
# the tensors were saved from CUDA; consider map_location=device.
# Do not leave the models in eval mode before training: if they contain
# BatchNorm/Dropout, .eval() changes their behaviour while training.
# G.load_state_dict(torch.load("drive/MyDrive/birds_GEN.pth"))
# G.eval()
# D.load_state_dict(torch.load("drive/MyDrive/birds_DIS.pth"))
# D.eval()

In [None]:
# Adam for both networks with the same hyper-parameters. StepLR halves the
# learning rate every 100 scheduler steps (the training loops step it once per epoch).
LR = 5e-5
BETAS = (0.5, 0.999)

g_optimizer = torch.optim.Adam(G.parameters(), lr=LR, betas=BETAS)
d_optimizer = torch.optim.Adam(D.parameters(), lr=LR, betas=BETAS)

g_lr_scheduler = lr_scheduler.StepLR(g_optimizer, step_size=100, gamma=0.5)
d_lr_scheduler = lr_scheduler.StepLR(d_optimizer, step_size=100, gamma=0.5)

In [None]:
# Adversarial training. Each iteration does one discriminator step and one
# generator step; checkpoints on Drive are overwritten after every epoch.
num_epochs = 50

# Force train mode in case .eval() was called after loading a checkpoint
# (it matters if the networks contain BatchNorm/Dropout).
G.train()
D.train()

for epoch in range(num_epochs):
    # Running sums of per-batch losses; divided by the batch count when logged.
    avg_D_real_loss = 0
    avg_D_real_c_loss = 0
    avg_D_fake_loss = 0
    avg_G_fake_loss = 0
    avg_G_fake_c_loss = 0
    avg_G_recon_loss = 0
    avg_kld = 0
    for i, (img, txt, len_txt) in enumerate(train_loader):
        img, txt, len_txt = img.to(device), txt.to(device), len_txt.to(device)
        # ToTensor yields [0, 1]; rescale to [-1, 1].
        img = img.mul(2).sub(1)
        # BTC to TBC
        txt = txt.transpose(1, 0)
        # Mismatched ("negative") text: shift the batch by one position so each
        # image is paired with another sample's description and its length.
        txt_m = torch.cat((txt[:, -1, :].unsqueeze(1), txt[:, :-1, :]), 1)
        len_txt_m = torch.cat((len_txt[-1].unsqueeze(0), len_txt[:-1]))

        # ---------------- UPDATE DISCRIMINATOR ----------------
        D.zero_grad()

        # Real images: realness target 1; text-match target 1 for the true text
        # and 0 for the negative text.
        real_logit, real_c_prob, real_c_prob_n = D(img, txt, len_txt, negative=True)
        real_loss = F.binary_cross_entropy_with_logits(real_logit, ones_like(real_logit))
        avg_D_real_loss += real_loss.item()
        real_c_loss = (F.binary_cross_entropy(real_c_prob, ones_like(real_c_prob)) +
                       F.binary_cross_entropy(real_c_prob_n, zeros_like(real_c_prob_n))) / 2
        avg_D_real_c_loss += real_c_loss.item()
        (real_loss + 10 * real_c_loss).backward()

        # Synthesized images (generated from the mismatched text) should be judged fake.
        # G is only a data source here, so skip building its autograd graph;
        # D's gradients are identical to using fake.detach().
        with torch.no_grad():
            fake, _ = G(img, (txt_m, len_txt_m))
        fake_logit, _ = D(fake, txt_m, len_txt_m)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, zeros_like(fake_logit))
        avg_D_fake_loss += fake_loss.item()
        fake_loss.backward()

        d_optimizer.step()

        # ---------------- UPDATE GENERATOR ----------------
        # The generator backward passes also write into D's .grad buffers; those
        # are cleared by D.zero_grad() at the start of the next iteration.
        G.zero_grad()

        # Edit towards the mismatched text: fool D (target 1) and make D's
        # conditional output accept the mismatched text (target 1).
        fake, (z_mean, z_log_stddev) = G(img, (txt_m, len_txt_m))
        kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1))
        avg_kld += 0.5 * kld.item()  # two KLD terms per iteration -> log their mean
        fake_logit, fake_c_prob = D(fake, txt_m, len_txt_m)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, ones_like(fake_logit))
        avg_G_fake_loss += fake_loss.item()
        fake_c_loss = F.binary_cross_entropy(fake_c_prob, ones_like(fake_c_prob))
        avg_G_fake_c_loss += fake_c_loss.item()
        G_loss = fake_loss + 10 * fake_c_loss + 0.5 * kld
        G_loss.backward()

        # Reconstruction: with the matching text, G should reproduce the input image.
        recon, (z_mean, z_log_stddev) = G(img, (txt, len_txt))
        kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1))
        avg_kld += 0.5 * kld.item()
        recon_loss = F.l1_loss(recon, img)
        avg_G_recon_loss += recon_loss.item()
        G_loss = 0.2 * recon_loss + 0.5 * kld
        G_loss.backward()

        g_optimizer.step()

        if i % 10 == 0:
            print('Epoch [%03d/%03d], Iter [%03d/%03d], D_real: %.4f, D_real_c: %.4f, D_fake: %.4f, G_fake: %.4f, G_fake_c: %.4f, G_recon: %.4f, KLD: %.4f'
                  % (epoch + 1, num_epochs, i + 1, len(train_loader), avg_D_real_loss / (i + 1),
                     avg_D_real_c_loss / (i + 1), avg_D_fake_loss / (i + 1),
                     avg_G_fake_loss / (i + 1), avg_G_fake_c_loss / (i + 1),
                     avg_G_recon_loss / (i + 1), avg_kld / (i + 1)))

    d_lr_scheduler.step()
    g_lr_scheduler.step()

    # Overwrite the Drive checkpoints every epoch so the latest weights survive an interrupted session.
    torch.save(G.state_dict(), "drive/MyDrive/birds_GEN.pth")
    torch.save(D.state_dict(), "drive/MyDrive/birds_DIS.pth")

Epoch [001/050], Iter [001/185], D_real: 1.0742, D_real_c: 0.6937, D_fake: 0.6527, G_fake: 0.9203, G_fake_c: 0.6991, G_recon: 0.4843, KLD: 0.0000
Epoch [001/050], Iter [011/185], D_real: 0.5464, D_real_c: 0.6933, D_fake: 0.5756, G_fake: 1.8368, G_fake_c: 0.6947, G_recon: 0.4663, KLD: 0.0000
Epoch [001/050], Iter [021/185], D_real: 0.5181, D_real_c: 0.6932, D_fake: 0.5501, G_fake: 2.4731, G_fake_c: 0.6939, G_recon: 0.4343, KLD: 0.0000
Epoch [001/050], Iter [031/185], D_real: 0.5165, D_real_c: 0.6932, D_fake: 0.5335, G_fake: 2.5885, G_fake_c: 0.6936, G_recon: 0.4106, KLD: 0.0000
Epoch [001/050], Iter [041/185], D_real: 0.5188, D_real_c: 0.6932, D_fake: 0.5254, G_fake: 2.5294, G_fake_c: 0.6935, G_recon: 0.3929, KLD: 0.0000
Epoch [001/050], Iter [051/185], D_real: 0.5259, D_real_c: 0.6932, D_fake: 0.5351, G_fake: 2.4375, G_fake_c: 0.6935, G_recon: 0.3798, KLD: 0.0000
Epoch [001/050], Iter [061/185], D_real: 0.5606, D_real_c: 0.6932, D_fake: 0.5553, G_fake: 2.3855, G_fake_c: 0.6934, G_recon

KeyboardInterrupt: ignored

In [None]:
# Continue adversarial training for another block of epochs. Each iteration does
# one discriminator step and one generator step; checkpoints on Drive are
# overwritten after every epoch.
num_epochs = 10

# Force train mode in case .eval() was called after loading a checkpoint
# (it matters if the networks contain BatchNorm/Dropout).
G.train()
D.train()

for epoch in range(num_epochs):
    # Running sums of per-batch losses; divided by the batch count when logged.
    avg_D_real_loss = 0
    avg_D_real_c_loss = 0
    avg_D_fake_loss = 0
    avg_G_fake_loss = 0
    avg_G_fake_c_loss = 0
    avg_G_recon_loss = 0
    avg_kld = 0
    for i, (img, txt, len_txt) in enumerate(train_loader):
        img, txt, len_txt = img.to(device), txt.to(device), len_txt.to(device)
        # ToTensor yields [0, 1]; rescale to [-1, 1].
        img = img.mul(2).sub(1)
        # BTC to TBC
        txt = txt.transpose(1, 0)
        # Mismatched ("negative") text: shift the batch by one position so each
        # image is paired with another sample's description and its length.
        txt_m = torch.cat((txt[:, -1, :].unsqueeze(1), txt[:, :-1, :]), 1)
        len_txt_m = torch.cat((len_txt[-1].unsqueeze(0), len_txt[:-1]))

        # ---------------- UPDATE DISCRIMINATOR ----------------
        D.zero_grad()

        # Real images: realness target 1; text-match target 1 for the true text
        # and 0 for the negative text.
        real_logit, real_c_prob, real_c_prob_n = D(img, txt, len_txt, negative=True)
        real_loss = F.binary_cross_entropy_with_logits(real_logit, ones_like(real_logit))
        avg_D_real_loss += real_loss.item()
        real_c_loss = (F.binary_cross_entropy(real_c_prob, ones_like(real_c_prob)) +
                       F.binary_cross_entropy(real_c_prob_n, zeros_like(real_c_prob_n))) / 2
        avg_D_real_c_loss += real_c_loss.item()
        (real_loss + 10 * real_c_loss).backward()

        # Synthesized images (generated from the mismatched text) should be judged fake.
        # G is only a data source here, so skip building its autograd graph;
        # D's gradients are identical to using fake.detach().
        with torch.no_grad():
            fake, _ = G(img, (txt_m, len_txt_m))
        fake_logit, _ = D(fake, txt_m, len_txt_m)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, zeros_like(fake_logit))
        avg_D_fake_loss += fake_loss.item()
        fake_loss.backward()

        d_optimizer.step()

        # ---------------- UPDATE GENERATOR ----------------
        # The generator backward passes also write into D's .grad buffers; those
        # are cleared by D.zero_grad() at the start of the next iteration.
        G.zero_grad()

        # Edit towards the mismatched text: fool D (target 1) and make D's
        # conditional output accept the mismatched text (target 1).
        fake, (z_mean, z_log_stddev) = G(img, (txt_m, len_txt_m))
        kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1))
        avg_kld += 0.5 * kld.item()  # two KLD terms per iteration -> log their mean
        fake_logit, fake_c_prob = D(fake, txt_m, len_txt_m)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, ones_like(fake_logit))
        avg_G_fake_loss += fake_loss.item()
        fake_c_loss = F.binary_cross_entropy(fake_c_prob, ones_like(fake_c_prob))
        avg_G_fake_c_loss += fake_c_loss.item()
        G_loss = fake_loss + 10 * fake_c_loss + 0.5 * kld
        G_loss.backward()

        # Reconstruction: with the matching text, G should reproduce the input image.
        recon, (z_mean, z_log_stddev) = G(img, (txt, len_txt))
        kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1))
        avg_kld += 0.5 * kld.item()
        recon_loss = F.l1_loss(recon, img)
        avg_G_recon_loss += recon_loss.item()
        G_loss = 0.2 * recon_loss + 0.5 * kld
        G_loss.backward()

        g_optimizer.step()

        if i % 10 == 0:
            print('Epoch [%03d/%03d], Iter [%03d/%03d], D_real: %.4f, D_real_c: %.4f, D_fake: %.4f, G_fake: %.4f, G_fake_c: %.4f, G_recon: %.4f, KLD: %.4f'
                  % (epoch + 1, num_epochs, i + 1, len(train_loader), avg_D_real_loss / (i + 1),
                     avg_D_real_c_loss / (i + 1), avg_D_fake_loss / (i + 1),
                     avg_G_fake_loss / (i + 1), avg_G_fake_c_loss / (i + 1),
                     avg_G_recon_loss / (i + 1), avg_kld / (i + 1)))

    d_lr_scheduler.step()
    g_lr_scheduler.step()

    # Overwrite the Drive checkpoints every epoch so the latest weights survive an interrupted session.
    torch.save(G.state_dict(), "drive/MyDrive/birds_GEN.pth")
    torch.save(D.state_dict(), "drive/MyDrive/birds_DIS.pth")

Epoch [001/050], Iter [001/185], D_real: 0.5057, D_real_c: 0.6871, D_fake: 0.3773, G_fake: 1.4122, G_fake_c: 0.5679, G_recon: 0.3074, KLD: 0.0802
Epoch [001/050], Iter [011/185], D_real: 0.5597, D_real_c: 0.6463, D_fake: 0.5562, G_fake: 1.3953, G_fake_c: 0.5454, G_recon: 0.2796, KLD: 0.0926
Epoch [001/050], Iter [021/185], D_real: 0.5639, D_real_c: 0.6453, D_fake: 0.5511, G_fake: 1.3903, G_fake_c: 0.5479, G_recon: 0.2807, KLD: 0.0929
Epoch [001/050], Iter [031/185], D_real: 0.5829, D_real_c: 0.6405, D_fake: 0.5473, G_fake: 1.3934, G_fake_c: 0.5433, G_recon: 0.2812, KLD: 0.0958
Epoch [001/050], Iter [041/185], D_real: 0.5721, D_real_c: 0.6405, D_fake: 0.5578, G_fake: 1.3925, G_fake_c: 0.5378, G_recon: 0.2824, KLD: 0.0960
Epoch [001/050], Iter [051/185], D_real: 0.5757, D_real_c: 0.6444, D_fake: 0.5604, G_fake: 1.3879, G_fake_c: 0.5421, G_recon: 0.2837, KLD: 0.0955
Epoch [001/050], Iter [061/185], D_real: 0.5841, D_real_c: 0.6434, D_fake: 0.5692, G_fake: 1.3877, G_fake_c: 0.5427, G_recon

In [None]:
# Continue adversarial training for another block of epochs. Each iteration does
# one discriminator step and one generator step; checkpoints on Drive are
# overwritten after every epoch.
num_epochs = 10

# Force train mode in case .eval() was called after loading a checkpoint
# (it matters if the networks contain BatchNorm/Dropout).
G.train()
D.train()

for epoch in range(num_epochs):
    # Running sums of per-batch losses; divided by the batch count when logged.
    avg_D_real_loss = 0
    avg_D_real_c_loss = 0
    avg_D_fake_loss = 0
    avg_G_fake_loss = 0
    avg_G_fake_c_loss = 0
    avg_G_recon_loss = 0
    avg_kld = 0
    for i, (img, txt, len_txt) in enumerate(train_loader):
        img, txt, len_txt = img.to(device), txt.to(device), len_txt.to(device)
        # ToTensor yields [0, 1]; rescale to [-1, 1].
        img = img.mul(2).sub(1)
        # BTC to TBC
        txt = txt.transpose(1, 0)
        # Mismatched ("negative") text: shift the batch by one position so each
        # image is paired with another sample's description and its length.
        txt_m = torch.cat((txt[:, -1, :].unsqueeze(1), txt[:, :-1, :]), 1)
        len_txt_m = torch.cat((len_txt[-1].unsqueeze(0), len_txt[:-1]))

        # ---------------- UPDATE DISCRIMINATOR ----------------
        D.zero_grad()

        # Real images: realness target 1; text-match target 1 for the true text
        # and 0 for the negative text.
        real_logit, real_c_prob, real_c_prob_n = D(img, txt, len_txt, negative=True)
        real_loss = F.binary_cross_entropy_with_logits(real_logit, ones_like(real_logit))
        avg_D_real_loss += real_loss.item()
        real_c_loss = (F.binary_cross_entropy(real_c_prob, ones_like(real_c_prob)) +
                       F.binary_cross_entropy(real_c_prob_n, zeros_like(real_c_prob_n))) / 2
        avg_D_real_c_loss += real_c_loss.item()
        (real_loss + 10 * real_c_loss).backward()

        # Synthesized images (generated from the mismatched text) should be judged fake.
        # G is only a data source here, so skip building its autograd graph;
        # D's gradients are identical to using fake.detach().
        with torch.no_grad():
            fake, _ = G(img, (txt_m, len_txt_m))
        fake_logit, _ = D(fake, txt_m, len_txt_m)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, zeros_like(fake_logit))
        avg_D_fake_loss += fake_loss.item()
        fake_loss.backward()

        d_optimizer.step()

        # ---------------- UPDATE GENERATOR ----------------
        # The generator backward passes also write into D's .grad buffers; those
        # are cleared by D.zero_grad() at the start of the next iteration.
        G.zero_grad()

        # Edit towards the mismatched text: fool D (target 1) and make D's
        # conditional output accept the mismatched text (target 1).
        fake, (z_mean, z_log_stddev) = G(img, (txt_m, len_txt_m))
        kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1))
        avg_kld += 0.5 * kld.item()  # two KLD terms per iteration -> log their mean
        fake_logit, fake_c_prob = D(fake, txt_m, len_txt_m)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, ones_like(fake_logit))
        avg_G_fake_loss += fake_loss.item()
        fake_c_loss = F.binary_cross_entropy(fake_c_prob, ones_like(fake_c_prob))
        avg_G_fake_c_loss += fake_c_loss.item()
        G_loss = fake_loss + 10 * fake_c_loss + 0.5 * kld
        G_loss.backward()

        # Reconstruction: with the matching text, G should reproduce the input image.
        recon, (z_mean, z_log_stddev) = G(img, (txt, len_txt))
        kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1))
        avg_kld += 0.5 * kld.item()
        recon_loss = F.l1_loss(recon, img)
        avg_G_recon_loss += recon_loss.item()
        G_loss = 0.2 * recon_loss + 0.5 * kld
        G_loss.backward()

        g_optimizer.step()

        if i % 10 == 0:
            print('Epoch [%03d/%03d], Iter [%03d/%03d], D_real: %.4f, D_real_c: %.4f, D_fake: %.4f, G_fake: %.4f, G_fake_c: %.4f, G_recon: %.4f, KLD: %.4f'
                  % (epoch + 1, num_epochs, i + 1, len(train_loader), avg_D_real_loss / (i + 1),
                     avg_D_real_c_loss / (i + 1), avg_D_fake_loss / (i + 1),
                     avg_G_fake_loss / (i + 1), avg_G_fake_c_loss / (i + 1),
                     avg_G_recon_loss / (i + 1), avg_kld / (i + 1)))

    d_lr_scheduler.step()
    g_lr_scheduler.step()

    # Overwrite the Drive checkpoints every epoch so the latest weights survive an interrupted session.
    torch.save(G.state_dict(), "drive/MyDrive/birds_GEN.pth")
    torch.save(D.state_dict(), "drive/MyDrive/birds_DIS.pth")

Epoch [001/050], Iter [001/185], D_real: 0.3427, D_real_c: 0.6817, D_fake: 0.8350, G_fake: 2.0738, G_fake_c: 0.6482, G_recon: 0.2746, KLD: 0.0930
Epoch [001/050], Iter [011/185], D_real: 0.5076, D_real_c: 0.6383, D_fake: 0.5324, G_fake: 1.5919, G_fake_c: 0.5544, G_recon: 0.2746, KLD: 0.1056
Epoch [001/050], Iter [021/185], D_real: 0.5410, D_real_c: 0.6380, D_fake: 0.5264, G_fake: 1.4741, G_fake_c: 0.5455, G_recon: 0.2648, KLD: 0.0980
Epoch [001/050], Iter [031/185], D_real: 0.5503, D_real_c: 0.6352, D_fake: 0.5382, G_fake: 1.4296, G_fake_c: 0.5378, G_recon: 0.2585, KLD: 0.0940
Epoch [001/050], Iter [041/185], D_real: 0.5628, D_real_c: 0.6312, D_fake: 0.5591, G_fake: 1.4256, G_fake_c: 0.5379, G_recon: 0.2594, KLD: 0.0928
Epoch [001/050], Iter [051/185], D_real: 0.5601, D_real_c: 0.6304, D_fake: 0.5682, G_fake: 1.4506, G_fake_c: 0.5337, G_recon: 0.2601, KLD: 0.0932
Epoch [001/050], Iter [061/185], D_real: 0.5586, D_real_c: 0.6296, D_fake: 0.5617, G_fake: 1.4426, G_fake_c: 0.5353, G_recon

In [None]:
# Continue adversarial training for another block of epochs. Each iteration does
# one discriminator step and one generator step; checkpoints on Drive are
# overwritten after every epoch.
num_epochs = 20

# Force train mode in case .eval() was called after loading a checkpoint
# (it matters if the networks contain BatchNorm/Dropout).
G.train()
D.train()

for epoch in range(num_epochs):
    # Running sums of per-batch losses; divided by the batch count when logged.
    avg_D_real_loss = 0
    avg_D_real_c_loss = 0
    avg_D_fake_loss = 0
    avg_G_fake_loss = 0
    avg_G_fake_c_loss = 0
    avg_G_recon_loss = 0
    avg_kld = 0
    for i, (img, txt, len_txt) in enumerate(train_loader):
        img, txt, len_txt = img.to(device), txt.to(device), len_txt.to(device)
        # ToTensor yields [0, 1]; rescale to [-1, 1].
        img = img.mul(2).sub(1)
        # BTC to TBC
        txt = txt.transpose(1, 0)
        # Mismatched ("negative") text: shift the batch by one position so each
        # image is paired with another sample's description and its length.
        txt_m = torch.cat((txt[:, -1, :].unsqueeze(1), txt[:, :-1, :]), 1)
        len_txt_m = torch.cat((len_txt[-1].unsqueeze(0), len_txt[:-1]))

        # ---------------- UPDATE DISCRIMINATOR ----------------
        D.zero_grad()

        # Real images: realness target 1; text-match target 1 for the true text
        # and 0 for the negative text.
        real_logit, real_c_prob, real_c_prob_n = D(img, txt, len_txt, negative=True)
        real_loss = F.binary_cross_entropy_with_logits(real_logit, ones_like(real_logit))
        avg_D_real_loss += real_loss.item()
        real_c_loss = (F.binary_cross_entropy(real_c_prob, ones_like(real_c_prob)) +
                       F.binary_cross_entropy(real_c_prob_n, zeros_like(real_c_prob_n))) / 2
        avg_D_real_c_loss += real_c_loss.item()
        (real_loss + 10 * real_c_loss).backward()

        # Synthesized images (generated from the mismatched text) should be judged fake.
        # G is only a data source here, so skip building its autograd graph;
        # D's gradients are identical to using fake.detach().
        with torch.no_grad():
            fake, _ = G(img, (txt_m, len_txt_m))
        fake_logit, _ = D(fake, txt_m, len_txt_m)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, zeros_like(fake_logit))
        avg_D_fake_loss += fake_loss.item()
        fake_loss.backward()

        d_optimizer.step()

        # ---------------- UPDATE GENERATOR ----------------
        # The generator backward passes also write into D's .grad buffers; those
        # are cleared by D.zero_grad() at the start of the next iteration.
        G.zero_grad()

        # Edit towards the mismatched text: fool D (target 1) and make D's
        # conditional output accept the mismatched text (target 1).
        fake, (z_mean, z_log_stddev) = G(img, (txt_m, len_txt_m))
        kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1))
        avg_kld += 0.5 * kld.item()  # two KLD terms per iteration -> log their mean
        fake_logit, fake_c_prob = D(fake, txt_m, len_txt_m)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, ones_like(fake_logit))
        avg_G_fake_loss += fake_loss.item()
        fake_c_loss = F.binary_cross_entropy(fake_c_prob, ones_like(fake_c_prob))
        avg_G_fake_c_loss += fake_c_loss.item()
        G_loss = fake_loss + 10 * fake_c_loss + 0.5 * kld
        G_loss.backward()

        # Reconstruction: with the matching text, G should reproduce the input image.
        recon, (z_mean, z_log_stddev) = G(img, (txt, len_txt))
        kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1))
        avg_kld += 0.5 * kld.item()
        recon_loss = F.l1_loss(recon, img)
        avg_G_recon_loss += recon_loss.item()
        G_loss = 0.2 * recon_loss + 0.5 * kld
        G_loss.backward()

        g_optimizer.step()

        if i % 10 == 0:
            print('Epoch [%03d/%03d], Iter [%03d/%03d], D_real: %.4f, D_real_c: %.4f, D_fake: %.4f, G_fake: %.4f, G_fake_c: %.4f, G_recon: %.4f, KLD: %.4f'
                  % (epoch + 1, num_epochs, i + 1, len(train_loader), avg_D_real_loss / (i + 1),
                     avg_D_real_c_loss / (i + 1), avg_D_fake_loss / (i + 1),
                     avg_G_fake_loss / (i + 1), avg_G_fake_c_loss / (i + 1),
                     avg_G_recon_loss / (i + 1), avg_kld / (i + 1)))

    d_lr_scheduler.step()
    g_lr_scheduler.step()

    # Overwrite the Drive checkpoints every epoch so the latest weights survive an interrupted session.
    torch.save(G.state_dict(), "drive/MyDrive/birds_GEN.pth")
    torch.save(D.state_dict(), "drive/MyDrive/birds_DIS.pth")

Epoch [001/050], Iter [001/185], D_real: 0.6175, D_real_c: 0.6193, D_fake: 0.4417, G_fake: 1.0549, G_fake_c: 0.5779, G_recon: 0.2370, KLD: 0.0777
Epoch [001/050], Iter [011/185], D_real: 0.5178, D_real_c: 0.6059, D_fake: 0.4930, G_fake: 1.5139, G_fake_c: 0.5356, G_recon: 0.2425, KLD: 0.1111
Epoch [001/050], Iter [021/185], D_real: 0.5213, D_real_c: 0.6043, D_fake: 0.5304, G_fake: 1.5995, G_fake_c: 0.5397, G_recon: 0.2474, KLD: 0.1114
Epoch [001/050], Iter [031/185], D_real: 0.5300, D_real_c: 0.6058, D_fake: 0.5234, G_fake: 1.5712, G_fake_c: 0.5423, G_recon: 0.2541, KLD: 0.1078
Epoch [001/050], Iter [041/185], D_real: 0.5160, D_real_c: 0.6068, D_fake: 0.5297, G_fake: 1.5824, G_fake_c: 0.5460, G_recon: 0.2577, KLD: 0.1058
Epoch [001/050], Iter [051/185], D_real: 0.5165, D_real_c: 0.6071, D_fake: 0.5165, G_fake: 1.5778, G_fake_c: 0.5384, G_recon: 0.2562, KLD: 0.1067
Epoch [001/050], Iter [061/185], D_real: 0.5065, D_real_c: 0.6069, D_fake: 0.5080, G_fake: 1.5715, G_fake_c: 0.5412, G_recon

In [None]:
# Continue adversarial training for another block of epochs. Each iteration does
# one discriminator step and one generator step; checkpoints on Drive are
# overwritten after every epoch.
num_epochs = 50

# Force train mode in case .eval() was called after loading a checkpoint
# (it matters if the networks contain BatchNorm/Dropout).
G.train()
D.train()

for epoch in range(num_epochs):
    # Running sums of per-batch losses; divided by the batch count when logged.
    avg_D_real_loss = 0
    avg_D_real_c_loss = 0
    avg_D_fake_loss = 0
    avg_G_fake_loss = 0
    avg_G_fake_c_loss = 0
    avg_G_recon_loss = 0
    avg_kld = 0
    for i, (img, txt, len_txt) in enumerate(train_loader):
        img, txt, len_txt = img.to(device), txt.to(device), len_txt.to(device)
        # ToTensor yields [0, 1]; rescale to [-1, 1].
        img = img.mul(2).sub(1)
        # BTC to TBC
        txt = txt.transpose(1, 0)
        # Mismatched ("negative") text: shift the batch by one position so each
        # image is paired with another sample's description and its length.
        txt_m = torch.cat((txt[:, -1, :].unsqueeze(1), txt[:, :-1, :]), 1)
        len_txt_m = torch.cat((len_txt[-1].unsqueeze(0), len_txt[:-1]))

        # ---------------- UPDATE DISCRIMINATOR ----------------
        D.zero_grad()

        # Real images: realness target 1; text-match target 1 for the true text
        # and 0 for the negative text.
        real_logit, real_c_prob, real_c_prob_n = D(img, txt, len_txt, negative=True)
        real_loss = F.binary_cross_entropy_with_logits(real_logit, ones_like(real_logit))
        avg_D_real_loss += real_loss.item()
        real_c_loss = (F.binary_cross_entropy(real_c_prob, ones_like(real_c_prob)) +
                       F.binary_cross_entropy(real_c_prob_n, zeros_like(real_c_prob_n))) / 2
        avg_D_real_c_loss += real_c_loss.item()
        (real_loss + 10 * real_c_loss).backward()

        # Synthesized images (generated from the mismatched text) should be judged fake.
        # G is only a data source here, so skip building its autograd graph;
        # D's gradients are identical to using fake.detach().
        with torch.no_grad():
            fake, _ = G(img, (txt_m, len_txt_m))
        fake_logit, _ = D(fake, txt_m, len_txt_m)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, zeros_like(fake_logit))
        avg_D_fake_loss += fake_loss.item()
        fake_loss.backward()

        d_optimizer.step()

        # ---------------- UPDATE GENERATOR ----------------
        # The generator backward passes also write into D's .grad buffers; those
        # are cleared by D.zero_grad() at the start of the next iteration.
        G.zero_grad()

        # Edit towards the mismatched text: fool D (target 1) and make D's
        # conditional output accept the mismatched text (target 1).
        fake, (z_mean, z_log_stddev) = G(img, (txt_m, len_txt_m))
        kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1))
        avg_kld += 0.5 * kld.item()  # two KLD terms per iteration -> log their mean
        fake_logit, fake_c_prob = D(fake, txt_m, len_txt_m)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, ones_like(fake_logit))
        avg_G_fake_loss += fake_loss.item()
        fake_c_loss = F.binary_cross_entropy(fake_c_prob, ones_like(fake_c_prob))
        avg_G_fake_c_loss += fake_c_loss.item()
        G_loss = fake_loss + 10 * fake_c_loss + 0.5 * kld
        G_loss.backward()

        # Reconstruction: with the matching text, G should reproduce the input image.
        recon, (z_mean, z_log_stddev) = G(img, (txt, len_txt))
        kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1))
        avg_kld += 0.5 * kld.item()
        recon_loss = F.l1_loss(recon, img)
        avg_G_recon_loss += recon_loss.item()
        G_loss = 0.2 * recon_loss + 0.5 * kld
        G_loss.backward()

        g_optimizer.step()

        if i % 10 == 0:
            print('Epoch [%03d/%03d], Iter [%03d/%03d], D_real: %.4f, D_real_c: %.4f, D_fake: %.4f, G_fake: %.4f, G_fake_c: %.4f, G_recon: %.4f, KLD: %.4f'
                  % (epoch + 1, num_epochs, i + 1, len(train_loader), avg_D_real_loss / (i + 1),
                     avg_D_real_c_loss / (i + 1), avg_D_fake_loss / (i + 1),
                     avg_G_fake_loss / (i + 1), avg_G_fake_c_loss / (i + 1),
                     avg_G_recon_loss / (i + 1), avg_kld / (i + 1)))

    d_lr_scheduler.step()
    g_lr_scheduler.step()

    # Overwrite the Drive checkpoints every epoch so the latest weights survive an interrupted session.
    torch.save(G.state_dict(), "drive/MyDrive/birds_GEN.pth")
    torch.save(D.state_dict(), "drive/MyDrive/birds_DIS.pth")

Epoch [001/050], Iter [001/185], D_real: 0.2096, D_real_c: 0.5932, D_fake: 0.6814, G_fake: 3.0371, G_fake_c: 0.5256, G_recon: 0.2412, KLD: 0.1623
Epoch [001/050], Iter [011/185], D_real: 0.4591, D_real_c: 0.5689, D_fake: 0.4782, G_fake: 1.8631, G_fake_c: 0.4790, G_recon: 0.2723, KLD: 0.1605
Epoch [001/050], Iter [021/185], D_real: 0.4445, D_real_c: 0.5739, D_fake: 0.4544, G_fake: 1.8468, G_fake_c: 0.4813, G_recon: 0.2731, KLD: 0.1613
Epoch [001/050], Iter [031/185], D_real: 0.4815, D_real_c: 0.5729, D_fake: 0.4819, G_fake: 1.8622, G_fake_c: 0.4816, G_recon: 0.2676, KLD: 0.1566
Epoch [001/050], Iter [041/185], D_real: 0.4645, D_real_c: 0.5663, D_fake: 0.4590, G_fake: 1.8072, G_fake_c: 0.4814, G_recon: 0.2626, KLD: 0.1586
Epoch [001/050], Iter [051/185], D_real: 0.4575, D_real_c: 0.5676, D_fake: 0.4468, G_fake: 1.7755, G_fake_c: 0.4788, G_recon: 0.2590, KLD: 0.1575
Epoch [001/050], Iter [061/185], D_real: 0.4470, D_real_c: 0.5714, D_fake: 0.4361, G_fake: 1.7842, G_fake_c: 0.4811, G_recon

Done training the network. The latest weights are saved to drive/MyDrive/birds_GEN.pth and drive/MyDrive/birds_DIS.pth.