In [1]:
# from zipfile import ZipFile

# with ZipFile('compressed.zip', 'r') as zf:
#     zf.extractall('.')

In [2]:
# !pip install torchinfo
# !pip install pandas
# !pip install matplotlib
# !pip install scipy

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchinfo import summary
from torchvision import transforms

from models.Discriminator import *
from models.Encoder import *
from models.Generator import *
from loss_functions import *
from DAMSM_trainer import *

import dataset
import os
import matplotlib.pyplot as plt
import config.settings as config

In [4]:
import time

cur = time.time()

# Training-time image pipeline: Resize with an int scales the shorter side to
# imsize, so RandomCrop(imsize) always fits; the horizontal flip is augmentation.
imsize = 299  # presumably the image encoder's input size (299 is Inception-v3's) — TODO confirm
image_steps = [
    transforms.Resize(int(imsize)),
    transforms.RandomCrop(imsize),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
transform = transforms.Compose(image_steps)

# The dataset is rooted at the current working directory.
data_set = dataset.TextDataset(os.getcwd(), transform=transform)

# Wall-clock seconds spent building the dataset.
print(time.time() - cur)

10.926204204559326


In [5]:
def weight_init(m):
    """Initialise a module's parameters in place; intended for ``model.apply(weight_init)``.

    Modules are matched by class name, as before:
      * name contains ``Conv``      -> orthogonal weight (gain 1.0)
      * name contains ``BatchNorm`` -> weight ~ N(1.0, 0.02), bias = 0
      * name contains ``Linear``    -> orthogonal weight (gain 1.0), bias = 0 if present

    Modules whose name matches but which carry no weight/bias tensor (e.g. a
    container such as ``ConvBlock``, or ``BatchNorm`` with ``affine=False``)
    are skipped for the missing tensor instead of raising ``AttributeError``.
    """
    classname = m.__class__.__name__
    # getattr with a default: nn.Module.__getattr__ raises for unknown names,
    # and affine-less norm layers register weight/bias as None.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    has_weight = isinstance(weight, torch.Tensor)
    has_bias = isinstance(bias, torch.Tensor)

    if 'Conv' in classname:
        if has_weight:
            nn.init.orthogonal_(weight, 1.0)
    elif 'BatchNorm' in classname:
        if has_weight:
            nn.init.normal_(weight, 1.0, 0.02)
        if has_bias:
            nn.init.zeros_(bias)
    elif 'Linear' in classname:
        if has_weight:
            nn.init.orthogonal_(weight, 1.0)
        if has_bias:
            nn.init.zeros_(bias)

In [6]:
# Build the text (RNN) and image (CNN) encoders on config.DEVICE and restore
# their saved weights. map_location makes checkpoints saved on a GPU loadable
# on whatever config.DEVICE is (e.g. a CPU-only machine). load_state_dict copies
# into the existing parameters, so the models stay on config.DEVICE.
cnn_model = CNN_Encoder(config.EMBEDDING_DIM).to(config.DEVICE)
rnn_model = RNN_Encoder(config.WORD_SIZE, number_hidden=config.EMBEDDING_DIM).to(config.DEVICE)

rnn_model.load_state_dict(torch.load("saved_models/rnn_model_state_dict.pt", map_location=config.DEVICE))
cnn_model.load_state_dict(torch.load("saved_models/cnn_model_state_dict.pt", map_location=config.DEVICE))



In [7]:
# One discriminator per generated resolution: train() pairs discriminators[i]
# with real images interpolated to 64px (i=0) and 128px (i=1). The meaning of
# down_sample_count is defined in models.Discriminator — presumably the number
# of extra down-sampling stages for the larger input; confirm there.
discriminators = [DiscriminatorNetwork().to(config.DEVICE), DiscriminatorNetwork(down_sample_count=1).to(config.DEVICE)]
generator = GenerativeNetwork().to(config.DEVICE)

# map_location lets GPU-saved checkpoints load on any config.DEVICE;
# load_state_dict keeps the parameters on their current device.
generator.load_state_dict(torch.load("saved_models/generator_state_dict.pt", map_location=config.DEVICE))
for i, netD in enumerate(discriminators):
    netD.load_state_dict(torch.load(f"saved_models/discriminator{i}_state_dict.pt", map_location=config.DEVICE))

In [8]:
def get_optimizer(generator, discriminators, lr=None, betas=(0.5, 0.999)):
    """Build one Adam optimizer for the generator and one per discriminator.

    Args:
        generator: module whose parameters the generator optimizer updates.
        discriminators: sequence of modules; ``d_optimizers[i]`` updates
            ``discriminators[i]``.
        lr: learning rate shared by all optimizers; ``None`` (default) uses
            ``config.LR`` as before.
        betas: Adam ``(beta1, beta2)`` coefficients; default ``(0.5, 0.999)``
            matches the previous hard-coded value.

    Returns:
        ``(g_optimizer, d_optimizers)``, with ``d_optimizers`` a list aligned
        with ``discriminators``.
    """
    # Local import: otherwise this only works because a *later* notebook cell
    # happens to import ``optim`` before the function is called.
    import torch.optim as optim

    if lr is None:
        lr = config.LR
    d_optimizers = [optim.Adam(netD.parameters(), lr=lr, betas=betas)
                    for netD in discriminators]
    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=betas)
    return g_optimizer, d_optimizers

