In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import *

import time
import os
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
model_name = "meta-llama/llama-2-7b-hf", 4096, tokenizer, train=True)

Loading checkpoint shards: 100%|██████████| 2/2 [01:06<00:00, 33.23s/it]


Loading FineWeb-Edu v2
Total tokens loaded: 8388608


In [None]:
from typing import Dict, Tuple


N = 10 # possible sparsity levels (0.0 - 0.9)
S = 3 # number of state features
        

class Environment:
    def __init__(self, model_name:str, num_samples:int, sequence_length:int, target_sparsity:float=0.5)->None:
        self.model_name = model_name
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        self.target_sparsity = target_sparsity
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        self.possible_sparsities = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
        # initialize state
        self.load_calibration_data()
        self.reset()

    def load_calibration_data(self):
        # caliberation data
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer
        num_tokens = self.num_samples * self.sequence_length
        self.calib_data = get_fineweb_edu(num_tokens, self.sequence_length, tokenizer, train=True)
        # self.test_data = get_fineweb_edu(num_tokens, self.sequence_length, tokenizer, train=False)
        _, self.test_data = get_w2_data(1, self.sequence_length, tokenizer)

    @torch.no_grad()
    def init(self) -> None:
        # create model, tokenizer, and calibration data.
        # model and tokenizer
        model = AutoModelForCausalLM.from_pretrained(self.model_name, dtype=torch.float16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.model = model
        self.tokenizer = tokenizer

        # caliberation data
        test_data = self.test_data

        # env attributes
        self.action_mask = torch.ones(N)
        self.layers = model.model.layers
        self.num_layers = len(self.layers)
        self.current_layer = 0
        self.global_sparsity = 0.0
        self.layer_sparsities = [0.0] * self.num_layers
        self.pruning_info = {}

        # buffers
        self.inps = torch.zeros((self.num_samples, self.sequence_length, model.config.hidden_size), dtype=torch.float16, device=self.device)
        self.outs = torch.zeros_like(self.inps)
        self.inp_kwargs = {}

        # obtain input into the first decoder layer
        cache = model.config.use_cache
        model.config.use_cache = False
        inps = self.inps
        inp_kwargs = self.inp_kwargs
        class catch_inps(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
                self.num_inps = 0
            def forward(self, inp, **kwargs):
                nonlocal inps, inp_kwargs
                inps[self.num_inps] = inp
                inp_kwargs.update(kwargs)
                self.num_inps += 1
                raise Exception("caught inps. Stopping forward pass.")
        self.layers[0] = catch_inps(self.layers[0])
        for sample in self.calib_data:
            try:
                model(sample.to(self.device))
            except Exception as e:
                pass
        self.layers[0] = self.layers[0].module
        self.inps = inps
        self.inp_kwargs = inp_kwargs

        # save the log targets to a file for computing the KL divergence later
        i_batches = 0
        os.makedirs(f"logs/kl/{self.model_name}", exist_ok=True)
        batch_size = 4
        log_probs = []
        for j in range(self.num_samples):
            if os.path.exists(f"logs/kl/{self.model_name}/log_targets_{(j//batch_size)}_{batch_size}.pt"):
                i_batches = j // batch_size
                continue
            sample = test_data[j]
            logits = model(sample.to(self.device)).logits
            log_probs.append(F.log_softmax(logits.float(), dim=-1).reshape(-1, model.config.vocab_size).cpu())
            if j % batch_size == batch_size-1:
                log_probs = torch.cat(log_probs, dim=0).cpu()
                torch.save(log_probs, f"logs/kl/{self.model_name}/log_targets_{i_batches}_{batch_size}.pt")
                log_probs = []
            elif j == self.num_samples - 1 and len(log_probs) > 0:
                log_probs = torch.cat(log_probs, dim=0).cpu()
                torch.save(log_probs, f"logs/kl/{self.model_name}/log_targets_{i_batches}_{batch_size}.pt")
            i_batches = j // batch_size
        # print(f"Saved {i_batches+1} batches of log probabilities for KL divergence computation.")
        # create a dataloader for computing KL divergence later
        model_name = self.model_name
        class KLDataset(torch.utils.data.Dataset):
            def __init__(self):
                self.path_format = f"logs/kl/{model_name}"+"/log_targets_{}_{}.pt"
            def __len__(self):
                return i_batches + 1
            def __getitem__(self, idx):
                nonlocal batch_size
                samples = torch.cat(test_data[idx*batch_size:(idx+1)*batch_size], dim=0)
                log_probs = torch.load(self.path_format.format(idx, batch_size))
                return samples, log_probs
        self.kl_dataloader = torch.utils.data.DataLoader(KLDataset(), batch_size=1, shuffle=False)
        # print(f"KL dataloader with {len(self.kl_dataloader)} batches created.")
        model.config.use_cache = cache

    def prune_layer(self, layer_idx:int, sparsity:float)->None:
        if layer_idx in self.pruning_info:
            raise Exception(f"Layer {layer_idx} already pruned. Skipping.")
        
        layer = self.layers[layer_idx]
        sublayers = {name: module for name, module in layer.named_modules() if isinstance(module, nn.Linear)}
        wrapped_layers = {}
        for name, sublayer in sublayers.items():
            wrapped_layers[name] = WrappedGPT(sublayer)

        # obtain the input activations to each sublayer, computing the feature-wise norms
        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp
        handles = []
        for name in wrapped_layers:
            handles.append(sublayers[name].register_forward_hook(add_batch(name)))
        for j in range(self.num_samples):
            self.outs[j] = layer(self.inps[j].unsqueeze(0), **self.inp_kwargs)[0]
        for h in handles:
            h.remove()
        
        for name in sublayers:
            wrapped_layers[name].prune(sparsity)
            wrapped_layers[name].clean()

        # outputs after pruning
        for j in range(self.num_samples):
            with torch.no_grad():
                self.outs[j] = layer(self.inps[j].unsqueeze(0), **self.inp_kwargs)[0]

        # the output from this layer should be the input to the next layer
        self.inps, self.outs = self.outs, self.inps

        # done pruning this layer. Prepare some info about this layer's pruning
        obtained_sparsity = np.mean([l.weight.data.eq(0).float().mean().item() for l in sublayers.values()]).item()
        info = {
            "layer": layer_idx,
            "layer_target_sparsity": sparsity,
            "layer_obtained_sparsity": obtained_sparsity,
        }
        self.pruning_info[layer_idx] = info

    def reset(self) -> Dict[str, torch.Tensor]:
        if hasattr(self, "inps"):
            del self.inps, self.outs, self.inp_kwargs
            del self.kl_dataloader
            del self.model, self.tokenizer
        torch.cuda.empty_cache()
        self.init()
        return self.get_state()

    def get_state(self) -> Dict[str, torch.Tensor]:
        s = [self.global_sparsity, self.target_sparsity, self.current_layer / self.num_layers]
        if self.current_layer == 0:
            mask = [1] * len(self.possible_sparsities)
        else:
            mask = [1 if (sum(self.layer_sparsities[:self.current_layer]) + s) / self.current_layer <= self.target_sparsity else 0 for s in self.possible_sparsities]
        state = {
            "state": torch.tensor(s, dtype=torch.float32),
            "action_mask": torch.tensor(mask, dtype=torch.float32)
        }
        return state

    @torch.no_grad()
    def step(self, action:int)->Tuple[Dict[str, torch.Tensor], float, bool, Dict[str, object]]:
        sparsity = self.possible_sparsities[action]
        self.prune_layer(self.current_layer, sparsity)
        # update global sparsity
        self.layer_sparsities[self.current_layer] = sparsity
        self.current_layer += 1
        self.global_sparsity = np.mean(self.layer_sparsities[:self.current_layer])
        # compute reward
        reward = 0
        done = self.current_layer == self.num_layers
        if done:
            # compute KL divergence between the pruned and unpruned model.
            # the logits have been saved to a file during initialization.
            running_kl = 0.0
            total_logprobs = 0
            # for batch in self.kl_dataloader:
            #     inps, target_log_probs = [batch[0].squeeze(0), batch[1].squeeze(0)]
            #     logits = self.model(inps.to(self.device)).logits.reshape(-1, self.model.config.vocab_size)
            #     log_probs = F.log_softmax(logits.float(), dim=-1)
            #     kl = F.kl_div(log_probs, target_log_probs.to(self.device), reduction="batchmean", log_target=True).item()
            #     running_kl *= (total_logprobs / (total_logprobs + target_log_probs.numel()))
            #     running_kl += (target_log_probs.numel() / (total_logprobs + target_log_probs.numel())) * kl
            #     total_logprobs += target_log_probs.numel()
            #     del target_log_probs, logits, kl
            #     torch.cuda.empty_cache()
            reward = running_kl
            reward = eval_ppl(self.model, self.test_data, self.sequence_length, device=self.device)

        return self.get_state(), reward, done, {}

In [11]:
# RL Agent
from typing import Optional


class RLActor(nn.Module):
    def __init__(self, state_size:int, action_size:int, device:str="cuda"):
        super(RLActor, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.device = device
        self.base = nn.Sequential(
            nn.Linear(S+N, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )
        self.head = nn.Linear(256, N)
        self.uniform_init()

    def forward(self, state:Dict[str, torch.Tensor]) -> torch.Tensor:
        s = state["state"]
        action_mask = state["action_mask"]
        large_neg = torch.finfo(s.dtype).min
        x = torch.cat([s, action_mask], dim=-1).to(self.device)
        x = self.base(x)
        logits = self.head(x)
        logits = torch.where(action_mask.to(self.device) == 1, logits, large_neg)
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        return dist

    def uniform_init(self):
        bias = self.head.bias.data.detach().clone()
        bias = torch.ones_like(bias)*(1/self.action_size)
        self.head.bias.data.copy_(bias)

    @torch.no_grad()
    def act(self, state:Dict[str, torch.Tensor], deterministic=False) -> tuple[torch.Tensor, torch.Tensor]:
        state = {k: v.unsqueeze(0) for k, v in state.items()}
        dist = self(state)
        action = dist.sample() if not deterministic else dist.mode
        log_prob = dist.log_prob(action)
        return action, log_prob


class Critic(nn.Module):
    def __init__(self, state_size:int, action_size:int):
        super(Critic, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.model = nn.Sequential(
            nn.Linear(S, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, state:torch.Tensor) -> torch.Tensor:
        return self.model(state)

In [12]:
class RLLearner:
    def __init__(self, model: nn.Module, gamma: float = 1.0, lr=1e-4, device: str = "cuda"):
        self.model = model
        self.gamma = gamma
        self.device = device
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def __call__(self, trajectories):
        R = 0
        returns = []
        for _, _, reward, log_prob in reversed(trajectories):
            R = reward + self.gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns).to(self.device)
        baseline = 6.0
        returns = (returns - baseline)
        bs = 4
        loss = 0
        self.optimizer.zero_grad()
        scale = 1/len(returns)
        for i in range(0, len(returns), bs):
            j = min(i+bs, len(returns))
            trans = trajectories[i:j]
            r = returns[i:j]
            x = {
                    "state": torch.stack([t[0]["state"] for t in trans], dim=0).to(self.device),
                    "action_mask": torch.stack([t[0]["action_mask"] for t in trans], dim=0).to(self.device)
            }
            a = torch.cat([t[1] for t in trans]).to(self.device)
            dist = self.model(x)
            loss += (scale)*(dist.log_prob(a) * r).sum()
        loss.backward()
        # for (state, action, _, log_prob), R in zip(trajectories, returns):
        #     # self.model(state[])
        #     print(log_prob)
        #     policy_loss.append(log_prob * R)
        # self.optimizer.zero_grad()
        # policy_loss = torch.cat(policy_loss).sum()
        # policy_loss.backward()
        # self.optimizer.step()
        # return policy_loss.item()
        return loss.item()
    

In [13]:
class ActorRollout:
    def __init__(self, env: Environment, actor: RLActor):
        self.env = env
        self.actor = actor

    def __call__(self):
        state = self.env.reset()
        done = False
        trajectories = []
        while not done:
            action, log_prob = self.actor.act(state)
            next_state, reward, done, _ = self.env.step(action.item())
            trajectories.append((state, action, reward, log_prob))
            state = next_state
        return trajectories

In [14]:
class Trainer:
    def __init__(self, env: Environment, actor: RLActor, gamma: float = 0.99, lr: float = 1e-4):
        self.env = env
        self.actor = actor
        self.learner = RLLearner(actor, gamma=gamma, lr=lr, device=actor.device)
        self.rollout = ActorRollout(env, actor)

    def __call__(self, num_episodes:int=100):
        for episode in range(num_episodes):
            start_time = time.time()
            trajectories = self.rollout()
            loss = self.learner(trajectories)
            end_time = time.time()
            kl = trajectories[-1][2]
            print(f"Episode {episode+1}/{num_episodes}, Loss: {loss:.4f}, KL:{kl}, Time: {end_time - start_time:.2f}s")

In [15]:
environment = Environment(model_name, num_samples=128, sequence_length=4096, target_sparsity=0.5)
actor = RLActor(state_size=S, action_size=N, device="cuda")
actor.to(actor.device)
trainer = Trainer(environment, actor, lr=1e-4)

Loading FineWeb-Edu v2
Total tokens loaded: 524288


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.32s/it]


In [16]:
trainer(20)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.30s/it]


Episode 1/20, Loss: -30.4306, KL:22.85381317138672, Time: 108.40s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.35s/it]


Episode 2/20, Loss: -1.1163, KL:7.542420387268066, Time: 110.07s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.33s/it]


Episode 3/20, Loss: -19.6847, KL:17.155426025390625, Time: 94.37s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.39s/it]


