In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import *

import time
import os
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Hugging Face hub id of the dense model that the environment loads and prunes
model_name: str = "meta-llama/llama-2-7b-hf"

In [None]:
N = 10  # number of discrete sparsity actions (0.0, 0.1, ..., 0.9)
S = 3  # number of scalar state features (global sparsity, target sparsity, layer progress)


class Model(nn.Module):
    """Policy MLP producing one logit per sparsity action.

    Expects an input whose last dimension is ``S + N``. Presumably this is the S
    state features concatenated with the N-entry action mask (TODO confirm
    at the call site). Returns logits whose last dimension is ``N``.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept stable so parameter init (RNG draws) and
        # state_dict keys ("base.0", "base.2", "head") stay the same.
        trunk = [
            nn.Linear(S + N, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        ]
        self.base = nn.Sequential(*trunk)
        self.head = nn.Linear(256, N)

    def forward(self, x):
        """Map a batch of feature vectors to unnormalized action logits."""
        return self.head(self.base(x))

class Environment:
    """RL environment that prunes a causal LM one decoder layer per step.

    ``step(action)`` prunes decoder layer ``current_layer`` to
    ``possible_sparsities[action]`` and advances to the next layer. The reward
    is 0 until the last layer is pruned. At that point it is the negative mean
    per-token KL divergence between the pruned model and log-probabilities of
    the dense model, which are cached on held-out data during ``init``.

    NOTE(review): ``get_fineweb_edu`` and ``WrappedGPT`` come from the local
    ``utils`` module (star import) and are not visible here. Their behavior is
    assumed, not verified.
    """

    def __init__(self, model_name, num_samples, sequence_length, target_sparsity=0.5):
        """
        Args:
            model_name: model id passed to ``from_pretrained``; also used in the cache path.
            num_samples: number of calibration / evaluation sequences.
            sequence_length: tokens per sequence.
            target_sparsity: upper bound, enforced by the action mask, on the
                mean requested sparsity of the layers pruned so far.
        """
        self.model_name = model_name
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        self.target_sparsity = target_sparsity
        # action index -> fraction of weights to prune in one layer (len == N)
        self.possible_sparsities = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # loads model/data and initializes all episode state, including pruning_info
        self.reset()

    @torch.no_grad()
    def init(self):
        """Load model, tokenizer and data, and reset all per-episode state.

        Sets ``model``, ``tokenizer``, ``layers``, the hidden-state buffers
        ``inps``/``outs``, ``inp_kwargs`` and ``kl_dataloader``, plus the episode
        counters. Writes dense log-prob targets under ``logs/kl/`` when they are
        not already cached.
        """
        model = AutoModelForCausalLM.from_pretrained(self.model_name, dtype=torch.float16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.model = model
        self.tokenizer = tokenizer

        # calibration data drives pruning; held-out data drives the KL reward
        num_tokens = self.num_samples * self.sequence_length
        calib_data = get_fineweb_edu(num_tokens, self.sequence_length, tokenizer, train=True)
        test_data = get_fineweb_edu(num_tokens, self.sequence_length, tokenizer, train=False)

        # episode state
        self.action_mask = torch.ones(N)
        self.layers = model.model.layers
        self.num_layers = len(self.layers)
        self.current_layer = 0
        self.global_sparsity = 0.0
        self.layer_sparsities = [0.0] * self.num_layers
        self.pruning_info = {}  # cleared here so every reset() starts a fresh episode
        self.calib_data = calib_data

        # inps feeds the layer being pruned; outs receives its output (swapped after each layer)
        self.inps = torch.zeros(
            (self.num_samples, self.sequence_length, model.config.hidden_size),
            dtype=torch.float16,
            device=self.device,
        )
        self.outs = torch.zeros_like(self.inps)
        self.inp_kwargs = {}

        # no forward pass below needs the KV cache; restore the flag even if something fails
        use_cache = model.config.use_cache
        model.config.use_cache = False
        try:
            self._capture_first_layer_inputs()
            self._build_kl_dataloader(test_data)
        finally:
            model.config.use_cache = use_cache

    @torch.no_grad()
    def _capture_first_layer_inputs(self):
        """Fill ``self.inps`` and ``self.inp_kwargs`` with what decoder layer 0 receives.

        Temporarily replaces layer 0 with a wrapper that records its input and
        kwargs, then aborts the forward pass with a private sentinel exception.
        """
        inps = self.inps
        inp_kwargs = self.inp_kwargs

        class _StopForward(Exception):
            """Private sentinel so genuine errors are not swallowed."""

        class CatchInps(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
                self.num_inps = 0

            def forward(self, inp, **kwargs):
                inps[self.num_inps] = inp
                inp_kwargs.update(kwargs)
                self.num_inps += 1
                raise _StopForward

        self.layers[0] = CatchInps(self.layers[0])
        try:
            # zip bounds the loop to the num_samples rows the buffer holds
            for _row, sample in zip(range(self.num_samples), self.calib_data):
                try:
                    self.model(sample.to(self.device))
                except _StopForward:
                    pass
        finally:
            self.layers[0] = self.layers[0].module

    @torch.no_grad()
    def _build_kl_dataloader(self, test_data, batch_size=4):
        """Cache the dense model's log-probs on ``test_data`` and build ``self.kl_dataloader``.

        Targets for each chunk of ``batch_size`` sequences are stored as one
        ``(tokens, vocab_size)`` float32 tensor under
        ``logs/kl/<model_name>/n<num_samples>_seq<sequence_length>/``.
        Existing chunk files are reused instead of recomputed.

        NOTE(review): reuse assumes ``get_fineweb_edu(..., train=False)`` returns
        the same sequences on every call (TODO confirm). ``test_data`` is
        assumed to be a list of ``(1, sequence_length)`` token tensors, as
        implied by the ``torch.cat(..., dim=0)`` below.
        """
        kl_dir = os.path.join("logs", "kl", self.model_name, f"n{self.num_samples}_seq{self.sequence_length}")
        os.makedirs(kl_dir, exist_ok=True)
        num_samples = self.num_samples
        num_batches = (num_samples + batch_size - 1) // batch_size
        vocab_size = self.model.config.vocab_size

        def target_path(idx):
            return os.path.join(kl_dir, f"log_targets_{idx}_{batch_size}.pt")

        def batch_rows(idx):
            # last chunk may be partial; never index past num_samples
            return range(idx * batch_size, min((idx + 1) * batch_size, num_samples))

        for idx in range(num_batches):
            path = target_path(idx)
            if os.path.exists(path):
                continue
            chunk = []
            for j in batch_rows(idx):
                logits = self.model(test_data[j].to(self.device)).logits
                chunk.append(F.log_softmax(logits.float(), dim=-1).reshape(-1, vocab_size).cpu())
            # write-then-rename so an interrupted save never looks like a valid cache entry
            tmp_path = path + ".tmp"
            torch.save(torch.cat(chunk, dim=0), tmp_path)
            os.replace(tmp_path, path)
        print(f"Saved {num_batches} batches of log probabilities for KL divergence computation.")

        class KLDataset(torch.utils.data.Dataset):
            def __len__(self):
                return num_batches

            def __getitem__(self, idx):
                samples = torch.cat([test_data[j] for j in batch_rows(idx)], dim=0)
                log_probs = torch.load(target_path(idx))
                return samples, log_probs

        self.kl_dataloader = torch.utils.data.DataLoader(KLDataset(), batch_size=1, shuffle=False)
        print(f"KL dataloader with {len(self.kl_dataloader)} batches created.")

    @torch.no_grad()
    def prune_layer(self, layer_idx, sparsity):
        """Prune every ``nn.Linear`` inside decoder layer ``layer_idx`` to ``sparsity``.

        Runs ``self.inps`` through the layer with forward hooks that feed each
        sublayer's input/output into a ``WrappedGPT`` pruner, prunes, then
        recomputes the pruned layer's outputs, which become ``self.inps`` for the
        next layer. Records requested and measured sparsity in
        ``self.pruning_info[layer_idx]``.

        Raises:
            Exception: if ``layer_idx`` was already pruned in this episode.
        """
        if layer_idx in self.pruning_info:
            raise Exception(f"Layer {layer_idx} already pruned; call reset() to start a new episode.")

        layer = self.layers[layer_idx]
        sublayers = {name: module for name, module in layer.named_modules() if isinstance(module, nn.Linear)}
        wrapped_layers = {name: WrappedGPT(sublayer) for name, sublayer in sublayers.items()}

        # collect input activations for each sublayer.
        # NOTE(review): WrappedGPT.add_batch is assumed to accumulate per-feature input stats — TODO confirm.
        def add_batch(name):
            def hook(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return hook

        handles = [sublayers[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]
        try:
            for j in range(self.num_samples):
                self.outs[j] = layer(self.inps[j].unsqueeze(0), **self.inp_kwargs)[0]
        finally:
            # never leave hooks attached, even if the forward pass fails
            for h in handles:
                h.remove()

        for name in sublayers:
            wrapped_layers[name].prune(sparsity)
            wrapped_layers[name].clean()

        # outputs of the pruned layer
        for j in range(self.num_samples):
            self.outs[j] = layer(self.inps[j].unsqueeze(0), **self.inp_kwargs)[0]

        # this layer's output is the next layer's input
        self.inps, self.outs = self.outs, self.inps

        obtained_sparsity = np.mean([l.weight.data.eq(0).float().mean().item() for l in sublayers.values()]).item()
        self.pruning_info[layer_idx] = {
            "layer": layer_idx,
            "layer_target_sparsity": sparsity,
            "layer_obtained_sparsity": obtained_sparsity,
        }

    def reset(self):
        """Drop the current model and buffers, reload everything, and return the initial state."""
        import gc

        # self.layers also references the old model's decoder weights; drop it too so
        # the previous copy can be freed before a new one is loaded
        stale = ("inps", "outs", "inp_kwargs", "kl_dataloader", "layers", "calib_data", "model", "tokenizer")
        for attr in stale:
            if hasattr(self, attr):
                delattr(self, attr)
        gc.collect()
        torch.cuda.empty_cache()
        self.init()
        return self.get_state()

    def get_state(self):
        """Return the observation for the layer about to be pruned.

        Returns:
            dict with
              "state": float32 tensor [global_sparsity, target_sparsity, current_layer / num_layers];
              "action_mask": float32 tensor of length ``len(possible_sparsities)``. An entry is 1
              when choosing that sparsity keeps the mean sparsity of layers
              ``0..current_layer`` (the candidate included) at or below ``target_sparsity``.
        Under this rule sparsity 0.0 stays available as long as earlier choices obeyed the mask.
        """
        features = [self.global_sparsity, self.target_sparsity, self.current_layer / self.num_layers]
        pruned_sum = sum(self.layer_sparsities[:self.current_layer])
        # the candidate layer is part of the mean, so divide by current_layer + 1
        num_counted = self.current_layer + 1
        # tolerance so sums of 0.1-steps that equal the target exactly are not rejected by rounding
        tol = 1e-9
        mask = [
            1 if (pruned_sum + s) / num_counted <= self.target_sparsity + tol else 0
            for s in self.possible_sparsities
        ]
        return {
            "state": torch.tensor(features, dtype=torch.float32),
            "action_mask": torch.tensor(mask, dtype=torch.float32),
        }

    @torch.no_grad()
    def _kl_to_dense(self):
        """Mean per-token KL(dense || pruned) over the cached held-out targets (0.0 if none)."""
        vocab_size = self.model.config.vocab_size
        kl_sum = 0.0
        num_tokens = 0
        for input_ids, target_log_probs in self.kl_dataloader:
            # DataLoader(batch_size=1) adds a leading dim of 1
            input_ids = input_ids.squeeze(0)
            target_log_probs = target_log_probs.squeeze(0)
            # no KV cache: this is a single full-sequence pass
            logits = self.model(input_ids.to(self.device), use_cache=False).logits.reshape(-1, vocab_size)
            log_probs = F.log_softmax(logits.float(), dim=-1)
            kl_sum += F.kl_div(
                log_probs, target_log_probs.to(log_probs.device), reduction="sum", log_target=True
            ).item()
            num_tokens += log_probs.shape[0]
            del logits, log_probs, target_log_probs
            torch.cuda.empty_cache()
        return kl_sum / num_tokens if num_tokens else 0.0

    @torch.no_grad()
    def step(self, action):
        """Prune the current layer with ``possible_sparsities[action]``.

        Returns:
            (state, reward, done, info): ``reward`` is 0 until the last layer is
            pruned, then the negative mean per-token KL to the dense model;
            ``info`` is an empty dict.
        """
        sparsity = self.possible_sparsities[action]
        self.prune_layer(self.current_layer, sparsity)
        self.layer_sparsities[self.current_layer] = sparsity
        self.current_layer += 1
        # mean of the requested sparsities of the layers pruned so far
        self.global_sparsity = np.mean(self.layer_sparsities[:self.current_layer])
        done = self.current_layer == self.num_layers
        reward = -self._kl_to_dense() if done else 0
        return self.get_state(), reward, done, {}

In [None]:
# Builds the environment: loads the model and data and caches dense log-prob targets (slow).
environment = Environment(
    model_name,
    num_samples=128,       # calibration / evaluation sequences
    sequence_length=4096,  # tokens per sequence
    target_sparsity=0.5,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.14s/it]


Loading FineWeb-Edu v2
Total tokens loaded: 524288
Loading FineWeb-Edu v2
Total tokens loaded: 524288
Saved 32 batches of log probabilities for KL divergence computation.
KL dataloader with 32 batches created.


In [5]:
# Baseline episode: prune every layer with the same action.
# Index 5 of environment.possible_sparsities is 0.5.
action = 5
# reset() reloads the model and data, so the episode starts from the dense weights
state = environment.reset()
done = False
while not done:
    lidx = environment.current_layer  # layer about to be pruned by this step
    state, reward, done, _ = environment.step(action)
    print(f"Layer: {lidx}, Action: {action}, Reward: {reward}, Done: {done}")
    

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.18s/it]


Loading FineWeb-Edu v2
Total tokens loaded: 524288
Loading FineWeb-Edu v2
Total tokens loaded: 524288
Saved 32 batches of log probabilities for KL divergence computation.
KL dataloader with 32 batches created.
Layer: 0, Action: 5, Reward: 0, Done: False
Layer: 1, Action: 5, Reward: 0, Done: False
Layer: 2, Action: 5, Reward: 0, Done: False
Layer: 3, Action: 5, Reward: 0, Done: False
Layer: 4, Action: 5, Reward: 0, Done: False
Layer: 5, Action: 5, Reward: 0, Done: False
Layer: 6, Action: 5, Reward: 0, Done: False
Layer: 7, Action: 5, Reward: 0, Done: False
Layer: 8, Action: 5, Reward: 0, Done: False
Layer: 9, Action: 5, Reward: 0, Done: False
Layer: 10, Action: 5, Reward: 0, Done: False
Layer: 11, Action: 5, Reward: 0, Done: False
Layer: 12, Action: 5, Reward: 0, Done: False
Layer: 13, Action: 5, Reward: 0, Done: False
Layer: 14, Action: 5, Reward: 0, Done: False
Layer: 15, Action: 5, Reward: 0, Done: False
Layer: 16, Action: 5, Reward: 0, Done: False
Layer: 17, Action: 5, Reward: 0, Do

In [6]:
# Per-layer requested vs. measured sparsity, recorded by Environment.prune_layer
environment.pruning_info

{0: {'layer': 0, 'layer_target_sparsity': 0.5, 'layer_obtained_sparsity': 0.5},
 1: {'layer': 1, 'layer_target_sparsity': 0.5, 'layer_obtained_sparsity': 0.5},
 2: {'layer': 2, 'layer_target_sparsity': 0.5, 'layer_obtained_sparsity': 0.5},
 3: {'layer': 3, 'layer_target_sparsity': 0.5, 'layer_obtained_sparsity': 0.5},
 4: {'layer': 4, 'layer_target_sparsity': 0.5, 'layer_obtained_sparsity': 0.5},
 5: {'layer': 5, 'layer_target_sparsity': 0.5, 'layer_obtained_sparsity': 0.5},
 6: {'layer': 6, 'layer_target_sparsity': 0.5, 'layer_obtained_sparsity': 0.5},
 7: {'layer': 7, 'layer_target_sparsity': 0.5, 'layer_obtained_sparsity': 0.5},
 8: {'layer': 8, 'layer_target_sparsity': 0.5, 'layer_obtained_sparsity': 0.5},
 9: {'layer': 9, 'layer_target_sparsity': 0.5, 'layer_obtained_sparsity': 0.5},
 10: {'layer': 10,
  'layer_target_sparsity': 0.5,
  'layer_obtained_sparsity': 0.5},
 11: {'layer': 11,
  'layer_target_sparsity': 0.5,
  'layer_obtained_sparsity': 0.5},
 12: {'layer': 12,
  'layer_