# PCA Analysis for `anaphora-1` models

In [None]:
import os
import sys

import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from hydra import compose, initialize, initialize_config_dir, initialize_config_module
from hydra.utils import instantiate
from matplotlib import pyplot as plt
from omegaconf import OmegaConf, open_dict
from pytorch_lightning import LightningDataModule, LightningModule
from sklearn.decomposition import PCA

sys.path.append("../")

In [None]:
def get_model_from_ckpt(exp_path):
    """Load the trained model and its datamodule for one experiment run.

    ``exp_path`` is resolved relative to ``../`` (the parent of the current
    working directory). The run directory is expected to contain
    ``.hydra/config.yaml``, ``wandb/latest-run/files/config.yaml`` and
    ``checkpoints/last.ckpt``.

    Args:
        exp_path: Run directory, relative to ``../``.

    Returns:
        A ``(model, datamodule)`` tuple. The model is restored from
        ``last.ckpt`` and put in eval mode. The datamodule is instantiated
        from the saved configs; ``setup()`` is NOT called on it here.

    Raises:
        FileNotFoundError: If one of the two config files or the checkpoint
            is missing.
    """
    config_name = "config"
    exp_dir = os.path.abspath(os.path.join("../", exp_path))
    ckpt_path = os.path.join(exp_dir, "checkpoints", "last.ckpt")
    # abspath also normalises away the trailing slash of the wandb sub-path.
    saved_wandb_dir = os.path.abspath(os.path.join(exp_dir, "wandb/latest-run/files/"))
    saved_cfg_dir = os.path.join(exp_dir, ".hydra")

    # Fail early with the offending path instead of a bare AssertionError
    # (asserts are also stripped when Python runs with -O).
    required_files = (
        os.path.join(saved_cfg_dir, f"{config_name}.yaml"),
        os.path.join(saved_wandb_dir, f"{config_name}.yaml"),
        ckpt_path,
    )
    for path in required_files:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Required experiment file not found: {path}")

    cfgs = {}

    # The Hydra config describes how to build the model/datamodule ...
    with initialize_config_dir(version_base="1.1", config_dir=saved_cfg_dir):
        cfgs["hydra"] = compose(config_name=config_name)

    # ... while the wandb config holds values that were logged for the run
    # (vocab sizes, data_dir) and are read back by the create_* helpers.
    with initialize_config_dir(version_base="1.1", config_dir=saved_wandb_dir):
        cfgs["wandb"] = compose(config_name=config_name)

    # Instantiate from config only to discover the concrete model class,
    # then restore the trained weights through that class.
    model = create_model(cfgs)
    model = type(model).load_from_checkpoint(ckpt_path)
    model.eval()

    datamodule = create_datamodule(cfgs)

    return model, datamodule


def create_datamodule(cfgs):
    """Build the datamodule described by the saved run configs.

    Args:
        cfgs: Mapping with a ``"hydra"`` config (provides ``.datamodule``)
            and a ``"wandb"`` config (provides the logged
            ``"datamodule/data_dir"`` entry, read through ``.value``).

    Returns:
        The object produced by ``instantiate`` on the datamodule config,
        after ``data_dir`` has been set to the logged value.
    """
    dm_cfg = cfgs["hydra"].datamodule
    logged_data_dir = cfgs["wandb"]["datamodule/data_dir"].value

    # Write the logged data_dir into the config inside open_dict
    # (presumably because the composed config is in struct mode).
    with open_dict(dm_cfg):
        dm_cfg.data_dir = logged_data_dir

    return instantiate(dm_cfg)


def create_model(cfgs):
    """Build an (untrained) model from the saved run configs.

    Args:
        cfgs: Mapping with a ``"hydra"`` config (provides ``.model``) and a
            ``"wandb"`` config holding the logged ``model/dec_vocab_size``,
            ``model/enc_vocab_size`` and ``model/dec_EOS_idx`` entries, each
            read through ``.value``.

    Returns:
        The object produced by ``instantiate`` on the model config once the
        three logged values have been copied into it.
    """
    model_cfg = cfgs["hydra"].model
    wandb_cfg = cfgs["wandb"]

    # Values recorded for the run that the model config itself does not carry.
    logged_values = {
        "dec_vocab_size": wandb_cfg["model/dec_vocab_size"].value,
        "enc_vocab_size": wandb_cfg["model/enc_vocab_size"].value,
        "dec_EOS_idx": wandb_cfg["model/dec_EOS_idx"].value,
    }

    with open_dict(model_cfg):
        for field_name, field_value in logged_values.items():
            setattr(model_cfg, field_name, field_value)

    return instantiate(model_cfg)


