In [2]:
!mkdir text2landmark_68_2d
%cd text2landmark_68_2d
!gdown 1-2phy4gND3k1F7Wi-HB4qzGs5ImKztrp
!unzip text2landmark_68_2d.zip
!rm text2landmark_68_2d.zip
%cd ..

mkdir: cannot create directory ‘text2landmark_68_2d’: File exists
/kaggle/working/text2landmark_68_2d
Downloading...
From: https://drive.google.com/uc?id=1-2phy4gND3k1F7Wi-HB4qzGs5ImKztrp
To: /kaggle/working/text2landmark_68_2d/text2landmark_68_2d.zip
100%|██████████████████████████████████████| 9.74M/9.74M [00:00<00:00, 68.2MB/s]
Archive:  text2landmark_68_2d.zip
replace landmarks.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: ^C
/kaggle/working


In [33]:
!pip install git+https://github.com/openai/CLIP.git
# !pip install transformers

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-g1qk1z4k
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-g1qk1z4k
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25ldone


In [34]:
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.autograd import Variable
import clip
# from transformers import CLIPTokenizer, CLIPTextModel

import os
import math
import time
import pickle
import random
import collections
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

In [35]:
class CelebALandmark(Dataset):
    """Dataset of facial-landmark heatmaps paired with text captions.

    Expects the following files under ``data_path``:
      * ``landmarks.txt``  - one sample per line: ``<name> <x>-<y> <x>-<y> ...``
      * ``captions.pickle`` - dict mapping ``'<stem>.txt'`` to a list of captions
      * ``splits/train_filenames.pickle`` / ``splits/test_filenames.pickle``
        - collections of sample stems (file name without extension)

    All samples of the chosen split are converted to one-hot 64x64 heatmaps
    eagerly in ``__init__`` and kept in memory.
    """

    # Split id -> pickle file (inside ``splits/``) listing the stems of that split.
    _SPLIT_FILES = {0: 'train_filenames.pickle', 1: 'test_filenames.pickle'}

    def __init__(self, data_path, size=1024, split=0):
        """
        data_path: directory holding the files described in the class docstring
        size: side length of the coordinate frame of the stored landmarks;
              coordinates are integer-divided by ``size // 64`` to land on a 64x64 grid
        split: 0 train, 1 testing (no other split file is wired up)

        Raises ValueError for an unsupported ``split``.
        """
        if split not in self._SPLIT_FILES:
            raise ValueError(
                f"Unsupported split {split!r}; expected one of {sorted(self._SPLIT_FILES)}")

        self.size = size
        self.landmarks_path = os.path.join(data_path, 'landmarks.txt')
        self.captions_path = os.path.join(data_path, 'captions.pickle')
        self.splits_path = os.path.join(data_path, 'splits')
        self.split = split
        self.names = []
        self.landmarks = []
        self.captions = []

        # NOTE: pickle.load executes arbitrary code from the file; only use trusted data.
        with open(os.path.join(self.splits_path, self._SPLIT_FILES[split]), "rb") as f:
            self.filenames = pickle.load(f)
        with open(self.captions_path, "rb") as f:
            self.all_captions = pickle.load(f)

        # Set lookup: the membership test below runs once per line of landmarks.txt.
        wanted = set(self.filenames)

        with open(self.landmarks_path, 'r') as f:
            all_labels = f.readlines()
        for label in all_labels:
            if not label.strip():
                continue  # tolerate blank lines (e.g. trailing newline at end of file)
            name, shape, _ = self.parse_label(label)
            filename = name.split('.')[0]
            if filename in wanted:
                self.names.append(name)
                shape = np.array(shape) // (self.size // 64) # convert to 64
                heatmaps = self.keypoints2heatmaps(shape, size=(64, 64))
                self.landmarks.append(heatmaps)
                captions = self.all_captions[f'{filename}.txt']
                self.captions.append(captions)

    @staticmethod
    def _parse_point(token):
        """Parse one ``'<x>-<y>'`` token where either number may carry a minus sign.

        The separator is the first '-' that is not the leading sign of x, so
        ``'5-10'``, ``'-5-10'``, ``'5--10'`` and ``'-5--10'`` are all accepted.
        """
        sep = token.index('-', 1)
        return int(token[:sep]), int(token[sep + 1:])

    def parse_label(self, label):
        """Split one line of ``landmarks.txt`` into ``(name, shape, ori_gaze)``.

        Lines with more than 107 fields are read as: name, one size field
        (ignored), 106 points, then integer extras returned as ``ori_gaze``.
        Shorter lines are read as: name followed by points only; their y
        coordinate is mirrored (``size - y``) and ``ori_gaze`` stays empty.
        ``shape`` is a list of ``[x, y]`` pairs.
        """
        l = label.strip().split()
        name = l[0]
        shape = []
        ori_gaze = []
        if len(l) > 107:
            # l[1] holds the original "<w>-<h>" size, which nothing here uses.
            for l_ in l[2:108]:
                w, h = self._parse_point(l_)
                shape.append([w, h])
            for l_ in l[108:]:
                ori_gaze.append(int(l_))
        else:
            for l_ in l[1:]:
                w, h = self._parse_point(l_)
                h = self.size - h # flip landmarks due to problem in saving
                shape.append([w, h])

        return name, shape, ori_gaze

    def keypoints2heatmaps(self, keypoints, size=(64, 64)):
        """Turn a ``(K, 2)`` array of ``(x, y)`` points into a ``size + (K,)`` int8 one-hot map.

        Points beyond the right/bottom border are clamped onto it; points with a
        negative coordinate are skipped, leaving that channel all zero.
        """
        keypoints = keypoints.astype(int)
        heatmaps = np.zeros(size + (keypoints.shape[0],), dtype=np.int8)
        for k in range(keypoints.shape[0]):
            x, y = keypoints[k]
            # Axis 0 is y (rows) and axis 1 is x (columns), so clamp each against its own extent.
            x, y = min(x, size[1] - 1), min(y, size[0] - 1)
            if x < 0 or y < 0:
                continue
            heatmaps[y, x, k] = 1
        return heatmaps

    def __len__(self):
        """Number of samples kept for the selected split."""
        return len(self.landmarks)

    def __getitem__(self, item):
        """Return ``(heatmaps, caption, caption_1, caption_2)``.

        ``caption`` is drawn at random from the captions of sample ``item``;
        ``caption_1`` and ``caption_2`` are drawn from randomly chosen samples.
        """
        landmark = self.landmarks[item]
        caption = random.choice(self.captions[item])
        caption_1 = random.choice(random.choice(self.captions))
        caption_2 = random.choice(random.choice(self.captions))

        return landmark, caption, caption_1, caption_2

