In [1]:
import glob
import os
import pickle
import re

import numpy as np
import torch
from torch.nn import functional as F
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.distributions import Laplace
from sklearn.cross_decomposition import CCA

import poisevae
from poisevae.datasets import CUB
from poisevae.utils import NN_lookup, Categorical, sent_emb
from poisevae.networks.CUBNetworks import EncImg, DecImg, EncTxt, DecTxt

# from cca import *

# Figure defaults for publication: TrueType (Type 42) fonts embedded in PDF/PS
# output, 20 pt Times New Roman text, Computer Modern mathtext, no external LaTeX.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'Times New Roman',
    'font.size': 20,
    'font.weight': 'normal',
    'mathtext.fontset': 'cm',
    'text.usetex': False,
})

## Declarations & Loading

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [3]:
# Root of the CUB dataset. Defaults to ~/Datasets/CUB/; set the CUB_DATA_PATH
# environment variable to point the notebook somewhere else without editing it.
HOME_PATH = os.path.expanduser('~')
DATA_PATH = os.environ.get('CUB_DATA_PATH', os.path.join(HOME_PATH, 'Datasets/CUB/'))

In [4]:
# Rows [:DATA_SIZE] of every stored array are used here; rows from DATA_SIZE
# onward are used as the held-out split in the overfitting check further below.
DATA_SIZE = 81920


def _load_cpu(path):
    """Load an object saved with `torch.save`, mapping every tensor to the CPU.

    `map_location='cpu'` lets files that were written from GPU tensors be opened
    on a machine without CUDA. Tensors are moved to `device` explicitly where
    needed (as for sent_PC / img_PC below).
    """
    return torch.load(path, map_location='cpu')


true_img = _load_cpu('true_data_img.pt')[:DATA_SIZE]
true_txt = _load_cpu('true_data_txt.pt')[:DATA_SIZE]
true_img_pca = _load_cpu('true_data_img_pca.pt').numpy()[:DATA_SIZE]
true_sent_emb = _load_cpu('true_data_sent_embedding.pt').numpy()[:DATA_SIZE]
sent_PC = _load_cpu('sentence_emb_PC.pt').to(device, torch.float32)
img_PC = _load_cpu('image_PC.pt').to(device, torch.float32)

# The CCA fit pairs rows of true_img_pca / true_sent_emb, and the reconstruction
# loop pairs true_img with true_txt batch by batch, so all four must line up.
assert len(true_img) == len(true_txt) == len(true_img_pca) == len(true_sent_emb), \
    'true_* arrays are not row-aligned (is DATA_SIZE larger than one of the saved files?)'

In [5]:
# Embedding assets consumed by `sent_emb` when embedding reconstructed sentences;
# emb.shape[1] is used below as the sentence-embedding width.
# NOTE(security): pickle.load can execute arbitrary code -- only open trusted files.
def _unpickle(rel_path):
    """Unpickle the file at `rel_path`, resolved relative to DATA_PATH."""
    with open(os.path.join(DATA_PATH, rel_path), 'rb') as fh:
        return pickle.load(fh)


emb = _unpickle('cub/oc:3_msl:32/cub.emb')
weights = _unpickle('cub/oc:3_msl:32/cub.weights')

In [6]:
def pca_transform(X, PC, chunk_size=2048):
    """Remove the projection onto `PC` from every row of `X`.

    Each row x is mapped to x - PC @ x. `PC` must therefore be a (d, d) matrix
    (presumably an outer product u u^T of leading principal directions -- TODO
    confirm against how the *_PC.pt files were built).

    Args:
        X: (n, d) tensor; rows are processed in chunks of `chunk_size`.
        PC: (d, d) tensor on the same device/dtype as `X`.
        chunk_size: rows per chunk (2048 keeps the original behavior).

    Returns:
        (n, d) tensor with the `PC` component subtracted from every row.
    """
    chunks = X.split(chunk_size, 0)
    # squeeze(-1), not squeeze(): only drop the trailing singleton added by
    # unsqueeze(-1). A bare squeeze() also collapses d == 1, which makes the
    # subtraction broadcast (b, 1) - (b,) into a (b, b) result.
    return torch.cat([e - torch.matmul(PC, e.unsqueeze(-1)).squeeze(-1) for e in chunks])

