In [1]:
import glob 
import os
import pickle
import numpy as np
import pandas as pd
import torch
from torch.nn import functional as F
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt

from torch.distributions import Laplace
import poisevae
from poisevae.datasets import CUB
from poisevae.utils import NN_lookup, Categorical, sent_emb
from poisevae.networks.CUBNetworks import EncImg, DecImg, EncTxt, DecTxt

from sklearn.cross_decomposition import CCA

# Figure defaults for publication output: TrueType (Type 42) fonts embedded in PDF/PS files,
# 20 pt normal-weight Times New Roman text, Computer Modern mathtext, no external LaTeX.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'Times New Roman',
    'font.size': 20,
    'font.weight': 'normal',
    'mathtext.fontset': 'cm',
    'text.usetex': False,
})

## Declarations & Loading

In [2]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU (kept as a plain string).
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Location of the local CUB dataset copy. Defaults to ~/Datasets/CUB/; set the CUB_DATA_PATH
# environment variable to point the notebook elsewhere without editing this cell.
HOME_PATH = os.path.expanduser('~')
DATA_PATH = os.environ.get('CUB_DATA_PATH', os.path.join(HOME_PATH, 'Datasets/CUB/'))

In [4]:
# Number of leading rows of each precomputed ground-truth artefact used throughout the notebook.
# 81920 = 40 * 2048, a multiple of the 2048-row batch size used further down.
DATA_SIZE = 81920

def _load_cpu(path):
    """``torch.load`` onto the CPU so artefacts saved from a GPU also open on CPU-only machines."""
    return torch.load(path, map_location='cpu')

true_img = _load_cpu('../../true_data_img.pt')[:DATA_SIZE]
true_txt = _load_cpu('../../true_data_txt.pt')[:DATA_SIZE]
true_img_pca = _load_cpu('../../true_data_img_pca.pt').numpy()[:DATA_SIZE]
true_sent_emb = _load_cpu('../../true_data_sent_embedding.pt').numpy()[:DATA_SIZE]
# Matrices consumed by pca_transform, moved to the compute device once here.
sent_PC = _load_cpu('../../sentence_emb_PC.pt').to(device, torch.float32)
img_PC = _load_cpu('../../image_PC.pt').to(device, torch.float32)

In [5]:
def _unpickle(rel_path):
    """Unpickle ``DATA_PATH/rel_path``.

    pickle.load can execute arbitrary code - only point this at files you produced/trust.
    """
    with open(os.path.join(DATA_PATH, rel_path), 'rb') as handle:
        return pickle.load(handle)

# Token embedding table and per-token weights, both passed to sent_emb when embedding generated text.
emb = _unpickle('cub/oc:3_msl:32/cub.emb')
weights = _unpickle('cub/oc:3_msl:32/cub.weights')

In [6]:
def pca_transform(X, PC, chunk_size=2048):
    """Return ``X`` with every row ``x`` replaced by ``x - PC @ x``, processed in row chunks.

    ``PC`` must be a ``(D, D)`` matrix for ``X`` of shape ``(N, D)``. It is loaded from the
    ``*_PC.pt`` artefacts; presumably it projects onto the leading principal component(s),
    so this removes that component - confirm against how those files were produced.

    Args:
        X: 2-D tensor of shape ``(N, D)``.
        PC: ``(D, D)`` tensor on the same device/dtype as ``X``.
        chunk_size: rows per chunk; bounds the size of the batched matmul intermediate.

    Returns:
        Tensor of shape ``(N, D)``.
    """
    chunks = X.split(chunk_size, dim=0)
    # squeeze(-1) drops only the trailing matmul dimension. A bare squeeze() would also drop a
    # size-1 feature dimension (D == 1), and (N, 1) - (N,) would then broadcast to (N, N).
    return torch.cat([chunk - torch.matmul(PC, chunk.unsqueeze(-1)).squeeze(-1) for chunk in chunks])

In [7]:
def calculate_corr(imgs, txts, true_img_mean, true_txt_mean):
    """Mean row-wise cosine similarity between centred image and text vectors.

    Every argument may be a NumPy array (wrapped with ``torch.from_numpy``) or a torch
    tensor; all are moved to the CPU. ``true_img_mean`` is subtracted from ``imgs`` and
    ``true_txt_mean`` from ``txts`` (broadcast over rows). Then the cosine similarity is
    taken along dim 1 and averaged. Inputs are assumed to be already projected into a
    shared space (e.g. CCA scores).

    Returns:
        A 0-dim torch tensor.
    """
    def _cpu_tensor(value):
        if isinstance(value, np.ndarray):
            value = torch.from_numpy(value)
        return value.cpu()

    img_t, txt_t = _cpu_tensor(imgs), _cpu_tensor(txts)
    img_mu, txt_mu = _cpu_tensor(true_img_mean), _cpu_tensor(true_txt_mean)
    centred_img = img_t - img_mu
    centred_txt = txt_t - txt_mu
    return F.cosine_similarity(centred_img, centred_txt).mean()

