## Colab Cells

In [3]:
# Mount Google Drive so the project directory under MyDrive is reachable (Colab only).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# Work from the project root; presumably required so the local `utils`, `pruners` and
# `mask_samplers` packages imported below resolve.
%cd /content/drive/MyDrive/ugs-applications/

/content/drive/MyDrive/ugs-applications


In [5]:
!pip3 install transformer_lens
!pip3 install seaborn
!pip3 install fancy_einsum
!pip3 install einops



## Imports

In [117]:
from utils.training_utils import load_model_data
import torch
from functools import partial
import torch.optim
from pruners.Pruner import Pruner
from utils.MaskConfig import VertexInferenceConfig
from utils.task_datasets import get_task_ds
from tqdm import tqdm
import pickle
from scipy.stats import spearmanr
from utils.training_utils import LinePlot
import seaborn as sns
import matplotlib.pyplot as plt
from pruners.Pruner import Pruner
from mask_samplers.AblationMaskSampler import MultiComponentMaskSampler
import sys
from utils.training_utils import load_model_data, update_means_variances_mixed

In [102]:
# Drop the cached module so the next run of the import cell re-executes it and picks up edits.
# `pop(..., None)` keeps this cell safe to re-run when the module is not (or no longer) cached.
sys.modules.pop("mask_samplers.AblationMaskSampler", None)

## Pruner Class

