#### Helpers

In [3]:
# Running covariance computation

import numpy as np


class OnlineCovariance:
    """
    Incrementally tracks the mean and the covariance matrix of
    n-dimensional observations that are added one at a time.

    The covariance is the population covariance (normalised by the number
    of observations ``n``, not ``n - 1``), i.e. it matches
    ``np.cov(data, rowvar=False, bias=True)``.
    """

    def __init__(self, order):
        """
        Parameters
        ----------
        order: int, The order (=="number of features") of the incrementally added
        dataset and of the resulting covariance matrix.
        """
        self._order = order
        self._shape = (order, order)
        self._count = 0
        self._mean = np.zeros(order)
        self._cov = np.zeros(self._shape)

    @property
    def count(self):
        """
        int, The number of observations that have been added
        to this instance of OnlineCovariance.
        """
        return self._count

    @property
    def mean(self):
        """
        ndarray of shape (order,), The per-feature mean of the added data.
        All zeros while no observation has been added.
        """
        return self._mean

    @property
    def cov(self):
        """
        ndarray of shape (order, order), The population covariance matrix
        of the added data. All zeros while no observation has been added.
        """
        return self._cov

    @property
    def corrcoef(self):
        """
        ndarray of shape (order, order) or None, The normalized covariance
        matrix of the added data, i.e. the Pearson correlation coefficients
        of the data's features.

        Returns None while no observation has been added. Features with zero
        variance produce NaN entries (division by zero), as with np.corrcoef.
        """
        if self._count < 1:
            return None
        variances = np.diagonal(self._cov)
        denominator = np.sqrt(variances[np.newaxis, :] * variances[:, np.newaxis])
        return self._cov / denominator

    def add(self, observation):
        """
        Add the given observation to this object.

        Parameters
        ----------
        observation: array_like of shape (order,), The observation to add.

        Raises
        ------
        ValueError
            If the observation is not a 1-d sequence of length ``order``.
            The object is left unchanged in that case.
        """
        observation = np.asarray(observation, dtype=float)
        # Validate before touching any state so a bad input cannot leave the
        # counter incremented without a matching mean/covariance update.
        if observation.ndim != 1 or observation.shape[0] != self._order:
            raise ValueError(f"Observation to add must be of size {self._order}")

        self._count += 1
        delta_prev = observation - self._mean  # deviation from the mean before this update
        # Rebind instead of updating in place, so arrays previously returned
        # by `mean` are not silently modified.
        self._mean = self._mean + delta_prev / self._count
        delta_new = observation - self._mean  # deviation from the updated mean
        # Rank-1 update. np.outer is O(order^2); the equivalent product of a
        # dense diagonal matrix with a broadcast matrix would be O(order^3).
        rank_one = np.outer(delta_prev, delta_new / self._count)
        self._cov = self._cov * (self._count - 1) / self._count + rank_one

    def merge(self, other):
        """
        Merges the current object and the given other object into a new OnlineCovariance object.

        Neither input is modified. Either (or both) inputs may be empty.

        Parameters
        ----------
        other: OnlineCovariance, The other OnlineCovariance to merge this object with.

        Returns
        -------
        OnlineCovariance

        Raises
        ------
        ValueError
            If the two objects have different orders.
        """
        if other._order != self._order:
            raise ValueError(
                f"""
                   Cannot merge two OnlineCovariances with different orders.
                   ({self._order} != {other._order})
                   """
            )

        merged_cov = OnlineCovariance(self._order)
        total = self.count + other.count
        merged_cov._count = total
        if total == 0:
            # Nothing has been observed on either side; the zero-initialised
            # result is already correct and we must not divide by `total`.
            return merged_cov

        # Count-weighted mean. Written without dividing by the individual
        # counts so that an empty operand does not introduce NaN/inf.
        merged_cov._mean = (
            self._mean * self.count + other._mean * other.count
        ) / total

        mean_diff = self._mean - other._mean
        count_corr = (other.count * self.count) / total
        merged_cov._cov = (
            self._cov * self.count
            + other._cov * other.count
            + np.outer(mean_diff, mean_diff) * count_corr
        ) / total
        return merged_cov



#### Save layer-wise covariance

In [22]:
import os
import os.path as osp
import torch
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from src.data.dataset_readers import DATASET_CLASSES
from src.data.Batcher import Batcher
from src.data.PytorchDataset import PytorchDataset

# Hyperparameters
DATASET_NAME = "qasc"
MODEL_NAME = "t5-base"

