In [None]:
import sys
sys.path.append('/common/users/ppk31/CS543_DL_Proj')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import math
from torch.autograd import Variable
from configs import config
from pytorch_model_summary import summary
from IPython.display import clear_output

from utils import (weights_init, make_train_test_split, load_data, compute_discriminator_loss, 
                   compute_generator_loss, KL_loss, L1_loss, save_img_results, save_model, load_from_checkpoint)
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, MultiStepLR
from dataset import Text2ImgDataset, Text2ImgDataset_reformed
from torch.utils.data import DataLoader
import os

import traceback

# Ask PyTorch's caching allocator to release unused cached GPU memory
# (useful when this cell is re-run in a kernel that already trained a model).
torch.cuda.empty_cache()

# Root directory of the datasets; get_loader() joins it with config.dataset
# and the image / caption list paths from the config.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
DATASET = '/freespace/local/ppk31_cs543/Project/Dataset'

In [None]:
if config.text_encoder == "distilbert-base-uncased":
    from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
elif config.text_encoder == "openai/clip-vit-base-patch32":
    from transformers import CLIPTokenizer, CLIPModel, CLIPProcessor

In [None]:
# Report the CUDA devices visible to this process. `gpus` is reused later as
# the device-id list for nn.parallel.data_parallel and to scale the batch size.
print(f"using GPU: {torch.cuda.is_available()}")
gpus = list(range(torch.cuda.device_count()))
print(f"GPU ids: {gpus}")

# Seed every RNG the notebook draws from. torch.manual_seed seeds the CPU and
# all CUDA generators; NumPy is seeded too because the label smoothing /
# label flipping in train() samples from np.random.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Only select a default device when one exists, so the cell does not crash
# with an IndexError on a CPU-only machine.
if gpus:
    torch.cuda.set_device(gpus[0])
# Let cuDNN auto-tune convolution algorithms (input sizes are fixed per stage).
cudnn.benchmark = True

In [None]:
def get_tokenizer(text_encoder):
    """Return the pretrained tokenizer that matches ``text_encoder``.

    Args:
        text_encoder: Hugging Face checkpoint name. Supported values are
            ``"distilbert-base-uncased"`` and ``"openai/clip-vit-base-patch32"``.

    Returns:
        The tokenizer loaded with ``from_pretrained(text_encoder)``.

    Raises:
        ValueError: if ``text_encoder`` is not one of the supported names
            (previously this case silently returned ``None``).
    """
    print(f"using {text_encoder} as text encoder")
    if text_encoder == "distilbert-base-uncased":
        return DistilBertTokenizer.from_pretrained(text_encoder)
    if text_encoder == "openai/clip-vit-base-patch32":
        return CLIPTokenizer.from_pretrained(text_encoder)
    raise ValueError(
        f"unsupported text encoder {text_encoder!r}; expected "
        "'distilbert-base-uncased' or 'openai/clip-vit-base-patch32'"
    )

class TextEncoder(nn.Module):
    """Wraps a pretrained text model and returns one embedding per caption.

    Args:
        text_encoder: Hugging Face checkpoint name; either
            ``"distilbert-base-uncased"`` or ``"openai/clip-vit-base-patch32"``.
        pretrained: kept for interface compatibility; currently ignored, the
            encoder is always loaded with ``from_pretrained``.

    Raises:
        ValueError: if ``text_encoder`` is not a supported checkpoint name.
    """

    def __init__(self, text_encoder, pretrained=True):
        super(TextEncoder, self).__init__()
        self.text_encoder = text_encoder
        if text_encoder == "distilbert-base-uncased":
            self.encoder = DistilBertModel.from_pretrained(text_encoder)
        elif text_encoder == "openai/clip-vit-base-patch32":
            self.encoder = CLIPModel.from_pretrained(text_encoder)
        else:
            # Fail at construction instead of leaving the module without an
            # encoder and crashing later inside forward().
            raise ValueError(
                f"unsupported text encoder {text_encoder!r}; expected "
                "'distilbert-base-uncased' or 'openai/clip-vit-base-patch32'"
            )
        # Position of the token whose hidden state is used as the sentence
        # embedding for DistilBERT (index 0, the first token).
        self.retrieve_token_index = 0

    def forward(self, input_tokens, attention_mask):
        """Encode tokenized captions.

        Args:
            input_tokens: token ids, passed as ``input_ids`` to the encoder.
            attention_mask: attention mask passed through to the encoder.

        Returns:
            One embedding per caption: the last hidden state at
            ``retrieve_token_index`` for DistilBERT, or the output of
            ``get_text_features`` for CLIP.
        """
        if self.text_encoder == "distilbert-base-uncased":
            out = self.encoder(input_ids=input_tokens, attention_mask=attention_mask)
            last_hidden_states = out.last_hidden_state
            embeddings = last_hidden_states[:, self.retrieve_token_index, :]    # output_dimensions = 768
        elif self.text_encoder == "openai/clip-vit-base-patch32":
            embeddings = self.encoder.get_text_features(input_ids=input_tokens, attention_mask=attention_mask)  # output_dimensions = 512
        else:
            # Only reachable if self.text_encoder was changed after construction.
            raise ValueError(f"unsupported text encoder {self.text_encoder!r}")
        return embeddings