In [136]:
class KQVPruner(Pruner):
    """Pruner that ablates the key / query / value activations of individual attention heads.

    Hooks are placed on ``blocks.{l}.attn.hook_k`` / ``hook_q`` / ``hook_v`` for every layer,
    plus a final hook on the last layer's residual stream. Masked-out entries are replaced by
    a "null" value: the counterfactual activations when ``self.counterfactual_mode`` is set,
    otherwise the stored ablation values in ``self.null_vals``.
    """

    # Mask values below this threshold let gradients flow into the null value itself;
    # at or above it the null value is detached.
    _NULL_GRAD_THRESHOLD = 0.001

    def __init__(self, *args, **kwargs):
        # Forward everything to Pruner, always enabling parallel inference.
        super().__init__(*args, **kwargs, parallel_inference=True)

    def process_null_val(self, node_type, layer_no):
        """Return the stored ablation value for one component at ``layer_no``.

        When ``self.condition_pos`` is set, the leading (sequence-position) dimension is
        truncated to ``self.seq_len``, or padded by repeating the last position.

        Raises:
            Exception: if ``node_type`` is not one of "attn", "k", "q", "v", "mlp".
        """
        if node_type in ["attn", "k", "q", "v"]:
            null_val = self.null_vals[node_type][..., layer_no, :, :]
        elif node_type == "mlp":
            null_val = self.null_vals['mlp'][..., layer_no, :]
        else:
            raise Exception("vertex type")

        if self.condition_pos:
            # seq_pos x ...: truncate, or pad by repeating the final position
            diff = self.seq_len - null_val.shape[0]
            if diff <= 0:
                null_val = null_val[:self.seq_len]
            else:
                padding = null_val[[-1]].expand(diff, *[-1 for _ in null_val.shape[1:]])
                null_val = torch.cat([null_val, padding], dim=0)

        return null_val

    def _counterfactual_null_val(self, acts):
        """Build null values from the counterfactual rows ``acts[:batch_size]``.

        When positions are not conditioned on, each example's positions are permuted with
        ``self.perms``. The result is tiled ``n_samples`` times along dim 0 so it lines up
        with the pruned rows.
        """
        bsz = self.pruning_cfg.batch_size
        # NOTE(review): this is a view into `acts`, so the permutation below also rewrites the
        # counterfactual rows of the live activation tensor — confirm this is intended.
        null_val = acts[:bsz]
        if not self.condition_pos:
            for i, p in enumerate(self.perms):
                null_val[i, :p.shape[0]] = null_val[i, p]
        return null_val.repeat(self.pruning_cfg.n_samples, *([1] * (null_val.dim() - 1)))

    def _apply_prune_mask(self, acts, mask_key, layer_no, null_val, n_unpruned):
        """Blend ``acts[n_unpruned:]`` with ``null_val`` using the sampled mask (0 = prune, 1 = keep).

        Rows before ``n_unpruned`` are left untouched. Position 0 of every row is restored
        afterwards, so the first token is never ablated. ``acts`` is modified in place and returned.
        """
        bos_out = acts[:, [0]].clone().detach()
        # (batch_size * n_samples) x [n_heads] -> add seq and trailing dims for broadcasting
        prune_mask = self.mask_sampler.sampled_mask[mask_key][layer_no].unsqueeze(1).unsqueeze(-1)
        threshold = self._NULL_GRAD_THRESHOLD
        try:
            acts[n_unpruned:] = (
                (prune_mask < threshold) * (1 - prune_mask) * null_val
                + (prune_mask >= threshold) * (1 - prune_mask) * null_val.detach()
            ) + prune_mask * acts[n_unpruned:].clone()
        except Exception:
            # Shape-mismatch diagnostics; every name printed here is guaranteed to be bound.
            print(
                f"{mask_key} layer {layer_no}: n_unpruned={n_unpruned}, "
                f"acts={tuple(acts.shape)}, null_val={tuple(null_val.shape)}, "
                f"prune_mask={tuple(prune_mask.shape)}"
            )
            raise
        acts[:, [0]] = bos_out
        return acts

    # attentions: (batch_size + batch_size * n_samples) x seq_len x n_heads x d_head
    # prune mask: (batch_size * n_samples) x n_heads, 0 = prune, 1 = keep
    def pruning_hook_attention_all_tokens(self, layer_no, node_type, attentions, hook):
        """Hook for attention-type activations (``node_type`` in attn/k/q/v) at ``layer_no``."""
        if self.counterfactual_mode:
            # first batch_size rows are counterfactual, the next batch_size are the true inputs
            null_val = self._counterfactual_null_val(attentions)
            n_unpruned = 2 * self.pruning_cfg.batch_size
        else:
            null_val = self.process_null_val(node_type, layer_no)
            n_unpruned = self.pruning_cfg.batch_size
        return self._apply_prune_mask(attentions, node_type, layer_no, null_val, n_unpruned)

    # mlp_out: (batch_size + batch_size * n_samples) x seq_len x d_model
    # prune mask: (batch_size * n_samples), 0 = prune, 1 = keep
    def pruning_hook_mlp_all_tokens(self, layer_no, mlp_out, hook):
        """Hook for the MLP output at ``layer_no`` (not registered by ``get_patching_hooks``)."""
        if self.counterfactual_mode:
            # first batch_size rows are counterfactual, the next batch_size are the true inputs
            null_val = self._counterfactual_null_val(mlp_out)
            n_unpruned = 2 * self.pruning_cfg.batch_size
        else:
            null_val = self.process_null_val("mlp", layer_no)
            n_unpruned = self.pruning_cfg.batch_size
        return self._apply_prune_mask(mlp_out, "mlp", layer_no, null_val, n_unpruned)

    def final_hook_last_token(self, out, hook):
        """Reduce the final residual stream to each example's last-token activation.

        Drops counterfactual rows (if any), then reshapes dim 0 to ``(-1, batch_size)``.
        When hooks are disabled, it adds a singleton leading dim instead. Finally it sums
        over positions weighted by ``self.last_token_mask``.
        """
        bsz = self.pruning_cfg.batch_size

        # remove counterfactuals
        if self.counterfactual_mode:
            out = out[bsz:]

        if self.disable_hooks:
            out = out.unsqueeze(0)
        else:
            out = out.unflatten(0, (-1, bsz))
        # presumably last_token_mask is one-hot over sequence positions — TODO confirm
        out = (out * self.last_token_mask.unsqueeze(-1)).sum(dim=2)
        return out

    def get_patching_hooks(self):
        """Return ``{name: (hook-point filter, hook fn)}`` for K, Q, V at every layer, plus the final hook.

        Key order is all ``patch_k_*``, then ``patch_q_*``, then ``patch_v_*``, then ``patch_final``.
        MLP patching is intentionally not registered in this K/Q/V-only pruner.
        """
        n_layers = self.base_model.cfg.n_layers

        def component_filter(node_type, layer_no, name):
            return name == f"blocks.{layer_no}.attn.hook_{node_type}"

        def final_embed_filter(name):
            return name == f"blocks.{n_layers - 1}.hook_resid_post"

        hooks = {
            f"patch_{node_type}_{layer_no}": (
                partial(component_filter, node_type, layer_no),
                partial(self.pruning_hook_attention_all_tokens, layer_no, node_type),
            )
            for node_type in ("k", "q", "v")
            for layer_no in range(n_layers)
        }
        hooks["patch_final"] = (final_embed_filter, self.final_hook_last_token)
        return hooks

