In [None]:
import math
import os
import json
import pickle
import datetime

import torchvision.transforms as transforms
from torchvision.utils import make_grid

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from tqdm.auto import tqdm 

plt.ion()
from IPython.display import clear_output

In [None]:
# Use the GPU when one is present; models and tensors below are moved to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed both RNGs this notebook draws from: torch (weight init, noise) and
# numpy (fixed-sample selection, caption sampling in the dataset).
torch.manual_seed(0)
np.random.seed(0)

# Models

In [None]:
text_model_name = "cointegrated/rubert-tiny2"
tokenizer = AutoTokenizer.from_pretrained(text_model_name)
# Text encoder used only for inference (eval mode, never given to an optimizer).
text_model = AutoModel.from_pretrained(text_model_name)
text_model = text_model.eval().to(device)

def embed_bert_cls(text):
    """Encode caption(s) with `text_model` and return the first-token (CLS) hidden state.

    Args:
        text: a string or a batch (list/tuple) of strings; batches are padded
            and truncated by the tokenizer.

    Returns:
        Tensor of shape (batch, hidden) on `text_model.device`. The embeddings
        are not normalised.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    inputs = {name: tensor.to(text_model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = text_model(**inputs).last_hidden_state
    return hidden[:, 0, :]

In [None]:
class Hyperparameters:
    """Plain attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)


# All run settings in one attribute bag. Entries from `write_logs` onward are
# run bookkeeping (logging, sampling and checkpoint cadence) and are left out of
# the TensorBoard run name, so keep them grouped at the end.
hp = Hyperparameters(n_epochs=1000,
                     batch_size=128,
                     train_size=1.0,  # NOTE(review): not read anywhere else in the visible notebook
                     lr=.0002,
                     b1=.5,  # Adam beta1
                     b2=0.999,  # Adam beta2
                     latent_dim=100,  # channels of the (latent_dim, 1, 1) noise input to Generator
                     text_emb_size=text_model.config.hidden_size,  # must equal the width of embed_bert_cls output
                     img_size=64,
                     channels=3,
                     ngf=64,  # Generator base feature width
                     ndf=64,  # Discriminator base feature width
                     max_desc_words=4,  # upper bound used when ImageDataset samples emotion words
                     
                     write_logs=True,
                     sample_interval=1,  # epochs
                     save_interval=10,  # epochs
                     n_sampled_examples=50,
                     sampled_grid_shape=(5, 10),
                     num_workers=1,  # NOTE(review): not passed to any DataLoader in the visible notebook
                    )

print(hp.lr)  # quick sanity print of the learning rate

In [None]:
cuda = torch.cuda.is_available()  # bool flag for the .cuda() / legacy tensor-type code paths
img_shape = (hp.channels, hp.img_size, hp.img_size)  # (C, H, W)


def to_img(x):
    """Map an image tensor from the model range [-1, 1] to displayable [0, 1].

    Real images are normalised with mean = std = 0.5 and the generator ends in
    Tanh, so both live in [-1, 1]. Clamping without rescaling would turn every
    negative value black, so the tensor is first shifted/scaled, then clamped
    to guard against small overshoots.
    """
    x = (x + 1) / 2
    return x.clamp(0, 1)


def show_image(img):
    """Display a single (C, H, W) image tensor with matplotlib.

    The tensor is detached and moved to CPU first. Without this,
    ``Tensor.numpy()`` raises on CUDA tensors and on tensors that require grad,
    such as raw generator output.
    """
    img = to_img(img.detach().cpu())
    plt.imshow(img.permute(1, 2, 0).numpy())  # CHW -> HWC for imshow


