# Expand backprop analysis

- train multiple models for analysis
- do backprop from final output to all of the intermediate layer outputs (as well as the inputs)
- confirm task orthogonality throughout

In [None]:
import torch
from torch import nn
from torch.optim import Adam
from matplotlib import pyplot as plt
from matplotlib import cm
import pandas as pd
from tqdm.notebook import tqdm
import numpy as np

from physics_mi.utils import set_all_seeds

In [None]:
# Draw a fresh random seed on every run; it is printed below so that an
# interesting run can be reproduced by pinning the value (see the example).
seed = np.random.randint(1, 2**32 - 1)
# seed = 1322468781  # this one is very interesting
# NOTE(review): set_all_seeds is imported from physics_mi.utils; presumably it
# seeds the Python, NumPy and torch RNGs - confirm against its implementation.
set_all_seeds(seed)
print(seed)

## Model

Keeping this extremely simple

In [None]:
class LinearLayer(nn.Module):
    """A fully connected layer with an optional ReLU non-linearity.

    Args:
        in_feats: size of each input sample.
        out_feats: size of each output sample.
        use_act: when True a ReLU (registered as the ``act`` submodule) is
            applied to the linear output; when False no ``act`` submodule is
            created and the raw linear output is returned.
    """

    def __init__(self, in_feats, out_feats, use_act=True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.linear = nn.Linear(in_feats, out_feats)
        self.use_act = use_act
        if self.use_act:
            self.act = nn.ReLU()

    def forward(self, x):
        """Apply the linear map and, if enabled, the ReLU."""
        out = self.linear(x)
        return self.act(out) if self.use_act else out


class Net(nn.Module):
    """Small MLP: two ReLU hidden layers followed by a linear output layer.

    Args:
        input_dim: number of input features.
        hidden_dim: width of both hidden layers.
        output_dim: number of outputs.
    """

    def __init__(
        self, input_dim=4, hidden_dim=16, output_dim=2, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        # Built in forward order so the submodules are registered as
        # layers.0, layers.1 and layers.2.
        stack = [
            LinearLayer(input_dim, hidden_dim, use_act=True),
            LinearLayer(hidden_dim, hidden_dim, use_act=True),
            LinearLayer(hidden_dim, output_dim, use_act=False),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run `x` through all three layers."""
        out = self.layers(x)
        return out

## Data

In [None]:
import torch

# Number of samples generated for each of the two tasks
n_samples = 10000

# Lower bound of the random divisor used in generate_X
# (the divisor is drawn uniformly from [eps, 1))
eps = 0.5


# Generate Y values
def generate_Y(n_samples):
    """Return a 1-D tensor of `n_samples` targets drawn uniformly from [0, 1)."""
    targets = torch.rand((n_samples,))
    return targets


# Generate X values based on Y
def generate_X(Y, eps):
    """Build random factor pairs whose product equals each target.

    For every target ``y`` a divisor ``r`` is drawn uniformly from
    ``[eps, 1)``; the pair is ``(y / r, r)``, with the two entries swapped at
    random (probability 0.5) so both columns share the same distribution.

    Args:
        Y: 1-D tensor of targets.
        eps: lower bound of the random divisor.

    Returns:
        Tensor of shape ``(len(Y), 2)`` with ``X[:, 0] * X[:, 1] ~= Y``.
    """
    n = len(Y)
    divisor = torch.rand(n) * (1 - eps) + eps

    X = torch.empty(n, 2)
    X[:, 0] = Y / divisor
    # Store the divisor itself instead of recomputing it as Y / X[:, 0]:
    # the recomputation is 0 / 0 = NaN whenever a target is exactly 0.
    X[:, 1] = divisor

    # Randomly swap x1 and x2
    mask = torch.rand(n) < 0.5
    X[mask] = X[mask].flip(1)

    return X


# Task A data: targets and the factor pairs that multiply to them
Y1 = generate_Y(n_samples)
X1 = generate_X(Y1, eps)

# Task B data: drawn separately (new Y and X values) so that the two tasks
# are statistically independent of each other
Y2 = generate_Y(n_samples)
X2 = generate_X(Y2, eps)

# Concatenate along the feature axis: columns 0-1 come from X1 (task A),
# columns 2-3 from X2 (task B)
X = torch.cat((X1, X2), dim=1)

# One target column per task: column 0 is Y1 (task A), column 1 is Y2 (task B)
Y = torch.stack((Y1, Y2), dim=1)

# Validate the relationship: each task's two inputs multiply to its target
assert torch.allclose(X[:, 0] * X[:, 1], Y[:, 0])
assert torch.allclose(X[:, 2] * X[:, 3], Y[:, 1])

# Print the shapes
print(X.shape, Y.shape)

In [None]:
fig, ax = plt.subplots()

# Overlay the distributions of the first task's two inputs and its target
for values, label in (
    (X[:, 0], "mass"),
    (X[:, 1], "acceleration"),
    (Y[:, 0], "force"),
):
    ax.hist(values, alpha=0.5, density=True, label=label)
ax.legend()

In [None]:
fig, ax = plt.subplots()

# Same overlay for the second task's two inputs and its target
for values, label in (
    (X[:, 2], "mass"),
    (X[:, 3], "acceleration"),
    (Y[:, 1], "force"),
):
    ax.hist(values, alpha=0.5, density=True, label=label)
ax.legend()

Ok, both now look identically distributed.

In [None]:
# Shuffle the sample indices with NumPy's global RNG, then split: the first
# 8000 shuffled samples form the training set and the remainder the
# validation set.
s_inds = np.random.permutation(range(X.shape[0]))  # shuffled indices

X_train = X[s_inds[:8000]]
Y_train = Y[s_inds[:8000]]
X_valid = X[s_inds[8000:]]
Y_valid = Y[s_inds[8000:]]

# Shown as the cell output to confirm the split sizes
X_train.shape, Y_train.shape, X_valid.shape, Y_valid.shape

## Training

I'll just do full-batch gradient descent (every update uses the whole training set, with Adam as the optimiser) to keep things simple.

In [None]:
N = 1000  # number of epochs
hidden_dim = 16  # number of hidden units

model = Net(input_dim=4, hidden_dim=hidden_dim, output_dim=2)
loss_func = nn.MSELoss()
optimiser = Adam(model.parameters(), lr=1e-2)
log = []  # one {"train_loss", "valid_loss"} record per epoch

for i in tqdm(range(N)):
    log_sample = {}

    # Training update: a single full-batch step over the whole training set
    model.train()
    model.zero_grad()
    Y_hat = model(X_train)
    loss = loss_func(Y_hat, Y_train)
    log_sample["train_loss"] = float(loss.detach())
    loss.backward()
    optimiser.step()

    # Validation set: the loss is only logged, never back-propagated, so skip
    # building an autograd graph for this forward pass.
    model.eval()
    with torch.no_grad():
        Y_hat = model(X_valid)
        loss = loss_func(Y_hat, Y_valid)
    log_sample["valid_loss"] = float(loss.detach())

    log.append(log_sample)

# Per-epoch loss history, one row per epoch
df = pd.DataFrame(log)

## Results

In [None]:
from physics_mi.eval import *


# need to avoid flattening here because we have multiple outputs
def get_preds(model, X_valid, Y_valid):
    """Return the model's predictions and the targets as NumPy arrays.

    The model is switched to eval mode and evaluated on `X_valid` without
    gradient tracking; neither array is flattened.

    Returns:
        (y_preds, y_targs): `model(X_valid)` and `Y_valid` converted to NumPy.
    """
    model.eval()

    with torch.inference_mode():
        y_preds = model(X_valid).numpy()

    return y_preds, Y_valid.numpy()

In [None]:
# Plot the logged training and validation loss histories.
# NOTE(review): plot_loss comes from the physics_mi.eval star import - confirm.
plot_loss(df["train_loss"], df["valid_loss"])

In [None]:
# Validation-set predictions and targets as NumPy arrays (not flattened)
y_preds, y_targs = get_preds(model, X_valid, Y_valid)
y_preds.shape, y_targs.shape

In [None]:
# NOTE(review): get_valid_loss is an external helper (physics_mi.eval star
# import); presumably it evaluates loss_func on the validation set - confirm.
get_valid_loss(model, loss_func, X_valid, Y_valid)

In [None]:
# One panel per output column: column 0 on the left, column 1 on the right
fig, axes = plt.subplots(1, 2, figsize=(10, 10))

# NOTE(review): plot_results is an external helper (physics_mi.eval star
# import); presumably it plots predictions against targets - confirm.
plot_results(y_preds[:, 0], y_targs[:, 0], ax=axes[0])
plot_results(y_preds[:, 1], y_targs[:, 1], ax=axes[1])

Both tasks seem to be doing well in parallel 👍

## Introspection

In [None]:
from physics_mi.analysis import capture_intermediate_outputs

In [None]:
def get_inputs(N=100, vary_task="A"):
    """Build probe inputs that sweep one task's input pair over a grid.

    The swept pair takes every combination of `N` evenly spaced values in
    [0, 1] (first element varying slowest), while the other task's pair is
    held fixed at 0.5.

    Args:
        N: number of grid points per axis; the result has ``N * N`` rows.
        vary_task: ``"A"`` sweeps columns 0-1, ``"B"`` sweeps columns 2-3.

    Returns:
        float32 tensor of shape ``(N * N, 4)``.

    Raises:
        ValueError: if `vary_task` is neither ``"A"`` nor ``"B"``.
    """
    # Fail loudly instead of hitting an unbound local further down.
    if vary_task not in ("A", "B"):
        raise ValueError(f"vary_task must be 'A' or 'B', got {vary_task!r}")

    grid = np.linspace(0, 1, N)
    # "ij" indexing + reshape enumerates (grid[i], grid[j]) with i outermost.
    pairs = np.stack(np.meshgrid(grid, grid, indexing="ij"), axis=-1).reshape(-1, 2)
    pairs = torch.tensor(pairs, dtype=torch.float32)
    fixed = torch.full((len(pairs), 2), 0.5)

    if vary_task == "A":
        return torch.cat((pairs, fixed), dim=1)
    return torch.cat((fixed, pairs), dim=1)

In [None]:
# Probe set: the task-A sweep (task B inputs fixed at 0.5) stacked on top of
# the task-B sweep (task A inputs fixed at 0.5)
task_inputs = torch.cat(
    (get_inputs(100, vary_task="A"), get_inputs(100, vary_task="B"))
)
task_inputs.shape

In [None]:
# Record intermediate outputs of the trained model on the validation set and
# on the probe inputs. The results are indexed below by module path
# (e.g. "layers.0.act").
# NOTE(review): capture_intermediate_outputs is an external helper from
# physics_mi.analysis - confirm exactly what it captures and returns.
valid_ios = capture_intermediate_outputs(model, X_valid)
task_ios = capture_intermediate_outputs(model, task_inputs)

Now we'd like to compare the principal components at `layers.0.act` with those from `layers.1.act` I think.

In [None]:
def get_pcs(data):
    """Compute principal components and their variances via SVD.

    Args:
        data: 2-D tensor with one sample per row and one feature per column.

    Returns:
        (principal_components, variances): the columns of
        `principal_components` are the principal directions, ordered by
        decreasing singular value; `variances` holds the matching sample
        variances, ``S**2 / (n_samples - 1)``.
    """
    # Step 1: Centre each feature on its mean
    data_centered = data - torch.mean(data, 0)

    # Step 2: Compute the (reduced) SVD. torch.linalg.svd supersedes the
    # deprecated torch.svd and returns V^H rather than V.
    _, S, Vh = torch.linalg.svd(data_centered, full_matrices=False)

    # The columns of V are the principal components
    principal_components = Vh.transpose(-2, -1).conj()

    # Step 3: Compute variances
    variances = S.pow(2) / (data.size(0) - 1)

    return principal_components, variances

In [None]:
# Post-activation outputs of the two hidden layers on the probe inputs,
# keyed by layer index
task_acts = {
    "0": task_ios["layers.0.act"],
    "1": task_ios["layers.1.act"],
}
task_acts["0"].shape, task_acts["1"].shape

In [None]:
# Post-activation outputs of the two hidden layers on the validation set,
# keyed by layer index
valid_acts = {
    "0": valid_ios["layers.0.act"],
    "1": valid_ios["layers.1.act"],
}
valid_acts["0"].shape, valid_acts["1"].shape

In [None]:
# Principal components (and their variances) of the validation activations
# at each hidden layer
pcs0, vars0 = get_pcs(valid_acts["0"])
pcs1, vars1 = get_pcs(valid_acts["1"])
valid_pcs = {"0": pcs0, "1": pcs1}
valid_vars = {"0": vars0, "1": vars1}
valid_pcs["0"].shape, valid_pcs["1"].shape

In [None]:
def get_pc_acts(pcs, acts):
    """Project activations onto a set of component directions.

    Args:
        pcs: tensor of shape ``(n_features, n_components)`` whose columns are
            the directions to project onto.
        acts: tensor of shape ``(n_samples, n_features)``.

    Returns:
        Tensor of shape ``(n_samples, n_components)``; entry ``[s, c]`` is the
        dot product of sample ``s`` with component ``c``. The activations are
        used as given (they are not mean-centred here).
    """
    # A plain matrix product gives the same projections without materialising
    # the (n_samples, n_components, n_features) broadcast intermediate.
    pc_acts = torch.matmul(acts, pcs)
    return pc_acts

In [None]:
# Validation activations of each hidden layer expressed in that layer's
# principal-component basis
valid_pc_acts = {
    "0": get_pc_acts(valid_pcs["0"], valid_acts["0"]),
    "1": get_pc_acts(valid_pcs["1"], valid_acts["1"]),
}
valid_pc_acts["0"].shape, valid_pc_acts["1"].shape

In [None]:
# Sanity check: shapes of the captured hidden-layer activations
valid_ios["layers.0.act"].shape, valid_ios["layers.1.act"].shape

In [None]:
# Display the trained model's module structure
model

Now the scaffold net needs to allow me to backprop from the final output to every intermediate layer's activations (and the input as a sanity check). In this case that would be:
- layer1
- layer0
- input

Question is, do I engineer something elegant that extends to any depth architecture, or do I just do this manually for now? 🤔

The elegant solution could include an argument to the forward method that selects the layer to inspect. There's a question of whether the input would be stored as it has been in the class or whether this time it would be simpler to store it externally (as there are now multiple inputs). I think the latter is best.

In [None]:
class ScaffoldNet(Net):
    """A frozen `Net` that can be run from any intermediate layer onwards.

    All parameters have `requires_grad` switched off, so a backward pass only
    produces gradients for the tensor fed in as `linputs`.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Freeze every weight and bias
        for param in self.parameters():
            param.requires_grad_(False)

    def forward(self, lidx, linputs):
        """Feed `linputs` through `layers[lidx:]` and return the final output.

        Args:
            lidx: index of the first layer to apply (0 runs the whole network).
            linputs: tensor to feed into that layer.
        """
        tail = self.layers[lidx:]
        return tail(linputs)


# Build the scaffold with the same dimensions as the trained model, rather
# than relying on the class defaults happening to match, so that the weights
# still load if hidden_dim is changed in the training cell.
scaffold_model = ScaffoldNet(input_dim=4, hidden_dim=hidden_dim, output_dim=2)
scaffold_model.load_state_dict(model.state_dict())

In [None]:
tasks = ["A", "B"]
grads = {}
for i, task in enumerate(tasks):
    # Fresh copy of the layers.0.act validation activations that tracks gradients
    linputs = valid_ios["layers.0.act"].clone().requires_grad_(True)

    # Run the frozen network from layer 1 onwards and back-propagate the mean
    # of this task's output column down to those activations
    out = scaffold_model(1, linputs)
    loss = out[:, i].mean()
    loss.backward()
    task_grads = linputs.grad.detach().clone()

    # Principal components of the per-sample gradients
    pcs, variances = get_pcs(task_grads)

    # Distinct gradient rows, split into unit direction and magnitude
    uq_grads = task_grads.unique(dim=0)
    uq_grads_norm = uq_grads.norm(dim=1)

    grads[task] = {
        "gradients": task_grads,
        "pcs": pcs,
        "vars": variances,
        "unique": {
            "comps": uq_grads / uq_grads_norm[:, None],
            "norms": uq_grads_norm,
        },
    }

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

for ax, task in zip(axes, ("A", "B")):
    pc_vars = grads[task]["vars"]
    # Number the bars from the data instead of assuming 16 components, so the
    # plot still works when the hidden width changes.
    ax.bar(range(1, len(pc_vars) + 1), pc_vars)
    ax.set_title(f"Task {task}")
    ax.set_xlabel("PC")
    ax.set_ylabel("Variance")

fig.suptitle("Task-wise principal components of activation gradients")
# set_tight_layout takes a bool (or a dict of padding options), not a string
fig.set_tight_layout(True)

The mean is crude but it could be informative:

In [None]:
# Mean activation gradient of each task, scaled to unit length
meanA, meanB = (grads[name]["gradients"].mean(0) for name in ("A", "B"))
meanA = meanA / meanA.norm()
meanB = meanB / meanB.norm()

# Cosine similarity between the two mean gradient directions
torch.dot(meanA, meanB)

In [None]:
# Pairwise dot products between the unit-normalised unique gradient rows of
# task A (rows of `sims`) and task B (columns of `sims`)
sims = torch.einsum(
    "ij,kj->ik", grads["A"]["unique"]["comps"], grads["B"]["unique"]["comps"]
).numpy()

In [None]:
def plot_similarity(sims, title="Dot-product Similarity", x_label=None, y_label=None):
    """Show a similarity matrix as a heatmap on a fixed [-1, 1] colour scale.

    Args:
        sims: 2-D array of similarities; rows are drawn along the y-axis and
            columns along the x-axis.
        title: axes title.
        x_label: x-axis label; defaults to "Task B PCA" when None.
        y_label: y-axis label; defaults to "Task A PCA" when None.
    """
    fig, ax = plt.subplots(figsize=(6, 6))

    # Diverging colour map pinned to [-1, 1] so that zero similarity is white
    im = ax.imshow(sims, cmap="bwr", vmin=-1, vmax=1)
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=14)

    ax.set_title(title, fontsize=16)

    # imshow draws rows on the y-axis and columns on the x-axis, so the
    # default labels suit a matrix whose rows index task A and whose columns
    # index task B.
    ax.set_xlabel("Task B PCA" if x_label is None else x_label, fontsize=14)
    ax.set_ylabel("Task A PCA" if y_label is None else y_label, fontsize=14)

In [None]:
# Heatmap of the task A vs task B gradient-direction similarities
plot_similarity(sims, title="Dot-product Similarity")

In [None]:
# task_inputs stacks the task-A sweep on top of the equally sized task-B
# sweep, so split at the midpoint instead of hard-coding the grid size.
n_per_task = task_inputs.shape[0] // 2
sliceA = slice(0, n_per_task)
sliceB = slice(n_per_task, None)

### Task A PCs

I'm first focusing on the gradient PCs extracted from backprop from the task A output and how much variance they explain in both tasks.

In [None]:
# Select which task's gradients the following cells analyse
task = "A"

Calculating the variance of the layer-0 activations along each of these unique gradient directions, separately for the two task datasets:

In [None]:
unique = grads[task]["unique"]

# Order the unique gradient directions from largest to smallest norm
sort_idxs = unique["norms"].argsort(descending=True)
norms = unique["norms"][sort_idxs]

# Project the probe activations onto every direction, then reorder the
# columns to match `norms`
grad_acts = torch.einsum("ij,kj->ki", unique["comps"], task_acts["0"])
grad_acts = grad_acts[:, sort_idxs]

# Variance along each direction, separately for the two probe sweeps
varA, varB = grad_acts[sliceA].var(0), grad_acts[sliceB].var(0)
varA.shape, varB.shape

In [None]:
def plot_pca_variances(varA, varB, grad_pc_variance, err_varA=None, err_varB=None):
    """Bar-plot two sets of per-component variances with a per-component marker overlay.

    Args:
        varA: per-component variances shown as the "Task A activations" bars.
        varB: per-component variances shown as the "Task B activations" bars.
        grad_pc_variance: per-component values scattered on a secondary
            y-axis (labelled "Norm").
        err_varA: optional error bars for `varA`.
        err_varB: optional error bars for `varB`.
    """
    fig, ax1 = plt.subplots(figsize=(8, 6))

    # Left axis: both bar series are drawn semi-transparently at the same
    # x positions
    bar_series = (
        (varA, err_varA, "Task A activations"),
        (varB, err_varB, "Task B activations"),
    )
    for variances, errors, series_label in bar_series:
        ax1.bar(
            range(len(variances)),
            variances,
            yerr=errors,
            width=0.4,
            align="center",
            label=series_label,
            alpha=0.5,
        )
    ax1.set_xlabel("Unique Gradient Component")
    ax1.set_ylabel("Variance")
    ax1.legend(loc="upper left")

    # Right axis: one red marker per component, with the axis starting at zero
    ax2 = ax1.twinx()
    positions = range(len(grad_pc_variance))
    ax2.scatter(
        positions,
        grad_pc_variance,
        label="Gradient Component Norm",
        color="r",
        marker="o",
    )
    ax2.set_ylim(0)
    ax2.set_ylabel("Norm")
    ax2.legend(loc="upper right")

    ax1.set_title(
        "Task-wise variance explained in activations by each unique gradient component"
    )

In [None]:
# Variances for both probe sweeps, with the gradient norms on the second axis
plot_pca_variances(varA, varB, norms)

### Task B PCs

In [None]:
# Select which task's gradients the following cells analyse
task = "B"

In [None]:
unique = grads[task]["unique"]

# Order the unique gradient directions from largest to smallest norm
sort_idxs = unique["norms"].argsort(descending=True)
norms = unique["norms"][sort_idxs]

# Project the probe activations onto every direction, then reorder the
# columns to match `norms`
grad_acts = torch.einsum("ij,kj->ki", unique["comps"], task_acts["0"])
grad_acts = grad_acts[:, sort_idxs]

# Variance along each direction, separately for the two probe sweeps
varA, varB = grad_acts[sliceA].var(0), grad_acts[sliceB].var(0)
varA.shape, varB.shape

In [None]:
# Variances for both probe sweeps, with the gradient norms on the second axis
plot_pca_variances(varA, varB, norms)