## Run Pruner

In [139]:
import os

def _stack_kqv_grads(mask_sampler):
    """Concatenate the gradients of the K, then Q, then V mask perturbations (per layer) into a column."""
    grads = []
    for node_type in ("k", "q", "v"):
        for layer_no, perturb in enumerate(mask_sampler.mask_perturb[node_type]):
            if perturb.grad is None:
                raise RuntimeError(
                    f"mask_perturb[{node_type!r}][{layer_no}] has no gradient; "
                    "was it used in the forward pass?"
                )
            grads.append(perturb.grad)
    return torch.cat(grads, dim=0).unsqueeze(-1)


def _count_ablated_kqv(mask_sampler, keep_threshold=1 - 1e-3):
    """Count, per K/Q/V head (same order as ``_stack_kqv_grads``), the sampled masks below ``keep_threshold``."""
    counts = [
        (mask < keep_threshold).sum(dim=0)
        for node_type in ("k", "q", "v")
        for mask in mask_sampler.sampled_mask[node_type]
    ]
    return torch.cat(counts, dim=0).unsqueeze(-1)


def run_MCMS(ablation_type = "mean_agnostic",
             dataset = "ioi",
             n_samples = 1,
             batch_size = 100,
             model_name = "gpt2-small",
             owt_batch_size = 10,
             k = 1,
             max_batches = 10000,
             folder = "results/mcms"):
    """Accumulate per-head attribution-patching statistics for K/Q/V components with a KQVPruner.

    Each iteration draws a batch, runs the pruner forward and backward, and reads the gradients
    of the mask perturbations. It normalises them by how often each head was ablated, then
    folds them into running means/variances via ``update_means_variances_mixed``.

    Args:
        ablation_type: ablation mode passed to ``get_task_ds``.
        dataset: task name passed to ``get_task_ds``.
        n_samples: mask samples per example (stored on the pruning config).
        batch_size: examples per batch.
        model_name, owt_batch_size: forwarded to ``load_model_data``.
        k: stored as ``pruning_cfg.k`` (interpreted by the mask sampler).
        max_batches: despite the name, an *example* budget; the loop runs
            ``int(max_batches / (batch_size * n_samples))`` iterations in total.
        folder: output folder for the pruning config; created if missing.

    Returns:
        ``(head_losses, head_vars, n_batches_by_head, n_samples_by_head)``. Rows are ordered
        as all K heads (layer-major), then Q, then V. They start as
        ``(3 * n_layers * n_heads, 1)`` tensors; presumably ``update_means_variances_mixed``
        preserves that shape.
    """
    if not os.path.exists(folder):
        print("Creating Folder", folder)
        os.makedirs(folder, exist_ok=True)

    # init model and tokenizer
    device, model, tokenizer, owt_iter = load_model_data(model_name, owt_batch_size)
    model.eval()
    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads

    # init pruning configs
    pruning_cfg = VertexInferenceConfig(model.cfg, device, folder, init_param=1)
    pruning_cfg.batch_size = batch_size
    pruning_cfg.n_samples = n_samples
    pruning_cfg.k = k
    print("---------------------------")
    print("Pruning Config")
    print("---------------------------")
    # distinct loop names so the `k` parameter is not shadowed
    for cfg_key, cfg_value in pruning_cfg.__dict__.items():
        if cfg_key not in ("constant_prune_mask", "init_params"):
            print(cfg_key, ":", cfg_value)

    # init ds configs
    task_ds = get_task_ds(dataset, pruning_cfg.batch_size, device, ablation_type)

    # the model is frozen: only the mask-sampler tensors receive gradients
    for param in model.parameters():
        param.requires_grad = False
    print("---------------------------")
    print("Dataset Config")
    print("---------------------------")
    for ds_key, ds_value in task_ds.__dict__.items():
        print(ds_key, ":", ds_value)
    print("---------------------------")

    pruner_args = task_ds.get_pruner_args({"zero", "mean", "resample", "cf_mean", "cf", "oa", "oa_specific","mean_agnostic"})

    # init mask_sampler and draw an initial mask
    mask_sampler = MultiComponentMaskSampler(pruning_cfg)
    mask_sampler()
    print("Attn mask shape per layer", mask_sampler.sampled_mask["attn"][0].shape)
    print("MLP mask shape per layer", mask_sampler.sampled_mask["mlp"][0].shape)

    print("---------------------------")
    print("Starting Evaluation")
    print("---------------------------")

    # init vertex pruner
    vertex_pruner = KQVPruner(model, pruning_cfg, mask_sampler, **pruner_args)
    vertex_pruner.add_patching_hooks()
    vertex_pruner.modal_attention.requires_grad = False
    vertex_pruner.modal_mlp.requires_grad = False

    # init results variables: one row per (component in k/q/v, layer, head)
    n_rows = n_layers * n_heads * 3
    head_losses = torch.zeros((n_rows, 1), device=device)
    head_vars = torch.zeros((n_rows, 1), device=device)
    n_batches_by_head = torch.zeros_like(head_losses)
    n_samples_by_head = torch.zeros_like(head_losses)

    n_iterations = int(max_batches / (batch_size * n_samples))

    for _ in tqdm(range(vertex_pruner.log.t, n_iterations)):
        batch, last_token_pos, _cf = task_ds.retrieve_batch_cf(tokenizer)
        last_token_pos = last_token_pos.int()

        # gradients are only read, never stepped, so zeroing the sampler suffices
        mask_sampler.zero_grad()

        loss = vertex_pruner(batch, last_token_pos, timing=False, print_loss=False)
        loss.backward()

        atp_losses = _stack_kqv_grads(mask_sampler)
        batch_n_samples = _count_ablated_kqv(mask_sampler)

        # Normalise by how many samples ablated each head (scaled by n_samples * batch_size);
        # heads never ablated in this batch contribute 0.
        atp_losses = torch.where(
            batch_n_samples > 0,
            atp_losses / batch_n_samples * n_samples * batch_size,
            0,
        )

        head_losses, head_vars, n_batches_by_head, n_samples_by_head = update_means_variances_mixed(
            head_losses, head_vars, atp_losses, n_batches_by_head, n_samples_by_head, batch_n_samples
        )
    print("---------------------------")
    print("Finished Evaluation")
    return head_losses, head_vars, n_batches_by_head, n_samples_by_head