In [None]:
class Augmented_Projection(nn.Module):
    """Conditioning augmentation, plus (stage 1 only) projection to a feature map.

    The text embedding is mapped by a linear layer + ReLU to ``2 * c_dim``
    values that are split into ``mu`` and ``logvar``; a conditioning code is
    sampled as ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, 1)``.
    For ``stage == 1`` the code is concatenated with a noise vector and
    projected (Linear -> BatchNorm1d -> ReLU) to
    ``gen_channels * gen_dim * gen_dim`` features.

    Args:
        stage: 1 enables the noise concatenation and projection; any other
            value returns the sampled conditioning code directly.
        gen_channels: channel count of the generator's first feature map.
        gen_dim: spatial size of the generator's first feature map.

    Sizes ``t_dim``, ``c_dim`` and ``z_dim`` are read from ``config``.
    """

    def __init__(self, stage, gen_channels, gen_dim):
        super(Augmented_Projection, self).__init__()
        self.stage = stage
        self.t_dim = config.text_dim
        self.c_dim = config.condition_dim
        self.z_dim = config.z_dim
        self.gen_in = gen_channels
        self.fc = nn.Linear(self.t_dim, self.c_dim * 2)
        self.relu = nn.ReLU()
        if stage == 1:
            # bias is omitted because the following BatchNorm1d has its own shift.
            self.project = nn.Sequential(
                nn.Linear(self.c_dim + self.z_dim, self.gen_in * gen_dim * gen_dim, bias=False),
                nn.BatchNorm1d(self.gen_in * gen_dim * gen_dim),
                nn.ReLU()
            )

    def augment(self, mu, logvar):
        """Sample ``mu + std * eps`` with ``std = exp(0.5 * logvar)``.

        ``eps`` is drawn on the same device (and dtype) as ``logvar`` so the
        module works on CPU as well as on whichever GPU holds the inputs.
        """
        std = logvar.mul(0.5).exp()
        eps = torch.randn_like(std)
        return mu + (std * eps)

    def forward(self, text_embedding, noise=None):
        """Return ``(c_code, mu, logvar)`` for a batch of text embeddings.

        Args:
            text_embedding: tensor whose last dimension is ``t_dim``.
            noise: optional ``(batch, z_dim)`` noise; only used for stage 1,
                where it is sampled from N(0, 1) when omitted.

        Returns:
            ``c_code`` (projected features for stage 1, otherwise the sampled
            ``(batch, c_dim)`` code), ``mu`` and ``logvar``.
        """
        if noise is None and self.stage == 1:
            noise = torch.randn(
                (text_embedding.shape[0], self.z_dim),
                dtype=torch.float32,
                device=text_embedding.device,
            )
        x = self.relu(self.fc(text_embedding))
        mu = x[:, :self.c_dim]
        logvar = x[:, self.c_dim:]
        c_code = self.augment(mu, logvar)

        if self.stage == 1:
            c_code = torch.cat((c_code, noise), dim=1)
            c_code = self.project(c_code)

        return c_code, mu, logvar

