In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
from typing import Optional, Dict, Tuple


import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import gymnasium as gym
try:
    import wandb
except ImportError:
    print("wandb not installed. WandBLogger will not work.")
    wandb = None

device = "cuda" if torch.cuda.is_available() else "cpu"

from utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [4]:

class Environment:
    def __init__(self, model_name:str, num_samples:int, sequence_length:int, target_sparsity:float=0.5)->None:
        self.model_name = model_name
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        self.target_sparsity = target_sparsity
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        self.possible_sparsities = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        self.device = device if torch.cuda.is_available() else "cpu"
        
        # initialize state
        self.load_calibration_data()
        self.reset()

    def load_calibration_data(self):
        # caliberation data
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer
        num_tokens = self.num_samples * self.sequence_length
        self.calib_data = get_fineweb_edu(num_tokens, self.sequence_length, tokenizer, train=True)
        # self.test_data = get_fineweb_edu(num_tokens, self.sequence_length, tokenizer, train=False)
        _, self.test_data = get_w2_data(self.num_samples, self.sequence_length, tokenizer)

    @torch.no_grad()
    def init(self) -> None:
        # create model, tokenizer, and calibration data.
        # model and tokenizer
        model = AutoModelForCausalLM.from_pretrained(self.model_name, dtype=torch.float16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.model = model
        self.tokenizer = tokenizer

        # caliberation data
        test_data = self.test_data

        # env attributes
        self.action_mask = torch.ones(N)
        self.layers = model.model.layers
        self.num_layers = len(self.layers)
        self.current_layer = 0
        self.global_sparsity = 0.0
        self.layer_sparsities = [0.0] * self.num_layers
        self.pruning_info = {}

        # buffers
        self.inps = torch.zeros((self.num_samples, self.sequence_length, model.config.hidden_size), dtype=torch.float16, device=self.device)
        self.outs = torch.zeros_like(self.inps)
        self.inp_kwargs = {}

        # obtain input into the first decoder layer
        cache = model.config.use_cache
        model.config.use_cache = False
        inps = self.inps
        inp_kwargs = self.inp_kwargs
        class catch_inps(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
                self.num_inps = 0
            def forward(self, inp, **kwargs):
                nonlocal inps, inp_kwargs
                inps[self.num_inps] = inp
                inp_kwargs.update(kwargs)
                self.num_inps += 1
                raise Exception("caught inps. Stopping forward pass.")
        self.layers[0] = catch_inps(self.layers[0])
        for sample in self.calib_data:
            try:
                model(sample.to(self.device))
            except Exception as e:
                pass
        self.layers[0] = self.layers[0].module
        self.inps = inps
        self.inp_kwargs = inp_kwargs

        # save the log targets to a file for computing the KL divergence later
        i_batches = 0
        os.makedirs(f"logs/kl/{self.model_name}", exist_ok=True)
        batch_size = 4
        log_probs = []
        for j in range(self.num_samples):
            if os.path.exists(f"logs/kl/{self.model_name}/log_targets_{(j//batch_size)}_{batch_size}.pt"):
                i_batches = j // batch_size
                continue
            sample = test_data[j]
            logits = model(sample.to(self.device)).logits
            log_probs.append(F.log_softmax(logits.float(), dim=-1).reshape(-1, model.config.vocab_size).cpu())
            if j % batch_size == batch_size-1:
                log_probs = torch.cat(log_probs, dim=0).cpu()
                torch.save(log_probs, f"logs/kl/{self.model_name}/log_targets_{i_batches}_{batch_size}.pt")
                print(f"Saved logs/kl/{self.model_name}/log_targets_{i_batches}_{batch_size}.pt")
                log_probs = []
            elif j == self.num_samples - 1 and len(log_probs) > 0:
                log_probs = torch.cat(log_probs, dim=0).cpu()
                torch.save(log_probs, f"logs/kl/{self.model_name}/log_targets_{i_batches}_{batch_size}.pt")
                print(f"Saved logs/kl/{self.model_name}/log_targets_{i_batches}_{batch_size}.pt")
            i_batches = j // batch_size
            
        # create a dataloader for computing KL divergence later
        model_name = self.model_name
        class KLDataset(torch.utils.data.Dataset):
            def __init__(self):
                self.path_format = f"logs/kl/{model_name}"+"/log_targets_{}_{}.pt"
            def __len__(self):
                return i_batches + 1
            def __getitem__(self, idx):
                nonlocal batch_size
                samples = torch.cat(test_data[idx*batch_size:(idx+1)*batch_size], dim=0)
                log_probs = torch.load(self.path_format.format(idx, batch_size))
                return samples, log_probs
        self.kl_dataloader = torch.utils.data.DataLoader(KLDataset(), batch_size=1, shuffle=False)
        # print(f"KL dataloader with {len(self.kl_dataloader)} batches created.")
        model.config.use_cache = cache

    def prune_layer(self, layer_idx:int, sparsity:float)->None:
        if layer_idx in self.pruning_info:
            raise Exception(f"Layer {layer_idx} already pruned. Skipping.")
        
        layer = self.layers[layer_idx]
        sublayers = {name: module for name, module in layer.named_modules() if isinstance(module, nn.Linear)}
        wrapped_layers = {}
        for name, sublayer in sublayers.items():
            wrapped_layers[name] = WrappedGPT(sublayer)

        # obtain the input activations to each sublayer, computing the feature-wise norms
        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp
        handles = []
        for name in wrapped_layers:
            handles.append(sublayers[name].register_forward_hook(add_batch(name)))
        for j in range(self.num_samples):
            self.outs[j] = layer(self.inps[j].unsqueeze(0), **self.inp_kwargs)[0]
        for h in handles:
            h.remove()
        
        for name in sublayers:
            wrapped_layers[name].prune(sparsity)
            wrapped_layers[name].clean()

        # outputs after pruning
        for j in range(self.num_samples):
            with torch.no_grad():
                self.outs[j] = layer(self.inps[j].unsqueeze(0), **self.inp_kwargs)[0]

        # the output from this layer should be the input to the next layer
        self.inps, self.outs = self.outs, self.inps

        # done pruning this layer. Prepare some info about this layer's pruning
        obtained_sparsity = np.mean([l.weight.data.eq(0).float().mean().item() for l in sublayers.values()]).item()
        info = {
            "layer": layer_idx,
            "layer_target_sparsity": sparsity,
            "layer_obtained_sparsity": obtained_sparsity,
        }
        self.pruning_info[layer_idx] = info

    def reset(self) -> Dict[str, torch.Tensor]:
        if hasattr(self, "inps"):
            del self.inps, self.outs, self.inp_kwargs
            del self.kl_dataloader
            del self.model, self.tokenizer
        torch.cuda.empty_cache()
        self.init()
        return self.get_state(), {}

    def get_state(self) -> Dict[str, torch.Tensor]:
        s = [self.global_sparsity, self.target_sparsity, self.current_layer / self.num_layers]
        if self.current_layer == 0:
            mask = [1] * len(self.possible_sparsities)
        else:
            mask = [1 if (sum(self.layer_sparsities[:self.current_layer]) + s) / self.current_layer <= self.target_sparsity else 0 for s in self.possible_sparsities]
        state = {
            "state": torch.tensor(s, dtype=torch.float32),
            "action_mask": torch.tensor(mask, dtype=torch.float32)
        }
        return state

    @torch.no_grad()
    def step(self, action:int)->Tuple[Dict[str, torch.Tensor], float, bool, Dict[str, object]]:
        sparsity = self.possible_sparsities[action]
        self.prune_layer(self.current_layer, sparsity)
        # update global sparsity
        self.layer_sparsities[self.current_layer] = sparsity
        self.current_layer += 1
        self.global_sparsity = np.mean(self.layer_sparsities[:self.current_layer])
        # compute reward
        reward = 0
        done = self.current_layer == self.num_layers
        if done:
            # compute KL divergence between the pruned and unpruned model.
            # the logits have been saved to a file during initialization.
            running_kl = 0.0
            total_logprobs = 0
            # for batch in self.kl_dataloader:
            #     inps, target_log_probs = [batch[0].squeeze(0), batch[1].squeeze(0)]
            #     logits = self.model(inps.to(self.device)).logits.reshape(-1, self.model.config.vocab_size)
            #     log_probs = F.log_softmax(logits.float(), dim=-1)
            #     kl = F.kl_div(log_probs, target_log_probs.to(self.device), reduction="batchmean", log_target=True).item()
            #     running_kl *= (total_logprobs / (total_logprobs + target_log_probs.numel()))
            #     running_kl += (target_log_probs.numel() / (total_logprobs + target_log_probs.numel())) * kl
            #     total_logprobs += target_log_probs.numel()
            #     del target_log_probs, logits, kl
            #     torch.cuda.empty_cache()
            # reward = -running_kl
            reward = -eval_ppl(self.model, self.test_data, self.sequence_length, device=self.device)

        return self.get_state(), reward, done, False, {}

In [5]:



class DebugEnvironment:
    def __init__(self):
        self.reset()

    def reset(self):
        self.cur_layer = 0
        self.n_layers = 5
        self.actions = []
        return self.get_state(), {}

    def get_state(self):
        state = {
            "state": torch.tensor([self.cur_layer / self.n_layers]*3, dtype=torch.float32),
            "action_mask": torch.ones(N, dtype=torch.float32)
        }
        # state = torch.tensor([self.cur_layer / self.n_layers]*3, dtype=torch.float32)
        return state

    def step(self, action:int):
        self.actions.append(action)
        self.cur_layer += 1
        done = self.cur_layer == self.n_layers
        reward = 0
        if done:
            target = [1,2,3,0,3]
            diff = 0
            for i in range(len(self.actions)):
                diff += abs(self.actions[i] - target[i])
            reward = max(0, 10 - diff)
        
        return self.get_state(), reward*100, done, False, {}


class Policy(nn.Module):
    def __init__(self, state_size:int, action_size:int, device:str=device):
        super(Policy, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.device = device
        self.base = nn.Sequential(
            nn.Linear(state_size, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )
        self.head = nn.Linear(256, action_size)
        self.uniform_init()

    def to(self, device:Optional[str]=None):
        if device is None:
            device = self.device
        self.device = device
        return super().to(device)

    def forward(self, state:Dict[str, torch.Tensor]) -> torch.Tensor:
        large_neg = torch.finfo(state.dtype).min
        action_mask = state[:, -self.action_size:]

        x = self.base(state)
        logits = self.head(x)
        logits = torch.where(action_mask.to(self.device) == 1, logits, large_neg)
        
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        return dist

    def uniform_init(self):
        bias = self.head.bias.data.detach().clone()
        bias = torch.ones_like(bias)*(1/self.action_size)
        self.head.bias.data.copy_(bias)

    @torch.no_grad()
    def act(self, state:Dict[str, torch.Tensor], deterministic=False) -> tuple[torch.Tensor, torch.Tensor]:
        if len(state.shape) == 1:
            state = state.unsqueeze(0)
        dist = self(state)
        action = dist.sample() if not deterministic else dist.mode
        log_prob = dist.log_prob(action)
        return action, log_prob


class Value(nn.Module):
    def __init__(self, state_size:int, device:str):
        super(Value, self).__init__()
        self.state_size = state_size
        self.model = nn.Sequential(
            nn.Linear(state_size, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
        self.device = device

    def to(self, device:Optional[str]=None):
        if device is None:
            device = self.device
        self.device = device
        return super().to(device)

    def forward(self, state:torch.Tensor) -> torch.Tensor:
        return self.model(state)


class PolicyValue:
    def __init__(self, policy_model: nn.Module, value_model: nn.Module):
        self.policy_model = policy_model
        self.value_model = value_model
        self.device = policy_model.device

    def to(self, device:Optional[str]=None):
        if device is None:
            device = self.device
        self.device = device
        self.policy_model.to(device)
        self.value_model.to(device)
        return self

    def get_action_and_value(self, x, action=None):
        dist = self.policy_model(x)
        value = self.value_model(x)
        if action is None:
            action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), value

    def get_value(self, x):
        return self.value_model(x)
    
    def get_dist(self, x):
        return self.policy_model(x)
    


def process_trajectory(trajectory, gamma, lam, device):
    lastgaelam = 0
    steps = len(trajectory)
    advantages = torch.zeros(steps).to(device)

    states, actions, rewards, log_probs, values = [], [], [], [], []
    for trans in trajectory:
        state, action, reward, log_prob, value = trans
        states.append(torch.tensor(state).to(device))
        actions.append(torch.tensor(action).to(device))
        rewards.append(torch.tensor(reward).to(device))
        log_probs.append(torch.tensor(log_prob).to(device))
        values.append(torch.tensor(value).to(device))

    values = torch.cat(values)
    states = torch.stack(states, dim=0)
    actions = torch.cat(actions)
    log_probs = torch.cat(log_probs)
    rewards = torch.stack(rewards)

    for t in reversed(range(steps)):
        if t == steps - 1:
            nextnonterminal = 0.0
            nextvalue = 0.0
        else:
            nextnonterminal = 1.0
            nextvalue = values[t+1]
        delta = rewards[t] + gamma * nextvalue * nextnonterminal - values[t]
        advantages[t] = lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
        
    returns = advantages + values.squeeze()

    return states, actions, log_probs, values, returns, advantages


def process_trajectories(trajectories, gamma, lam, device):
    states, actions, log_probs, values, returns, advantages = [], [], [], [], [], []
    for trajectory in trajectories:
        s, a, lp, v, r, adv = process_trajectory(trajectory, gamma, lam, device)
        states.append(s)
        actions.append(a)
        log_probs.append(lp)
        values.append(v)
        returns.append(r)
        advantages.append(adv)
    states = torch.cat(states, dim=0)
    actions = torch.cat(actions, dim=0)
    log_probs = torch.cat(log_probs, dim=0)
    values = torch.cat(values, dim=0)
    returns = torch.cat(returns, dim=0)
    advantages = torch.cat(advantages, dim=0)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return states, actions, log_probs, values, returns, advantages


class Logger:
    def __init__(self):
        self.step = 0

    def log(self, metrics:Dict[str, float], step:Optional[int]=None):
        raise NotImplementedError
    
    def term(self, *args):
        print(*args)


class WandBLogger(Logger):
    def __init__(self, entity:str="ldfrancis", project_name:str="RLPress"):
        super().__init__()
        wandb.init(project=project_name, entity=entity)

    def log(self, metrics:Dict[str, float], step:Optional[int]=None):
        if step is None:
            step = self.step
            self.step += 1
        wandb.log(metrics, step=step)


class TerminalLogger(Logger):
    def __init__(self):
        super().__init__()

    def log(self, metrics:Dict[str, float], step:Optional[int]=None):
        if step is None:
            step = self.step
            self.step += 1
        print(f"Step {step}:")
        for k, v in metrics.items():
            print(f"\t{k} : {v}")
        print("\n")


class RLLearner:
    def __init__(self, policy_n_value:PolicyValue, gamma: float = 0.99, lam: float = 0.95, lr=1e-4, device: str = device):
        self.policy_n_value = policy_n_value
        self.gamma = gamma
        self.lam = lam
        self.device = device
        self.policy_optimizer = torch.optim.Adam(self.policy_n_value.policy_model.parameters(), lr=lr)
        self.value_optimizer = torch.optim.Adam(self.policy_n_value.value_model.parameters(), lr=1e-5)
        self.global_step = 0

    def __call__(self, trajectories, epochs:int=10):
        states, actions, log_probs, values, returns, advantages = process_trajectories(trajectories, self.gamma, lam=0.95, device=self.device)
        max_grad_norm = 0.5
        target_kl = 0.01
        self.global_step += len(states)
       
        bs = 32
        inds = np.arange(0, len(states))
        clip_coef = 0.2
        clipfracs = []
        
        policy_losses = []
        value_losses = []
        entropy_losses = []
        approx_kls = []
        old_approx_kls = []
        grad_steps = 0

        for epoch in range(epochs):
            stop_updates = False
            np.random.shuffle(inds)
            # self.policy_optimizer.zero_grad(); self.value_optimizer.zero_grad()
            for start in range(0, len(states), bs):
                end = min(start+bs, len(states))
                b_inds = inds[start:end]

                x = states[b_inds].to(self.device)
                a = actions[b_inds].to(self.device)

                dist = self.policy_n_value.get_dist(x)
                newlogprob = dist.log_prob(a)
                entropy = dist.entropy()
                newvalue = self.policy_n_value.get_value(x)
                
                logratio = newlogprob - log_probs[b_inds].to(self.device)
                ratio = logratio.exp()

                # Policy loss
                pg_obj1 = advantages[b_inds] * ratio
                pg_obj2 = advantages[b_inds] * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
                pg_loss = -torch.min(pg_obj1, pg_obj2).mean()

                # Value loss
                v_loss = 0.5 * ((newvalue - returns[b_inds])**2).mean()

                # Entropy loss
                entropy_loss = -entropy.mean()

                # Combined loss
                loss = pg_loss + 0.5 * entropy_loss

                self.policy_optimizer.zero_grad(); self.value_optimizer.zero_grad()
                loss.backward()
                v_loss.backward()
                nn.utils.clip_grad_norm_(self.policy_n_value.value_model.parameters(), max_grad_norm)
                nn.utils.clip_grad_norm_(self.policy_n_value.policy_model.parameters(), max_grad_norm)
                self.policy_optimizer.step(); self.value_optimizer.step()

                # Approx kl
                with torch.no_grad():
                    old_approx_kl = (-logratio).mean()
                    approx_kl = (ratio - 1 - logratio).mean()
                    clipfracs += [((ratio > (1 + clip_coef)) | (ratio < (1 - clip_coef))).float().mean().item()]

                grad_steps += 1
                policy_losses += [pg_loss.item()]
                value_losses += [v_loss.item()]
                entropy_losses += [entropy_loss.item()]
                approx_kls += [approx_kl.item()]
                old_approx_kls += [old_approx_kl.item()]

                # if approx_kl > target_kl:
                #     stop_updates = True
                #     break
            # self.policy_optimizer.step(); self.value_optimizer.step()
            if stop_updates:
                break

        learner_results = {
            "learner/losses/policy_loss": np.mean(policy_losses) if policy_losses else 0,
            "learner/losses/value_loss": np.mean(value_losses) if value_losses else 0,
            "learner/losses/entropy_loss": np.mean(entropy_losses) if entropy_losses else 0,
            "learner/losses/approx_kls": np.mean(approx_kls) if approx_kls else 0,
            "learner/losses/old_approx_kls": np.mean(old_approx_kls) if old_approx_kls else 0,
            "learner/losses/clipfrac": np.mean(clipfracs),
            "global_step": self.global_step,
        }

        return learner_results
        

class PolicyValueRollout:
    def __init__(self, env: DebugEnvironment, policy_n_value: PolicyValue):
        self.env = env
        self.policy_n_value = policy_n_value
        self.policy_model = policy_n_value.policy_model
        self.value_model = policy_n_value.value_model

    @torch.no_grad()
    def __call__(self, deterministic=False):
        state, _ = self.env.reset()
        done = False
        trajectory = []
        step = 0
        while not done:
            # import pdb; pdb.set_trace()
            state = torch.cat([state["state"], state["action_mask"]], dim=0).float().to(self.policy_n_value.device)
            if not deterministic:
                action, log_prob, _, value = self.policy_n_value.get_action_and_value(state.unsqueeze(0))
            else:
                action, log_prob = self.policy_model.act(state, deterministic=True)
                value = self.value_model(state.unsqueeze(0))
            next_state, reward, done, truncated, info = self.env.step(action.item())
            done = done or truncated
            trajectory.append((state, action, reward, log_prob, value))
            state = next_state
        return trajectory


class Trainer:
    def __init__(self, env: DebugEnvironment, policy_n_value: PolicyValue, logger: Logger, gamma: float = 0.99, lam: float = 0.95, lr: float = 1e-4):
        self.env = env
        self.policy_n_value = policy_n_value
        self.learner = RLLearner(policy_n_value, gamma=gamma, lam=lam, lr=lr, device=policy_n_value.device)
        self.rollout = PolicyValueRollout(env, policy_n_value)
        self.logger = logger
        self.best_score = 0

    def __call__(self, num_iters:int=100):
        for iter in range(num_iters):
            start_time = time.time()
            trajectories = [self.rollout() for _ in range(4)]
            learner_results = self.learner(trajectories)
            with torch.no_grad():
                trj =  self.rollout(deterministic=True)
                rew = sum(tran[2] for tran in trj)
                if rew > self.best_score:
                    self.best_score = rew
                    torch.save(self.policy_n_value.policy_model.state_dict(), "best_policy.pt")
                    print(f"New best model saved with score {self.best_score}")
            end_time = time.time()
            loss = learner_results["learner/losses/policy_loss"]
            self.logger.log({**learner_results, "Score": rew}, step=learner_results["global_step"])
            print(f"Iteration {iter+1}/{num_iters}, Loss: {loss:.4f}, Rew: {rew:.2f}, Global Step: {learner_results['global_step']}, Time: {end_time - start_time:.2f}s")
            del trajectories, trj
            torch.cuda.empty_cache()

N = 10
S = 3
# env = gym.make("CartPole-v1", max_episode_steps=200) #DebugEnvironment()
# env = DebugEnvironment()
env = Environment(model_name="meta-llama/Llama-2-7b-hf", num_samples=32, sequence_length=2048, target_sparsity=0.5)

policy_model = Policy(state_size=S+N, action_size=N, device=device)
value_model = Value(state_size=S+N, device=device)

policy_n_value = PolicyValue(policy_model, value_model)
policy_n_value.to(device)

logger = WandBLogger(entity="ldfrancis", project_name="RLPress")
trainer = Trainer(env, policy_n_value, logger=logger, gamma=1.0, lam=0.95, lr=1e-3)
# trainer(1000)

Loading FineWeb-Edu v2
Total tokens loaded: 65536


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s]
[34m[1mwandb[0m: Currently logged in as: [33mldfrancis[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
policy_model.load_state_dict(torch.load("best_policy.pt"))
state, _ = env.reset()
done = False
score = 0
while not done:
    state = torch.cat([state["state"], state["action_mask"]], dim=0).float().to(device)
    with torch.no_grad():
        dist = policy_model(state.unsqueeze(0))
        action = dist.mode
    next_state, reward, done, truncated, info = env.step(action.item())
    done = done or truncated
    state = next_state
    score += reward
    print(action.item())
print(f"Achieved Score: {score}")

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.02it/s]


2
0
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
Achieved Score: -5.540319442749023


In [12]:
np.mean([m.weight.data.eq(0).float().mean().item() for layer in env.layers for m in layer.modules() if isinstance(m, nn.Linear)])

np.float64(0.1937019782886529)