def visualise_output(images, texts, x, y, save_path=None):
    """Plot images in an x-by-y grid, each titled with its caption.

    Args:
        images: batch of (C, H, W) image tensors, converted for display with
            `to_img`; indexed in step with `texts`.
        texts: one caption per image. Captions are wrapped two words per line.
        x, y: number of grid rows and columns. At most x*y items are drawn;
            grid cells without an item are left blank.
        save_path: if given, the figure is also written to this path.
    """
    with torch.no_grad():
        images = to_img(images.detach().cpu())

    # squeeze=False keeps `axes` 2-D even when x or y is 1.
    fig, axes = plt.subplots(x, y, figsize=(32, 20), squeeze=False)
    for item_idx, ax in enumerate(axes.flat):
        ax.axis("off")
        if item_idx >= len(texts):
            continue
        words = texts[item_idx].split()
        title_lines = [" ".join(words[k:k + 2]) for k in range(0, len(words), 2)]
        ax.imshow(images[item_idx].permute(1, 2, 0).numpy())  # CHW -> HWC
        ax.set_title("\n".join(title_lines))

    plt.subplots_adjust(hspace=-0.1, wspace=0.15)
    if save_path is not None:
        fig.savefig(save_path, dpi=fig.dpi)
    plt.show()
    # Release the figure explicitly. This runs every sampling interval, and
    # depending on the backend plt.show() may leave it open.
    plt.close(fig)