In [140]:
# `max_batches` is an example budget: run_MCMS runs max_batches / (batch_size * n_samples) iterations.
(head_losses, head_vars,
 n_batches_by_head, n_samples_by_head) = run_MCMS(max_batches=100_000)

Loading model...
Loaded pretrained model gpt2-small into HookedTransformer
Loading OWT...
Loading OWT data from disk
Making DataLoader
---------------------------
Pruning Config
---------------------------
device : cuda:0
n_layers : 12
n_heads : 12
folder : results/mcms
lamb : None
record_every : 100
checkpoint_every : 5
starting_beta : 0.6666666666666666
hard_concrete_endpoints : (-0.1, 1.1)
layers_to_prune : [('attn', 0), ('mlp', 0), ('attn', 1), ('mlp', 1), ('attn', 2), ('mlp', 2), ('attn', 3), ('mlp', 3), ('attn', 4), ('mlp', 4), ('attn', 5), ('mlp', 5), ('attn', 6), ('mlp', 6), ('attn', 7), ('mlp', 7), ('attn', 8), ('mlp', 8), ('attn', 9), ('mlp', 9), ('attn', 10), ('mlp', 10), ('attn', 11), ('mlp', 11), ('mlp', 12)]
temp_min_reg : 0.001
temp_adj_intv : 10
temp_avg_intv : 20
temp_comp_intv : 200
temp_convergence_target : 2000
temp_c : 0
temp_momentum : 0
batch_size : 100
n_samples : 1
lr : 0.01
lr_modes : 0.002
k : 1
---------------------------
Dataset Config
---------------------

100%|██████████| 1000/1000 [05:18<00:00,  3.14it/s]


---------------------------
Finished Evaluation