In [9]:
def perform_cca(gen_img=None, gen_txt_emb=None):
    """Project (generated) image/text features with the fitted CCA and return their correlation.

    Args:
        gen_img: list of image tensors (concatenated along dim 0), or ``None`` to use the
            ground-truth ``true_img_pca`` features instead.
        gen_txt_emb: list of NumPy sentence-embedding arrays (stacked row-wise), or ``None``
            to use the ground-truth ``true_sent_emb`` instead.

    Generated inputs go through ``pca_transform`` with ``img_PC`` / ``sent_PC`` before
    ``cca.transform``. The result is ``calculate_corr`` centred with the ground-truth CCA
    means.

    Returns:
        The correlation as a Python float.
    """
    if gen_img is None:
        img_feats = true_img_pca
    else:
        stacked_img = torch.cat(gen_img).to(device, torch.float32)
        img_feats = pca_transform(stacked_img, img_PC.to(device, torch.float32)).cpu().numpy()

    if gen_txt_emb is None:
        txt_feats = true_sent_emb
    else:
        stacked_txt = torch.from_numpy(np.vstack(gen_txt_emb)).to(device, torch.float32)
        txt_feats = pca_transform(stacked_txt, sent_PC.to(device, torch.float32)).cpu().numpy()

    img_scores, txt_scores = cca.transform(img_feats, txt_feats)
    score = calculate_corr(img_scores, txt_scores, true_img_cca_mean, true_sent_cca_mean)
    return float(score)

def _embed_generated_text(txt_dist, batch_size):
    """Greedy-decode a text output distribution and return its sentence embeddings.

    Takes the argmax over dim 1 of ``txt_dist.probs`` and reshapes the token ids to
    ``(batch_size, -1)``. ``sent_emb`` then writes into a freshly zeroed float64 array of
    shape ``(batch_size, emb.shape[1])``, which is returned.
    NOTE(review): relies on ``sent_emb`` filling its last argument in place (its return
    value is unused) - confirm in poisevae.utils.
    """
    token_ids = txt_dist.probs.argmax(dim=1).reshape(batch_size, -1)
    embedded = np.zeros((batch_size, emb.shape[1]))
    sent_emb(token_ids.cpu().numpy().astype(np.int32), emb, weights, embedded)
    return embedded