In [36]:
def init_networks(network, name, init_type='normal', init_gain=0.02, verbose=False):
    """Re-initialise the weights of every sub-module of ``network`` in place.

    Modules whose class name contains ``Conv`` or ``Linear`` get their weight
    drawn according to ``init_type`` (``normal``, ``orthogonal``,
    ``xavier_normal`` or ``kaiming_normal``) and their bias, if any, zeroed.
    ``BatchNorm2d`` modules get weight ~ N(1, init_gain) and zero bias.

    Raises NotImplementedError for an unknown ``init_type`` (only once a
    Conv/Linear module is actually visited). ``name`` is used for the log line
    printed when ``verbose`` is set.
    """
    def init_params(module):
        cls_name = module.__class__.__name__
        conv_or_linear = 'Conv' in cls_name or 'Linear' in cls_name

        if hasattr(module, 'weight') and conv_or_linear:
            weight = module.weight.data
            if init_type == 'normal':
                nn.init.normal_(weight, 0.0, init_gain)
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(weight, init_gain)
            elif init_type == 'xavier_normal':
                nn.init.xavier_normal_(weight, init_gain)
            elif init_type == 'kaiming_normal':
                nn.init.kaiming_normal_(weight)
            else:
                raise NotImplementedError(init_type)

            bias = getattr(module, 'bias', None)
            if bias is not None:
                nn.init.constant_(bias.data, 0.0)
            return

        if 'BatchNorm2d' in cls_name:
            nn.init.normal_(module.weight.data, 1.0, init_gain)
            nn.init.constant_(module.bias.data, 0.0)

    network.apply(init_params)
    if verbose:
        print(f'[INFO] Network {name} initialized')

In [37]:
"""
Inspired by the Stage-1 network of TIPS: Text-Induced Pose Synthesis
"""


import torch
import torch.nn as nn


def linear(in_features, out_features, bias=True):
    """Fully connected layer followed by an in-place LeakyReLU (default slope)."""
    fc = nn.Linear(in_features, out_features, bias=bias)
    activation = nn.LeakyReLU(inplace=True)
    return nn.Sequential(fc, activation)


def upconv4x(in_channels, out_channels, bias=False):
    """Transposed conv (kernel 4, stride 4, no padding) -> BatchNorm2d -> in-place ReLU.

    Multiplies the spatial resolution by 4.
    """
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, 4, 4, 0, bias=bias),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


def upconv2x_hidden(in_channels, out_channels, bias=False):
    """Transposed conv (kernel 4, stride 2, padding 1) -> BatchNorm2d -> in-place ReLU.

    Doubles the spatial resolution; used for the hidden generator stages.
    """
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=bias),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