# Experiment run to analyse; the path is resolved relative to "../".
# Alternative run: "outputs/anaphora-1/2022-09-12_16-14-42"
model, datamodule = get_model_from_ckpt(exp_path="outputs/SCAN (Add Jump)/2022-10-02_11-43-28")

# Prepare the datamodule so its dataloaders can be requested below.
datamodule.setup()

In [None]:
# Iterate through the entire dataset (train, val and test dataloaders).
# Compute the last-state encodings of each input and stack the results
# into a (D, H)-sized tensor, where D = length of dataset and
# H = encoder hidden size.
#
# Perform k=2 PCA on this to create a (D, 2)-sized tensor for analysis,
# and collect everything into the DataFrame `df`.


def _pad_and_cat(chunks, pad_value=0):
    """Concatenate tensors on dim 0, right-padding dim 1 to a common length.

    Batches may have different sequence lengths, so each chunk is padded
    along dim 1 up to the longest chunk before concatenation. Tensors with
    fewer than two dims are concatenated as-is.
    """
    if chunks[0].dim() < 2:
        return torch.cat(chunks, dim=0)

    width = max(chunk.shape[1] for chunk in chunks)
    padded = []
    for chunk in chunks:
        delta = width - chunk.shape[1]
        if delta > 0:
            # F.pad takes (left, right) pairs starting from the LAST dim, so
            # emit (0, 0) for every trailing dim before reaching dim 1.
            pad_spec = (0, 0) * (chunk.dim() - 2) + (0, delta)
            chunk = F.pad(chunk, pad_spec, "constant", pad_value)
        padded.append(chunk)
    return torch.cat(padded, dim=0)


enc_chunks = []
input_chunks = []
pred_chunks = []
target_chunks = []

for dl in [
    datamodule.train_dataloader(),
    datamodule.val_dataloader(),
    datamodule.test_dataloader(),
    # datamodule.gen_dataloader(),  # generalisation split currently excluded
]:
    for batch in dl:
        with torch.no_grad():
            batch_enc = model.get_encodings(batch)["encoder_last_state"]
            _, preds, target = model.step(batch)

        # Collect per-batch tensors and concatenate once after the loop,
        # instead of re-copying the growing tensors on every batch.
        enc_chunks.append(batch_enc)
        input_chunks.append(batch[0])
        pred_chunks.append(preds)
        target_chunks.append(target)

if not enc_chunks:
    raise ValueError("The train/val/test dataloaders produced no batches")

data_encodings = torch.cat(enc_chunks, dim=0)
# Inputs, predictions and targets are all padded with 0, the value the
# input padding always used. NOTE(review): assumes index 0 is the padding
# token in both vocabularies -- confirm against the dataset.
data_inputs = _pad_and_cat(input_chunks)
data_preds = _pad_and_cat(pred_chunks)
data_target = _pad_and_cat(target_chunks)

if data_encodings.dim() > 2:
    # Only look at the last layer (dim 1 is treated as the layer axis).
    # Index -1 instead of a hard-coded 1 so this holds for any layer count,
    # and reshape rather than a bare squeeze so a dataset of size 1 keeps
    # its leading axis.
    data_encodings = data_encodings[:, -1].reshape(data_encodings.shape[0], -1)

# Token sequences -> space-joined strings for the DataFrame.
train_dataset = datamodule.data_train.dataset

i_labels = [
    " ".join(train_dataset.convert_tokens_to_string(k, col="source"))
    for k in data_inputs
]
t_labels = [
    " ".join(train_dataset.convert_tokens_to_string(k, col="target"))
    for k in data_target
]
p_labels = [
    " ".join(train_dataset.convert_tokens_to_string(k, col="target"))
    for k in data_preds
]

# pca_lowrank returns (U, S, V); the columns of V (index 2) are the
# principal directions the encodings are projected onto.
# NOTE(review): pca_lowrank centres the data by default while the projection
# below uses the uncentred encodings, so the coordinates carry a constant
# offset; the relative layout of the points is unaffected.
pt_pca = torch.pca_lowrank(data_encodings, q=2)
pt_reduced_enc = (data_encodings @ pt_pca[2]).detach()