In [9]:
from torch.utils.data import DataLoader
import torch.optim as optim

# Mini-batches of 2, reshuffled every epoch.
data_loader = DataLoader(data_set, batch_size=2, shuffle=True)

# Encoder parameters: every RNN parameter followed by the trainable CNN ones.
para = list(rnn_model.parameters()) + [
    p for p in cnn_model.parameters() if p.requires_grad
]

# Gradient-norm limit for the RNN encoder (train() reads config.RNN_GRAD).
config.RNN_GRAD = 0.25

print(data_loader.batch_size)


2


In [10]:
def discriminator_loss(netD, real_imgs, fake_imgs, conditions,
                       real_labels, fake_labels):
    """Binary cross-entropy loss for a single discriminator.

    Terms (``nn.BCELoss``, so the heads must output probabilities in [0, 1]):
      * conditional real:  (real image, its condition)        -> real_labels
      * conditional fake:  (generated image, condition)       -> fake_labels
      * conditional wrong: real image i with condition i + 1  -> fake_labels
      * unconditional real / fake, only if ``netD.unconditional_discriminator``
        is not None.

    With an unconditional head the result is the mean of the two "real" terms
    plus the mean of the three "fake/wrong" terms; without it,
    ``cond_real + mean(cond_fake, cond_wrong)``.

    ``fake_imgs`` is detached, so this loss never back-propagates into the
    generator.
    """
    bce = nn.BCELoss()
    feats_real = netD(real_imgs)
    feats_fake = netD(fake_imgs.detach())

    cond_head = netD.conditional_discriminator
    loss_cond_real = bce(cond_head(feats_real, conditions), real_labels)
    loss_cond_fake = bce(cond_head(feats_fake, conditions), fake_labels)

    # Mismatched pairs: shift the conditions by one so image i meets condition i+1.
    n = feats_real.size(0)
    loss_cond_wrong = bce(cond_head(feats_real[:(n - 1)], conditions[1:n]),
                          fake_labels[1:n])

    uncond_head = netD.unconditional_discriminator
    if uncond_head is None:
        return loss_cond_real + (loss_cond_fake + loss_cond_wrong) / 2.

    loss_real = bce(uncond_head(feats_real), real_labels)
    loss_fake = bce(uncond_head(feats_fake), fake_labels)
    return ((loss_real + loss_cond_real) / 2. +
            (loss_fake + loss_cond_fake + loss_cond_wrong) / 3.)


In [11]:
def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                   words_embs, sent_emb, match_labels,
                   cap_lens, class_ids):
    """Adversarial plus text-image matching loss for the generator.

    For every discriminator ``netsD[k]`` the fakes ``fake_imgs[k]`` are scored
    by the conditional head (with ``sent_emb``) and, when present, by the
    unconditional head; both BCE terms target ``real_labels``.

    The fakes matching the last discriminator are then encoded with
    ``image_encoder``. The word-level (``word_loss``) and sentence-level
    (``sent_loss``) matching losses, each pair scaled by ``config.LAMBDA``,
    are added to the total.

    Returns:
        The summed loss tensor, or the int ``0`` when ``netsD`` is empty.
    """
    bce = nn.BCELoss()
    batch_size = real_labels.size(0)
    last = len(netsD) - 1

    errG_total = 0
    for k, netD in enumerate(netsD):
        feats = netD(fake_imgs[k])
        cond_errG = bce(netD.conditional_discriminator(feats, sent_emb), real_labels)
        if netD.unconditional_discriminator is None:
            stage_loss = cond_errG
        else:
            uncond_errG = bce(netD.unconditional_discriminator(feats), real_labels)
            stage_loss = uncond_errG + cond_errG
        errG_total += stage_loss

    if netsD:
        # Matching losses on the fakes of the last discriminator.
        region_features, cnn_code = image_encoder(fake_imgs[last])
        w_loss0, w_loss1, _ = word_loss(region_features, words_embs,
                                         cap_lens, match_labels,
                                         class_ids, batch_size)
        w_loss = (w_loss0 + w_loss1) * config.LAMBDA
        s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb,
                                     match_labels, class_ids, batch_size)
        s_loss = (s_loss0 + s_loss1) * config.LAMBDA
        errG_total += w_loss + s_loss

    return errG_total