def upconv2x_output(in_channels, out_channels, bias=False):
    """Transposed conv (kernel 4, stride 2, padding 1) followed by Tanh.

    Doubles the spatial resolution and squashes the result into [-1, 1].
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=bias)
    squash = nn.Tanh()
    return nn.Sequential(upsample, squash)


def conv2x(in_channels, out_channels, bias=False):
    """Conv (kernel 4, stride 2, padding 1) followed by in-place LeakyReLU(0.2).

    Halves the spatial resolution.
    """
    downsample = nn.Conv2d(in_channels, out_channels,
                           kernel_size=4, stride=2, padding=1, bias=bias)
    activation = nn.LeakyReLU(0.2, True)
    return nn.Sequential(downsample, activation)


def conv1x(in_channels, out_channels, bias=False):
    """Conv (kernel 3, stride 1, padding 1) followed by in-place LeakyReLU(0.2).

    Keeps the spatial resolution unchanged.
    """
    same_conv = nn.Conv2d(in_channels, out_channels,
                          kernel_size=3, stride=1, padding=1, bias=bias)
    activation = nn.LeakyReLU(0.2, True)
    return nn.Sequential(same_conv, activation)


def conv_output(in_channels, out_channels, bias=False):
    """Final conv (kernel 4, stride 1, no padding) with no activation.

    Collapses a 4x4 feature map to 1x1 and returns raw, unbounded scores.
    """
    head = nn.Conv2d(in_channels, out_channels,
                     kernel_size=4, stride=1, padding=0, bias=bias)
    # No nn.Sigmoid() here: the output is left as an unnormalised score.
    return nn.Sequential(head)


class TextEncoder(nn.Module):
    """Text feature extractor built on the text tower of a pretrained CLIP ViT-B/32."""

    def __init__(self, device='cpu'):
        super(TextEncoder, self).__init__()

        # Text Transformer
        # clip.load returns the model with its pretrained weights already in place.
        # Do NOT call self.clip.initialize_parameters() afterwards: that routine
        # re-draws the token/positional embeddings and transformer weights at
        # random and would throw the pretrained text encoder away.
        self.clip, _ = clip.load('ViT-B/32', device)

    def encode_text(self, text, device):
        """Tokenise ``text`` (truncating over-long captions) and encode it without gradients.

        text: a string or a list of strings
        device: device the token ids are moved to before the forward pass
        Returns one feature vector per caption, in CLIP's working dtype.
        """
        with torch.no_grad():
            text = clip.tokenize(text, truncate=True).to(device)
            x = self.clip.token_embedding(text).type(self.clip.dtype)  # [batch_size, n_ctx, d_model]

            x = x + self.clip.positional_embedding.type(self.clip.dtype)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.clip.transformer(x)
            x = self.clip.ln_final(x).type(self.clip.dtype)
            # B, D
            # Feature at the LAST sequence position of every caption.
            # NOTE(review): CLIP's own encode_text pools at the end-of-text token
            # (text.argmax(dim=-1)) and applies text_projection; confirm that
            # taking the last (usually padding) position is intended.
            x = x[-1]

        return x

    def forward(self, text, device):
        """Alias of :meth:`encode_text`."""
        return self.encode_text(text, device)


class Text2LandmarkG(nn.Module):
    """Generator: noise vector + 512-d text feature -> landmark heatmaps.

    The text feature is projected to ``noise_dim`` channels, concatenated with
    the noise, and upsampled 1x1 -> 4x4 -> 8 -> 16 -> 32 -> 64 by five
    transposed-conv stages. The last stage ends in Tanh.
    """

    def __init__(self, noise_dim=128, heatmap_channels=68, ngf=32):
        super().__init__()
        # Upsampling trunk; its input carries noise and projected text side by side.
        self.combined_up1 = upconv4x(noise_dim * 2, ngf * 8)
        self.combined_up2 = upconv2x_hidden(ngf * 8, ngf * 4)
        self.combined_up3 = upconv2x_hidden(ngf * 4, ngf * 2)
        self.combined_up4 = upconv2x_hidden(ngf * 2, ngf)
        self.combined_up5 = upconv2x_output(ngf, heatmap_channels)

        # Projects the 512-d text feature down to noise_dim.
        self.text_linear = linear(512, noise_dim)

        init_networks(self, 'Text2LandmarkG', verbose=True)

    def forward(self, x, text):
        """
        x: noise, reshaped to (B, -1, 1, 1)
        text: text features fed to the 512-input projection layer
        Returns heatmaps in channels-last layout (B, H, W, heatmap_channels).
        """
        noise_map = x.view(x.size(0), -1, 1, 1)

        # (B, D) -> (B, D, 1, 1) so it can be stacked with the noise on the channel axis.
        text_feat = self.text_linear(text)
        text_map = text_feat.view(text_feat.size(0), -1, 1, 1)

        out = torch.cat((noise_map, text_map), dim=1)
        stages = (self.combined_up1, self.combined_up2, self.combined_up3,
                  self.combined_up4, self.combined_up5)
        for stage in stages:
            out = stage(out)

        # channels-first -> channels-last
        return out.permute(0, 2, 3, 1)


class Text2LandmarkD(nn.Module):
    """Text-conditioned discriminator over landmark heatmaps.

    Four stride-2 convolutions shrink the heatmaps by a factor of 16, the
    projected text feature is tiled over the resulting grid and concatenated on
    the channel axis, and two more convolutions produce an unbounded score.
    """

    def __init__(self, noise_dim=128, in_channels=68, ndf=32):
        super(Text2LandmarkD, self).__init__()
        self.conv1 = conv2x(in_channels, ndf)
        self.conv2 = conv2x(ndf, ndf * 2, bias=False)
        self.conv3 = conv2x(ndf * 2, ndf * 4, bias=False)
        self.conv4 = conv2x(ndf * 4, ndf * 4, bias=False)
        # conv5 consumes conv4's ndf*4 channels plus the noise_dim text channels.
        # With the defaults this equals ndf*8 (the previously hard-coded width),
        # so existing checkpoints still load; other noise_dim/ndf pairs now work too.
        self.conv5 = conv1x(ndf * 4 + noise_dim, ndf * 8, bias=False)
        self.conv6 = conv_output(ndf * 8, 1, bias=True)

        # Projects the 512-d text feature down to noise_dim.
        self.text_linear = linear(512, noise_dim)

        init_networks(self, 'Text2LandmarkD', verbose=True)

    def forward(self, x, text):
        """
        x: heatmaps in channels-last layout (B, H, W, in_channels)
        text: text features fed to the 512-input projection layer
        Returns the raw score map of conv6; for 64x64 heatmaps that is (B, 1, 1, 1).
        """
        x = x.permute(0, 3, 1, 2)
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        y = self.conv4(y)

        # B, D
        text_embed = self.text_linear(text)
        # B, D, h, w - tile the text vector over whatever grid conv4 produced
        # (4x4 for 64x64 input) rather than assuming a fixed 4x4.
        text_embed = text_embed[:, :, None, None].expand(-1, -1, y.size(2), y.size(3))

        combined = torch.cat((y, text_embed), dim=1)
        y = self.conv5(combined)
        y = self.conv6(y)

        return y

In [38]:
def save_state(model_G, model_D, optimizer_G, optimizer_D, epoch_no, best=False, save_dir=None):
    """Write generator/discriminator weights and optimizer state to disk.

    Files land in ``<save_dir>/checkpoints`` as ``model_G<postfix>.pth``,
    ``model_D<postfix>.pth`` and ``params<postfix>.pth``, where ``<postfix>``
    is ``_best`` when ``best`` is set and ``_last`` otherwise.

    model_G, model_D: networks; each is unwrapped on its own if it is an nn.DataParallel
    optimizer_G, optimizer_D: their optimizers
    epoch_no: stored next to the optimizer states
    save_dir: output root; defaults to the notebook-level ``save_path``
    """
    if save_dir is None:
        save_dir = save_path
    ckpt_dir = os.path.join(save_dir, 'checkpoints')
    # Create the folder on demand so the function does not rely on earlier setup.
    os.makedirs(ckpt_dir, exist_ok=True)

    save_postfix = '_best' if best else '_last'

    def _unwrap(model):
        # Save the bare module so state-dict keys carry no 'module.' prefix.
        return model.module if isinstance(model, nn.DataParallel) else model

    params = {'optimizer_G': optimizer_G.state_dict(),
              'optimizer_D': optimizer_D.state_dict(),
              'epoch': epoch_no}

    torch.save(_unwrap(model_G).state_dict(), os.path.join(ckpt_dir, f"model_G{save_postfix}.pth"))
    torch.save(_unwrap(model_D).state_dict(), os.path.join(ckpt_dir, f"model_D{save_postfix}.pth"))
    torch.save(params, os.path.join(ckpt_dir, f"params{save_postfix}.pth"))

def heatmaps2keypoints(heatmaps, confidence):
    """Pick one ``(x, y)`` peak per channel of an ``(H, W, K)`` heatmap stack.

    A channel whose maximum does not exceed ``confidence`` yields ``(-1, -1)``.
    On ties the first maximum in row-major order wins. Returns an int32 array
    with one ``(x, y)`` row per channel.
    """
    coords = []
    n_channels = heatmaps.shape[2]
    for channel in range(n_channels):
        plane = heatmaps[:, :, channel]
        peak = np.max(plane)
        if not peak > confidence:
            coords.append((-1, -1))
            continue
        # argmax scans in row-major order, i.e. returns the first maximum.
        row, col = np.unravel_index(np.argmax(plane), plane.shape)
        coords.append((col, row))
    return np.int32(coords)

def keypoints2heatmaps(keypoints, size=(64, 64)):
    """Turn ``(x, y)`` keypoints into a ``size + (K,)`` float32 one-hot heatmap stack.

    keypoints: array reshaped to ``(K, 2)``
    size: ``(height, width)`` of each heatmap

    Points with a negative coordinate are skipped (their channel stays zero).
    Points beyond the right/bottom border are clamped onto it, matching
    ``CelebALandmark.keypoints2heatmaps``, instead of raising IndexError.
    """
    keypoints = keypoints.reshape(-1, 2).astype(np.int32)
    heatmaps = np.zeros(size + (keypoints.shape[0],), dtype=np.float32)
    for k in range(keypoints.shape[0]):
        x, y = keypoints[k]
        if x < 0 or y < 0:
            continue
        # Axis 0 is y (rows), axis 1 is x (columns).
        x, y = min(x, size[1] - 1), min(y, size[0] - 1)
        heatmaps[y, x, k] = 1
    return heatmaps

def visualize(fake_lmks, real_lmks, caption, epoch_no):
    """Plot up to 16 generated/real landmark pairs side by side and save the figure.

    fake_lmks, real_lmks: sequences of (H, W, 68) heatmap stacks
    caption: accepted for call compatibility; not drawn
    epoch_no: used as the file name, ``<save_path>/results/<epoch_no>.png``

    Keypoints are extracted with a 0.2 confidence threshold and scaled by 4.
    """
    # Everything below is identical for every sample, so build it once.
    plot_style = dict(marker='o',
                      markersize=4,
                      linestyle='-',
                      lw=2)

    FacePart = collections.namedtuple('prediction_type', ['slice', 'color'])
    face_parts = {'face': FacePart(slice(0, 17), (0.682, 0.780, 0.909, 0.5)),
                  'eyebrow1': FacePart(slice(17, 22), (1.0, 0.498, 0.055, 0.4)),
                  'eyebrow2': FacePart(slice(22, 27), (1.0, 0.498, 0.055, 0.4)),
                  'nose': FacePart(slice(27, 31), (0.345, 0.239, 0.443, 0.4)),
                  'nostril': FacePart(slice(31, 36), (0.345, 0.239, 0.443, 0.4)),
                  'eye1': FacePart(slice(36, 42), (0.596, 0.875, 0.541, 0.3)),
                  'eye2': FacePart(slice(42, 48), (0.596, 0.875, 0.541, 0.3)),
                  'lips': FacePart(slice(48, 60), (0.596, 0.875, 0.541, 0.3)),
                  'teeth': FacePart(slice(60, 68), (0.596, 0.875, 0.541, 0.4))
                  }

    def _to_points(heatmaps):
        # Peak per channel, scaled by 4; missing points stay at -1.
        kB = heatmaps2keypoints(heatmaps, 0.2)
        return np.where(kB < 0, -1, kB * 4)

    def _draw(ax, title, points):
        # One poly-line per facial part.
        ax.title.set_text(title)
        for part in face_parts.values():
            ax.plot(points[part.slice, 0],
                    points[part.slice, 1],
                    color=part.color, **plot_style)
        ax.axis('off')

    # The grid holds 16 pairs; a smaller batch simply leaves cells empty.
    n_samples = min(16, len(fake_lmks), len(real_lmks))

    fig = plt.figure(figsize=(16, 8))
#     fig.suptitle(caption)
    for i in range(n_samples):
        _draw(fig.add_subplot(4, 8, 2 * i + 1), 'Fake', _to_points(fake_lmks[i]))
        _draw(fig.add_subplot(4, 8, 2 * i + 2), 'Real', _to_points(real_lmks[i]))

    # Save before showing so the file is written regardless of how the backend
    # treats the figure after show(); then close it, because this runs once per
    # epoch and open figures would otherwise pile up in memory.
    fig.savefig(f"{save_path}/results/{epoch_no}.png")
    plt.show()
    plt.close(fig)

def plot_loss(train_loss_epoch, valid_loss_epoch, train_loss, valid_loss, name='G'):
    """Plot training and validation loss curves and save them.

    train_loss_epoch, valid_loss_epoch: x values (epoch numbers) of each curve
    train_loss, valid_loss: matching y values
    name: network tag used in the legend and in the file name
          ``<save_path>/loss_<name>.png``
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(train_loss_epoch, train_loss, label=f'Train {name}')
    ax.plot(valid_loss_epoch, valid_loss, label=f'Valid {name}')
    ax.grid(True)
    ax.legend()
    # Save before showing so the file is written regardless of how the backend
    # treats the figure after show(); close it afterwards to release the figure.
    fig.savefig(f"{save_path}/loss_{name}.png")
    plt.show()
    plt.close(fig)