In [7]:
def calculate_corr(imgs, txts, true_img_mean, true_txt_mean):
    """Mean per-sample cosine similarity between centered image and text vectors.

    Row i of `imgs` is paired with row i of `txts`. Both are centered with the
    supplied means, the cosine similarity of each pair is taken along dim 1, and
    the average over rows is returned.

    Args:
        imgs, txts: row-aligned 2-D arrays/tensors of projected features (here,
            CCA scores). NumPy arrays, torch tensors (any device) or nested lists.
        true_img_mean, true_txt_mean: per-column centering vectors, broadcast
            over the rows of `imgs` / `txts` respectively.

    Returns:
        0-dim CPU tensor holding the mean cosine similarity.
    """
    # torch.as_tensor shares memory with NumPy input (like torch.from_numpy),
    # returns tensors unchanged, and also accepts plain Python sequences.
    imgs, txts, true_img_mean, true_txt_mean = (
        torch.as_tensor(x).cpu() for x in (imgs, txts, true_img_mean, true_txt_mean)
    )
    # Assume all are projected into the same space.
    corr = F.cosine_similarity(imgs - true_img_mean, txts - true_txt_mean).mean()
    return corr

## CCA over ground-truth data

In [8]:
# Fit the CCA once and cache it on disk. An existing cache is reused as-is, so
# delete the file to refit (e.g. after changing DATA_SIZE).
# NOTE(security): pickle.load can execute arbitrary code -- only trust local caches.
CCA_CACHE_PATH = 'CCA_model.pkl'
if os.path.exists(CCA_CACHE_PATH):
    with open(CCA_CACHE_PATH, 'rb') as f:
        cca = pickle.load(f)
else:
    cca = CCA(n_components=10, tol=1e-4)
    cca.fit(true_img_pca, true_sent_emb)
    with open(CCA_CACHE_PATH, 'wb') as f:
        pickle.dump(cca, f)

In [9]:
true_img_cca, true_sent_cca = cca.transform(true_img_pca, true_sent_emb)
# Column means of the in-sample projections; every score below is centered on these.
true_img_cca_mean = np.mean(true_img_cca, axis=0)
true_sent_cca_mean = np.mean(true_sent_cca, axis=0)
calculate_corr(true_img_cca, true_sent_cca, true_img_cca_mean, true_sent_cca_mean)

tensor(0.4875, dtype=torch.float64)

### Check on convergence: the score should not depend on sample order

In [10]:
# Re-score a shuffled copy of the training rows with the same fitted CCA. The
# score is a mean over per-row similarities, so it is expected to reproduce the
# in-sample value up to floating-point rounding. A seeded Generator makes the
# permutation reproducible and leaves NumPy's global RNG state untouched.
SHUFFLE_SEED = 0
idx = np.random.default_rng(SHUFFLE_SEED).permutation(true_img_pca.shape[0])
true_img_cca_hat, true_sent_cca_hat = cca.transform(true_img_pca[idx], true_sent_emb[idx])
calculate_corr(true_img_cca_hat, true_sent_cca_hat, true_img_cca_mean, true_sent_cca_mean)

tensor(0.4875, dtype=torch.float64)

### Check on overfitting: score on held-out samples

In [11]:
# Held-out rows immediately after the DATA_SIZE slice loaded at the top.
TEST_SIZE = 1000
held_out = slice(DATA_SIZE, DATA_SIZE + TEST_SIZE)
# map_location='cpu' so GPU-saved files also open on CPU-only machines.
test_true_img_pca = torch.load('true_data_img_pca.pt', map_location='cpu').numpy()[held_out]
test_true_sent_emb = torch.load('true_data_sent_embedding.pt', map_location='cpu').numpy()[held_out]
# Fail with a clear message if the saved arrays have no rows past DATA_SIZE.
assert len(test_true_img_pca) == len(test_true_sent_emb) > 0, \
    'no row-aligned held-out samples beyond DATA_SIZE in the saved arrays'