def _generate_batch(vae, mode, batch_idx, n_gibbs_iter):
    """Run ground-truth batch ``batch_idx`` through ``vae`` for ``mode``.

    Returns ``(images, text_embeddings)``. The modality that is not generated is ``None``:
      * 'i2s'   - feed only the true images, decode text        -> (None, text)
      * 's2i'   - feed only the true text, decode images        -> (images, None)
      * 'joint' - ``vae.generate`` with no data inputs          -> (images, text)
    Images are the ``.loc`` of the image output distribution. ``results`` stays local, so it
    is released as soon as the batch has been processed.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    batch_size = true_img[batch_idx].shape[0]
    if mode == 'i2s':
        results = vae([true_img[batch_idx], None], n_gibbs_iter=n_gibbs_iter)
        return None, _embed_generated_text(results['x_rec'][1], batch_size)
    if mode == 's2i':
        results = vae([None, true_txt[batch_idx]], n_gibbs_iter=n_gibbs_iter)
        return results['x_rec'][0].loc, None
    if mode == 'joint':
        results = vae.generate(batch_size, n_gibbs_iter=n_gibbs_iter)
        return results['x_rec'][0].loc, _embed_generated_text(results['x_rec'][1], batch_size)
    raise ValueError(f'unknown evaluation mode: {mode!r}')


def eval_model(vae, n_gibbs_iter=50):
    """Score ``vae`` against the ground-truth batches in the 'joint', 'i2s' and 's2i' modes.

    Every batch of ``true_img`` / ``true_txt`` (tuples of batches) is processed without
    gradients. The generated modalities are then scored with ``perform_cca``; the
    non-generated modality falls back to the ground truth inside ``perform_cca``.

    Args:
        vae: the POISE-VAE model to evaluate.
        n_gibbs_iter: number of Gibbs iterations passed to every forward/generate call.

    Returns:
        List of ``(mode, correlation)`` tuples, in the order joint, i2s, s2i.

    Raises:
        ValueError: if ``true_img`` contains no batches.
    """
    if len(true_img) == 0:
        raise ValueError('true_img contains no batches to evaluate')
    corr = []
    with torch.no_grad():
        for mode in ('joint', 'i2s', 's2i'):
            gen_img, gen_txt_emb = [], []
            for batch_idx in range(len(true_img)):
                img_out, txt_out = _generate_batch(vae, mode, batch_idx, n_gibbs_iter)
                if img_out is not None:
                    gen_img.append(img_out)
                if txt_out is not None:
                    gen_txt_emb.append(txt_out)
            # Empty list -> None, so perform_cca substitutes the ground truth for that modality.
            corr.append((mode, perform_cca(gen_img=gen_img or None, gen_txt_emb=gen_txt_emb or None)))
            del gen_img, gen_txt_emb  # free this mode's generations before the next mode
    return corr

## CCA over ground-truth image–text pairs

In [10]:
# Fit a 10-component CCA on the ground-truth image PCA features and sentence embeddings,
# or reuse the fit pickled by a previous run.
# NOTE(review): the cache file name does not encode DATA_SIZE or n_components - delete
# CCA_model.pkl after changing either, otherwise the stale fit is silently reused.
# pickle.load can execute arbitrary code; only load a CCA_model.pkl you produced yourself.
CCA_CACHE_PATH = 'CCA_model.pkl'
try:
    with open(CCA_CACHE_PATH, 'rb') as cache_file:
        cca = pickle.load(cache_file)
except FileNotFoundError:
    cca = CCA(n_components=10, tol=1e-4).fit(true_img_pca, true_sent_emb)
    with open(CCA_CACHE_PATH, 'wb') as cache_file:
        pickle.dump(cca, cache_file)

In [11]:
true_img_cca, true_sent_cca = cca.transform(true_img_pca, true_sent_emb)
# Per-component means of the ground-truth CCA scores; later correlations are centred with these.
true_img_cca_mean = true_img_cca.mean(axis=0)
true_sent_cca_mean = true_sent_cca.mean(axis=0)
# Reference value: correlation of the ground-truth pairs themselves.
calculate_corr(true_img_cca, true_sent_cca, true_img_cca_mean, true_sent_cca_mean)

tensor(0.4875, dtype=torch.float64)

### Check on convergence: shuffling the samples must not change the correlation

In [12]:
# Sanity check: cca.transform maps each row independently, and calculate_corr averages per-row
# cosine similarities. Permuting both modalities jointly must therefore reproduce the value
# from the previous cell.
shuffle_rng = np.random.default_rng(0)  # seeded, and leaves the global NumPy RNG untouched
idx = shuffle_rng.permutation(true_img_pca.shape[0])
true_img_cca_hat, true_sent_cca_hat = cca.transform(true_img_pca[idx], true_sent_emb[idx])
shuffled_corr = calculate_corr(true_img_cca_hat, true_sent_cca_hat, true_img_cca_mean, true_sent_cca_mean)
unshuffled_corr = calculate_corr(true_img_cca, true_sent_cca, true_img_cca_mean, true_sent_cca_mean)
assert torch.isclose(shuffled_corr, unshuffled_corr), (float(shuffled_corr), float(unshuffled_corr))
shuffled_corr

tensor(0.4875, dtype=torch.float64)

### Check on overfitting (held-out rows after `DATA_SIZE`)

In [13]:
# Held-out check (opt-in): project rows DATA_SIZE .. DATA_SIZE + HELDOUT_SIZE, which were not
# used to fit the CCA, and compare their correlation with the in-sample value above.
# Disabled by default because it reloads the full ground-truth artefacts from disk.
RUN_HELDOUT_CHECK = False
HELDOUT_SIZE = 1000
heldout_corr = None
if RUN_HELDOUT_CHECK:
    heldout_rows = slice(DATA_SIZE, DATA_SIZE + HELDOUT_SIZE)
    # Same '../../' artefact paths as the loading cell.
    test_true_img_pca = torch.load('../../true_data_img_pca.pt').cpu().numpy()[heldout_rows]
    test_true_sent_emb = torch.load('../../true_data_sent_embedding.pt').cpu().numpy()[heldout_rows]
    if test_true_img_pca.shape[0] == 0:
        raise ValueError('no rows beyond DATA_SIZE are available for the held-out check')
    test_true_img_cca_hat, test_true_sent_cca_hat = cca.transform(test_true_img_pca, test_true_sent_emb)
    heldout_corr = calculate_corr(test_true_img_cca_hat, test_true_sent_cca_hat, true_img_cca_mean, true_sent_cca_mean)
heldout_corr

## CCA over trained models

### Loading checkpoints

In [14]:
condition = 'worew'  # without reweighting
# NOTE(review): `condition` is not read by the evaluation loop, which visits both the
# 'worew' and 'wrew' run directories.
LATENT_DIM = 128
TXT_DIM = 1590  # first argument of EncTxt/DecTxt; presumably the text vocabulary size - confirm


def _on_device(module):
    """Move a network to the notebook's device as float32."""
    return module.to(device, torch.float32)