In [39]:
# Mini-batch sizes of the training and validation DataLoaders.
batch_size_train = 64
batch_size_valid = 64
# Only referenced by the commented-out test DataLoader.
batch_size_test = 8
# The training loop iterates over epochs start_epoch .. n_epoch - 1.
n_epoch = 50
# Validation runs every `valid_epoch_interval` epochs, and only when `do_validation` is True.
valid_epoch_interval = 2
do_validation = False
start_epoch = 0
# Weight of the gradient-penalty term in the discriminator loss.
lambda_gp = 2
# Learning rate shared by both Adam optimizers.
lr = 0.0001

# Output directory for checkpoints, result figures and loss plots.
save_path = f'e50-lgp2-lr0_0001-128-68-64-l36-cr3'

In [40]:
# Folder produced by the download cell (absolute Kaggle working-directory path).
data_path = '/kaggle/working/text2landmark_68_2d/'

# split=0 loads the training file list; all heatmaps are built up front in the constructor.
train_dataset = CelebALandmark(data_path, split=0)
print('>>> Training dataset length: {:d}'.format(train_dataset.__len__()))
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=0, pin_memory=True)

# The "validation" set is split=1, i.e. the file list the dataset class loads from test_filenames.pickle.
if do_validation:
    valid_dataset = CelebALandmark(data_path, split=1)
    print('>>> Validation dataset length: {:d}'.format(valid_dataset.__len__()))
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size_valid, shuffle=True, num_workers=0, pin_memory=True)