# Checkpoint paths are built as $SCRATCH/ties/exp_out/training/<model>/<task>/best_model.pt
root_dir = osp.join(os.environ.get("SCRATCH", ""), "ties", "exp_out", "training", MODEL_NAME)
task_to_model_dict = {
    task: osp.join(root_dir, task, "best_model.pt") for task in ("qasc", "quartz")
}

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def createPytorchDataset_fn(dataset):
    """Wrap `dataset` in a PytorchDataset bound to the notebook's tokenizer and device."""
    return PytorchDataset(dataset, tokenizer, device)


dataset_kwargs = dict(
    few_shot_random_seed=None,
    num_val_samples=32,
    max_datapoints_per_dataset_without_templates=None,
)

dataset_reader = DATASET_CLASSES[DATASET_NAME](dataset_kwargs)
# NOTE(review): `device=0` is passed here rather than the `device` object
# defined above — confirm against Batcher's signature that an index is expected.
batcher = Batcher(
    dataset_reader,
    createPytorchDataset_fn,
    train_batchSize=32,
    eval_batchSize=32,
    world_size=1,
    device=0,
)
train_iterator = batcher.get_trainBatches("train", 0)


In [23]:
from tqdm import tqdm
from src.model.T5Wrapper import T5Wrapper

# Init model. Use MODEL_NAME (the same constant the tokenizer was loaded from)
# instead of a second hard-coded "t5-base", so model and tokenizer cannot drift apart.
transformer = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model = T5Wrapper(transformer, tokenizer)

# Optionally load state_dict
# model.load_state_dict(torch.load(task_to_model_dict[DATASET_NAME], map_location=device))
model.to(device)
model.train()
# Push three training batches through the model; the outputs are discarded.
for _ in tqdm(range(3)):
    out = next(train_iterator)
    model(out)

  0%|          | 0/3 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 3/3 [01:06<00:00, 22.23s/it]


In [25]:
import torch
import itertools

# Compute running covariance of the inputs of the model's Linear layers
stats = {}  # layer_name -> OnlineCovariance over that layer's sampled inputs
handles = []  # hook handles, kept so every hook can be removed afterwards
MAX_HOOKS = 20  # upper bound on the number of Linear layers that get a hook
MAX_BATCHES = 20  # upper bound on the number of batches pushed through the model

# `model`, `batcher` and `device` are expected to exist from the previous cells.
# Hooks are removed in the `finally` block below, so re-running this cell does
# not leave duplicate hooks behind even if a forward pass fails.


def hook(name):
    """
    Build a forward hook that feeds sampled inputs of a module into `stats[name]`.

    For every call, one token position per batch element is drawn uniformly
    at random and that input vector is added to the layer's OnlineCovariance.
    The draw is unseeded, so results differ between runs.
    """
    # Hook gets (module, input, output)
    def h(mod, _inp, out):
        inp = _inp[0]
        if not torch.is_tensor(inp):
            return
        # The unpacking requires a 3-d input; presumably (batch, tokens, features) — TODO confirm.
        B, T, D = inp.shape
        if name not in stats:
            stats[name] = OnlineCovariance(D)
        ocov = stats[name]
        # Pick one random position per batch element and move all picked
        # vectors to the host in a single transfer.
        rows = torch.arange(B, device=inp.device)
        cols = torch.randint(0, T, (B,), device=inp.device)
        sampled = inp.detach()[rows, cols].cpu().numpy()
        for v in sampled:
            ocov.add(v)
    return h


num_registered = 0
try:
    for module_name, module in model.named_modules():
        # Check the budget *before* registering so that exactly MAX_HOOKS hooks
        # are installed and `num_registered` matches the real number.
        if num_registered >= MAX_HOOKS:
            break
        if isinstance(module, torch.nn.Linear):
            print("Registering hook for", module_name)
            handles.append(module.register_forward_hook(hook(module_name)))
            num_registered += 1

    print("Registered", num_registered, "hooks")

    # Run forward passes
    model.to(device)
    train_iterator = batcher.get_trainBatches(
        "train", 0
    )
    with torch.no_grad():
        print("Running forward passes")
        # islice stops after exactly MAX_BATCHES batches.
        for i, batch in enumerate(itertools.islice(train_iterator, MAX_BATCHES)):
            print("Processing batch", i)
            model(batch)
finally:
    # Clear hooks, also when registration or a forward pass raised.
    for handle in handles:
        handle.remove()