# Constructed in the same order as before (image enc/dec, then text enc/dec).
enc_img = _on_device(EncImg(LATENT_DIM))
dec_img = _on_device(DecImg(LATENT_DIM))
enc_txt = _on_device(EncTxt(TXT_DIM, LATENT_DIM))
dec_txt = _on_device(DecTxt(TXT_DIM, LATENT_DIM))
vae = poisevae.POISEVAE(
    [enc_img, enc_txt],
    [dec_img, dec_txt],
    likelihoods=[Laplace, Categorical],
    latent_dims=[LATENT_DIM, (LATENT_DIM, 1, 1)],
    batch_size=2048,
    device=device,
)

In [15]:
# Move the ground truth to the model's device/dtype and cut it into 2048-row batches (the same
# batch_size passed to POISEVAE); eval_model indexes these batches. The guards make the cell
# idempotent: re-running it would otherwise call .to() on the already-split tuple.
if torch.is_tensor(true_img):
    true_img = true_img.to(device, torch.float32).split(2048, dim=0)
if torch.is_tensor(true_txt):
    true_txt = true_txt.to(device, torch.float32).split(2048, dim=0)

In [16]:
# Column-oriented accumulator for the evaluation loop: one list per output column.
corr = {column: [] for column in ('model', 'mode', 'correlation')}

In [None]:
# Evaluate every checkpoint N_EVAL_REPEATS times. Generation goes through Gibbs iterations,
# presumably stochastic, so repeats give a spread. Runs without 'training_50.pt' are skipped.
N_EVAL_REPEATS = 5
RUN_GROUPS = (('worew', 'POISE-VAE'), ('wrew', 'POISE-VAE*'))  # (run sub-directory, display name)
assert isinstance(corr, dict), 'corr is no longer the accumulator dict; re-run the cell that initialises it'
for repeat in tqdm(range(N_EVAL_REPEATS)):
    for run_dir, model_label in RUN_GROUPS:
        # sorted(): glob order is arbitrary, so sort for a reproducible evaluation/row order.
        for path in sorted(glob.glob('../example/runs/CUB/%s/*' % run_dir)):
            try:
                vae, _, _ = poisevae.utils.load_checkpoint(vae, load_path=os.path.join(path, 'training_50.pt'))
            except FileNotFoundError:
                continue
            for mode, val in eval_model(vae):
                corr['model'].append(model_label)
                corr['mode'].append(mode)
                corr['correlation'].append(val)

In [35]:
# Tidy frame of scores: one row per checkpoint x repeat x mode.
corr = pd.DataFrame(corr)
# An empty frame means no checkpoint was found/evaluated; fail loudly instead of saving empty results.
assert len(corr) > 0, 'no checkpoints were evaluated - check the run directories and checkpoint names'

In [44]:
# Persist the raw per-run scores (pandas also writes the RangeIndex as the first column).
corr.to_csv(path_or_buf='CCA_poise.csv')

In [47]:
# Mean correlation per (model, mode) across all checkpoints and repeats.
corr.groupby(['model', 'mode'])[['correlation']].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,correlation
model,mode,Unnamed: 2_level_1
POISE-VAE,i2s,0.349495
POISE-VAE,joint,0.335744
POISE-VAE,s2i,0.143429
POISE-VAE*,i2s,0.033019
POISE-VAE*,joint,0.031961
POISE-VAE*,s2i,0.065847


In [48]:
# Sample standard deviation (pandas default ddof=1) per (model, mode) across checkpoints and repeats.
corr.groupby(['model', 'mode'])[['correlation']].std()

Unnamed: 0_level_0,Unnamed: 1_level_0,correlation
model,mode,Unnamed: 2_level_1
POISE-VAE,i2s,0.003378
POISE-VAE,joint,0.03878
POISE-VAE,s2i,0.024031
POISE-VAE*,i2s,0.002069
POISE-VAE*,joint,0.011168
POISE-VAE*,s2i,0.001182
