In [None]:
import torch
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from transformers import ViTConfig
from stable_baselines3 import PPO
from typing import List

class FixedViTMDPEnvironment(gym.Env):
    """Per-layer pruning MDP over the architecture described by a ViT config.

    Only the ``ViTConfig`` for ``model_name`` is loaded (no weights). The
    episode visits each encoder layer once, in order. For the current layer
    the agent picks:

    * a fraction of the remaining attention heads to prune (``[0, 0.7]``),
    * a fraction of the remaining FFN width to prune (``[0, 0.7]``),
    * a skip score (``[0, 1]``); a score above ``SKIP_THRESHOLD`` deactivates
      the layer.

    After the last layer a compression bonus (minus a penalty when fewer than
    30% of layers remain active) is added to the final step's reward.

    Observation, all in ``[0, 1]``:
    ``[progress, mean heads remaining, mean FFN size ratio, active-layer ratio]``.

    Args:
        model_name: Hugging Face model id whose ``ViTConfig`` supplies the
            number of layers/heads and the hidden/FFN sizes.
        debug: Stored as ``self.debug``; not otherwise used in this class.
    """

    # A layer whose skip score exceeds this value is marked inactive.
    SKIP_THRESHOLD = 0.6

    def __init__(self, model_name="google/vit-base-patch16-224", debug=True):
        super().__init__()
        self.debug = debug

        self.config = ViTConfig.from_pretrained(model_name)
        self.num_layers = self.config.num_hidden_layers
        self.num_heads = self.config.num_attention_heads
        # Take dimensions from the config rather than hard-coding ViT-Base
        # values (768 / 64 / 3072) so other ViT sizes are measured correctly.
        self.hidden_size = self.config.hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        self.intermediate_size = self.config.intermediate_size

        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(4,), dtype=np.float32
        )
        self.action_space = spaces.Box(
            low=np.array([0.0, 0.0, 0.0], dtype=np.float32),
            high=np.array([0.7, 0.7, 1.0], dtype=np.float32),
            dtype=np.float32
        )
        self.reset()

    def reset(self, seed=None, options=None):
        """Restore every layer to its unpruned, active state.

        Returns:
            ``(observation, info)`` per the gymnasium API; ``info`` is empty.
        """
        super().reset(seed=seed)
        self.current_layer = 0
        self.total_reward = 0.0
        self.layer_states = [
            {
                'heads_remaining': 1.0,
                'ffn_size_ratio': 1.0,
                'is_active': True,
                'original_heads': self.num_heads,
                'original_ffn_size': self.intermediate_size
            } for _ in range(self.num_layers)
        ]
        return self._get_current_state(), {}

    def _get_current_state(self):
        """Build the 4-dim float32 observation from the per-layer states."""
        progress = self.current_layer / self.num_layers
        avg_heads = np.mean([ls['heads_remaining'] for ls in self.layer_states])
        avg_ffn = np.mean([ls['ffn_size_ratio'] for ls in self.layer_states])
        active_ratio = sum(ls['is_active'] for ls in self.layer_states) / self.num_layers
        return np.array([progress, avg_heads, avg_ffn, active_ratio], dtype=np.float32)

    def step(self, action):
        """Apply ``action`` to the current layer and advance to the next one.

        Args:
            action: Array-like ``[head_prune, ffn_prune, skip_score]``; values
                are clipped to the action-space bounds.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``truncated`` is
            always False. On the final layer ``reward`` includes the episode
            bonus. ``info`` holds the layer index just processed, the running
            total reward, the compression ratio and the active-layer count.
        """
        # Clip to the declared bounds so callers that do not clip cannot
        # feed values that turn (1 - prune) negative.
        action = np.clip(
            np.asarray(action, dtype=np.float32).reshape(-1),
            self.action_space.low,
            self.action_space.high,
        )
        head_ratio, ffn_ratio, skip_prob = map(float, action)
        reward = self._apply_action(self.current_layer, head_ratio, ffn_ratio, skip_prob)
        self.current_layer += 1
        terminated = self.current_layer >= self.num_layers

        if terminated:
            reward += self._calculate_final_bonus()
        # Accumulate exactly once (after any bonus) so the last step's
        # per-layer reward is not counted twice.
        self.total_reward += reward

        next_state = self._get_current_state()
        info = {
            'layer': self.current_layer - 1,
            'total_reward': self.total_reward,
            'compression_ratio': self._calculate_compression_ratio(),
            'active_layers': sum(ls['is_active'] for ls in self.layer_states)
        }
        return next_state, reward, terminated, False, info

    def _apply_action(self, idx, head_prune, ffn_prune, skip_prob):
        """Mutate layer ``idx``'s state and return its step reward (0 if out of range)."""
        if idx >= len(self.layer_states):
            return 0.0
        state = self.layer_states[idx]
        prev_heads = state['heads_remaining']
        prev_ffn = state['ffn_size_ratio']
        # Keep at least 10% of heads/FFN so a layer is never fully emptied.
        state['heads_remaining'] = max(0.1, prev_heads * (1 - head_prune))
        state['ffn_size_ratio'] = max(0.1, prev_ffn * (1 - ffn_prune))
        if skip_prob > self.SKIP_THRESHOLD:
            state['is_active'] = False
        return self._calculate_step_reward(idx, head_prune, ffn_prune, skip_prob, prev_heads, prev_ffn, state['is_active'])

    def _calculate_step_reward(self, idx, head_prune, ffn_prune, skip_prob, prev_heads, prev_ffn, active):
        """Return 10 * (compression gain - importance-weighted performance penalty)."""
        head_comp = head_prune * prev_heads
        ffn_comp = ffn_prune * prev_ffn
        skip_comp = 1.0 if not active else 0.0
        comp_reward = head_comp * 0.3 + ffn_comp * 0.5 + skip_comp * 1.0
        importance = self._layer_importance(idx)
        perf_penalty = head_prune * 0.4 * importance + ffn_prune * 0.6 * importance + skip_comp * 1.2 * importance
        return 10.0 * (comp_reward - perf_penalty)

    def _layer_importance(self, idx):
        """Heuristic weight: 1.5 for the first 30% of depth, 1.3 for the last 30%, else 1.0."""
        # A single-layer model has no depth range; treat its only layer as early.
        pos = idx / (self.num_layers - 1) if self.num_layers > 1 else 0.0
        if pos < 0.3:
            return 1.5
        elif pos > 0.7:
            return 1.3
        return 1.0

    def _calculate_final_bonus(self):
        """End-of-episode bonus: 50 * compression, minus 20 if < 30% of layers stay active."""
        comp_ratio = self._calculate_compression_ratio()
        active_count = sum(ls['is_active'] for ls in self.layer_states)
        bonus = comp_ratio * 50.0
        penalty = 20.0 if active_count < self.num_layers * 0.3 else 0.0
        return bonus - penalty

    def _calculate_compression_ratio(self):
        """Return the fraction of (proxy) parameters removed, clamped to ``[0, 1]``.

        NOTE(review): this is a proxy count — one ``head_dim x hidden`` block
        per head and one ``intermediate x hidden`` matrix per FFN — not the
        exact ViT parameter count. Inactive layers contribute nothing.
        """
        attn_unit = self.head_dim * self.hidden_size
        total_orig = total_curr = 0.0
        for ls in self.layer_states:
            orig_head_p = ls['original_heads'] * attn_unit
            orig_ffn_p = ls['original_ffn_size'] * self.hidden_size
            total_orig += orig_head_p + orig_ffn_p
            if ls['is_active']:
                total_curr += orig_head_p * ls['heads_remaining'] + orig_ffn_p * ls['ffn_size_ratio']
        if total_orig <= 0.0:
            return 0.0
        return max(0.0, min(1.0, 1.0 - (total_curr / total_orig)))