Episode 4/20, Loss: -19.8593, KL:17.34784698486328, Time: 112.75s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.33s/it]


Episode 5/20, Loss: -4.1235, KL:9.071385383605957, Time: 106.55s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.32s/it]


Episode 6/20, Loss: -1.3517, KL:7.661917686462402, Time: 97.04s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.33s/it]


Episode 7/20, Loss: -643.3914, KL:332.116943359375, Time: 111.63s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.31s/it]


Episode 8/20, Loss: -5.7742, KL:9.996315002441406, Time: 98.63s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.32s/it]


Episode 9/20, Loss: -8.4242, KL:11.264843940734863, Time: 89.34s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.34s/it]


Episode 10/20, Loss: 0.0755, KL:6.900660037994385, Time: 111.36s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.35s/it]


Episode 11/20, Loss: -5.9629, KL:10.077919006347656, Time: 106.39s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.32s/it]


Episode 12/20, Loss: -2.9005, KL:8.48875617980957, Time: 91.52s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.32s/it]


Episode 13/20, Loss: -4.8548, KL:9.553827285766602, Time: 106.47s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.32s/it]


Episode 14/20, Loss: -8.0569, KL:11.159964561462402, Time: 105.77s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.29s/it]


Episode 15/20, Loss: -5.3118, KL:9.735267639160156, Time: 89.48s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.37s/it]


Episode 16/20, Loss: -22.6391, KL:18.800804138183594, Time: 112.17s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.38s/it]


Episode 17/20, Loss: -1.5397, KL:7.766565322875977, Time: 108.86s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.31s/it]


Episode 18/20, Loss: -13.5594, KL:13.879164695739746, Time: 90.51s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.31s/it]


Episode 19/20, Loss: -4.0738, KL:9.069145202636719, Time: 101.13s


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.37s/it]


Episode 20/20, Loss: -11.4958, KL:12.84262466430664, Time: 111.72s


In [None]:
model_name = "meta-llama/llama-2-7b-hf"

model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
w2_train_data, w2_test_data = get_w2_data(128, 4096, tokenizer)

num_tokens = 2**23
fw_train_data = get_fineweb_edu(num_tokens, 4096, tokenizer, train=True)



In [None]:
prune_default(model, fw_train_data, 0.5, theta1=0.42, theta2=0.51, theta3=0.38, is_sparsegpt=True, device=torch.device("cuda:0"))

In [None]:
eval_ppl(model, w2_test_data, 4096,  bs=1, device="cuda")