In [None]:
# Conditional DCGAN
class Generator(nn.Module):
    """Conditional DCGAN generator.

    The noise and the text embedding (each passed as (B, C, 1, 1)) are projected
    separately to (ngf*4) x 4 x 4 feature maps. These are concatenated
    channel-wise (text first) and upsampled by strided transposed convolutions
    to a ``channels x 64 x 64`` image squashed by Tanh.
    """

    def __init__(self):
        super().__init__()
        proj_width = hp.ngf * 4

        def projection(in_channels):
            # (in_channels) x 1 x 1  ->  (ngf*4) x 4 x 4
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, proj_width, 4, 1, 0, bias=False),
                nn.BatchNorm2d(proj_width),
                nn.ReLU(True),
            )

        # Creation order (text, noise, trunk) fixes the seeded RNG draws at init.
        self.text_layer = projection(hp.text_emb_size)
        self.noise_layer = projection(hp.latent_dim)

        # Each stage doubles the spatial size: 4 -> 8 -> 16 -> 32, then 32 -> 64.
        widths = (hp.ngf * 8, hp.ngf * 4, hp.ngf * 2, hp.ngf)
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        layers.extend([
            nn.ConvTranspose2d(hp.ngf, hp.channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, noise, text_emb):
        noise_feat = self.noise_layer(noise)
        text_feat = self.text_layer(text_emb)
        return self.model(torch.cat([text_feat, noise_feat], 1))
    

class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    The text embedding goes through a Linear layer and is reshaped to a single
    img_size x img_size plane. That plane is stacked onto the image as an extra
    channel, and the conv stack reduces the result to one sigmoid probability
    per sample.

    NOTE(review): Sigmoid + BCELoss is less numerically stable than returning
    logits with BCEWithLogitsLoss; kept as-is because callers expect
    probabilities.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.text_layer = nn.Sequential(
            nn.Linear(hp.text_emb_size, (hp.img_size ** 2)),
            nn.ReLU(inplace=True),
        )
        self.model = nn.Sequential(
            # input is ``(nc + 1 for text) x 64 x 64``
            nn.Conv2d(hp.channels + 1, hp.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(hp.ndf, hp.ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hp.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(hp.ndf * 2, hp.ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hp.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(hp.ndf * 4, hp.ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hp.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(hp.ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # state size. ``1 x 1 x 1``
        )

    def forward(self, img, text_emb):
        """Score (image, text) pairs.

        Args:
            img: (B, channels, img_size, img_size) image batch.
            text_emb: B text embeddings; trailing singleton dims such as
                (B, E, 1, 1) are flattened away.

        Returns:
            (B, 1) tensor of probabilities that each image is real for its text.
        """
        # Reshape against the image batch size instead of squeeze(), which would
        # also drop the batch axis when B == 1.
        text_flat = text_emb.reshape(img.size(0), -1)
        text_out = self.text_layer(text_flat)
        text_out = text_out.view(-1, 1, hp.img_size, hp.img_size)
        d_input = torch.cat([img, text_out], 1)
        validity = self.model(d_input).view(-1, 1)
        return validity


def weights_init(m):
    """DCGAN-style initialisation, meant to be used via ``nn.Module.apply``.

    Modules whose class name contains 'Conv' (Conv2d, ConvTranspose2d) get
    weights ~ N(0, 0.02). Modules containing 'BatchNorm' get weights ~ N(1, 0.02)
    and a zero bias. Every other module (e.g. Linear) is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        

# Build both networks and apply the DCGAN init to every submodule.
generator = Generator()
generator.apply(weights_init)
discriminator = Discriminator()
discriminator.apply(weights_init)

In [None]:
# Binary cross-entropy on the Discriminator's sigmoid output
# (torch.nn.MSELoss was an alternative that is currently disabled).
loss_fn = torch.nn.BCELoss()
if cuda:
    for module in (generator, discriminator, loss_fn):
        module.cuda()

In [None]:
# Optimizers: Adam with identical settings for both players.
adam_betas = (hp.b1, hp.b2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=hp.lr, betas=adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=hp.lr, betas=adam_betas)

# Legacy tensor-type aliases matching the active device.
if cuda:
    FloatTensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    FloatTensor, LongTensor = torch.FloatTensor, torch.LongTensor

# Data

In [None]:
# Preprocessed records, unpacked later as (image, item_type, item_emotions).
# NOTE(review): pickle.load can execute arbitrary code; only load files you produced.
train_path = f"prep_data_size_{int(hp.img_size)}.pkl"
with open(train_path, "rb") as f:
    train_raw = pickle.load(f)

# Fixed examples for periodic sampling. They are drawn from the training data
# itself (there is no held-out split) and kept in their original dataset order.
# Indexing directly avoids an O(N * k) membership scan over the whole dataset.
test_ids = np.random.choice(np.arange(len(train_raw)), hp.n_sampled_examples, replace=False)
test_raw = [train_raw[idx] for idx in sorted(test_ids)]
len(train_raw), len(test_raw)

In [None]:
# ToTensor scales uint8/PIL images to [0, 1]; Normalize with mean = std = 0.5
# then maps each channel to [-1, 1], the same range as the generator's Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

In [None]:
class ImageDataset(Dataset):
    """Images paired with a freshly sampled short caption on every access.

    Each record in `dataset` is unpacked as ``(image, item_type, item_emotions)``.
    The caption includes the item type with probability 0.5, followed by a
    random subset of the emotion words.

    Args:
        dataset: indexable sequence of ``(image, item_type, item_emotions)``.
        transform: callable applied to the image.
        image_size: stored as an attribute; not used by this class.
        max_desc_words: bound used when sampling emotion words (see
            ``_sample_emotions``). If None, ``hp.max_desc_words`` is read at
            access time.
    """

    def __init__(self, dataset, transform, image_size, max_desc_words=None):
        self.transform = transform
        self.image_size = image_size
        self.data = dataset
        self.max_desc_words = max_desc_words

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(transform(image), [caption])``.

        The caption is wrapped in a one-element list; after default batching
        callers read the batch of captions back as ``texts[0]``.
        """
        image, item_type, item_emotions = self.data[idx]
        item_type = [item_type] if np.random.sample() < 0.5 else []
        emotions = self._sample_emotions(item_emotions)
        text_description = [" ".join(item_type + emotions)]
        return self.transform(image), text_description

    def _sample_emotions(self, item_emotions):
        """Pick a random subset of distinct emotion words (empty list if none).

        The count comes from ``np.random.randint(1, high)`` with
        ``high = min(max_desc_words, len(item_emotions))``. When that range is
        empty (a single available word, or a bound of 1) exactly one word is
        used instead of letting randint raise ValueError.
        """
        limit = hp.max_desc_words if self.max_desc_words is None else self.max_desc_words
        n_available = len(item_emotions)
        if n_available == 0:
            return []
        high = min(limit, n_available)
        # NOTE(review): randint's `high` is exclusive, so at most high - 1 words
        # are chosen when high > 1. Kept to preserve the caption distribution.
        choose_n = np.random.randint(1, high) if high > 1 else 1
        return np.random.choice(item_emotions, size=choose_n, replace=False).tolist()

In [None]:
# Same transform for both sets; captions are re-sampled on every access.
train_data, test_data = (ImageDataset(records, transform, hp.img_size)
                         for records in (train_raw, test_raw))
# One noise batch reused at every sampling step so samples are comparable over time.
noise_shape = (hp.n_sampled_examples, hp.latent_dim, 1, 1)
fixed_noise = torch.randn(noise_shape).to(device)

In [None]:
# Two independent shuffles of the training set. The second loader only supplies
# randomly paired captions for the fake-image branch of discriminator training.
# drop_last=True keeps every batch at hp.batch_size, the size of the BCE targets.
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=hp.batch_size,
                                               shuffle=True, drop_last=True)
train_dataloader_fake = torch.utils.data.DataLoader(train_data, batch_size=hp.batch_size,
                                                    shuffle=True, drop_last=True)
# A single batch holding all sampled examples, fetched once so its captions stay fixed.
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=hp.n_sampled_examples, shuffle=False)
# Use the built-in next(): DataLoader iterators implement __next__, and the
# `.next()` alias is gone in recent PyTorch releases.
test_data_fixed_example = next(iter(test_dataloader))

In [None]:
def sample_image(x, y, save_path=None):
    """Render generator samples for the fixed test captions and fixed noise.

    Uses `test_data_fixed_example` and `fixed_noise`, so successive calls are
    comparable during training. Draws an x-by-y grid via `visualise_output`
    and optionally saves it to `save_path`. Example: ``sample_image(5, 10)``.

    NOTE(review): the generator is not switched to eval mode here, so its
    BatchNorm layers use batch statistics while sampling.
    """
    _, texts = test_data_fixed_example
    captions = texts[0]
    condition = embed_bert_cls(captions).to(device)[..., None, None]
    with torch.no_grad():
        samples = generator(fixed_noise, condition)

    visualise_output(samples, captions, x, y, save_path=save_path)

In [None]:
# logs
gan_algorithm = "conditional_dcgan"
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

samples_save_path = f"samples_save_path/{gan_algorithm}_{timestamp}"
os.makedirs(samples_save_path, exist_ok=True)

saved_models_path = f"saved_models/{gan_algorithm}_{timestamp}"
os.makedirs(saved_models_path, exist_ok=True)

# Run-bookkeeping settings are left out of the TensorBoard run name; all other
# hp entries are encoded as key_value pairs. They are excluded by name so the run
# name stays correct if hyperparameters are added or reordered.
RUN_NAME_EXCLUDED_KEYS = {"write_logs", "sample_interval", "save_interval",
                          "n_sampled_examples", "sampled_grid_shape", "num_workers"}

write_logs = hp.write_logs
if write_logs:
    tb_log_name_parts = [f"{k}_{v}" for k, v in hp.__dict__.items()
                         if k not in RUN_NAME_EXCLUDED_KEYS]
    tb_log_dir = f"tb_logs_{gan_algorithm}/{timestamp}_" + "_".join(tb_log_name_parts)
    os.makedirs(tb_log_dir, exist_ok=True)
    writer = SummaryWriter(tb_log_dir)

In [None]:
# Adversarial ground truths: constant BCE targets sized for a full batch
# (both train loaders use drop_last=True, so every batch has hp.batch_size items).
valid = torch.full((hp.batch_size,), 1.0, device=device)
fake = torch.full((hp.batch_size,), 0.0, device=device)

try:
    for epoch in range(hp.n_epochs):
        # Per-batch statistics, aggregated once per epoch for printing / TensorBoard.
        g_losses = []
        d_real_losses = []
        d_fake_losses = []
        d_real_vals = []
        d_fake_vals = []
        d_gen_vals = []
        # The second loader is used only for its independently shuffled captions.
        for (imgs, texts), (_, random_texts) in zip(train_dataloader, train_dataloader_fake):
            batch_size = imgs.shape[0]
            real_imgs = imgs.to(device)
            # Captions arrive as texts[0] (one-element list per sample in ImageDataset).
            # Embed both caption sets in a single encoder call, then split them.
            all_texts = texts[0] + random_texts[0]
            all_labels = embed_bert_cls(all_texts).to(device)[..., None, None]
            gen_labels = all_labels[:batch_size]
            random_labels = all_labels[batch_size:]

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()
            validity_real = discriminator(real_imgs, gen_labels).view(-1)
            d_real_loss = loss_fn(validity_real, valid)  # Loss for real images

            # Fake images are conditioned on the randomly drawn captions.
            z = torch.randn((batch_size, hp.latent_dim, 1, 1)).to(device)
            gen_imgs = generator(z, random_labels)

            # detach(): the discriminator update must not backpropagate into the generator.
            validity_fake = discriminator(gen_imgs.detach(), random_labels).view(-1)
            d_fake_loss = loss_fn(validity_fake, fake)  # Loss for fake images

            d_loss = d_real_loss + d_fake_loss  # Total discriminator loss
            d_loss.backward()
            optimizer_D.step()

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Loss measures generator's ability to fool the discriminator
            z = torch.randn((batch_size, hp.latent_dim, 1, 1)).to(device)
            gen_imgs = generator(z, gen_labels)

            validity = discriminator(gen_imgs, gen_labels).view(-1)
            g_loss = loss_fn(validity, valid)

            g_loss.backward()
            optimizer_G.step()

            g_losses.append(g_loss.item())
            d_real_losses.append(d_real_loss.item())
            d_fake_losses.append(d_fake_loss.item())
            d_real_vals.append(validity_real.detach().cpu().numpy().mean())
            d_fake_vals.append(validity_fake.detach().cpu().numpy().mean())
            d_gen_vals.append(validity.detach().cpu().numpy().mean())

        if not epoch % hp.sample_interval:
            clear_output()
            dloss_real = round(np.mean(d_real_losses), 5)
            dloss_fake = round(np.mean(d_fake_losses), 5)
            gloss = round(np.mean(g_losses), 5)
            print(f"Epoch: {epoch+1}/{hp.n_epochs}, DLoss Real: {dloss_real}, DLoss Fake: {dloss_fake}, GLoss: {gloss}")
            img_save_path = f"{samples_save_path}/{epoch}.png"
            sample_image(hp.sampled_grid_shape[0], hp.sampled_grid_shape[1], img_save_path)

        if not epoch % hp.save_interval:
            torch.save(generator.state_dict(), f"{saved_models_path}/generator_{epoch}.pth")
            torch.save(discriminator.state_dict(), f"{saved_models_path}/discriminator_{epoch}.pth")

        if write_logs:
            writer.add_scalar("losses/generator_loss_mean", np.mean(g_losses), epoch)
            writer.add_scalar("losses/generator_loss_std", np.std(g_losses), epoch)
            writer.add_scalar("losses/discriminator_real_loss_mean", np.mean(d_real_losses), epoch)
            writer.add_scalar("losses/discriminator_real_loss_std", np.std(d_real_losses), epoch)
            writer.add_scalar("losses/discriminator_fake_loss_mean", np.mean(d_fake_losses), epoch)
            writer.add_scalar("losses/discriminator_fake_loss_std", np.std(d_fake_losses), epoch)
            writer.add_scalar("logs/validity_real", np.mean(d_real_vals), epoch)
            writer.add_scalar("logs/validity_fake", np.mean(d_fake_vals), epoch)
            writer.add_scalar("logs/validity_gen", np.mean(d_gen_vals), epoch)
finally:
    # Close the TensorBoard writer even if training stops early (e.g. KeyboardInterrupt).
    if write_logs:
        writer.close()

# final save
torch.save(generator.state_dict(), f"{saved_models_path}/generator_final.pth")
torch.save(discriminator.state_dict(), f"{saved_models_path}/discriminator_final.pth")