Registering hook for transformer.encoder.block.0.layer.0.SelfAttention.q
Registering hook for transformer.encoder.block.0.layer.0.SelfAttention.k
Registering hook for transformer.encoder.block.0.layer.0.SelfAttention.v
Registering hook for transformer.encoder.block.0.layer.0.SelfAttention.o
Registering hook for transformer.encoder.block.0.layer.1.DenseReluDense.wi
Registering hook for transformer.encoder.block.0.layer.1.DenseReluDense.wo
Registering hook for transformer.encoder.block.1.layer.0.SelfAttention.q
Registering hook for transformer.encoder.block.1.layer.0.SelfAttention.k
Registering hook for transformer.encoder.block.1.layer.0.SelfAttention.v
Registering hook for transformer.encoder.block.1.layer.0.SelfAttention.o
Registering hook for transformer.encoder.block.1.layer.1.DenseReluDense.wi
Registering hook for transformer.encoder.block.1.layer.1.DenseReluDense.wo
Registering hook for transformer.encoder.block.2.layer.0.SelfAttention.q
Registering hook for transformer.encoder.bl

In [26]:
import pickle

results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

# Pull the accumulated covariance matrix and mean vector out of each tracker.
covs, means = {}, {}
for layer_name, tracker in stats.items():
    layer_cov = tracker.cov
    layer_mean = tracker.mean
    print(layer_name, layer_cov.shape)
    covs[layer_name] = layer_cov
    means[layer_name] = layer_mean

# One pickle per (dataset, model) pair holding both dictionaries.
payload = {
    "covs": covs,
    "means": means,
}
out_path = os.path.join(results_dir, f"covs_d{DATASET_NAME}_m{MODEL_NAME}.pkl")
with open(out_path, "wb") as f:
    pickle.dump(payload, f)


transformer.encoder.block.0.layer.0.SelfAttention.q (768, 768)
transformer.encoder.block.0.layer.0.SelfAttention.k (768, 768)
transformer.encoder.block.0.layer.0.SelfAttention.v (768, 768)
transformer.encoder.block.0.layer.0.SelfAttention.o (768, 768)
transformer.encoder.block.0.layer.1.DenseReluDense.wi (768, 768)
transformer.encoder.block.0.layer.1.DenseReluDense.wo (3072, 3072)
transformer.encoder.block.1.layer.0.SelfAttention.q (768, 768)
transformer.encoder.block.1.layer.0.SelfAttention.k (768, 768)
transformer.encoder.block.1.layer.0.SelfAttention.v (768, 768)
transformer.encoder.block.1.layer.0.SelfAttention.o (768, 768)
transformer.encoder.block.1.layer.1.DenseReluDense.wi (768, 768)
transformer.encoder.block.1.layer.1.DenseReluDense.wo (3072, 3072)
transformer.encoder.block.2.layer.0.SelfAttention.q (768, 768)
transformer.encoder.block.2.layer.0.SelfAttention.k (768, 768)
transformer.encoder.block.2.layer.0.SelfAttention.v (768, 768)
transformer.encoder.block.2.layer.0.SelfAtt

#### Analyze layer-wise covariance

In [30]:
!pip install matplotlib seaborn

Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Using cached seaborn-0.13.2-py3-none-any.whl (294 kB)
Installing collected packages: seaborn
Successfully installed seaborn-0.13.2


In [27]:
# Load covs and means
import torch
import numpy as np
import pickle

import os  # imported here so this cell does not depend on an import made in an earlier cell

RESULTS_DIR = "results"
DATASET_NAME = "qasc"
MODEL_NAME = "t5-base"

# NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
with open(os.path.join(RESULTS_DIR, f"covs_d{DATASET_NAME}_m{MODEL_NAME}.pkl"), "rb") as f:
    data = pickle.load(f)
covs = data["covs"]  # layer name -> covariance matrix
means = data["means"]  # layer name -> mean vector

In [33]:
# Save to pdf
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Only the [start_idx:stop_idx] window of each covariance matrix is drawn.
start_idx = 0
stop_idx = 500 + start_idx

with PdfPages(f"covs_heatmaps_start{start_idx}_stop{stop_idx}.pdf") as pdf:
    # One page per layer, showing the absolute value of the covariance entries.
    for c_mat_name, c_mat in covs.items():
        fig, ax = plt.subplots(figsize=(16, 16))
        sns.heatmap(
            np.abs(c_mat[start_idx:stop_idx, start_idx:stop_idx]),
            cmap="rocket",
            annot=False,
            linewidths=0,    # <-- no grid lines
            ax=ax,
        )
        ax.set_xticks([])  # <-- no ticks
        ax.set_yticks([])

        # The title must be a string: a list would be rendered as its Python
        # repr. The full module name is used so that pages belonging to
        # different blocks (e.g. block.0 vs block.1) can be told apart.
        ax.set_title(f"{c_mat_name} | start {start_idx}, stop {stop_idx}")
        pdf.savefig(fig)
        plt.close(fig)