test_true_img_cca_hat, test_true_sent_cca_hat = cca.transform(test_true_img_pca, test_true_sent_emb)
# Centered with the in-sample means, exactly like the scores above.
calculate_corr(test_true_img_cca_hat, test_true_sent_cca_hat, true_img_cca_mean, true_sent_cca_mean)

tensor(0.4462, dtype=torch.float64)

## CCA over trained model

### Loading checkpoints

In [12]:
# Encoders/decoders for both modalities; 128 is the latent size shared with
# latent_dims below, 1590 presumably the text vocabulary size (TODO confirm).
enc_img = EncImg(128).to(device)
dec_img = DecImg(128).to(device)
enc_txt = EncTxt(1590, 128).to(device)
dec_txt = DecTxt(1590, 128).to(device)
vae = poisevae.POISEVAE([enc_img, enc_txt], [dec_img, dec_txt], likelihoods=[Laplace, Categorical],
                        latent_dims=[128, (128, 1, 1)]).to(device)

CKPT_GLOB = '../example/runs/CUB/wrew/train*.pt'


def _natural_key(path):
    """Sort key that compares digit runs numerically ('train_10' after 'train_9').

    When names differ only in equal-width zero-padded numbers this agrees with
    plain lexicographic order.
    """
    # re.split with a capture group alternates text (even idx) and digits (odd idx).
    return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r'(\d+)', path))]


ckpt_paths = sorted(glob.glob(CKPT_GLOB), key=_natural_key)
if not ckpt_paths:
    raise FileNotFoundError(f'no checkpoint matches {CKPT_GLOB!r}')
# NOTE(review): not visible here whether load_checkpoint switches the model to
# eval mode; the reconstruction cell calls vae(...) directly -- confirm.
vae, _, epoch = poisevae.utils.load_checkpoint(vae, load_path=ckpt_paths[-1])
epoch

50

### Reconstruction

In [13]:
# Reconstruct every (image, text) pair through the trained VAE in batches.
# Each batch is moved to `device` on its own, so the full ground-truth dataset
# never has to be copied to the GPU at once.
RECON_BATCH_SIZE = 2048
assert len(true_img) == len(true_txt), 'true_img and true_txt must be row-aligned'

rec_txt_emb = []
rec_img = []
with torch.no_grad():
    for img_batch, txt_batch in zip(true_img.split(RECON_BATCH_SIZE, dim=0),
                                    true_txt.split(RECON_BATCH_SIZE, dim=0)):
        results = vae([img_batch.to(device, torch.float32),
                       txt_batch.to(device, torch.float32)])

        # Image modality (Laplace likelihood): keep the location parameter.
        rec_img_i = results['x_rec'][0].loc
        rec_img.append(rec_img_i)

        # Text modality (Categorical likelihood): argmax over dim 1, one row per sample.
        rec_txt_i = results['x_rec'][1].probs.argmax(dim=1).reshape(rec_img_i.shape[0], -1)
        # sent_emb's return value is discarded; it presumably fills the
        # pre-allocated zero buffer in place -- confirm in poisevae.utils.
        rec_txt_emb.append(np.zeros((rec_img_i.shape[0], emb.shape[1])))
        sent_emb(rec_txt_i.cpu().numpy().astype(np.int32), emb, weights, rec_txt_emb[-1])

rec_img = torch.cat(rec_img).to(device, torch.float32)
rec_txt_emb = torch.from_numpy(np.vstack(rec_txt_emb)).to(device, torch.float32)

In [14]:
# Apply the same principal-component removal to the reconstructions (text with
# sent_PC, images with img_PC), then bring the results back to NumPy for sklearn.
rec_txt_emb_pca, rec_img_pca = (
    pca_transform(features, pcs.to(device, torch.float32)).cpu().numpy()
    for features, pcs in ((rec_txt_emb, sent_PC), (rec_img, img_PC))
)

In [15]:
# Project the reconstructions with the CCA fitted on ground truth and score them
# against the in-sample ground-truth means, like the checks above.
rec_img_cca, rec_sent_cca = cca.transform(rec_img_pca, rec_txt_emb_pca)
calculate_corr(imgs=rec_img_cca, txts=rec_sent_cca,
               true_img_mean=true_img_cca_mean, true_txt_mean=true_sent_cca_mean)

tensor(0.3894, dtype=torch.float64)