In [12]:
from torchvision.utils import save_image

def save_generate_img(generator, epoch, statics):
    """Save the first sample of the 64px and 128px generator outputs as PNGs.

    Args:
        generator: called as ``generator(*statics)``; its first output must be
            an indexable of image batches, index 0 the 64px and index 1 the
            128px output.
        epoch: used as the file stem: ``x64/{epoch}.png`` and ``x128/{epoch}.png``.
        statics: positional arguments for the generator, e.g. the fixed
            ``(noise, sent_emb, word_embs, mask)`` captured by ``train``.
    """
    # Sampling only: no autograd graph is needed, and building one costs GPU
    # memory on an already tight device.
    with torch.no_grad():
        fake_imgs = generator(*statics)[0]

    for folder, batch in (("x64", fake_imgs[0]), ("x128", fake_imgs[1])):
        os.makedirs(folder, exist_ok=True)  # save_image does not create directories
        save_image(batch[0], f"{folder}/{epoch}.png")

In [13]:
def _kl_loss(mu, log_var):
    """KL divergence of N(mu, exp(log_var)) from N(0, I), averaged over all elements.

    Same value as ``-0.5 * mean(1 + log_var - mu^2 - exp(log_var))``. It is
    computed out of place so ``mu``, which is still part of the autograd
    graph, is not mutated.
    """
    return -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())


