In [23]:
import glob 
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F  # activation function
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.linalg import eig

import poisevae
from poisevae.datasets import CUB
from poisevae.utils import NN_lookup

# Shared Matplotlib settings for every figure in this notebook:
# font type 42 for PDF/PS output, Times New Roman at size 20 with normal
# weight, the 'cm' mathtext font set, and no external LaTeX rendering.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'Times New Roman',
    'font.size': 20,
    'font.weight': 'normal',
    'mathtext.fontset': 'cm',
    'text.usetex': False,
})

In [24]:
def cca(views, k=None, eps=1e-12):
    """Compute (multi-view) CCA by solving a generalised eigenvalue problem.

    Cross-view covariances fill the off-diagonal blocks of ``A`` and
    within-view covariances fill the diagonal blocks of ``B``; the pencil
    ``A v = lambda B v`` is then solved with ``scipy.linalg.eig``.

    Args:
        views (list): list of views where each view `v_i` is of size `N x o_i`
        k (int): joint projection dimension | if None, find using Otsu
        eps (float): regulariser added to the diagonals of both ``A`` and
            ``B`` [default: 1e-12]
    Returns:
        correlations: real parts of the ``k`` largest eigenvalues, in
            descending order, as a tensor with the type of ``views[0]``
        projections: tuple with one ``o_i x k`` projection matrix per view;
            column ``j`` belongs to ``correlations[j]``
    """
    n_views = len(views)
    n_obs = views[0].size(0)  # number of observations (same across views)
    # Called `dims` rather than `os` so the imported `os` module is not shadowed.
    dims = [v.size(1) for v in views]
    kmax = np.min(dims)
    offsets = np.cumsum([0] + dims)
    total = sum(dims)
    A, B = np.zeros([total, total]), np.zeros([total, total])

    # Centre each view once; every pair below reuses the centred copies.
    centered = [v - v.mean(0).expand_as(v) for v in views]
    scale = 1.0 / (n_obs - 1)

    def _cov(a, b):
        # Explicit conversion so CUDA / grad-tracking tensors can be written
        # into the NumPy block matrices.
        return (scale * torch.mm(a.t(), b)).detach().cpu().numpy()

    for i in range(n_views):
        rows_i = slice(offsets[i], offsets[i + 1])
        # Within-view covariance goes on B's diagonal; A's diagonal blocks
        # stay zero (apart from the eps regulariser added below).
        B[rows_i, rows_i] = _cov(centered[i], centered[i])
        for j in range(i + 1, n_views):
            rows_j = slice(offsets[j], offsets[j + 1])
            c_ij = _cov(centered[i], centered[j])
            A[rows_i, rows_j] = c_ij
            A[rows_j, rows_i] = c_ij.T

    A[np.diag_indices_from(A)] += eps
    B[np.diag_indices_from(B)] += eps

    eigenvalues, eigenvectors = eig(A, B)
    # TODO: sanity check to see that all eigenvalues are e+0i
    # Only real parts are returned, so rank on the real part. Eigenvalues and
    # eigenvectors are reordered together, exactly once, so they stay aligned
    # and plain `[:k]` slicing below picks the top-k pairs.
    order = eigenvalues.real.argsort()[::-1]  # sort descending
    eigenvalues = eigenvalues[order]
    eigenvectors = eigenvectors[:, order]

    if k is None:
        # NOTE(review): `threshold` is not defined or imported in the visible
        # cells; per the docstring it is expected to be an Otsu threshold
        # function -- confirm it is in scope before calling with k=None.
        t = threshold(eigenvalues.real[:kmax])
        k = np.abs(np.asarray(eigenvalues.real[0::10]) - t).argmin() * 10  # closest k % 10 == 0 idx
        print('k unspecified, (auto-)choosing:', k)

    eigenvalues = eigenvalues[:k]
    eigenvectors = eigenvectors[:, :k]

    # `.real` can be a strided view of a complex array; make it contiguous
    # before handing the buffer to torch.
    correlations = torch.from_numpy(np.ascontiguousarray(eigenvalues.real)).type_as(views[0])
    stacked = torch.from_numpy(np.ascontiguousarray(eigenvectors.real)).type_as(views[0])
    proj_matrices = torch.split(stacked, dims)  # one row-block per view

    return correlations, proj_matrices

In [42]:
# Load the ground-truth image features and sentence embeddings as float32 CPU tensors.
# `map_location='cpu'` remaps the storages at load time, so files that were
# saved from a GPU still deserialize on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load files from a trusted source.
true_img_ft = torch.load('true_data_img_feature.pt', map_location='cpu').to(dtype=torch.float32)
true_sent_emb = torch.load('true_data_sent_embedding.pt', map_location='cpu').to(dtype=torch.float32)

In [43]:
# Fit two-view CCA on the ground-truth image features and sentence embeddings,
# keeping a 40-dimensional joint projection; unpack one projection matrix per view.
corr, (img_cca, txt_cca) = cca(
    views=[true_img_ft, true_sent_emb],
    k=40,
)

In [44]:
# `img_cca` and `txt_cca` are produced by `torch.split` inside `cca`, so both
# are views into one shared tensor. `torch.save` serialises the whole storage
# behind a view; cloning first gives each file only its own projection matrix.
torch.save(img_cca.clone(), 'true_data_img_feature_cca.pt')
torch.save(txt_cca.clone(), 'true_data_sent_embedding_cca.pt')

In [45]:
def calculate_corr(imgs, txts, true_img_mean, true_txt_mean, img_cca, txt_cca):
    """Mean cosine similarity between CCA-projected image and text features.

    Each modality is shifted by the supplied mean, multiplied by its
    projection matrix, and the paired projections are compared with
    ``F.cosine_similarity`` (default ``dim=1``); the similarities are then
    averaged.

    Args:
        imgs: image features (assumed ``N x o_img`` -- TODO confirm with caller)
        txts: text features (assumed ``N x o_txt`` -- TODO confirm with caller)
        true_img_mean: value subtracted from ``imgs`` before projection
        true_txt_mean: value subtracted from ``txts`` before projection
        img_cca: projection matrix right-multiplied onto the shifted images
        txt_cca: projection matrix right-multiplied onto the shifted texts
    Returns:
        Tensor holding the mean of the cosine similarities.
    """
    img_projected = (imgs - true_img_mean).matmul(img_cca)
    txt_projected = (txts - true_txt_mean).matmul(txt_cca)
    similarities = F.cosine_similarity(img_projected, txt_projected)
    return similarities.mean()

In [46]:
# Per-feature means of the ground-truth data (reduced over dim 0); these are
# the offsets `calculate_corr` subtracts before projecting.
true_img_mean = torch.mean(true_img_ft, dim=0)
true_txt_mean = torch.mean(true_sent_emb, dim=0)

# Score of the ground-truth image/text pairs under the fitted CCA projections.
true_corr = calculate_corr(
    imgs=true_img_ft,
    txts=true_sent_emb,
    true_img_mean=true_img_mean,
    true_txt_mean=true_txt_mean,
    img_cca=img_cca,
    txt_cca=txt_cca,
)

In [47]:
# Show the score computed by `calculate_corr` for the ground-truth pairs.
true_corr

tensor(0.1393)