In [None]:
from google.colab import drive
drive.mount('/content/gdrive')
!pip install tensorboardX
import tensorflow as tf
tf.test.gpu_device_name()

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive
Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/5c/76/89dd44458eb976347e5a6e75eb79fecf8facd46c1ce259bad54e0044ea35/tensorboardX-1.6-py2.py3-none-any.whl (129kB)
[K    100% |████████████████████████████████| 133kB 4.7MB/s 
Installing collected packages: tensorboardX
Successfully installed tensorboardX-1.6


'/device:GPU:0'

In [None]:
##########################################################
###DATASET CLASS
##########################################################

import torch.utils.data as data
from PIL import Image
import PIL
import os
import os.path
import pickle
import random
import numpy as np
import pandas as pd

class Dataset(data.Dataset):
    """CUB-200-2011 images paired with pre-computed caption embeddings (StackGAN layout).

    Files read, relative to ``data_dir``:
      * ``CUB_200_2011/bounding_boxes.txt`` and ``CUB_200_2011/images.txt`` -- per-image boxes
      * ``CUB_200_2011/images/<key>.jpg`` -- the images
      * ``<split>/filenames.pickle`` -- ordered image keys for the split
      * ``<split>/<embedding file>.pickle`` -- caption embeddings, indexed like the keys
      * ``<split>/class_info.pickle`` (optional) -- class ids

    Each item is ``(image, embedding)``: the bbox-cropped, resized (and optionally
    transformed) image plus one randomly chosen caption embedding for that image.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this at trusted data.
    """

    # embedding_type -> pickle file name inside the split directory.
    _EMBEDDING_FILES = {
        'cnn-rnn': '/char-CNN-RNN-embeddings.pickle',
        'cnn-gru': '/char-CNN-GRU-embeddings.pickle',
        'skip-thought': '/skip-thought-embeddings.pickle',
    }

    def __init__(self, data_dir, split='train', embedding_type='cnn-rnn',
                 imsize=256, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.imsize = imsize
        self.data = []
        self.data_dir = data_dir
        self.bbox = self.load_bbox()
        split_dir = os.path.join(data_dir, split)

        self.filenames = self.load_filenames(split_dir)
        self.embeddings = self.load_embedding(split_dir, embedding_type)
        self.class_id = self.load_class_id(split_dir, len(self.filenames))

    def get_img(self, img_path, bbox):
        """Load an RGB image, optionally crop around ``bbox``, resize, then apply ``transform``.

        ``bbox`` is ``[x_left, y_top, width, height]``. The crop is a square centred on the
        box with half-side ``0.75 * max(width, height)``, clipped to the image bounds. The
        image is resized to ``int(imsize * 76 / 64)`` pixels square (presumably leaving
        room for a random crop inside ``transform``).
        """
        # Use a context manager so the underlying file handle is closed; convert()
        # returns a new in-memory image, so it stays valid after the file closes.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        width, height = img.size
        if bbox is not None:
            R = int(np.maximum(bbox[2], bbox[3]) * 0.75)
            center_x = int((2 * bbox[0] + bbox[2]) / 2)
            center_y = int((2 * bbox[1] + bbox[3]) / 2)
            y1 = np.maximum(0, center_y - R)
            y2 = np.minimum(height, center_y + R)
            x1 = np.maximum(0, center_x - R)
            x2 = np.minimum(width, center_x + R)
            img = img.crop([x1, y1, x2, y2])
        load_size = int(self.imsize * 76 / 64)
        img = img.resize((load_size, load_size), PIL.Image.BILINEAR)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def load_bbox(self):
        """Return ``{image_key: [x_left, y_top, width, height]}`` for every CUB image.

        ``image_key`` is the path from ``images.txt`` without its 4-character extension.
        Row ``i`` of ``bounding_boxes.txt`` is paired with row ``i`` of ``images.txt``.

        Raises:
            ValueError: if ``bounding_boxes.txt`` has fewer rows than ``images.txt``.
        """
        data_dir = self.data_dir
        bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt')
        # sep=r'\s+' is the supported spelling of the deprecated delim_whitespace=True.
        df_bounding_boxes = pd.read_csv(bbox_path, sep=r'\s+', header=None).astype(int)

        filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt')
        df_filenames = pd.read_csv(filepath, sep=r'\s+', header=None)
        filenames = df_filenames[1].tolist()

        # Drop the leading image-id column; one vectorised conversion instead of a
        # per-row iloc lookup.
        bboxes = df_bounding_boxes.iloc[:, 1:].values.tolist()
        if len(bboxes) < len(filenames):
            raise ValueError('bounding_boxes.txt has fewer rows (%d) than images.txt (%d)'
                             % (len(bboxes), len(filenames)))
        return {img_file[:-4]: bbox for img_file, bbox in zip(filenames, bboxes)}

    def load_all_captions(self):
        """Return ``{key: [caption, ...]}`` read from ``<data_dir>/text/<key>.txt``."""
        caption_dict = {}
        for key in self.filenames:
            caption_name = '%s/text/%s.txt' % (self.data_dir, key)
            captions = self.load_captions(caption_name)
            caption_dict[key] = captions
        return caption_dict

    def load_captions(self, caption_name):
        """Read one caption per line from a UTF-8 file, skipping empty lines.

        Pairs of U+FFFD replacement characters are replaced by a single space.
        """
        # Decode while reading: Python 3 ``str`` has no ``.decode`` method.
        with open(caption_name, "r", encoding="utf8") as f:
            captions = f.read().split('\n')
        captions = [cap.replace("\ufffd\ufffd", " ")
                    for cap in captions if len(cap) > 0]
        return captions

    def load_embedding(self, data_dir, embedding_type):
        """Load the caption-embedding pickle for ``embedding_type`` as a numpy array.

        Raises:
            ValueError: if ``embedding_type`` is not one of the known types.
        """
        try:
            embedding_filename = self._EMBEDDING_FILES[embedding_type]
        except KeyError:
            raise ValueError('Unknown embedding_type %r; expected one of %s'
                             % (embedding_type, sorted(self._EMBEDDING_FILES)))

        # latin1 lets Python 3 read pickles written by Python 2.
        with open(data_dir + embedding_filename, 'rb') as f:
            embeddings = pickle.load(f, encoding='latin1')
        return np.array(embeddings)

    def load_class_id(self, data_dir, total_num):
        """Load ``class_info.pickle`` if present, else ``np.arange(total_num)``."""
        if os.path.isfile(data_dir + '/class_info.pickle'):
            with open(data_dir + '/class_info.pickle', 'rb') as f:
                class_id = pickle.load(f, encoding='latin1')
        else:
            class_id = np.arange(total_num)
        return class_id

    def load_filenames(self, data_dir):
        """Load the ordered list of image keys from ``filenames.pickle``."""
        filepath = os.path.join(data_dir, 'filenames.pickle')
        with open(filepath, 'rb') as f:
            # Same encoding as the other pickle loads (files may come from Python 2).
            filenames = pickle.load(f, encoding='latin1')
        return filenames

    def __getitem__(self, index):
        """Return ``(image, embedding)`` using one random caption embedding of image ``index``."""
        key = self.filenames[index]
        if self.bbox is not None:
            bbox = self.bbox[key]
            data_dir = '%s/CUB_200_2011' % self.data_dir
        else:
            bbox = None
            data_dir = self.data_dir

        embeddings = self.embeddings[index, :, :]
        img_name = '%s/images/%s.jpg' % (data_dir, key)
        img = self.get_img(img_name, bbox)

        embedding_ix = random.randint(0, embeddings.shape[0] - 1)
        embedding = embeddings[embedding_ix, :]
        if self.target_transform is not None:
            embedding = self.target_transform(embedding)
        return img, embedding

    def __len__(self):
        return len(self.filenames)

In [None]:
##########################################################
###SELF ATTENTION CLASS
##########################################################

import torch
import torch.nn as nn

class Self_Attn(nn.Module):
    """SAGAN-style self-attention over the spatial positions of a feature map.

    Args:
        in_dim: number of input (and output) channels; queries/keys use ``in_dim // 8``.
        activation: stored on the module but not used by ``forward``.
    """
    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        # Attribute names (including the misspelt ``chanel_in``) are kept as-is.
        self.chanel_in = in_dim
        self.activation = activation

        reduced = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Residual weight; starts at zero so the layer initially returns its input.
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attended + x``; output has the same (B, C, W, H) shape as ``x``."""
        batch, channels, width, height = x.size()
        positions = width * height

        query = self.query_conv(x).view(batch, -1, positions).permute(0, 2, 1)  # B x N x C'
        key = self.key_conv(x).view(batch, -1, positions)                      # B x C' x N
        attention = self.softmax(torch.bmm(query, key))                         # B x N x N
        value = self.value_conv(x).view(batch, -1, positions)                   # B x C x N

        attended = torch.bmm(value, attention.transpose(1, 2))
        attended = attended.view(batch, channels, width, height)
        return self.gamma * attended + x


In [None]:
##########################################################
###MODEL CLASS
##########################################################

import torch
import torch.nn as nn
import torch.nn.parallel
#from self_att import self_att
#import Self_Attn
from torch.autograd import Variable
from torch.nn.utils import spectral_norm

def conv3x3(in_shape, out_shape, stride=1):
    """Bias-free 3x3 ``nn.Conv2d`` with padding 1 (spatial size preserved at stride 1)."""
    layer_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_shape, out_shape, **layer_kwargs)

def conv5x5(in_shape, out_shape, stride=1):
    """Bias-free 5x5 ``nn.Conv2d`` with padding 2 (spatial size preserved at stride 1)."""
    layer_kwargs = dict(kernel_size=5, stride=stride, padding=2, bias=False)
    return nn.Conv2d(in_shape, out_shape, **layer_kwargs)

def upResolution(in_shape, out_shape):
    """2x upsampling stage: nearest-neighbour upsample -> spectral-norm 5x5 conv -> BatchNorm -> ReLU.

    Args:
        in_shape: input channel count.
        out_shape: output channel count.

    Returns:
        ``nn.Sequential``; its layer order determines the state_dict keys.
    """
    layers = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        spectral_norm(conv5x5(in_shape, out_shape)),
        nn.BatchNorm2d(out_shape),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
  
class ResBlock(nn.Module):
    """Residual block: (SN-conv5x5 -> BN -> ReLU -> SN-conv5x5 -> BN) + identity, then ReLU.

    Channel count and spatial size are preserved.
    """
    def __init__(self, in_channels):
        super(ResBlock, self).__init__()
        layers = [
            spectral_norm(conv5x5(in_channels, in_channels)),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(True),
            spectral_norm(conv5x5(in_channels, in_channels)),
            nn.BatchNorm2d(in_channels),
        ]
        self.block = nn.Sequential(*layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Skip connection is added before the final activation.
        return self.relu(self.block(x) + x)


class generator1(nn.Module):
    """Stage-I generator: (text embedding, noise) -> low-resolution image.

    ``CA_NET`` turns the text embedding into a condition code; this is concatenated
    after the noise vector, projected to a 128 x 4 x 4 feature map, passed through four
    2x upsampling stages (4 -> 64 pixels per side, 128 -> 8 channels) and finally a
    3-channel 5x5 convolution with ``tanh``.
    """
    def __init__(self):
        super(generator1, self).__init__()
        self.in_dim = 128 * 8  # not read anywhere in this class
        self.cond_dim = 128
        self.noise_dim = 100
        self.doSeq()

    def doSeq(self):
        """Build the sub-modules; attribute names and order define the saved state_dict."""
        base = 128
        latent = self.noise_dim + self.cond_dim
        self.ca_net = CA_NET()

        self.fc = nn.Sequential(
            nn.Linear(latent, base * 4 * 4, bias=False),
            nn.BatchNorm1d(base * 4 * 4),
            nn.ReLU(inplace=True),
        )

        self.up1 = upResolution(base, base // 2)
        self.up2 = upResolution(base // 2, base // 4)
        self.up3 = upResolution(base // 4, base // 8)
        self.up4 = upResolution(base // 8, base // 16)

        self.img = nn.Sequential(
            conv5x5(base // 16, 3),
            nn.Tanh(),
        )

    def forward(self, text_embd, noise):
        """Return ``(None, fake_image, mu, logvar)``; ``mu``/``logvar`` come from ``CA_NET``."""
        c_code, mu, logvar = self.ca_net(text_embd)
        features = self.fc(torch.cat((noise, c_code), 1)).view(-1, 128, 4, 4)
        for upsample in (self.up1, self.up2, self.up3, self.up4):
            features = upsample(features)
        return None, self.img(features), mu, logvar

class discriminator1(nn.Module):
    """Stage-I discriminator image encoder.

    ``forward`` applies four stride-2, 4x4, spectral-normalised convolutions
    (3 -> 64 -> 128 -> 256 -> 512 channels, LeakyReLU 0.2), halving the spatial size
    each time. The conditional head ``get_cond_logits`` is applied by the caller; this
    discriminator has no unconditional head (``get_uncond_logits`` is ``None``).
    """
    def __init__(self):
        super(discriminator1, self).__init__()
        self.in_dim = 64
        self.cond_dim = 128
        self.doSeq()

    def doSeq(self):
        """Build the image encoder and the conditional logits head."""
        dim = self.in_dim
        self.enc_img = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=3, out_channels=dim, kernel_size=4,
                                    stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim * 2, kernel_size=4,
                                    stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(in_channels=dim * 2, out_channels=dim * 4, kernel_size=4,
                                    stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(in_channels=dim * 4, out_channels=dim * 8, kernel_size=4,
                                    stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.get_cond_logits = getLogits(dim, self.cond_dim)
        # disc_loss / gen_loss look this attribute up; None means "no unconditional
        # head" instead of raising AttributeError.
        self.get_uncond_logits = None

    def forward(self, image):
        """Return the encoded feature map for ``image``."""
        return self.enc_img(image)
      
class generator2(nn.Module):
    """Stage-II generator: refines the output of a frozen Stage-I generator.

    Pipeline in ``forward``: frozen ``generator1`` -> (graph cut) -> trainable
    self-attention refinement (``rsz`` -> ``attn`` -> ``rszs``) -> 4x downsampling
    encoder -> concatenation with the spatially tiled ``CA_NET`` code -> residual
    blocks -> four 2x upsampling stages -> 3-channel ``tanh`` image. With the
    64x64 Stage-I output of ``generator1`` this gives a 256x256 image.
    """
    def __init__(self, generator1):
        super(generator2, self).__init__()
        self.in_dim = 128
        self.condition_dim = 128
        self.noise_dim = 100
        self.generator1 = generator1
        # Stage-I weights are frozen; only the Stage-II layers are trained.
        for param in self.generator1.parameters():
            param.requires_grad = False
        # Self-attention over the dim // 16 = 8 channel lifted Stage-I image.
        self.attn = Self_Attn(8, 'relu')
        self.doSeq()

    def _make_layer(self, block, channel_num):
        """Return an ``nn.Sequential`` of four ``block(channel_num)`` modules."""
        return nn.Sequential(*[block(channel_num) for _ in range(4)])

    def doSeq(self):
        """Build encoder, joint-conditioning, residual, upsampling and output layers."""
        dim = self.in_dim

        self.ca_net = CA_NET()

        # Two stride-2 convs: 4x spatial downsampling, 3 -> dim -> dim*2 -> dim*4 channels.
        self.encoder = nn.Sequential(
            spectral_norm(conv5x5(3, dim)),
            nn.ReLU(True),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim * 2, kernel_size=4,
                                    stride=2, padding=1, bias=False)),
            nn.BatchNorm2d(dim * 2),
            nn.ReLU(True),
            spectral_norm(nn.Conv2d(in_channels=dim * 2, out_channels=dim * 4, kernel_size=4,
                                    stride=2, padding=1, bias=False)),
            nn.BatchNorm2d(dim * 4),
            nn.ReLU(True))
        # Fuse image features with the tiled condition code.
        self.hr_joint = nn.Sequential(
            spectral_norm(conv5x5(self.condition_dim + dim * 4, dim * 4)),
            nn.BatchNorm2d(dim * 4),
            nn.ReLU(True))

        self.residual = self._make_layer(ResBlock, dim * 4)

        self.up1 = upResolution(dim * 4, dim * 2)
        self.up2 = upResolution(dim * 2, dim)
        self.up3 = upResolution(dim, dim // 2)
        self.up4 = upResolution(dim // 2, dim // 4)
        self.img = nn.Sequential(
            spectral_norm(conv5x5(dim // 4, 3)),
            nn.Tanh())
        # Lift the 3-channel Stage-I image to dim // 16 channels for attention, and back.
        self.rsz = nn.Sequential(
            spectral_norm(conv5x5(3, dim // 16)),
            nn.Tanh())
        self.rszs = nn.Sequential(
            spectral_norm(conv5x5(dim // 16, 3)),
            nn.Tanh())

    def forward(self, text_embedding, noise):
        """Return ``(refined_stage1_img, fake_img, mu, logvar)``.

        ``refined_stage1_img`` is returned detached (no autograd graph), as before.
        """
        _, stage1_img, _, _ = self.generator1(text_embedding, noise)
        # Cut the graph at the frozen Stage-I output. Detaching *after* the
        # rsz/attn/rszs refinement (as before) meant those trainable layers never
        # received gradients. NOTE(review): backprop through the attention map
        # (N x N with N = H*W of the Stage-I image) raises GPU memory use.
        stage1_img = stage1_img.detach()
        refined = self.rszs(self.attn(self.rsz(stage1_img)))
        encoded_img = self.encoder(refined)

        c_code, mu, logvar = self.ca_net(text_embedding)
        # Tile the condition code over the encoder's spatial grid (no hard-coded size).
        c_code = c_code.view(-1, self.condition_dim, 1, 1)
        c_code = c_code.repeat(1, 1, encoded_img.size(2), encoded_img.size(3))
        h_code = self.hr_joint(torch.cat([encoded_img, c_code], 1))
        h_code = self.residual(h_code)

        for upsample in (self.up1, self.up2, self.up3, self.up4):
            h_code = upsample(h_code)

        fake_img = self.img(h_code)
        return refined.detach(), fake_img, mu, logvar
      
class discriminator2(nn.Module):
    """Stage-II discriminator image encoder with self-attention.

    ``forward`` maps an image to a feature map with ``df_dim * 8`` (= 512) channels.
    The conditional / unconditional heads (``get_cond_logits`` / ``get_uncond_logits``)
    are not called here; the caller applies them to the returned features.
    """
    def __init__(self):
        super(discriminator2, self).__init__()
        self.df_dim = 64
        self.ef_dim = 128
        self.define_module()
        # Self-attention over encode_img's output; 512 == ndf * 8 channels.
        self.attn = Self_Attn(512, 'relu')

    def define_module(self):
        """Build the image encoder and the two logits heads.

        The spatial-size comments below assume a 256x256 input: six 4x4 stride-2
        convolutions halve the size each time (256 -> 4), and the two 5x5 convs keep it.
        """
        ndf, nef = self.df_dim, self.ef_dim
        # NOTE(review): only the last conv5x5 is spectral-normalised. Changing the
        # layer list changes the state_dict keys, so saved checkpoints would no longer load.
        self.encode_img = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),  # 128 * 128 * ndf
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),  # 64 * 64 * ndf * 2
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),  # 32 * 32 * ndf * 4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),  # 16 * 16 * ndf * 8
            nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),  # 8 * 8 * ndf * 16
            nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 32),
            nn.LeakyReLU(0.2, inplace=True),  # 4 * 4 * ndf * 32
            conv5x5(ndf * 32, ndf * 16),
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),   # 4 * 4 * ndf * 16
            spectral_norm(conv5x5(ndf * 16, ndf * 8)),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True)   # 4 * 4 * ndf * 8
        )

        # Heads take the ndf * 8 channel features; the conditional one also takes
        # the nef-dimensional condition code.
        self.get_cond_logits = getLogits(ndf, nef, bcondition=True)
        self.get_uncond_logits = getLogits(ndf, nef, bcondition=False)

    def forward(self, image):
        """Return the self-attended feature map of ``image`` (logits are computed by the caller)."""
        img_embedding = self.encode_img(image)
        img_embedding = self.attn(img_embedding)

        return img_embedding


class getLogits(nn.Module):
    """Discriminator head mapping image features (optionally plus a condition) to probabilities.

    Args:
        dim: base width; incoming features have ``dim * 8`` channels.
        cond: length of the condition vector concatenated to the features.
        bcondition: if True, build the conditional head, which requires ``c_code``.

    ``forward`` returns a flat tensor of sigmoid outputs. With a 4x4 feature map this is
    one value per sample; larger maps give one value per 4x4 patch.
    """
    def __init__(self, dim, cond, bcondition=True):
        super(getLogits, self).__init__()
        self.dim = dim
        self.cond = cond
        self.bcondition = bcondition
        if bcondition:
            self.logits = nn.Sequential(
                conv3x3(dim * 8 + cond, dim * 8),
                nn.BatchNorm2d(dim * 8),
                nn.Conv2d(in_channels=dim * 8, out_channels=1, kernel_size=4,
                          stride=4),
                nn.Sigmoid()
            )
        else:
            self.logits = nn.Sequential(
                nn.Conv2d(in_channels=dim * 8, out_channels=1, kernel_size=4,
                          stride=4),
                nn.Sigmoid()
            )

    def forward(self, codes, c_code=None):
        """Return flattened probabilities for ``codes`` (B, dim*8, H, W) and optional ``c_code``.

        Raises:
            ValueError: if this is a conditional head and ``c_code`` is None.
        """
        if self.bcondition:
            if c_code is None:
                raise ValueError('conditional getLogits head requires c_code')
            # Tile the condition over the feature map's actual spatial size
            # (previously hard-coded to 4x4).
            c_code = c_code.view(-1, self.cond, 1, 1)
            c_code = c_code.repeat(1, 1, codes.size(2), codes.size(3))
            codes = torch.cat((codes, c_code), 1)
        outputs = self.logits(codes)
        return outputs.view(-1)

class CA_NET(nn.Module):
    """Conditioning augmentation: text embedding -> sampled condition code.

    A linear layer followed by ReLU maps the 1024-d embedding to 256 values, split
    into ``mu`` and ``logvar`` (128 each). ``c_code = mu + eps * exp(logvar / 2)`` with
    ``eps ~ N(0, I)``.

    NOTE(review): the ReLU before the split forces ``logvar >= 0`` (std >= 1); kept as-is.
    """
    def __init__(self):
        super(CA_NET, self).__init__()
        self.t_dim = 1024
        self.c_dim = 128
        self.fc = nn.Linear(self.t_dim, self.c_dim * 2, bias=True)
        self.relu = nn.ReLU()

    def encode(self, text_embedding):
        """Return ``(mu, logvar)``, each of width ``c_dim``."""
        x = self.relu(self.fc(text_embedding))
        mu = x[:, :self.c_dim]
        logvar = x[:, self.c_dim:]
        return mu, logvar

    def reparametrize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with standard-normal ``eps``."""
        std = logvar.mul(0.5).exp_()
        # Draw noise on std's device and dtype instead of forcing a CUDA tensor,
        # which crashed on CPU-only runs.
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, text_embedding):
        """Return ``(c_code, mu, logvar)``."""
        mu, logvar = self.encode(text_embedding)
        c_code = self.reparametrize(mu, logvar)
        return c_code, mu, logvar


In [None]:
##########################################################
###LOSS CLASS
##########################################################

import torch
import torch.nn as nn

def disc_loss(dis, real_imgs, fake_imgs, real_labels, fake_labels, conditions):
    """Discriminator BCE loss over real, mismatched and fake (image, condition) pairs.

    Args:
        dis: discriminator; ``dis(x)`` returns features, ``dis.get_cond_logits(feat, cond)``
            returns probabilities. An optional ``dis.get_uncond_logits(feat)`` head is used
            when present and not None.
        real_imgs, fake_imgs: image batches (fake images are detached here).
        real_labels, fake_labels: target tensors (typically ones / zeros).
        conditions: condition codes (detached here).

    Returns:
        ``(total_loss, real_loss, wrong_loss, fake_loss)``: the total is a tensor for
        ``backward()``, the other three are Python floats (averaged with the
        unconditional terms when that head exists).
    """
    criterion = nn.BCELoss()
    batch_size = real_imgs.size(0)
    # The discriminator step must not push gradients into the generator / CA_NET.
    cond = conditions.detach()
    fake = fake_imgs.detach()
    real_features = dis(real_imgs)
    fake_features = dis(fake)

    # Real image with its own condition -> real.
    real_logits = dis.get_cond_logits(real_features, cond)
    dis_error_real = criterion(real_logits, real_labels)
    # Mismatched pairs: image i with the condition of image i+1 -> fake.
    wrong_logits = dis.get_cond_logits(real_features[:(batch_size - 1)], cond[1:])
    dis_error_wrong = criterion(wrong_logits, fake_labels[1:])
    # Generated image with its condition -> fake.
    fake_logits = dis.get_cond_logits(fake_features, cond)
    dis_error_fake = criterion(fake_logits, fake_labels)

    # Discriminators without an unconditional head may omit the attribute entirely.
    uncond_head = getattr(dis, 'get_uncond_logits', None)
    if uncond_head is not None:
        uncond_dis_error_real = criterion(uncond_head(real_features), real_labels)
        uncond_dis_error_fake = criterion(uncond_head(fake_features), fake_labels)
        errorz = ((dis_error_real + uncond_dis_error_real) / 2. +
                  (dis_error_fake + uncond_dis_error_fake + dis_error_wrong) / 3.)
        dis_error_real = (dis_error_real + uncond_dis_error_real) / 2.
        dis_error_fake = (dis_error_fake + uncond_dis_error_fake) / 2.
    else:
        errorz = dis_error_real + (dis_error_fake + dis_error_wrong) * 0.5
    return errorz, dis_error_real.item(), dis_error_wrong.item(), dis_error_fake.item()

def gen_loss(dis, fake_imgs, real_labels, conditions):
    """Generator BCE loss: generated images should be classified as real.

    Uses ``dis.get_cond_logits`` on (fake features, condition) and, when the
    discriminator has a non-None ``get_uncond_logits`` head, adds its loss too.
    The condition is detached, so no gradient reaches it through this loss.

    Returns:
        Scalar loss tensor.
    """
    criterion = nn.BCELoss()
    cond = conditions.detach()
    fake_features = dis(fake_imgs)
    errorz = criterion(dis.get_cond_logits(fake_features, cond), real_labels)
    # Discriminators without an unconditional head may omit the attribute entirely.
    uncond_head = getattr(dis, 'get_uncond_logits', None)
    if uncond_head is not None:
        errorz = errorz + criterion(uncond_head(fake_features), real_labels)
    return errorz

def KL_loss(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from N(0, I), averaged over all elements.

    Computes ``-0.5 * mean(1 + logvar - mu^2 - exp(logvar))`` without modifying the inputs.
    """
    per_element = (mu.pow(2) + logvar.exp()).neg() + 1 + logvar
    return per_element.mean() * -0.5


In [None]:
##########################################################
###UTILS CLASS
##########################################################

import os
import errno
import numpy as np

from copy import deepcopy

from torch.nn import init
import torch
import torch.nn as nn
import torchvision.utils as vutils

def weights_init(m):
    """DCGAN-style initialiser intended for ``module.apply(weights_init)``.

    * ``*Conv*``: weight ~ N(0, 0.02)
    * ``*BatchNorm*``: weight ~ N(1, 0.02), bias = 0
    * ``*Linear*``: weight ~ N(0, 0.02), bias = 0 (if present)

    Layers wrapped with ``torch.nn.utils.spectral_norm`` store the trainable weight in
    ``weight_orig`` and recompute ``weight`` from it on every forward, so the
    initialisation is written into ``weight_orig`` for those layers. Layers without
    a weight (e.g. ``BatchNorm`` with ``affine=False``) are left untouched.
    """
    classname = m.__class__.__name__
    # Prefer the real parameter of spectral-normalised layers over the derived tensor.
    weight = getattr(m, 'weight_orig', None)
    if weight is None:
        weight = getattr(m, 'weight', None)
    if weight is None:
        return
    if classname.find('Conv') != -1:
        weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        weight.data.normal_(0.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
            
def save_img_results(data_img, fake, epoch, image_dir):
    """Save image grids of at most the first 64 samples into ``image_dir``.

    With ``data_img``: writes it to ``real_samples.png`` and ``fake`` to
    ``fake_samples_epoch_<epoch:03d>.png``. Without it: writes ``fake`` to
    ``lr_fake_samples_epoch_<epoch:03d>.png``. Each call passes ``normalize=True``
    to ``torchvision.utils.save_image``.
    """
    num = 64
    fake = fake[0:num]
    if data_img is None:
        vutils.save_image(
            fake.data,
            '%s/lr_fake_samples_epoch_%03d.png' % (image_dir, epoch),
            normalize=True)
        return

    vutils.save_image(
        data_img[0:num],
        '%s/real_samples.png' % image_dir,
        normalize=True)
    vutils.save_image(
        fake.data,
        '%s/fake_samples_epoch_%03d.png' % (image_dir, epoch),
        normalize=True)
        
def save_model(netG, netD, epoch, model_dir):
    """Overwrite ``netG_epoch_last.pth`` and ``netD_epoch_last.pth`` in ``model_dir``.

    Each file holds the network's ``state_dict()``. ``epoch`` is accepted for interface
    compatibility but is not part of the file names.
    """
    checkpoints = ((netG, '%s/netG_epoch_last.pth'),
                   (netD, '%s/netD_epoch_last.pth'))
    for net, pattern in checkpoints:
        torch.save(net.state_dict(), pattern % (model_dir))



In [None]:
##########################################################
###TRAINER CLASS
##########################################################

import os
#print(!ls)
#os.chdir('/content/gdrive/My Drive/Test')

import torch
import torch.nn as nn
#import Dataset
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
#import generator, discriminator
#import disc_loss, gen_loss, KL_loss
from tensorboardX import summary
from tensorboardX import FileWriter
#import weights_init, save_img_results, save_model

# Outputs on the mounted Google Drive.
model_dir = '/content/gdrive/My Drive/OUTPUTS/STAGE2WA/model'
image_dir = '/content/gdrive/My Drive/OUTPUTS/STAGE2WA/image'
log_dir = '/content/gdrive/My Drive/OUTPUTS/STAGE2WA/log'
summary_writer = FileWriter(log_dir)

# Training configuration.
BATCH_SIZE = 8
NOISE_DIM = 100  # must match generator1.noise_dim
# FIXME: the loop used to be `range(128, 100)`, which is empty, so this cell never
# trained anything. Set END_EPOCH (exclusive) to the epoch to resume up to; the
# default keeps the previous no-op behaviour but now reports it.
START_EPOCH = 128
END_EPOCH = 100
if END_EPOCH <= START_EPOCH:
    print('WARNING: epoch range [{}, {}) is empty; no training will run.'.format(
        START_EPOCH, END_EPOCH))

# Networks: Stage-II generator wrapping a frozen Stage-I generator.
gen1 = generator1()
gen = generator2(gen1)
gen.apply(weights_init)

dis = discriminator2()
dis.apply(weights_init)

# Resume from checkpoints. generator1 is a registered sub-module of gen, so the
# Stage-II checkpoint also carries (and overrides) the Stage-I weights.
gen.generator1.load_state_dict(torch.load('/content/gdrive/My Drive/OUTPUTS/netG_epoch_280.pth'))
gen.load_state_dict(torch.load('/content/gdrive/My Drive/OUTPUTS/STAGE2WA/model/netG_epoch_last.pth'))
dis.load_state_dict(torch.load('/content/gdrive/My Drive/OUTPUTS/STAGE2WA/model/netD_epoch_last.pth'))

dis.cuda()
gen.cuda()

image_transform = transforms.Compose([
    transforms.RandomCrop(256),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

dataset = Dataset('/content/gdrive/My Drive/CUB_200_2011/DATA',
                  'train', imsize=256, transform=image_transform)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, drop_last=True, shuffle=True)

# Plain tensors replace the deprecated Variable(..., volatile=True) wrappers.
noise = torch.empty(BATCH_SIZE, NOISE_DIM).cuda()         # resampled every iteration
fixed_noise = torch.randn(BATCH_SIZE, NOISE_DIM).cuda()   # reused for per-epoch samples
real_labels = torch.ones(BATCH_SIZE).cuda()
fake_labels = torch.zeros(BATCH_SIZE).cuda()

optim_dis = optim.Adam(dis.parameters(), lr=0.0001, betas=(0.5, 0.999))
optim_gen = optim.Adam(gen.parameters(), lr=0.0001, betas=(0.5, 0.999))

dis_error = []
gen_error = []
count = 0
for epoch in range(START_EPOCH, END_EPOCH):
    for real_img_org, txt_embd in dataloader:
        real_img = real_img_org.cuda()
        txt_embd = txt_embd.cuda()
        noise.normal_(0, 1)
        _, fake_img, mu, logvar = gen(txt_embd, noise)

        # Discriminator step. Up to epoch 25, D is only updated while its loss is
        # above 1 (presumably to keep it from overpowering G early on).
        dis.zero_grad()
        errorDz, dis_error_real, dis_error_wrong, dis_error_fake = disc_loss(
            dis, real_img, fake_img, real_labels, fake_labels, mu)
        if epoch > 25 or errorDz > 1:
            errorDz.backward()
            optim_dis.step()

        # Generator step: adversarial loss plus KL term (weight 2.0) on the CA codes.
        gen.zero_grad()
        errorGz = gen_loss(dis, fake_img, real_labels, mu)
        kl_loss = KL_loss(mu, logvar)
        errorGzT = kl_loss * 2.0 + errorGz
        errorGzT.backward()
        optim_gen.step()

        count = count + 1

    print("#############################################################")
    print("Epoch: {}".format(epoch))
    print("Discriminator Loss: {}".format(errorDz.item()))
    print("Generator Loss: {}".format(errorGz.item()))
    print("Real Loss: {}".format(dis_error_real))
    print("Wrong Loss: {}".format(dis_error_wrong))
    print("Fake Loss: {}".format(dis_error_fake))
    print("#############################################################")
    dis_error.append(errorDz.item())
    gen_error.append(errorGzT.item())
    # Samples are for visualisation only; no autograd graph is needed.
    with torch.no_grad():
        _, fake, _, _ = gen(txt_embd, fixed_noise)
    save_img_results(real_img_org, fake, epoch, image_dir)
    save_model(gen, dis, epoch, model_dir)