# Disabled: split=2 is not handled by CelebALandmark.__init__ (no file list is loaded for it).
# test_dataset = CelebALandmark(data_path, split=2)
# print('>>> Test dataset length: {:d}'.format(test_dataset.__len__()))
# test_loader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, num_workers=0, pin_memory=True)

>>> Training dataset length: 23698


In [None]:
# Pick the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device: %s' % device)

# Output folders: run root, per-epoch figures, checkpoints.
for out_dir in (save_path, f'{save_path}/results', f'{save_path}/checkpoints'):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

# Networks: 128-d noise, 68 heatmap channels, base width 32.
model_G = Text2LandmarkG(128, 68, 32)
model_D = Text2LandmarkD(128, 68, 32)
text_encoder = TextEncoder(device)

model_G.to(device)
model_D.to(device)

# Both adversarial objectives are mean-squared errors against 0/1 targets.
criterion = nn.MSELoss()

adam_betas = (0, 0.9)
optimizer_G = Adam(model_G.parameters(), lr=lr, betas=adam_betas)
optimizer_D = Adam(model_D.parameters(), lr=lr, betas=adam_betas)

# Per-epoch loss history, consumed by plot_loss.
train_loss_G, train_loss_D = [], []
valid_loss_G, valid_loss_D = [], []
train_loss_epoch, valid_loss_epoch = [], []

