In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import *

import time
import os
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Model identifier string (the commented-out setup cells pass it to
# `AutoModelForCausalLM.from_pretrained` / `Environment`).
model_name = "meta-llama/llama-2-7b-hf"

# Sequence length and number of samples; the commented-out setup cell hands
# these to `Environment` as `sequence_length` and `num_samples`.
ctx_len = 4096
n_samples = 128
# Device string read by `Environment` and used as the default `device`
# argument of `RLPolicy` and `RLLearner`.
device = "cuda"

In [None]:
from typing import Dict, Tuple


# Number of discrete sparsity levels, i.e. the number of actions
# (0.0 - 0.9 in steps of 0.1); also the length of the action mask.
N = 10 # possible sparsity levels (0.0 - 0.9)
# Number of scalar features in the state vector returned by the environments.
S = 3 # number of state features
        

class Environment:
    """Layer-by-layer pruning environment around a Hugging Face causal LM.

    An episode walks over the decoder layers in order. Each `step(action)`
    prunes the current layer at `possible_sparsities[action]` and returns the
    next state. The reward is 0 until the last layer has been pruned; on the
    final step it is the KL divergence between the pruned model's token
    distributions and log-probabilities of the unpruned model that `init`
    cached on disk under ``logs/kl/<model_name>/``.

    Relies on notebook-level names that are not defined in this class:
    `device`, `N`, and `get_fineweb_edu`, `get_w2_data`, `WrappedGPT`
    (presumably provided by ``from utils import *`` - confirm).

    Args:
        model_name: identifier passed to `from_pretrained`.
        num_samples: number of calibration / evaluation sequences.
        sequence_length: number of tokens per sequence.
        target_sparsity: sparsity budget used to build the action mask.
    """

    def __init__(self, model_name:str, num_samples:int, sequence_length:int, target_sparsity:float=0.5)->None:
        self.model_name = model_name
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        self.target_sparsity = target_sparsity
        self.possible_sparsities = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        # Fall back to CPU when CUDA is unavailable, whatever the notebook-level `device` says.
        self.device = device if torch.cuda.is_available() else "cpu"

        # initialize state
        self.load_calibration_data()
        self.reset()

    def load_calibration_data(self):
        """Create the tokenizer and load the calibration and test token sets.

        Stores `self.tokenizer`, `self.calib_data` (from `get_fineweb_edu`)
        and `self.test_data` (second return value of `get_w2_data`).
        """
        # calibration data
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer
        num_tokens = self.num_samples * self.sequence_length
        self.calib_data = get_fineweb_edu(num_tokens, self.sequence_length, tokenizer, train=True)
        _, self.test_data = get_w2_data(self.num_samples, self.sequence_length, tokenizer)

    @torch.no_grad()
    def init(self) -> None:
        """(Re)load the unpruned model and rebuild all per-episode state.

        Steps:
          1. Load model and tokenizer.
          2. Reset the bookkeeping attributes (current layer, sparsities, ...).
          3. Capture the hidden states entering decoder layer 0 for every
             calibration sample by temporarily wrapping that layer.
          4. Cache the unpruned model's log-probabilities on the test data in
             ``logs/kl/<model_name>/log_targets_<batch>_<batch_size>.pt``
             (batches whose file already exists are skipped).
          5. Build `self.kl_dataloader`, which yields (tokens, cached log-probs).
        """
        # model and tokenizer
        model = AutoModelForCausalLM.from_pretrained(self.model_name, dtype=torch.float16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.model = model
        self.tokenizer = tokenizer

        # evaluation data used for the KL targets
        test_data = self.test_data

        # env attributes
        self.action_mask = torch.ones(N)
        self.layers = model.model.layers
        self.num_layers = len(self.layers)
        self.current_layer = 0
        self.global_sparsity = 0.0
        self.layer_sparsities = [0.0] * self.num_layers
        self.pruning_info = {}

        # buffers: inputs to / outputs of the layer currently being pruned
        self.inps = torch.zeros((self.num_samples, self.sequence_length, model.config.hidden_size), dtype=torch.float16, device=self.device)
        self.outs = torch.zeros_like(self.inps)
        self.inp_kwargs = {}

        # obtain input into the first decoder layer
        cache = model.config.use_cache
        model.config.use_cache = False
        inps = self.inps
        inp_kwargs = self.inp_kwargs

        class _InputsCaptured(Exception):
            """Sentinel raised by `catch_inps` to abort the forward pass early."""

        class catch_inps(nn.Module):
            """Stand-in for decoder layer 0 that records its inputs and stops."""
            def __init__(self, module):
                super().__init__()
                self.module = module
                self.num_inps = 0
            def forward(self, inp, **kwargs):
                # `inps` / `inp_kwargs` are mutated in place, so the outer
                # buffers see the captured values.
                inps[self.num_inps] = inp
                inp_kwargs.update(kwargs)
                self.num_inps += 1
                raise _InputsCaptured("caught inps. Stopping forward pass.")

        catcher = catch_inps(self.layers[0])
        self.layers[0] = catcher
        try:
            for sample in self.calib_data:
                # The buffer holds `num_samples` rows; stop once it is full.
                if catcher.num_inps >= self.num_samples:
                    break
                try:
                    model(sample.to(self.device))
                except _InputsCaptured:
                    # Expected: only the sentinel is swallowed, so genuine
                    # failures (OOM, shape errors, ...) surface to the caller.
                    pass
        finally:
            # Always put the real layer back, even if a forward pass failed.
            self.layers[0] = catcher.module
        self.inps = inps
        self.inp_kwargs = inp_kwargs

        # save the log targets to a file for computing the KL divergence later
        i_batches = 0
        os.makedirs(f"logs/kl/{self.model_name}", exist_ok=True)
        batch_size = 4
        log_probs = []
        for j in range(self.num_samples):
            if os.path.exists(f"logs/kl/{self.model_name}/log_targets_{(j//batch_size)}_{batch_size}.pt"):
                # This batch was cached by an earlier run.
                i_batches = j // batch_size
                continue
            sample = test_data[j]
            logits = model(sample.to(self.device)).logits
            log_probs.append(F.log_softmax(logits.float(), dim=-1).reshape(-1, model.config.vocab_size).cpu())
            if j % batch_size == batch_size-1:
                log_probs = torch.cat(log_probs, dim=0).cpu()
                torch.save(log_probs, f"logs/kl/{self.model_name}/log_targets_{i_batches}_{batch_size}.pt")
                print(f"Saved logs/kl/{self.model_name}/log_targets_{i_batches}_{batch_size}.pt")
                log_probs = []
            elif j == self.num_samples - 1 and len(log_probs) > 0:
                # Trailing partial batch.
                log_probs = torch.cat(log_probs, dim=0).cpu()
                torch.save(log_probs, f"logs/kl/{self.model_name}/log_targets_{i_batches}_{batch_size}.pt")
                print(f"Saved logs/kl/{self.model_name}/log_targets_{i_batches}_{batch_size}.pt")
            i_batches = j // batch_size

        # create a dataloader for computing KL divergence later
        model_name = self.model_name
        class KLDataset(torch.utils.data.Dataset):
            """Pairs each batch of test samples with its cached log-probabilities."""
            def __init__(self):
                self.path_format = f"logs/kl/{model_name}"+"/log_targets_{}_{}.pt"
            def __len__(self):
                return i_batches + 1
            def __getitem__(self, idx):
                samples = torch.cat(test_data[idx*batch_size:(idx+1)*batch_size], dim=0)
                log_probs = torch.load(self.path_format.format(idx, batch_size))
                return samples, log_probs
        self.kl_dataloader = torch.utils.data.DataLoader(KLDataset(), batch_size=1, shuffle=False)
        model.config.use_cache = cache

    def prune_layer(self, layer_idx:int, sparsity:float)->None:
        """Prune every `nn.Linear` inside decoder layer `layer_idx`.

        Runs the buffered inputs through the layer with forward hooks that
        feed each linear sub-layer's activations to a `WrappedGPT` wrapper,
        prunes each wrapper at `sparsity`, recomputes the layer outputs with
        the pruned weights, and swaps the input/output buffers so the outputs
        become the next layer's inputs. A summary is stored in
        `self.pruning_info[layer_idx]`.

        Raises:
            Exception: if `layer_idx` has already been pruned this episode.
        """
        if layer_idx in self.pruning_info:
            raise Exception(f"Layer {layer_idx} already pruned. Skipping.")

        layer = self.layers[layer_idx]
        sublayers = {name: module for name, module in layer.named_modules() if isinstance(module, nn.Linear)}
        wrapped_layers = {}
        for name, sublayer in sublayers.items():
            wrapped_layers[name] = WrappedGPT(sublayer)

        # obtain the input activations to each sublayer, computing the feature-wise norms
        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp
        handles = []
        for name in wrapped_layers:
            handles.append(sublayers[name].register_forward_hook(add_batch(name)))
        for j in range(self.num_samples):
            self.outs[j] = layer(self.inps[j].unsqueeze(0), **self.inp_kwargs)[0]
        for h in handles:
            h.remove()

        for name in sublayers:
            wrapped_layers[name].prune(sparsity)
            wrapped_layers[name].clean()

        # outputs after pruning
        for j in range(self.num_samples):
            with torch.no_grad():
                self.outs[j] = layer(self.inps[j].unsqueeze(0), **self.inp_kwargs)[0]

        # the output from this layer should be the input to the next layer
        self.inps, self.outs = self.outs, self.inps

        # done pruning this layer. Prepare some info about this layer's pruning
        obtained_sparsity = np.mean([l.weight.data.eq(0).float().mean().item() for l in sublayers.values()]).item()
        info = {
            "layer": layer_idx,
            "layer_target_sparsity": sparsity,
            "layer_obtained_sparsity": obtained_sparsity,
        }
        self.pruning_info[layer_idx] = info

    def reset(self) -> Dict[str, torch.Tensor]:
        """Drop the current model and buffers, reload everything, return the initial state."""
        if hasattr(self, "inps"):
            del self.inps, self.outs, self.inp_kwargs
            del self.kl_dataloader
            del self.model, self.tokenizer
        torch.cuda.empty_cache()
        self.init()
        return self.get_state()

    def get_state(self) -> Dict[str, torch.Tensor]:
        """Return the observation for the layer about to be pruned.

        Returns:
            dict with
              "state": float32 tensor
                  [global_sparsity, target_sparsity, current_layer / num_layers]
              "action_mask": float32 tensor with one 0/1 entry per element of
                  `possible_sparsities`.
        """
        s = [self.global_sparsity, self.target_sparsity, self.current_layer / self.num_layers]
        if self.current_layer == 0:
            mask = [1] * len(self.possible_sparsities)
        else:
            # NOTE(review): the candidate is added to the sum of the first
            # `current_layer` sparsities but the total is divided by
            # `current_layer`, not `current_layer + 1` - confirm this is the
            # intended budget rule.
            pruned_so_far = sum(self.layer_sparsities[:self.current_layer])
            mask = [
                1 if (pruned_so_far + candidate) / self.current_layer <= self.target_sparsity else 0
                for candidate in self.possible_sparsities
            ]
        state = {
            "state": torch.tensor(s, dtype=torch.float32),
            "action_mask": torch.tensor(mask, dtype=torch.float32)
        }
        return state

    @torch.no_grad()
    def step(self, action:int)->Tuple[Dict[str, torch.Tensor], float, bool, Dict[str, object]]:
        """Prune the current layer with the chosen sparsity and advance.

        Args:
            action: index into `possible_sparsities`.

        Returns:
            (next_state, reward, done, info). `reward` is 0 until the last
            layer is pruned; on the final step it is the size-weighted running
            mean of the per-batch `F.kl_div(..., reduction="batchmean")`
            values against the cached targets. `info` is always an empty dict.

        NOTE(review): the reward is the KL divergence itself (larger means the
        pruned model diverges more) - confirm the sign the learner expects.
        """
        sparsity = self.possible_sparsities[action]
        self.prune_layer(self.current_layer, sparsity)
        # update global sparsity
        self.layer_sparsities[self.current_layer] = sparsity
        self.current_layer += 1
        self.global_sparsity = np.mean(self.layer_sparsities[:self.current_layer])
        # compute reward
        reward = 0
        done = self.current_layer == self.num_layers
        if done:
            # compute KL divergence between the pruned and unpruned model.
            # the logits have been saved to a file during initialization.
            running_kl = 0.0
            total_logprobs = 0
            for batch in self.kl_dataloader:
                inps, target_log_probs = [batch[0].squeeze(0), batch[1].squeeze(0)]
                logits = self.model(inps.to(self.device)).logits.reshape(-1, self.model.config.vocab_size)
                log_probs = F.log_softmax(logits.float(), dim=-1)
                kl = F.kl_div(log_probs, target_log_probs.to(self.device), reduction="batchmean", log_target=True).item()
                # incremental weighted mean, weights = number of cached entries per batch
                running_kl *= (total_logprobs / (total_logprobs + target_log_probs.numel()))
                running_kl += (target_log_probs.numel() / (total_logprobs + target_log_probs.numel())) * kl
                total_logprobs += target_log_probs.numel()
                del target_log_probs, logits, kl
                torch.cuda.empty_cache()
            reward = running_kl

        return self.get_state(), reward, done, {}

In [73]:
class DebugEnvironment:
    """Tiny stand-in environment for checking the RL loop without a model.

    An episode lasts five steps. The only rewarded episode is the exact
    action sequence [1, 1, 1, 1, 3], which yields 10 on the final step;
    every other step or sequence yields 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart the episode and return the initial state dict."""
        self.cur_layer, self.n_layers = 0, 5
        self.actions = []
        return self.get_state()

    def get_state(self):
        """Return {"state": [progress, 0.5, 0.5], "action_mask": ones(N)}."""
        progress = self.cur_layer / self.n_layers
        features = torch.tensor([progress, 0.5, 0.5], dtype=torch.float32)
        all_allowed = torch.ones(N, dtype=torch.float32)
        return {"state": features, "action_mask": all_allowed}

    def step(self, action:int):
        """Record `action`, advance one step; returns (state, reward, done, {})."""
        self.actions.append(action)
        self.cur_layer += 1
        done = self.cur_layer == self.n_layers
        hit_target = done and self.actions == [1,1,1,1,3]
        reward = 10 if hit_target else 0
        return self.get_state(), reward, done, {}

In [None]:
# RL Agent
from typing import Optional, Dict


class RLPolicy(nn.Module):
    """Masked categorical policy over discrete actions.

    The network input is the state vector concatenated with the action mask;
    the output is a `torch.distributions.Categorical` in which masked-out
    actions (mask != 1) receive the most negative finite logit.

    Args:
        state_size: number of state features.
        action_size: number of actions (and length of the action mask).
        device: device the inputs are moved to in `forward`.
    """

    def __init__(self, state_size:int, action_size:int, device:str=device):
        super(RLPolicy, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.device = device
        # Layer widths follow the constructor arguments rather than the
        # notebook-level S / N globals, so the arguments are actually honoured.
        self.base = nn.Sequential(
            nn.Linear(state_size + action_size, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )
        self.head = nn.Linear(256, action_size)
        self.uniform_init()

    def to(self, device:Optional[str]=None):
        """Move the module to `device` (default: the stored one) and remember it."""
        if device is None:
            device = self.device
        self.device = device
        return super().to(device)

    def forward(self, state:Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the action distribution for `state`.

        Args:
            state: dict with "state" and "action_mask" tensors that can be
                concatenated along the last dimension.

        Returns:
            A `torch.distributions.Categorical` over the actions.
        """
        s = state["state"]
        action_mask = state["action_mask"]
        large_neg = torch.finfo(s.dtype).min
        x = torch.cat([s, action_mask], dim=-1).to(self.device)
        x = self.base(x)
        logits = self.head(x)
        # Disallowed actions get the lowest finite logit, i.e. ~0 probability.
        logits = torch.where(action_mask.to(self.device) == 1, logits, large_neg)
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        return dist

    def uniform_init(self):
        """Set every output bias to 1 / action_size.

        An equal bias shifts all logits by the same amount, so it does not by
        itself make the initial policy uniform; the head weights are left at
        their default initialisation.
        """
        with torch.no_grad():
            self.head.bias.fill_(1 / self.action_size)

    @torch.no_grad()
    def act(self, state:Dict[str, torch.Tensor], deterministic=False) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick an action for a single (unbatched) state.

        Args:
            state: unbatched state dict; a leading batch dimension is added here.
            deterministic: take the distribution's mode instead of sampling.

        Returns:
            (action, log_prob), each with a leading batch dimension of 1.
        """
        state = {k: v.unsqueeze(0) for k, v in state.items()}
        dist = self(state)
        action = dist.sample() if not deterministic else dist.mode
        log_prob = dist.log_prob(action)
        return action, log_prob


class Value(nn.Module):
    """Small MLP mapping a state vector to a scalar value estimate.

    Args:
        state_size: number of input features.
        action_size: stored for reference; not used by the network.
    """

    def __init__(self, state_size:int, action_size:int):
        super(Value, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        # Parameters are created on the CPU; `to()` reads this attribute when
        # called without an argument, so it must exist from the start.
        self.device = "cpu"
        self.model = nn.Sequential(
            nn.Linear(state_size, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def to(self, device:Optional[str]=None):
        """Move the module to `device` (default: the stored one) and remember it."""
        if device is None:
            device = self.device
        self.device = device
        return super().to(device)

    def forward(self, state:torch.Tensor) -> torch.Tensor:
        """Return the value estimate; the last dimension of the output has size 1."""
        return self.model(state)

In [80]:
class RLLearner:
    """REINFORCE-style policy-gradient update with a constant baseline.

    Args:
        model: policy whose call returns a distribution exposing `log_prob`.
        gamma: discount factor used when accumulating returns.
        lr: Adam learning rate.
        device: device the batched tensors are moved to.
        baseline: constant subtracted from every return (default keeps the
            previously hard-coded 6.0).
        batch_size: number of transitions per forward pass (default keeps the
            previously hard-coded 4).
    """

    def __init__(self, model: nn.Module, gamma: float = 1.0, lr=1e-4, device: str = device,
                 baseline: float = 6.0, batch_size: int = 4):
        self.model = model
        self.gamma = gamma
        self.device = device
        self.baseline = baseline
        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def __call__(self, trajectories):
        """Run one optimizer step on a single episode.

        Args:
            trajectories: list of (state, action, reward, log_prob) tuples in
                time order; `state` is a dict with "state" and "action_mask"
                tensors and `action` is a 1-element tensor.

        Returns:
            The scalar loss as a Python float (0.0 for an empty episode, in
            which case no optimizer step is taken).
        """
        if len(trajectories) == 0:
            # Nothing to learn from; avoids a division by zero below.
            return 0.0

        # Discounted return-to-go, accumulated backwards then flipped.
        running = 0
        returns = []
        for _, _, reward, _ in reversed(trajectories):
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device) - self.baseline

        self.optimizer.zero_grad()
        scale = 1/len(returns)  # average the summed loss over the episode length
        loss = 0
        for i in range(0, len(returns), self.batch_size):
            j = min(i + self.batch_size, len(returns))
            trans = trajectories[i:j]
            r = returns[i:j]
            x = {
                "state": torch.stack([t[0]["state"] for t in trans], dim=0).to(self.device),
                "action_mask": torch.stack([t[0]["action_mask"] for t in trans], dim=0).to(self.device)
            }
            a = torch.cat([t[1] for t in trans]).to(self.device)
            # Log-probabilities are recomputed with gradients enabled; the ones
            # stored in the trajectory are not used.
            dist = self.model(x)
            loss += (scale)*(-dist.log_prob(a) * r).sum()
        loss.backward()
        self.optimizer.step()
        return loss.item()
    

In [None]:
class PolicyRollout:
    """Callable that plays one full episode of `env` with `policy_model`."""

    def __init__(self, env: Environment, policy_model: RLPolicy):
        self.env = env
        self.policy_model = policy_model

    def __call__(self, deterministic=False):
        """Reset the environment and act until it reports `done`.

        Returns:
            List of (state, action, reward, log_prob) tuples in time order,
            where `state` is the observation the action was chosen from.
        """
        transitions = []
        observation = self.env.reset()
        finished = False
        while not finished:
            action, log_prob = self.policy_model.act(observation, deterministic=deterministic)
            following, reward, finished, _ = self.env.step(action.item())
            transitions.append((observation, action, reward, log_prob))
            observation = following
        return transitions


class Trainer:
    """Alternates sampled rollouts, learner updates and greedy evaluation rollouts."""

    def __init__(self, env: Environment, policy_model: RLPolicy, gamma: float = 0.99, lr: float = 1e-4):
        self.env = env
        self.policy_model = policy_model
        self.learner = RLLearner(policy_model, gamma=gamma, lr=lr, device=policy_model.device)
        self.rollout = PolicyRollout(env, policy_model)

    def __call__(self, num_episodes:int=100):
        """Train for `num_episodes` episodes, printing one progress line each.

        The printed reward is the final-step reward of the deterministic
        rollout, not of the sampled rollout used for the update.
        """
        for episode in range(num_episodes):
            start_time = time.time()
            # One stochastic episode drives the update ...
            sampled_episode = self.rollout()
            loss = self.learner(sampled_episode)
            kl = sampled_episode[-1][2]  # final reward of the sampled episode (currently not reported)
            # ... and one greedy episode is used for reporting.
            greedy_episode = self.rollout(deterministic=True)
            rew = greedy_episode[-1][2]
            end_time = time.time()
            print(f"Episode {episode+1}/{num_episodes}, Loss: {loss:.4f}, Rew: {rew:.2f}, Time: {end_time - start_time:.2f}s")
    

In [None]:
# N = 10
# S = 3
# environment = Environment(model_name, num_samples=n_samples, sequence_length=ctx_len, target_sparsity=0.5)
# policy_model = RLPolicy(state_size=S, action_size=N, device=device)
# policy_model.to(policy_model.device)
# trainer = Trainer(environment, policy_model, lr=1e-4)

Episode 1/1000, Loss: 6.1627, Rew: 9.00, Time: 0.02s
Episode 2/1000, Loss: 7.8744, Rew: 9.00, Time: 0.02s
Episode 3/1000, Loss: 6.7069, Rew: 9.00, Time: 0.02s
Episode 4/1000, Loss: 5.3994, Rew: 9.00, Time: 0.02s
Episode 5/1000, Loss: 7.5183, Rew: 9.00, Time: 0.02s
Episode 6/1000, Loss: 2.7919, Rew: 9.00, Time: 0.01s
Episode 7/1000, Loss: 6.7477, Rew: 9.00, Time: 0.02s
Episode 8/1000, Loss: 5.9487, Rew: 10.00, Time: 0.02s
Episode 9/1000, Loss: 4.3386, Rew: 9.00, Time: 0.02s
Episode 10/1000, Loss: 4.4731, Rew: 12.00, Time: 0.02s
Episode 11/1000, Loss: 2.8370, Rew: 9.00, Time: 0.01s
Episode 12/1000, Loss: 3.1000, Rew: 11.00, Time: 0.01s


Episode 13/1000, Loss: 7.2151, Rew: 10.00, Time: 0.02s
Episode 14/1000, Loss: 7.4841, Rew: 11.00, Time: 0.02s
Episode 15/1000, Loss: 5.1593, Rew: 12.00, Time: 0.02s
Episode 16/1000, Loss: 2.8151, Rew: 11.00, Time: 0.01s
Episode 17/1000, Loss: 5.1653, Rew: 10.00, Time: 0.02s
Episode 18/1000, Loss: 4.2159, Rew: 11.00, Time: 0.02s
Episode 19/1000, Loss: 7.8530, Rew: 14.00, Time: 0.02s
Episode 20/1000, Loss: 7.0413, Rew: 14.00, Time: 0.02s
Episode 21/1000, Loss: 2.7710, Rew: 16.00, Time: 0.01s
Episode 22/1000, Loss: 4.0561, Rew: 14.00, Time: 0.01s
Episode 23/1000, Loss: 4.5466, Rew: 15.00, Time: 0.02s
Episode 24/1000, Loss: 3.4174, Rew: 12.00, Time: 0.01s
Episode 25/1000, Loss: 3.6777, Rew: 11.00, Time: 0.01s
Episode 26/1000, Loss: 4.1702, Rew: 11.00, Time: 0.02s
Episode 27/1000, Loss: 5.9140, Rew: 10.00, Time: 0.02s
Episode 28/1000, Loss: 3.2166, Rew: 9.00, Time: 0.01s
Episode 29/1000, Loss: 6.2404, Rew: 8.00, Time: 0.02s
Episode 30/1000, Loss: 6.0291, Rew: 8.00, Time: 0.02s
Episode 31/10

In [None]:
# Debug check: identity of the policy's first parameter object, for comparison
# with the id printed from the learner's model.
print(id(next(policy_model.parameters())))

135943351178304


In [11]:
# Debug check: identity of the first parameter of the model held by the
# trainer's learner; an id equal to the policy's means both share one object.
print(id(next(trainer.learner.model.parameters())))

135943351178304


In [86]:
# model_name = "meta-llama/llama-2-7b-hf"

# model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float16, device_map="auto")
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# w2_train_data, w2_test_data = get_w2_data(128, 4096, tokenizer)

# num_tokens = 2**23
# fw_train_data = get_fineweb_edu(num_tokens, 4096, tokenizer, train=True)



In [None]:
# prune_default(model, fw_train_data, 0.5, theta1=0.42, theta2=0.51, theta3=0.38, is_sparsegpt=True, device=device)

In [None]:
# eval_ppl(model, w2_test_data, 4096,  bs=1, device=device)

In [116]:
from ray.rllib.algorithms.ppo import PPOConfig

# Create an RLlib PPO configuration and point it at the "CartPole-v1" env id.
config = PPOConfig()
# The call's return value is the cell's last expression, so the config repr is displayed.
config.environment("CartPole-v1")

2025-09-30 09:59:33,480	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-09-30 09:59:33,560	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x7c365dc63c70>

AssertionError: Expected env to be a `gymnasium.Env` but got <class 'gymnasium.wrappers.common.PassiveEnvChecker'>

In [19]:
import gymnasium as gym

# Single CartPole-v1 environment; no render_mode is requested here.
env = gym.make("CartPole-v1")  

In [None]:
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy
import os
import random
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter


@dataclass
class Args:
    """Hyper-parameters and run settings for the PPO training script below."""

    exp_name: str = "ppo-test"
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """the value assigned to `torch.backends.cudnn.deterministic`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "CartPole-v1"
    """the id of the environment"""
    total_timesteps: int = 500000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 4
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    num_minibatches: int = 4
    """the number of mini-batches"""
    update_epochs: int = 4
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.01
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: float = None
    """the target KL divergence threshold"""

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""


def make_env(env_id, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one wrapped environment.

    Args:
        env_id: environment id passed to `gym.make`.
        idx: index of this environment; only index 0 may record video.
        capture_video: whether video recording is requested.
        run_name: sub-folder of ``videos/`` used for recordings.

    Returns:
        A callable that creates the environment, wrapped in
        `RecordEpisodeStatistics` (and in `RecordVideo` for the recording env).
    """
    def thunk():
        should_record = capture_video and idx == 0
        if should_record:
            base_env = gym.make(env_id, render_mode="rgb_array")
            base_env = gym.wrappers.RecordVideo(base_env, f"videos/{run_name}")
        else:
            base_env = gym.make(env_id)
        return gym.wrappers.RecordEpisodeStatistics(base_env)

    return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise `layer` in place and return it.

    The weight gets an orthogonal initialisation scaled by `std`; every bias
    entry is set to `bias_const`.
    """
    initializers = torch.nn.init
    initializers.orthogonal_(layer.weight, gain=std)
    initializers.constant_(layer.bias, val=bias_const)
    return layer


class Agent(nn.Module):
    """Actor-critic agent with two separate MLPs: `value_model` and `policy_model`.

    Both networks take the flattened observation and use two hidden Tanh
    layers of width 64; the value head outputs one number and the policy head
    outputs one logit per discrete action.
    """

    def __init__(self, envs):
        super().__init__()
        obs_dim = np.array(envs.single_observation_space.shape).prod()

        def build_mlp(out_features, out_std):
            # Layers are created in input-to-output order so that the random
            # initialisation consumes the RNG in a fixed sequence.
            return nn.Sequential(
                layer_init(nn.Linear(obs_dim, 64)),
                nn.Tanh(),
                layer_init(nn.Linear(64, 64)),
                nn.Tanh(),
                layer_init(nn.Linear(64, out_features), std=out_std),
            )

        # Value network first, policy network second.
        self.value_model = build_mlp(1, 1.0)
        self.policy_model = build_mlp(envs.single_action_space.n, 0.01)

    def get_value(self, x):
        """Return the value estimate for observations `x`."""
        return self.value_model(x)

    def get_action_and_value(self, x, action=None):
        """Return (action, log_prob, entropy, value) for observations `x`.

        When `action` is None an action is sampled from the policy; otherwise
        the given action is evaluated.
        """
        policy_dist = Categorical(logits=self.policy_model(x))
        chosen = policy_dist.sample() if action is None else action
        return chosen, policy_dist.log_prob(chosen), policy_dist.entropy(), self.value_model(x)


# PPO training script: collects `num_steps` transitions from `num_envs`
# vectorised environments per iteration, computes GAE advantages, then runs
# `update_epochs` passes of clipped-surrogate minibatch updates, logging
# scalars to TensorBoard under runs/<run_name>.
if __name__ == "__main__":
    # CLI parsing via tyro is disabled here; the dataclass defaults are used.
    args = Args()#tyro.cli(Args)
    # Derived sizes: transitions per iteration, per minibatch, and iteration count.
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    # Record all hyper-parameters as a markdown table.
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    # NOTE(review): when run as a notebook cell this rebinds the notebook-level
    # name `device` (a plain string in the config cell) to a torch.device.
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, i, args.capture_video, run_name) for i in range(args.num_envs)],
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
    # Rollout buffers, all indexed as [step, env].
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)

    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            # Linear decay from learning_rate towards 0 over all iterations.
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        # ---- rollout phase: fill the buffers with num_steps transitions ----
        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
            # Termination and truncation are both treated as episode end.
            next_done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

            # NOTE(review): the saved output of this cell contains only "SPS:"
            # lines and no "episodic_return" lines, so this branch did not print
            # in that run - confirm the info-dict layout ("final_info") against
            # the installed gymnasium version.
            if "final_info" in infos:
                for info in infos["final_info"]:
                    if info and "episode" in info:
                        print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                        writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            # Generalised advantage estimation, computed backwards in time.
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                # Re-evaluate the stored actions under the current policy.
                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                # Clipped surrogate: element-wise max of the two negated objectives.
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            # Optional early stop, checked once per epoch against the last minibatch's approx_kl.
            if args.target_kl is not None and approx_kl > args.target_kl:
                break

        # Explained variance of the value predictions w.r.t. the returns.
        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        # The loss scalars logged here come from the last minibatch of the iteration.
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        # Environment steps per second since the start of training.
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    writer.close()

SPS: 1511
SPS: 2068
SPS: 2343
SPS: 2524
SPS: 2634
SPS: 2722
SPS: 2790
SPS: 2845
SPS: 2894
SPS: 2936
SPS: 2970
SPS: 3001
SPS: 3027
SPS: 3046
SPS: 3060
SPS: 3074
SPS: 3085
SPS: 3095
SPS: 3104
SPS: 3113
SPS: 3120
SPS: 3128
SPS: 3133
SPS: 3140
SPS: 3144
SPS: 3150
SPS: 3153
SPS: 3157
SPS: 3161
SPS: 3165
SPS: 3169
SPS: 3173
SPS: 3176
SPS: 3180
SPS: 3183
SPS: 3186
SPS: 3189
SPS: 3192
SPS: 3194
SPS: 3197
SPS: 3199
SPS: 3201
SPS: 3203
SPS: 3205
SPS: 3207
SPS: 3209
SPS: 3211
SPS: 3213
SPS: 3214
SPS: 3216
SPS: 3217
SPS: 3218
SPS: 3219
SPS: 3220
SPS: 3222
SPS: 3223
SPS: 3224
SPS: 3226
SPS: 3227
SPS: 3228
SPS: 3229
SPS: 3230
SPS: 3231
SPS: 3232
SPS: 3234
SPS: 3235
SPS: 3236
SPS: 3237
SPS: 3238
SPS: 3240
SPS: 3242
SPS: 3244
SPS: 3246
SPS: 3248
SPS: 3250
SPS: 3252
SPS: 3253
SPS: 3255
SPS: 3256
SPS: 3256
SPS: 3256
SPS: 3257
SPS: 3257
SPS: 3258
SPS: 3258
SPS: 3258
SPS: 3258
SPS: 3259
SPS: 3259
SPS: 3260
SPS: 3259
SPS: 3259
SPS: 3260
SPS: 3260
SPS: 3261
SPS: 3261
SPS: 3262
SPS: 3262
SPS: 3262
SPS: 3262


In [32]:
# Play one episode with uniformly random actions, printing every step.
state, _ = env.reset()
done = False
step = 0
while not done:
    action = env.action_space.sample()
    state, reward, terminated, truncated, info = env.step(action)
    # An episode is over when it terminates OR is truncated (e.g. by a time
    # limit); looking at `terminated` alone would keep stepping a truncated env.
    done = terminated or truncated
    env.render()
    step += 1
    print(f"Step {step}: reward={reward}, done={done}")
env.close()

Step 1: reward=1.0, done=False
Step 2: reward=1.0, done=False
Step 3: reward=1.0, done=False
Step 4: reward=1.0, done=False
Step 5: reward=1.0, done=False
Step 6: reward=1.0, done=False
Step 7: reward=1.0, done=False
Step 8: reward=1.0, done=False
Step 9: reward=1.0, done=False
Step 10: reward=1.0, done=False
Step 11: reward=1.0, done=False
Step 12: reward=1.0, done=False
Step 13: reward=1.0, done=False
Step 14: reward=1.0, done=False
Step 15: reward=1.0, done=True


  gym.logger.warn(


array([ 0.20310922,  0.22362837, -0.21195334, -0.5994116 ], dtype=float32)

Thank you very much for being available to mentor me. I am currently working on model compression (pruning and quantization) for LLMs and automatic speech recognition (ASR) models, and on compiling the compressed models for efficient inference on GPUs. More broadly, I am interested in the efficient inference of deep learning models, which is why I have been exploring both model compression and compilation. 

Your mentorship would be highly valuable as I navigate this research area. I also intend to apply 