In [142]:
# Persist the per-head K/Q/V attribution statistics returned by run_MCMS.
# NOTE(review): these are torch tensors (on the run's device) pickled directly; loading them
# requires torch, and a GPU if they live on CUDA — consider torch.save / .cpu() for portability.
ATP_OUT_DIR = "atp/ioi"
os.makedirs(ATP_OUT_DIR, exist_ok=True)

for stat_name, stat_tensor in {
    "losses": head_losses,
    "vars": head_vars,
    "n_samples_by_head": n_samples_by_head,
}.items():
    with open(os.path.join(ATP_OUT_DIR, f"single_component_kqv_{stat_name}.pickle"), "wb") as f:
        pickle.dump(stat_tensor, f)

In [110]:
# Standalone debugging setup: load the model, a small-batch vertex config and the task dataset.
# NOTE(review): unlike run_MCMS, the model's parameters are not frozen (requires_grad) here.
model_name = "gpt2-small"
owt_batch_size = 10
batch_size = 10
ablation_type = "mean_agnostic"
dataset = "ioi"

device, model, tokenizer, owt_iter = load_model_data(model_name, owt_batch_size)
model = model.eval()
# model.cfg.use_attn_result = True

pruning_cfg = VertexInferenceConfig(model.cfg, device, "./", init_param=1,batch_size = batch_size)
task_ds = get_task_ds(dataset,batch_size, device, ablation_type)
# same ablation-type set as run_MCMS
pruner_args = task_ds.get_pruner_args({"zero", "mean", "resample", "cf_mean", "cf", "oa", "oa_specific","mean_agnostic"})

Loading model...




Loaded pretrained model gpt2-small into HookedTransformer
Loading OWT...
Loading OWT data from disk
Making DataLoader


In [114]:
# Inspect the initial mask parameters of the debug config.
# NOTE(review): this renders every tensor in full (very long output); consider displaying shapes only.
pruning_cfg.init_params

{'attn': [tensor([[1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.]], device='cuda:0'),
  tensor([[1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.]], device='cuda:0'),
  tensor([[1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.]], device='cuda:0'),
  tensor([[1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.]], device='cuda:0'),
  tensor([[1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.],
    

In [111]:
# Build a mask sampler and KQV pruner on the debug config, then fetch one counterfactual batch.
# NOTE(review): `folder` is not used below — pruning_cfg was already created with "./".
folder = "./"
pruning_cfg.k =10
mask_sampler = MultiComponentMaskSampler(pruning_cfg)
pruner = KQVPruner(model, pruning_cfg, mask_sampler, **pruner_args)
pruner.add_patching_hooks()
batch, last_token_pos,cf = task_ds.retrieve_batch_cf(tokenizer)

In [112]:
# One forward pass of the pruner on the debug batch; the saved output is a scalar loss tensor.
# NOTE(review): the saved output was produced by an earlier KQVPruner definition (In[112]
# predates the class cell In[136], whose hooks do not print) — re-run to refresh it.
# NOTE(review): run_MCMS casts last_token_pos with .int() before this call; confirm whether needed here.
pruner(batch,last_token_pos)

attention hook layer 0, k
attention hook layer 0, q
attention hook layer 0, v
attention hook layer 1, k
attention hook layer 1, q
attention hook layer 1, v
attention hook layer 2, k
attention hook layer 2, q
attention hook layer 2, v
attention hook layer 3, k
attention hook layer 3, q
attention hook layer 3, v
attention hook layer 4, k
attention hook layer 4, q
attention hook layer 4, v
attention hook layer 5, k
attention hook layer 5, q
attention hook layer 5, v
attention hook layer 6, k
attention hook layer 6, q
attention hook layer 6, v
attention hook layer 7, k
attention hook layer 7, q
attention hook layer 7, v
attention hook layer 8, k
attention hook layer 8, q
attention hook layer 8, v
attention hook layer 9, k
attention hook layer 9, q
attention hook layer 9, v
attention hook layer 10, k
attention hook layer 10, q
attention hook layer 10, v
attention hook layer 11, k
attention hook layer 11, q
attention hook layer 11, v
Cuda time 6.5966081619262695
Cuda time 157.95404052734375


tensor(0.1707, device='cuda:0', grad_fn=<AddBackward0>)