In [None]:
class Downsample(nn.Module):
    """
    A downsampling layer: convolution (or average pooling), optionally
    followed by batch normalisation and LeakyReLU(0.2).

    :param channels: channels in the inputs.
    :param out_channels: channels in the outputs; defaults to ``channels``.
    :param kernel_size: convolution kernel size (ignored when pooling).
    :param stride: convolution stride, or pooling window and stride.
    :param padding: convolution padding (ignored when pooling).
    :param batch_norm: a bool determining if BatchNorm2d is applied.
    :param activation: a bool determining if LeakyReLU(0.2) is applied.
    :param use_conv: a bool determining if a convolution is applied; when
        False an average pool is used and the channel count cannot change.
    :param bias: a bool determining if the convolution has a bias term.
    """

    def __init__(self, channels, out_channels=None, kernel_size=4, stride=2, padding=1, batch_norm=True, activation=True, use_conv=True, bias=False):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.batch_norm = batch_norm
        self.activation = activation
        if use_conv:
            self.op = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        else:
            assert self.channels == self.out_channels
            self.op = nn.AvgPool2d(kernel_size=stride, stride=stride)
        if batch_norm:
            # Use the resolved channel count: the raw `out_channels` argument
            # is None when the caller relies on the default.
            self.batchnorm = nn.BatchNorm2d(self.out_channels)
        if activation:
            self.activtn = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply op -> (batch norm) -> (activation) to an (N, channels, H, W) input."""
        assert x.shape[1] == self.channels
        x = self.op(x)
        if self.batch_norm:
            x = self.batchnorm(x)
        if self.activation:
            x = self.activtn(x)
        return x

class Upsample(nn.Module):
    """
    A 2x upsampling layer, optionally followed by batch normalisation,
    ReLU and Dropout2d(0.5).

    Upsampling is either a transposed convolution (kernel 4, stride 2) or a
    nearest-neighbour interpolation followed by a 3x3 convolution.

    :param channels: channels in the inputs.
    :param out_channels: channels in the outputs; defaults to ``channels``.
    :param stride: stride of the 3x3 convolution (interpolation path only).
    :param padding: padding of the convolution / transposed convolution.
    :param batch_norm: a bool determining if BatchNorm2d is applied.
    :param activation: a bool determining if ReLU is applied.
    :param bias: a bool determining if the convolution has a bias term.
    :param use_deconv: use ConvTranspose2d instead of interpolate + conv.
    :param dropout: a bool determining if Dropout2d(0.5) is applied.
    """

    def __init__(self, channels, out_channels=None, stride=1, padding=1, batch_norm=True, activation=True, bias=False, use_deconv=False, dropout=False):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.batch_norm = batch_norm
        self.activation = activation
        self.dropout = dropout
        self.use_deconv = use_deconv

        if use_deconv:
            self.deconv = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=4, stride=2, padding=padding, bias=bias) # use when not using interpolate
        else:
            self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias)
        if batch_norm:
            # Use the resolved channel count: the raw `out_channels` argument
            # is None when the caller relies on the default.
            self.batchnorm = nn.BatchNorm2d(self.out_channels)
        if activation:
            self.activtn = nn.ReLU()
        if self.dropout:
            self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        """Upsample an (N, channels, H, W) input, then normalise / activate / drop."""
        assert x.shape[1] == self.channels
        if self.use_deconv:
            x = self.deconv(x)
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            x = self.conv(x)
        if self.batch_norm:
            x = self.batchnorm(x)
        if self.activation:
            x = self.activtn(x)
        if self.dropout:
            x = self.drop(x)
        return x

In [None]:
class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.

    Computes ``relu(skip(x) + bn2(conv2(relu(bn1(conv1(x))))))`` with 3x3
    convolutions.

    :param in_channels: the number of input channels.
    :param out_channels: if specified, the number of out channels; defaults
        to ``in_channels``.
    :param stride: stride of the first convolution and of the skip path.
    :param padding: padding of both 3x3 convolutions.
    """
    def __init__(
        self,
        in_channels,
        out_channels=None,
        stride = 1,
        padding = 1
    ):
        super(ResBlock, self).__init__()
        # Resolve the documented default; nn.Conv2d cannot take None.
        out_channels = out_channels or in_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding)
        # Only the first convolution strides. Striding both would shrink the
        # main branch by stride**2 while the skip path shrinks by stride, so
        # the residual sum could not be formed.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=padding)
        if in_channels == out_channels and stride == 1:
            self.x_residual = nn.Identity()
        else:
            # 1x1 projection so the skip path matches channels and resolution.
            self.x_residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the activated sum of the skip path and the conv branch."""
        g = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        x = self.x_residual(x)
        h = x + g
        return self.relu(h)

In [None]:
class AttentionBlock(nn.Module):
    """
    Residual attention over the spatial positions of a feature map.

    :param channels: is the number of channels in the feature map
    :param n_heads: is the number of attention heads
    :param cond_channels: channels of an optional conditioning feature map;
        when given, a 1x1 convolution turns it into extra keys and values.
    """
    def __init__(self, channels, n_heads=1, cond_channels=None):
        super().__init__()
        self.channels = channels
        assert (
            channels % n_heads == 0
        ), f"q,k,v channels {channels//n_heads} cannot be constructed for {n_heads} heads, input channels: {channels}"
        self.n_heads = n_heads
        # A single 1x1 convolution produces queries, keys and values at once.
        self.qkv = nn.Conv2d(channels, channels*3, kernel_size=1)
        self.attention = QKVAttention(self.n_heads)
        if cond_channels is not None:
            self.cond_kv = nn.Conv2d(cond_channels, channels*2, kernel_size=1)

    def forward(self, x, cond_out = None):
        """Return ``x`` plus the attention output, in the shape of ``x``."""
        batch, channels, *grid = x.shape
        height, width = grid
        # (batch, 3 * channels, height * width)
        fused = self.qkv(x).view(batch, -1, height * width)
        if cond_out is None:
            attended = self.attention(fused)
        else:
            _, _, *cond_grid = cond_out.shape
            cond_h, cond_w = cond_grid
            # (batch, 2 * channels, cond_h * cond_w) extra keys / values
            cond_tokens = self.cond_kv(cond_out).view(batch, -1, cond_h * cond_w)
            attended = self.attention(fused, cond_tokens)
        return x + attended.reshape(batch, channels, *grid)

class QKVAttention(nn.Module):
    """Multi-head dot-product attention over a fused Q/K/V tensor."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, cond_kv=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param cond_kv: optional [N x (H * 2 * C) x S] tensor of extra keys
            and values that are prepended to the ones taken from ``qkv``.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, n_tokens = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        head_ch = width // (3 * heads)  # channels of q, k and v per head

        per_head = qkv.reshape(batch * heads, head_ch * 3, n_tokens)
        query, key, value = torch.split(per_head, head_ch, dim=1)

        if cond_kv is not None:
            assert cond_kv.shape[1] == heads * head_ch * 2
            cond_per_head = cond_kv.reshape(batch * heads, head_ch * 2, -1)
            extra_key, extra_value = torch.split(cond_per_head, head_ch, dim=1)
            key = torch.cat([extra_key, key], dim=-1)
            value = torch.cat([extra_value, value], dim=-1)

        # Scale q and k separately by ch ** -0.25 before the product.
        scale = 1 / math.sqrt(math.sqrt(head_ch))
        logits = torch.einsum("bct,bcs->bts", query * scale, key * scale)
        # Softmax is evaluated in float32 and cast back to the input dtype.
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = torch.einsum("bts,bcs->bct", probs, value)
        return attended.reshape(batch, -1, n_tokens)

