# PCA Analysis for `anaphora-1` models

The goal of this notebook is to analyze the learned encoding space of SRN/GRU models that successfully solve the `anaphora-1` task.
- Specifically, we are interested in how reflexive inputs (e.g., "Alice sees herself") are represented in the encoder's hidden space, compared to pseudo-reflexive inputs (e.g., "Alice sees Alice") and regular transitive expressions (e.g., "Alice sees Bob").

In [None]:
import sys
sys.path.append("../")

import os
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from hydra.utils import instantiate
from omegaconf import OmegaConf, open_dict

import pandas as pd
import seaborn as sns
import torch
from sklearn.decomposition import PCA
from pytorch_lightning import LightningModule, LightningDataModule

from matplotlib import pyplot as plt

In [None]:
def get_model_from_ckpt(exp_path, ckpt_name="epoch_448.ckpt"):
    """Rebuild the model and datamodule of a finished experiment run.

    Two saved configs are composed: the Hydra config under
    ``<exp_path>/.hydra`` and the wandb config under
    ``<exp_path>/wandb/latest-run/files``. Both are handed to
    ``create_model`` / ``create_datamodule``, and the checkpoint weights
    are then loaded into the freshly instantiated model.

    Args:
        exp_path: Experiment output directory, relative to the parent of
            the current working directory (``"../"``).
        ckpt_name: File name of the checkpoint inside
            ``<exp_path>/checkpoints``. Defaults to the checkpoint this
            notebook was written against.

    Returns:
        A ``(model, datamodule)`` tuple. The model holds the checkpoint
        weights and is in eval mode; ``setup()`` has not been called on
        the datamodule.
    """
    config_name = "config"
    wandb_path = "wandb/latest-run/files/"
    exp_dir = os.path.abspath(os.path.join("../", exp_path))
    ckpt_dir = os.path.abspath(os.path.join("../", exp_path, "checkpoints"))
    ckpt_path = os.path.join(ckpt_dir, ckpt_name)
    saved_wandb_dir = os.path.abspath(os.path.join("../", exp_path, wandb_path))
    saved_cfg_dir = os.path.join(exp_dir, ".hydra")

    hydra_cfg_file = f"{saved_cfg_dir}/{config_name}.yaml"
    wandb_cfg_file = f"{saved_wandb_dir}/{config_name}.yaml"
    assert os.path.exists(hydra_cfg_file), f"missing Hydra config: {hydra_cfg_file}"
    assert os.path.exists(wandb_cfg_file), f"missing wandb config: {wandb_cfg_file}"

    cfgs = {}

    with initialize_config_dir(version_base="1.1", config_dir=saved_cfg_dir):
        cfgs["hydra"] = compose(config_name=config_name)

    with initialize_config_dir(version_base="1.1", config_dir=saved_wandb_dir):
        cfgs["wandb"] = compose(config_name=config_name)

    model = create_model(cfgs)

    # `LightningModule.load_from_checkpoint` is a classmethod that RETURNS a
    # new module; calling it on an instance and discarding the result leaves
    # `model` with its random initial weights. Load the saved weights into
    # the instance we already built instead.
    # NOTE(review): recent torch versions default `torch.load` to
    # weights_only=True, which may reject Lightning checkpoints that pickle
    # extra objects -- confirm against the installed torch version.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    # The notebook only inspects a trained model, so disable train-time
    # behaviour such as dropout.
    model.eval()

    datamodule = create_datamodule(cfgs)

    return model, datamodule

def create_datamodule(cfgs):
    """Instantiate the datamodule described by the saved Hydra config.

    Args:
        cfgs: Mapping with a ``"hydra"`` config (providing ``datamodule``)
            and a ``"wandb"`` config (providing the
            ``"datamodule/data_dir"`` entry, read through ``.value``).

    Returns:
        The object produced by ``instantiate`` for the datamodule config.
    """
    dm_cfg = cfgs["hydra"].datamodule
    # Look the directory up before opening the config for modification.
    resolved_data_dir = cfgs["wandb"]["datamodule/data_dir"].value

    # open_dict lets us set `data_dir` on the config inside this block.
    with open_dict(dm_cfg):
        dm_cfg.data_dir = resolved_data_dir

    return instantiate(dm_cfg)