df = pd.DataFrame(
    {
        "input": i_labels,
        "target": t_labels,
        "prediction": p_labels,
        # Plain NumPy columns so pandas/seaborn get numeric dtypes.
        "pca1": pt_reduced_enc[:, 0].cpu().numpy(),
        "pca2": pt_reduced_enc[:, 1].cpu().numpy(),
    }
)

In [None]:
# Boolean flags marking which SCAN tokens/phrases occur in each input.
for column, phrase in (
    ("jump", "jump"),
    ("walk", "walk"),
    ("turn_right", "turn right"),
    ("twice", "twice"),
):
    df[column] = df["input"].str.contains(phrase)

# Conjunctions: the action occurs and "twice" occurs in the same input.
for action in ("jump", "walk", "turn_right"):
    df[f"{action}_twice"] = df[action] & df["twice"]

In [None]:
# Scatter of the two PCA components, coloured by the "jump" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="jump", fit_reg=False)
lm.fig.suptitle("PCA of SRN for SCAN (Add Jump)")

In [None]:
# Scatter of the two PCA components, coloured by the "walk" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="walk", fit_reg=False)
lm.fig.suptitle("PCA of SRN for SCAN (Add Jump)")

In [None]:
# Scatter of the two PCA components, coloured by the "twice" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="twice", fit_reg=False)
lm.fig.suptitle("PCA of SRN for SCAN (Add Jump)")

In [None]:
# Scatter of the two PCA components, coloured by the "jump_twice" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="jump_twice", fit_reg=False)
lm.fig.suptitle("PCA of SRN for SCAN (Add Jump)")

In [None]:
# Scatter of the two PCA components, coloured by the "walk_twice" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="walk_twice", fit_reg=False)
lm.fig.suptitle("PCA of SRN for SCAN (Add Jump)")

In [None]:
# Boolean flags marking which anaphora-dataset tokens occur in each input.
for token in ("himself", "herself"):
    df[token] = df["input"].str.contains(token)
df["refl"] = df["himself"] | df["herself"]

for token in ("alice", "bob", "claire", "knows", "likes", "sees"):
    df[token] = df["input"].str.contains(token)

df["alice_refl"] = df["alice"] & df["refl"]
df["claire_refl"] = df["claire"] & df["refl"]
# Inputs containing the literal "<PAD>" token are flagged as "intrans".
df["intrans"] = df["input"].str.contains("<PAD>")


# Encode the reflexive as an integer (herself -> 1, himself -> 2, neither
# -> 0) and translate that code into a readable label. An input containing
# both reflexives yields 3, which has no entry and raises KeyError.
int_to_refl = {0: "non-reflexive", 1: "herself", 2: "himself"}
refl_code = df["herself"].apply(int) + df["himself"].apply(int).apply(lambda x: 2 * x)
df["refl_type"] = refl_code.apply(lambda code: int_to_refl[code])

In [None]:
# Scatter of the two PCA components, coloured by the "turn_right" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="turn_right", fit_reg=False)
lm.fig.suptitle("PCA of SRN for SCAN (Add Jump)")

In [None]:
# Scatter of the two PCA components, coloured by the "turn_right_twice" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="turn_right_twice", fit_reg=False)
lm.fig.suptitle("PCA of SRN for SCAN (Add Jump)")

In [None]:
# Scatter of the two PCA components, coloured by the "intrans" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="intrans", fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")

In [None]:
# Scatter of the two PCA components, coloured by the "refl_type" label.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="refl_type", fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")

In [None]:
# Scatter of the two PCA components, coloured by the "alice" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="alice", fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")

In [None]:
# Scatter of the two PCA components, coloured by the "bob" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="bob", fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")

In [None]:
# Scatter of the two PCA components, coloured by the "knows" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="knows", fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")

In [None]:
# Scatter of the two PCA components, coloured by the "likes" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="likes", fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")

In [None]:
# Scatter of the two PCA components, coloured by the "sees" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="sees", fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")

In [None]:
# Scatter of the two PCA components, coloured by the "alice_refl" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="alice_refl", fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")

In [None]:
# Scatter of the two PCA components, coloured by the "claire_refl" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="claire_refl", fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")

In [None]:
# Scatter of the two PCA components, coloured by the "claire" flag.
# x/y are passed by keyword because seaborn >= 0.12 makes them keyword-only.
lm = sns.lmplot(data=df, x="pca1", y="pca2", hue="claire", fit_reg=False)
lm.fig.suptitle("PCA of 49-dim SRN")