# Lowest mean validation loss seen so far. It has to live outside the epoch
# loop and start at +inf: resetting it to 0 on every validation pass meant the
# "best" checkpoint could never be written, since the MSE losses are never negative.
best_valid_loss = float('inf')

# Each epoch: one discriminator step and one generator step per batch, a
# visualisation of the last batch, optional validation, and a "_last" checkpoint.
for epoch_no in range(start_epoch, n_epoch):
    model_G.train()
    model_D.train()
    avg_loss_D = 0
    avg_loss_G = 0
    with tqdm(train_loader) as it:
        for batch_no, batch in enumerate(it, start=1):
            landmarks, captions, captions_1, captions_2 = batch
            b_size = len(landmarks)
            real_lmks = landmarks.to(device).type(torch.float)

            ###### Text Encoder ######
            # NOTE(review): the CLIP text encoder is bypassed during training
            # (all-zero embeddings) while validation below uses real caption
            # embeddings - confirm this is intended.
            # encoded_captions = text_encoder(captions, device).type(torch.float)
            # encoded_captions_1 = text_encoder(captions_1, device).type(torch.float)
            # encoded_captions_2 = text_encoder(captions_2, device).type(torch.float)
            encoded_captions = torch.zeros(size=(b_size, 512)).to(device).type(torch.float)
            encoded_captions_1 = torch.zeros(size=(b_size, 512)).to(device).type(torch.float)
            encoded_captions_2 = torch.zeros(size=(b_size, 512)).to(device).type(torch.float)

            ##### Discriminator ######
            # z = torch.normal(0, 1, size=(b_size, 128)).to(device)
            for _ in range(1):
                # Only D receives gradients during its own step.
                for p in model_D.parameters():
                    p.requires_grad = True
                for p in model_G.parameters():
                    p.requires_grad = False
                model_D.zero_grad()
                # real
                label = torch.full((b_size,), 1, dtype=torch.float, device=device)
                real_D = model_D(real_lmks, encoded_captions).view(-1)
                loss_real_D = criterion(real_D, label)

                # fake
                z = torch.normal(0, 1, size=(b_size, 128)).to(device)
                fake_lmks = model_G(z, encoded_captions)
                label = torch.full((b_size,), 0, dtype=torch.float, device=device)
                fake_D = model_D(fake_lmks, encoded_captions).view(-1)
                loss_fake_D = criterion(fake_D, label)

                # grad penalty: push the norm of D's gradient towards 1 on
                # random blends of real and generated samples
                alpha = torch.rand(b_size, 1, 1, 1).to(device)
                interp_lmks = alpha * real_lmks.data + (1- alpha) * fake_lmks.data
                interp_lmks = Variable(interp_lmks, requires_grad=True).to(device)
                interp_D = model_D(interp_lmks, encoded_captions)
                weight = torch.ones(interp_D.size()).to(device)
                gradients = torch.autograd.grad(outputs=interp_D,
                                                inputs=interp_lmks,
                                                grad_outputs=weight,
                                                retain_graph=True,
                                                create_graph=True,
                                                only_inputs=True)[0]
                gradients = gradients.view(gradients.size(0), -1)
                # epsilon keeps the sqrt differentiable at zero
                gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
                loss_gp_D = ((gradients_norm - 1) ** 2).mean()

                # discriminator loss
                loss_D = loss_real_D + loss_fake_D + loss_gp_D * lambda_gp
                # optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
            avg_loss_D += loss_D.item()

            ####### Generator ########
            for _ in range(1):
                # Only G receives gradients during its own step.
                for p in model_D.parameters():
                    p.requires_grad = False
                for p in model_G.parameters():
                    p.requires_grad = True
                model_G.zero_grad()
                z = torch.normal(0, 1, size=(b_size, 128)).to(device)
                fake_lmks = model_G(z, encoded_captions)
                fake_D = model_D(fake_lmks, encoded_captions).view(-1)
                label = torch.full((b_size,), 1, dtype=torch.float, device=device)
                loss_fake_G = criterion(fake_D, label)

                # Second term: samples conditioned on the average of two
                # random caption embeddings must fool D as well.
                interp_captions = (encoded_captions_1 + encoded_captions_2) / 2
                z = torch.normal(0, 1, size=(b_size, 128)).to(device)
                fake_interp_lmks = model_G(z, interp_captions)
                fake_interp_D = model_D(fake_interp_lmks, interp_captions).view(-1)
                loss_interp_G = criterion(fake_interp_D, label)

                loss_G = loss_fake_G + loss_interp_G

                # optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()
            avg_loss_G += loss_G.item()

            it.set_postfix(
                ordered_dict={
                    "loss_D": avg_loss_D / batch_no,
                    "loss_G": avg_loss_G / batch_no,
                    "epoch": epoch_no,
                },
                refresh=False,
            )

    # Show samples from the last training batch of the epoch.
    # rand_ind = np.random.randint(b_size)
    fake_lmks = fake_lmks[:16].detach().cpu().numpy()
    real_lmks = real_lmks[:16].detach().cpu().numpy()
    caption = captions[:16]
    visualize(fake_lmks, real_lmks, caption, epoch_no)

    train_loss_G.append(avg_loss_G / batch_no)
    train_loss_D.append(avg_loss_D / batch_no)
    train_loss_epoch.append(epoch_no)

    if do_validation and (epoch_no + 1) % valid_epoch_interval == 0:
        model_G.eval()
        model_D.eval()
        avg_loss_D_valid = 0
        avg_loss_G_valid = 0
        with torch.no_grad():
            with tqdm(valid_loader) as it:
                for batch_no, batch in enumerate(it, start=1):
                    landmarks, captions, captions_1, captions_2 = batch
                    b_size = len(landmarks)
                    real_lmks = landmarks.to(device).type(torch.float)

                    ###### Text Encoder ######
                    encoded_captions = text_encoder(captions, device).type(torch.float32)
                    encoded_captions_1 = text_encoder(captions_1, device).type(torch.float32)
                    encoded_captions_2 = text_encoder(captions_2, device).type(torch.float32)

                    ###### Discriminator ######
                    # real
                    label = torch.full((b_size,), 1, dtype=torch.float, device=device)
                    real_D = model_D(real_lmks, encoded_captions).view(-1)
                    loss_real_D = criterion(real_D, label)

                    # fake
                    z = torch.normal(0, 1, size=(b_size, 128)).to(device)
                    fake_lmks = model_G(z, encoded_captions).detach()
                    label = torch.full((b_size,), 0, dtype=torch.float, device=device)
                    fake_D = model_D(fake_lmks, encoded_captions).view(-1)
                    loss_fake_D = criterion(fake_D, label)

                    # No gradient penalty here: it needs autograd, which is
                    # switched off inside torch.no_grad().

                    # discriminator loss
                    loss_D = loss_real_D + loss_fake_D
                    avg_loss_D_valid += loss_D.item()

                    ###### Generator ######
                    z = torch.normal(0, 1, size=(b_size, 128)).to(device)
                    fake_lmks = model_G(z, encoded_captions)
                    fake_D = model_D(fake_lmks, encoded_captions).view(-1)
                    label = torch.full((b_size,), 1, dtype=torch.float, device=device)
                    loss_fake_G = criterion(fake_D, label)

                    interp_captions = (encoded_captions_1 + encoded_captions_2) / 2
                    z = torch.normal(0, 1, size=(b_size, 128)).to(device)
                    fake_interp_lmks = model_G(z, interp_captions)
                    fake_interp_D = model_D(fake_interp_lmks, interp_captions).view(-1)
                    loss_interp_G = criterion(fake_interp_D, label)

                    loss_G = loss_fake_G + loss_interp_G
                    # loss_G = (loss_fake_G + loss_interp_G) / 2

                    avg_loss_G_valid += loss_G.item()

                    # The two labels used to be swapped (D showed G's loss and vice versa).
                    it.set_postfix(
                        ordered_dict={
                            "loss_D": avg_loss_D_valid / batch_no,
                            "loss_G": avg_loss_G_valid / batch_no,
                            "epoch": epoch_no,
                        },
                        refresh=False,
                    )

        mean_loss_G_valid = avg_loss_G_valid / batch_no
        mean_loss_D_valid = avg_loss_D_valid / batch_no
        valid_loss_G.append(mean_loss_G_valid)
        valid_loss_D.append(mean_loss_D_valid)
        valid_loss_epoch.append(epoch_no)

        # Track the per-batch mean of both losses, i.e. the same numbers that are printed.
        current_valid_loss = (mean_loss_G_valid + mean_loss_D_valid) / 2
        if current_valid_loss < best_valid_loss:
            best_valid_loss = current_valid_loss
            print(f"\n Best G loss is updated to {mean_loss_G_valid} at {epoch_no}")
            print(f"Best D loss is updated to {mean_loss_D_valid} at {epoch_no}")
            save_state(model_G, model_D, optimizer_G, optimizer_D, epoch_no, best=True)

    if (epoch_no + 1) == n_epoch:
        plot_loss(train_loss_epoch, valid_loss_epoch, train_loss_G, valid_loss_G, name='G')
        plot_loss(train_loss_epoch, valid_loss_epoch, train_loss_D, valid_loss_D, name='D')

    # Always refresh the "_last" checkpoint at the end of the epoch.
    save_state(model_G, model_D, optimizer_G, optimizer_D, epoch_no)

In [None]:
def vis(real_lmks):
    """Debug view: scatter the first landmark of each of the first 16 heatmap stacks.

    real_lmks: indexable collection of (H, W, K) heatmap stacks with at least 16 entries.
    Keypoints are extracted with a 0.2 threshold and scaled by 4; each sample
    gets its own cell of an 8x8 subplot grid.
    """
    fig = plt.figure(figsize=(16, 16))

    for idx in range(16):
        points = heatmaps2keypoints(real_lmks[idx], 0.2)
        scaled = np.where(points < 0, -1, points * 4)

        axis = fig.add_subplot(8, 8, idx + 1)
        axis.title.set_text('Real')
        # Only landmark 0 is drawn.
        first_x, first_y = scaled[0, 0], scaled[0, 1]
        axis.scatter(first_x, first_y)
        axis.axis('off')

    plt.show()

# Debug pass over the whole training loader: prints the size of every batch
# next to the total number of batches. The visualisation call is disabled.
for batch in train_loader:
    landmarks, captions, captions_1, captions_2 = batch
    landmarks = landmarks.type(torch.float)
    print(len(landmarks), len(train_loader))
#     vis(landmarks[:64].detach().cpu().numpy())