def create_model(cfgs):
    """Instantiate the model described by the saved Hydra config.

    The encoder/decoder vocabulary sizes and the decoder EOS index are
    copied from the wandb config into the Hydra model config before
    instantiation.

    Args:
        cfgs: Mapping with a ``"hydra"`` config (providing ``model``) and a
            ``"wandb"`` config (providing the ``"model/..."`` entries, each
            read through ``.value``).

    Returns:
        The object produced by ``instantiate`` for the model config.
    """
    model_cfg = cfgs["hydra"].model
    wandb_cfg = cfgs["wandb"]

    # Resolve every value first so the model config is only modified once
    # all lookups have succeeded.
    overrides = (
        ("dec_vocab_size", wandb_cfg["model/dec_vocab_size"].value),
        ("enc_vocab_size", wandb_cfg["model/enc_vocab_size"].value),
        ("dec_EOS_idx", wandb_cfg["model/dec_EOS_idx"].value),
    )

    with open_dict(model_cfg):
        for field_name, field_value in overrides:
            setattr(model_cfg, field_name, field_value)

    return instantiate(model_cfg)
    
# Experiment run whose checkpoint and configs are analysed in this notebook.
EXPERIMENT_PATH = "outputs/anaphora-1/2022-09-12_16-14-42"

model, datamodule = get_model_from_ckpt(exp_path=EXPERIMENT_PATH)
# Call setup() before any of the dataloaders are requested below.
datamodule.setup()

In [None]:

# Iterate through the entire dataset (train, val, test and gen splits).
# Compute the last-state encodings of each input,
# and stack the results into a (D,H)-sized tensor,
# where D = length of dataset and H = encoder hidden
# size.
#
# Perform k=2 PCA on this to create a (D,2)-sized tensor
# for analysis

dataloaders = [
    datamodule.train_dataloader(),
    datamodule.val_dataloader(),
    datamodule.test_dataloader(),
    datamodule.gen_dataloader(),
]

# Collect per-batch tensors and concatenate once at the end; growing a
# tensor with torch.cat inside the loop re-copies everything gathered so far
# on every batch.
encoding_chunks = []
input_chunks = []
# Only forward passes are needed, so skip building autograd graphs.
with torch.no_grad():
    for dl in dataloaders:
        for batch in dl:
            encoding_chunks.append(model.get_encodings(batch)['encoder_last_state'])
            # batch[0] is treated as the source-token tensor (it is decoded
            # with col='source' below).
            input_chunks.append(batch[0])

data_encodings = torch.squeeze(torch.cat(encoding_chunks, dim=0))
data_inputs = torch.cat(input_chunks, dim=0)

# Decode every input row back into a space-separated string label.
source_dataset = datamodule.data_train.dataset
labels = [
    ' '.join(source_dataset.convert_tokens_to_string(tokens, col='source'))
    for tokens in data_inputs
]

# pca_lowrank returns (U, S, V); the columns of V are the principal directions.
# NOTE: pca_lowrank centers the data by default, but the projection below is
# applied to the uncentered encodings, so the coordinates are shifted by a
# constant relative to centered PCA scores. This is left unchanged on purpose:
# later cells use hard-coded thresholds (e.g. pca1 < 0) on these coordinates.
pt_pca = torch.pca_lowrank(data_encodings, q=2)
pt_reduced_enc = (data_encodings @ pt_pca[2]).detach()

# Hand pandas plain NumPy arrays rather than torch tensors so the PCA columns
# are ordinary numeric columns.
df = pd.DataFrame({
    'input': labels,
    'pca1': pt_reduced_enc[:, 0].cpu().numpy(),
    'pca2': pt_reduced_enc[:, 1].cpu().numpy(),
})

In [None]:
# Print predictions next to targets for the FIRST batch of the gen split.
target_dataset = datamodule.data_train.dataset
with torch.no_grad():  # inspection only; no gradients needed
    for batch in datamodule.gen_dataloader():
        loss, preds, target = model.step(batch)
        p_labels = [
            ' '.join(target_dataset.convert_tokens_to_string(tokens, col='target'))
            for tokens in preds
        ]
        t_labels = [
            ' '.join(target_dataset.convert_tokens_to_string(tokens, col='target'))
            for tokens in target
        ]

        print(preds)
        for predicted, expected in zip(p_labels, t_labels):
            print(predicted, "***", expected)

        # Stop after one batch. The original used `raise SystemExit`, which
        # also aborts a "Run All" (or the whole script when this file is run
        # top to bottom), so none of the later cells would execute.
        break