def train_fixed_agent(total_timesteps=2000, n_eval_episodes=3):
    """Train PPO on ``FixedViTMDPEnvironment`` and print deterministic rollouts.

    Args:
        total_timesteps: PPO training budget passed to ``model.learn``.
        n_eval_episodes: Number of greedy evaluation episodes to run and print.

    Returns:
        The trained ``PPO`` model (previously nothing was returned).
    """
    env = FixedViTMDPEnvironment(debug=False)
    model = PPO("MlpPolicy", env, verbose=1, n_steps=512, batch_size=32, n_epochs=5, learning_rate=3e-4)
    model.learn(total_timesteps=total_timesteps)
    for ep in range(n_eval_episodes):
        obs, _ = env.reset()
        terminated = truncated = False
        episode_reward = 0.0
        # Respect both end signals of the gymnasium 5-tuple step API.
        while not (terminated or truncated):
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
        # Report the sum of rewards actually returned by step(), which is
        # independent of how the env tracks its own running total.
        print(f"Episode {ep+1}: Reward={episode_reward:.2f}, Compression={info['compression_ratio']:.2%}, Active Layers={info['active_layers']}/{env.num_layers}")
    return model

if __name__ == "__main__":
    # Train and evaluate only when run as a script, not when imported.
    train_fixed_agent()