In [None]:
class Generator1(nn.Module):
    """Stage-1 generator: text embedding (+ noise) -> low-resolution image.

    The projected conditioning vector is reshaped to an
    ``in_channels x in_dims x in_dims`` feature map, run through a stack of
    residual / attention / upsampling blocks described by ``config`` and
    mapped to a 3-channel image by a 3x3 convolution followed by tanh.
    """

    def __init__(self, stage):
        super().__init__()
        self.stage = stage
        self.in_dims = config.in_dims
        self.in_channels = config.generator_dim * 8
        self.channel_mul = config.channel_mul
        self.num_resblocks = config.n_resblocks
        self.use_deconv = config.use_deconv
        self.dropout = config.dropout
        base_ch = self.in_channels

        self.c_dim = config.condition_dim
        heads = config.attention_heads
        attn_resolutions = config.attention_resolutions
        resolution = self.in_dims  # spatial size of the current level

        self.aug_project = Augmented_Projection(self.stage, self.in_channels, self.in_dims)

        # One level per entry of channel_mul: residual blocks, optional
        # attention, then an upsample to the next level (all but the last).
        self.blocks = nn.ModuleList()
        final_layer = len(self.channel_mul) - 1
        for layer, cmul in enumerate(self.channel_mul):
            width = base_ch // cmul

            n_res = self.num_resblocks[layer]
            for _ in range(n_res):
                self.blocks.append(ResBlock(width, width, stride=1, padding=1))

            if resolution in attn_resolutions:
                self.blocks.append(AttentionBlock(width, n_heads=heads))

            if layer != final_layer:
                next_width = base_ch // self.channel_mul[layer + 1]
                self.blocks.append(Upsample(width, next_width, use_deconv=self.use_deconv, dropout=self.dropout))

            resolution *= 2

        self.out = nn.Sequential(
            nn.Conv2d(base_ch // self.channel_mul[-1], 3, kernel_size=3, padding=1, bias=True),
            nn.Tanh()
        )

    def forward(self, text_embedding, noise=None):
        """Return ``(image, mu, logvar)`` for a batch of text embeddings."""
        projected, mu, logvar = self.aug_project(text_embedding, noise)
        feat = projected.view(-1, self.in_channels, self.in_dims, self.in_dims)

        for block in self.blocks:
            feat = block(feat)
        return self.out(feat), mu, logvar

In [None]:
class Generator2(nn.Module):
    """Stage-2 generator: refines a stage-1 image using the text embedding.

    The stage-1 image is encoded by three ``Downsample`` layers, concatenated
    with the spatially tiled conditioning code, fused by one more layer and
    then passed through residual / attention / upsampling blocks described by
    ``config`` before a 3x3 convolution + tanh produces the output image.
    """

    def __init__(self, stage):
        super().__init__()
        self.stage = stage
        self.in_dims = config.in_dims * config.in_dims
        self.in_channels = config.generator_dim
        self.channel_mul = config.channel_mul_stage2
        self.num_resblocks = config.n_resblocks_stage2
        self.use_deconv = config.use_deconv2
        self.dropout = config.dropout2
        base_ch = self.in_channels * 4

        self.c_dim = config.condition_dim
        heads = config.attention_heads
        attn_resolutions = config.attention_resolutions
        resolution = self.in_dims  # spatial size of the current level

        self.aug_project = Augmented_Projection(self.stage, self.in_channels, self.in_dims)

        # Image encoder: one stride-1 layer, then two stride-2 layers
        # (Downsample's defaults), ending at 4x the base channel count.
        enc_ch = self.in_channels
        self.downblocks = nn.Sequential(
            Downsample(3, enc_ch, kernel_size=3, stride=1, padding=1, batch_norm=False),
            Downsample(enc_ch, enc_ch * 2),
            Downsample(enc_ch * 2, enc_ch * 4)
        )
        # Fuses encoded image features with the tiled conditioning code.
        self.combined = nn.Sequential(
            Downsample(enc_ch * 4 + self.c_dim, enc_ch * 4, kernel_size=3, stride=1, padding=1)
        )

        self.blocks = nn.ModuleList()
        final_layer = len(self.channel_mul) - 1
        for layer, cmul in enumerate(self.channel_mul):
            width = base_ch // cmul

            n_res = self.num_resblocks[layer]
            for _ in range(n_res):
                self.blocks.append(ResBlock(width, width, stride=1, padding=1))

            if resolution in attn_resolutions:
                self.blocks.append(AttentionBlock(width, n_heads=heads))

            if layer != final_layer:
                next_width = base_ch // self.channel_mul[layer + 1]
                # Dropout is only ever enabled on the first two upsamples.
                level_dropout = self.dropout if layer < 2 else False
                self.blocks.append(Upsample(width, next_width, use_deconv=self.use_deconv, dropout=level_dropout))

            resolution *= 2

        self.out = nn.Sequential(
            nn.Conv2d(base_ch // self.channel_mul[-1], 3, kernel_size=3, padding=1, bias=True),
            nn.Tanh()
        )
        print("Initialized stage2 Generator")

    def forward(self, text_embedding, stage1_out):
        """Return ``(image, mu, logvar)`` given text embeddings and stage-1 images."""
        encoded = self.downblocks(stage1_out)

        code, mu, logvar = self.aug_project(text_embedding)
        # Tile the conditioning code over the in_dims x in_dims grid.
        tiled = code.view(-1, self.c_dim, 1, 1).repeat(1, 1, self.in_dims, self.in_dims)
        feat = self.combined(torch.cat([encoded, tiled], dim=1))

        for block in self.blocks:
            feat = block(feat)
        return self.out(feat), mu, logvar

In [None]:
class D_Logits(nn.Module):
    """Discriminator head mapping encoded image features to a probability.

    With ``condition=True`` the text embedding is compressed to ``c_dim``
    values (Linear + ReLU), tiled to 4x4, concatenated with the image
    features and reduced to one sigmoid output per sample; otherwise only
    the image features are used.

    Args:
        d_ch: base discriminator channel count; features carry ``d_ch * 8``.
        c_dim: size of the compressed text condition.
        txt_dim: size of the incoming text embedding.
        condition: whether the head is text-conditioned.
    """

    def __init__(self, d_ch, c_dim, txt_dim, condition):
        super().__init__()
        self.condition = condition
        self.d_ch = d_ch
        self.c_dim = c_dim
        self.txt_dim = txt_dim

        feat_ch = self.d_ch * 8

        # Created in both modes, but only used by forward() when conditioning.
        self.compress = nn.Sequential(
            nn.Linear(self.txt_dim, self.c_dim),
            nn.ReLU()
        )

        # Final layer: Downsample's default 4x4 kernel applied with stride 4
        # and no padding, followed by a sigmoid.
        if condition:
            self.outlogits = nn.Sequential(
                Downsample(feat_ch + self.c_dim, feat_ch, kernel_size=1, stride=1, padding=0, bias=False),
                Downsample(feat_ch, 1, stride=4, padding=0, batch_norm=False, activation=False, bias=True),
                nn.Sigmoid(),
            )
        else:
            self.outlogits = nn.Sequential(
                Downsample(feat_ch, 1, stride=4, padding=0, batch_norm=False, activation=False, bias=True),
                nn.Sigmoid(),
            )

    def forward(self, feat, cond_out=None):
        """Return a flat tensor of scores for ``feat`` (and ``cond_out`` if conditioned)."""
        if not self.condition:
            joint = feat
        else:
            # Compress the text embedding and tile it over a 4x4 grid.
            code = self.compress(cond_out)
            code = code.view(-1, self.c_dim, 1, 1).repeat(1, 1, 4, 4)
            joint = torch.cat((feat, code), 1)
        return self.outlogits(joint).view(-1)

In [None]:
class Discriminator(nn.Module):
    """Image encoder of the discriminator.

    ``forward`` only returns the encoded feature map. The conditional logits
    head is stored as ``cond_discriminator_logits`` and is not called here.
    Stage 2 adds four further ``Downsample`` layers after the shared encoder.
    """

    def __init__(self, stage):
        super().__init__()
        self.stage = stage
        self.d_ch = config.discriminator_channel
        self.c_dim = config.condition_dim
        self.txt_dim = config.text_dim

        base = self.d_ch
        # Four stride-2 layers (Downsample defaults), widening to base * 8.
        self.encode = nn.Sequential(
            Downsample(3, base, batch_norm=False),
            Downsample(base, base * 2),
            Downsample(base * 2, base * 4),
            Downsample(base * 4, base * 8),
        )

        if stage == 2:
            # Two more stride-2 layers, then two stride-1 layers that bring
            # the channel count back to base * 8.
            self.encode_further = nn.Sequential(
                Downsample(base * 8, base * 16),
                Downsample(base * 16, base * 32),
                Downsample(base * 32, base * 16, kernel_size=3, stride=1, padding=1),
                Downsample(base * 16, base * 8, kernel_size=3, stride=1, padding=1),
            )

        self.cond_discriminator_logits = D_Logits(base, self.c_dim, self.txt_dim, condition=True)
        # No unconditional head is built.
        self.uncond_discriminator_logits = None
        print("Initialized, stage {} discriminator".format(stage))

    def forward(self, x):
        """Encode images; stage 2 applies the additional encoder as well."""
        feat = self.encode(x)
        return self.encode_further(feat) if self.stage == 2 else feat

In [None]:
def get_loader(stage, batch_size, random_captions=True):
    """Build the shuffled training DataLoader for the given stage.

    Args:
        stage: 1 uses ``config.imageSize``; any other value uses 4x that size.
        batch_size: batch size of the returned loader.
        random_captions: when true, read pre-split lists with ``load_data``
            and use ``Text2ImgDataset_reformed``; otherwise split the full
            lists with ``make_train_test_split`` and use ``Text2ImgDataset``.

    Returns:
        A ``DataLoader`` over the training set with ``shuffle=True`` and
        ``drop_last=True``.
    """
    imageSize = config.imageSize if stage == 1 else config.imageSize * 4

    print(f"Genearting Dataset with image size: {imageSize}")

    tokenizer = get_tokenizer(config.text_encoder)

    # check these params in config before running
    imageFolder = os.path.join(DATASET, config.dataset, config.imageFolder)

    if not random_captions:
        dataset_root = os.path.join(DATASET, config.dataset)
        imageListPath = os.path.join(dataset_root, config.imageListPath)
        captionsListPath = os.path.join(dataset_root, config.captionsListPath)
        # Only the training half of the split is used here.
        train_images, train_captions, _, _ = make_train_test_split(imageListPath, captionsListPath, config.test_size)
        train_dataset = Text2ImgDataset(imageFolder, tokenizer, config.text_encoder, train_images, train_captions, imageSize, augmentImage=False)
    else:
        train_images, train_captions = load_data(imageListPath=config.trainImageListPath, captionsListPath=config.trainCaptionsListPath)
        train_dataset = Text2ImgDataset_reformed(imageFolder, tokenizer, config.text_encoder, train_images, train_captions, imageSize, augmentImage=False)

    print("Dataset created:\n                length of train dataset: {}\n".format(len(train_dataset)))

    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# Turn on autograd anomaly detection for the whole session so a failing
# backward pass reports the forward operation that caused it.
# NOTE(review): PyTorch documents this as a debugging aid that slows
# execution -- consider disabling it for long training runs.
torch.autograd.set_detect_anomaly(True)

In [None]:
# Create the output directories used by train(): TensorBoard logs,
# model checkpoints and sample images. Existing directories are left as-is.
os.makedirs(config.tb_dir, exist_ok=True)       # change these in config when training stage1 and stage2 accordingly
os.makedirs(config.model_out, exist_ok=True)    # change these in config when training stage1 and stage2 accordingly
os.makedirs(config.out_img, exist_ok=True)      # change these in config when training stage1 and stage2 accordingly

In [None]:
def get_controllers(netD, netG):
    """Create optimizers, LR schedulers and loss modules for one GAN stage.

    Both networks use Adam with betas (0.5, 0.999) and a ``MultiStepLR``
    schedule whose milestones and decay factor come from ``config``.

    Returns:
        ``(optimizerD, scheduler_D, optimizerG, scheduler_G, criterion, L1Loss)``
        where ``criterion`` is ``nn.BCELoss()`` and ``L1Loss`` is ``nn.L1Loss()``.
    """
    d_lr = config.d_lr
    g_lr = config.g_lr
    adam_betas = (0.5, 0.999)

    def _adam_with_schedule(network, learning_rate):
        # One Adam optimizer plus its step-wise learning-rate decay.
        optimizer = optim.Adam(network.parameters(), lr=learning_rate, betas=adam_betas)
        scheduler = MultiStepLR(optimizer, milestones=config.lr_decay_epoch, gamma=config.lr_gamma, verbose=True)
        return optimizer, scheduler

    # Discriminator first, then generator (same order as the returned tuple).
    optimizerD, scheduler_D = _adam_with_schedule(netD, d_lr)
    optimizerG, scheduler_G = _adam_with_schedule(netG, g_lr)

    return optimizerD, scheduler_D, optimizerG, scheduler_G, nn.BCELoss(), nn.L1Loss()

In [None]:
## Training
def train(stage, batch_size, trainloader):
    """Train the stage-1 or stage-2 text-conditioned GAN.

    Stage 1 trains ``Generator1`` against a stage-1 ``Discriminator``.
    Any other stage loads a frozen ``Generator1`` from ``config.gen1_ckpt``
    and trains ``Generator2`` on its outputs against a stage-2 discriminator.

    Per epoch the function logs averaged losses to TensorBoard, steps both
    LR schedulers and saves a checkpoint every ``config.save_snapshot``
    epochs. Sample images are written every 100 iterations.

    Args:
        stage: 1 for stage-1 training, anything else for stage-2 training.
        batch_size: batch size of ``trainloader``; the label and noise
            buffers are allocated with this size, so the loader must yield
            full batches (``get_loader`` uses ``drop_last=True``).
        trainloader: iterable of dict batches with keys ``'input_ids'``,
            ``'attention_mask'``, ``'image'`` and optionally ``'caption'``.

    Requires a CUDA device: buffers and networks are moved with ``.cuda()``.
    """
    noise_dim = config.z_dim
    # Reusable buffers; the noise is refilled in place at every iteration.
    noise = torch.empty(batch_size, noise_dim, dtype=torch.float32).cuda()
    real_labels = torch.ones(batch_size).float().cuda()
    fake_labels = torch.zeros(batch_size).float().cuda()

    assert batch_size == real_labels.shape[0], "batch_size and target size do not match in real_labels"
    assert batch_size == fake_labels.shape[0], "batch_size and target size do not match in fake_labels"

    # The text encoder is only used for inference (eval mode, no_grad below).
    text_encoder = TextEncoder(config.text_encoder, pretrained=True)
    text_encoder.eval()

    # stage 1 training, only stage 1 g and stage1 d
    if stage == 1:
        netG = Generator1(stage=stage)
        netD = Discriminator(stage=stage)

    # stage 2 training, stage1 g output is fed to stage2 g, stage2 d
    else:
        stage1_G = Generator1(1)
        stage1_G = load_from_checkpoint(stage1_G, config.gen1_ckpt)
        stage1_G.float().cuda()
        # fix parameters of stageI GAN
        for param in stage1_G.parameters():
            param.requires_grad = False
        stage1_G.eval()
        netG = Generator2(stage=stage)
        netD = Discriminator(stage=stage)

    # Either resume both networks from checkpoints or initialise them fresh.
    recovered_epoch = 0
    if config.load_checkpoint:
        if stage == 1:
            recovered_epoch, netG, netD = load_from_checkpoint(netG, config.gen1_ckpt, netD, config.d1_ckpt)
        else:
            recovered_epoch, netG, netD = load_from_checkpoint(netG, config.gen2_ckpt, netD, config.d2_ckpt)
    else:
        netG.apply(weights_init)
        netD.apply(weights_init)
    netG.float().cuda()
    netD.float().cuda()

    optimizerD, schedulerD, optimizerG, schedulerG, criterion, L1Loss = get_controllers(netD, netG)

    # Run name: stage, batch size, image size and the epoch resumed from.
    tb = 'stage' + str(stage) + '_b' + str(batch_size) + '_d' + (str(config.imageSize) if stage==1 else str(config.imageSize*4)) + '_' + str(recovered_epoch)
    # Named `writer` so it does not shadow the imported `summary` function.
    writer = SummaryWriter(os.path.join(config.tb_dir, tb))

    KL_coeff = config.KL_COEFF
    alpha_L1 = config.alpha_L1

    print(f"Traininig Stage: {stage}, outputs at: {config.tb_dir}, {config.out_img}, {config.model_out}")

    # try/finally guarantees the event file is flushed and closed even when
    # training is interrupted or raises.
    try:
        for epoch in range(recovered_epoch+1, config.max_epoch+1):
            D_loss = 0
            D_real_loss = 0
            D_fake_loss = 0
            D_wrong_loss = 0
            G_loss = 0
            KL_l = 0
            netG.train()
            netD.train()
            for i, batch in enumerate(trainloader):
                with torch.no_grad():
                    text_embeddings = text_encoder(batch['input_ids'], batch['attention_mask'])
                text_embeddings = text_embeddings.float().cuda()
                real_images = batch['image']
                caption = None
                if 'caption' in batch:
                    caption = batch['caption']

                noise.normal_(0, 1)

                low_res = None
                if stage == 1:
                    # Generate fake image
                    inputs = (text_embeddings, noise)
                    fake_images, mu, logvar = nn.parallel.data_parallel(netG, inputs, gpus)
                    assert fake_images.shape[-1] == 64, f"Image size {fake_images.shape[-1]} differs from 64"

                else:
                    # Generate fake image
                    s1_inputs = (text_embeddings, noise)
                    with torch.no_grad():
                        low_res, _, _ = nn.parallel.data_parallel(stage1_G, s1_inputs, gpus)
                    # pass stage 1 output to generator
                    s2_inputs = (text_embeddings, low_res.detach())
                    fake_images, mu, logvar = nn.parallel.data_parallel(netG, s2_inputs, gpus)
                    assert fake_images.shape[-1] == 256, f"Image size {fake_images.shape[-1]} differs from 256"

                rlabels = real_labels.clone()
                flabels = fake_labels.clone()
                # label smoothing
                if config.label_smoothening:
                    r = np.random.rand(1)[0]
                    if r <= 0.5:
                        # Pull each label towards the other class by 0 .. 0.2.
                        smoothening = np.random.choice(a=np.linspace(0., 0.20, num=5), replace=True, size=batch_size)
                        smoothening = torch.tensor(smoothening).float().cuda()
                        rlabels -= smoothening
                        flabels += smoothening
                    # occasionally flip labels
                    else:
                        rlabels = np.random.choice(a=[0.,1.], replace=True, size=batch_size, p=[0.05,0.95])
                        flabels = np.random.choice(a=[0.,1.], replace=True, size=batch_size, p=[0.95,0.05])
                        rlabels = torch.tensor(rlabels).float().cuda()
                        flabels = torch.tensor(flabels).float().cuda()

                # Update discriminator network
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = compute_discriminator_loss(netD, criterion, real_images, fake_images,
                                                                                    rlabels, flabels, text_embeddings, gpus)
                errD.backward()

                # Gradient Norm Clipping
                if config.clip_grad:
                    nn.utils.clip_grad_norm_(netD.parameters(), max_norm=2.0, norm_type=2)

                optimizerD.step()
                D_loss += errD.item()
                D_real_loss += errD_real
                D_fake_loss += errD_fake
                D_wrong_loss += errD_wrong

                # Update generator network
                netG.zero_grad()
                # NOTE(review): the generator target is `rlabels`, i.e. the
                # possibly smoothed / flipped labels -- confirm this is intended.
                errG_fake = compute_generator_loss(netD, criterion, fake_images, real_images,
                                                   rlabels, text_embeddings, gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG_fake + (KL_coeff * kl_loss)
                if alpha_L1 > 0:
                    errG_L1 = L1_loss(L1Loss, fake_images, real_images)
                    errG_total += (alpha_L1 * errG_L1)

                errG_total.backward()

                # Gradient Norm Clipping
                if config.clip_grad:
                    nn.utils.clip_grad_norm_(netG.parameters(), max_norm=2.0, norm_type=2)

                optimizerG.step()
                G_loss += errG_total.item()
                KL_l += kl_loss.item()

                if i%100 == 0:
                    print('[%d/%d] [%d/%d] Loss_D: %.5f, Loss_G: %.5f, Loss_KL: %.5f' % (epoch, config.max_epoch, i, len(trainloader), errD.item(), errG_total.item(), kl_loss.item()))
                    save_img_results(real_images, fake_images, low_res,  caption, epoch, config.out_img)

            # Per-epoch averages over the number of batches.
            writer.add_scalars('Discriminator', {'DLoss':D_loss/len(trainloader),
                                                 'RealLoss':D_real_loss/len(trainloader),
                                                 'FakeLoss':D_fake_loss/len(trainloader),
                                                 'WrongLoss':D_wrong_loss/len(trainloader)}, epoch)
            writer.add_scalars('Generator', {'GLoss':G_loss/len(trainloader),
                                             'KL_Loss':KL_l/len(trainloader)}, epoch)
            print('[%d/%d] Loss_D: %.5f, Loss_G: %.5f, Loss_KL: %.5f' % (epoch, config.max_epoch, D_loss/len(trainloader), G_loss/len(trainloader), KL_l/len(trainloader)))

            schedulerD.step()
            schedulerG.step()

            if epoch % config.save_snapshot == 0:
                save_model(netG, netD, epoch, config.model_out, stage=stage)
    finally:
        writer.close()

In [None]:
# """
# call model.eval() before feeding the data, as this will change the behavior of the BatchNorm layer 
# to use the running estimates instead of calculating them for the current batch
# """
# netG1.eval()
# netD1.eval()
# clear_output()

In [None]:
# Stage-1 run: the effective batch size is the per-GPU batch size from the
# config multiplied by the number of visible GPUs (data_parallel splits it).
stage=1
batch_size = config.batch_size * len(gpus)
trainloader = get_loader(stage, batch_size, random_captions=True)

In [None]:
# Train with the stage / batch size / loader set in the previous cell.
train(stage, batch_size, trainloader)

In [None]:
# Stage-2 run: rebuilds the loader at 4x config.imageSize (see get_loader).
# train() loads the frozen stage-1 generator from config.gen1_ckpt, so a
# stage-1 checkpoint must exist and the config output paths should be
# switched to the stage-2 locations first.
stage=2
batch_size = config.batch_size * len(gpus)
trainloader = get_loader(stage, batch_size, random_captions=True)

In [None]:
# Train with the stage / batch size / loader set in the previous cell.
train(stage, batch_size, trainloader)