def train(cnn_model, rnn_model, generator, discriminators, data_loader, damsm_train=False, epochs=1):
    """Train the GAN and, optionally, fine-tune the text/image encoders.

    Args:
        cnn_model: image encoder. It is used for the DAMSM losses when
            ``damsm_train`` is set, and always as ``image_encoder`` inside
            ``generator_loss``.
        rnn_model: text encoder; ``rnn_model(caps, lengths, hidden)`` returns
            ``(word_embs, sentence_emb)``.
        generator: called as ``generator(noise, sent_emb, word_embs, mask)``,
            returning ``(fake_imgs, _, mu, log_var)``.
        discriminators: ``discriminators[i]`` judges ``fake_imgs[i]`` against
            real images interpolated to 64px (i=0) and 128px (i=1).
        data_loader: yields ``(imgs, caps, lengths, ids)``. Batches smaller than
            ``data_loader.batch_size`` are skipped.
        damsm_train: when True, also update ``rnn_model``/``cnn_model`` with the
            word+sentence matching loss and save them after every epoch.
        epochs: number of passes over ``data_loader`` (default 1, the previous
            hard-coded value).

    Returns:
        ``(generator, discriminators)``.
    """
    generator.train()
    for netD in discriminators:
        netD.train()

    g_optim, d_optims = get_optimizer(generator, discriminators)
    rc_optimizer = get_rc_optimizer(rnn_model, cnn_model)

    full_batch = data_loader.batch_size
    real_labels = torch.ones(full_batch, device=config.DEVICE)
    fake_labels = torch.zeros(full_batch, device=config.DEVICE)
    # Matching labels for the DAMSM losses: caption i belongs to image i.
    labels = torch.arange(full_batch, dtype=torch.long, device=config.DEVICE)
    z_dim = config.Z_DIM

    # Fixed generator inputs from the first full batch, so the per-epoch sample
    # images are comparable across epochs.
    statics = None

    for epoch in range(epochs):
        cur_time = time.time()
        last_g_loss = last_d_loss = None
        for imgs, caps, lengths, ids in data_loader:
            batch_size = imgs.size(0)
            if batch_size < full_batch:
                continue  # the label tensors above are sized for full batches

            imgs, caps = imgs.to(config.DEVICE), caps.to(config.DEVICE)
            lengths, ids = lengths.to(config.DEVICE), ids.to(config.DEVICE)

            # The encoder graph is only needed when the encoders are being trained.
            with torch.set_grad_enabled(damsm_train):
                hidden = rnn_model.init_hidden(batch_size)
                word_embs, sentences_emb = rnn_model(caps.long(), lengths.long(), hidden)

            if damsm_train:
                # Clear stale gradients, including those that generator_loss
                # back-propagated into cnn_model on the previous step.
                rnn_model.zero_grad()
                cnn_model.zero_grad()
                features, cnn_code = cnn_model(imgs)
                w_loss0, w_loss1, _ = word_loss(features, word_embs, lengths, labels, ids, batch_size)
                s_loss0, s_loss1 = sent_loss(cnn_code, sentences_emb, labels, ids, batch_size)
                damsm_loss = w_loss0 + w_loss1 + s_loss0 + s_loss1
                damsm_loss.backward()
                # Clip after backward: clipping before any gradient exists does nothing.
                torch.nn.utils.clip_grad_norm_(rnn_model.parameters(), config.RNN_GRAD)
                rc_optimizer.step()

            word_embs, sentences_emb = word_embs.detach(), sentences_emb.detach()

            # Mask of caption positions equal to 0 (presumably padding — confirm
            # against dataset.TextDataset), trimmed to the number of word embeddings.
            mask = caps == 0
            num_words = word_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            noises = torch.randn(batch_size, z_dim, device=config.DEVICE)
            if statics is None:
                statics = noises, sentences_emb, word_embs, mask

            # real_imgs[i] is the real counterpart for discriminators[i].
            real_imgs = [F.interpolate(imgs, 64), F.interpolate(imgs, 128)]

            # ---- discriminator updates ----
            d_loss_sum = 0.0
            for _ in range(config.DISCRIMINATOR_REPEAT):
                for i, (netD, d_optim) in enumerate(zip(discriminators, d_optims)):
                    # discriminator_loss detaches the fakes, so tracking the
                    # generator graph here only wasted memory.
                    with torch.no_grad():
                        d_noises = torch.randn(batch_size, z_dim, device=config.DEVICE)
                        fake_imgs = generator(d_noises, sentences_emb, word_embs, mask)[0]
                    netD.zero_grad()
                    d_loss = discriminator_loss(netD, real_imgs[i], fake_imgs[i], sentences_emb,
                                                real_labels, fake_labels)
                    d_loss.backward()
                    d_optim.step()
                    # .item() so the running total does not keep autograd graphs alive.
                    d_loss_sum += d_loss.item()

            # ---- generator update ----
            fake_imgs, _, mu, log_var = generator(noises, sentences_emb, word_embs, mask)
            generator.zero_grad()
            g_loss = generator_loss(discriminators, cnn_model, fake_imgs, real_labels,
                                    word_embs, sentences_emb, labels, lengths, ids)
            g_loss = g_loss + _kl_loss(mu, log_var)
            g_loss.backward()
            g_optim.step()

            last_g_loss, last_d_loss = g_loss.item(), d_loss_sum

        if statics is not None:  # nothing to sample from until a full batch was seen
            save_generate_img(generator, epoch, statics)
        print(f"Epoch {epoch} time {time.time() - cur_time}")
        print(f"g_loss: {last_g_loss}, d_loss: {last_d_loss}")

        if damsm_train:
            save_models(rnn_model, cnn_model, ("rnn_model_state_dict.pt", "cnn_model_state_dict.pt"))
    return generator, discriminators

In [14]:
# Bind the result instead of leaving it as the cell's last expression, which
# would dump the full module reprs of the generator and discriminators.
generator, discriminators = train(cnn_model, rnn_model, generator, discriminators, data_loader, damsm_train=True)

torch.Size([2, 2])


  return self._call_impl(*args, **kwargs)


torch.Size([2, 2])


OutOfMemoryError: CUDA out of memory. Tried to allocate 36.00 MiB. GPU 0 has a total capacity of 1.94 GiB of which 21.19 MiB is free. Including non-PyTorch memory, this process has 1.60 GiB memory in use. Of the allocated memory 1.41 GiB is allocated by PyTorch, and 113.48 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Persist all trained weights under the same file names the loading cells read.
os.makedirs("saved_models", exist_ok=True)

torch.save(generator.state_dict(), "saved_models/generator_state_dict.pt")
# One file per discriminator, matching the f"discriminator{i}_state_dict.pt" loader loop.
for i, netD in enumerate(discriminators):
    torch.save(netD.state_dict(), f"saved_models/discriminator{i}_state_dict.pt")

torch.save(rnn_model.state_dict(), "saved_models/rnn_model_state_dict.pt")
torch.save(cnn_model.state_dict(), "saved_models/cnn_model_state_dict.pt")