In [None]:
# Boolean feature columns: True when the input string contains the token.
for pronoun in ("himself", "herself"):
    df[pronoun] = df['input'].str.contains(pronoun)
df['refl'] = df['himself'] | df['herself']

for word in ("alice", "bob", "claire", "knows", "likes", "sees"):
    df[word] = df['input'].str.contains(word)

# Reflexive inputs that mention a particular name.
df['alice_refl'] = df['alice'] & df['refl']
df['claire_refl'] = df['claire'] & df['refl']
# Inputs containing a <PAD> token are flagged as intransitive.
df['intrans'] = df['input'].str.contains("<PAD>")


# Categorical label for the reflexive pronoun:
# code = herself + 2 * himself, then mapped to a readable name.
int_to_refl = {0: 'non-reflexive', 1: 'herself', 2: 'himself'}
df['refl_type'] = df['herself'].astype(int) + 2 * df['himself'].astype(int)
df['refl_type'] = df['refl_type'].map(lambda code: int_to_refl[code])

In [None]:
# Scatter of the two PCA coordinates coloured by reflexive type.
# The pca1 < 0 filter is what the title calls "Excluding Intransitives".
# x/y are passed by keyword: recent seaborn releases no longer accept them
# positionally.
# df.plot.scatter(x='pca1', y='pca2', c='refl', colormap='viridis')
lm = sns.lmplot(x='pca1', y='pca2', data=df[df['pca1'] < 0], hue='refl_type', fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN, Excluding Intransitives")

In [None]:
# Same scatter, coloured by whether the input mentions "alice".
# x/y are passed by keyword: recent seaborn releases no longer accept them
# positionally.
lm = sns.lmplot(x='pca1', y='pca2', data=df[df['pca1'] < 0], hue='alice', fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN, Excluding Intransitives")

In [None]:
# Scatter over ALL inputs, coloured by the intransitive (<PAD>) flag.
# x/y are passed by keyword: recent seaborn releases no longer accept them
# positionally.
lm = sns.lmplot(x='pca1', y='pca2', data=df, hue='intrans', fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")

In [None]:
# Reflexive inputs lying above the line pca2 = -(0.5 * pca1 + 27/20).
# Both conditions are combined into one mask: chaining two boolean indexers
# applies a full-length mask to an already-filtered frame, which makes pandas
# reindex the mask and emit a UserWarning.
df[df['refl'] & (df['pca2'] > -1 * (0.5 * df['pca1'] + 27.0 / 20))]

In [None]:
# Same scatter, coloured by whether the input mentions "bob".
# x/y are passed by keyword: recent seaborn releases no longer accept them
# positionally.
lm = sns.lmplot(x='pca1', y='pca2', data=df[df['pca1'] < 0], hue='bob', fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN, Excluding Intransitives")

In [None]:
# Same scatter, coloured by whether the input contains the verb "knows".
# x/y are passed by keyword: recent seaborn releases no longer accept them
# positionally.
lm = sns.lmplot(x='pca1', y='pca2', data=df[df['pca1'] < 0], hue='knows', fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN, Excluding Intransitives")

In [None]:
# Same scatter, coloured by whether the input contains the verb "likes".
# x/y are passed by keyword: recent seaborn releases no longer accept them
# positionally.
lm = sns.lmplot(x='pca1', y='pca2', data=df[df['pca1'] < 0], hue='likes', fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN, Excluding Intransitives")

In [None]:
# Same scatter, coloured by whether the input contains the verb "sees".
# x/y are passed by keyword: recent seaborn releases no longer accept them
# positionally.
lm = sns.lmplot(x='pca1', y='pca2', data=df[df['pca1'] < 0], hue='sees', fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN, Excluding Intransitives")

In [None]:
# Display the reflexive-type label assigned to every input.
df.loc[:, 'refl_type']