# Implementation of Algorithm Distillation

## Установка нужных библиотек

In [2]:
!pip install wandb
!pip install stable_baselines3
!pip install shimmy
!pip install gymnasium stable-baselines3
!pip install gym



## Вход в учётную запись Weights & Biases

In [None]:
!wandb login

## Импорт нужных модулей


In [3]:
import bisect
import itertools
import json
import pickle

import gym
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from torch import nn
from torch.utils.data import Dataset
from transformers import GPT2Config, GPT2Model




### Объявляем модель — трансформер

In [4]:
class PolicyTransformerConfig(GPT2Config):
    """
    Configuration for PolicyTransformer: a GPT2Config extended with the
    environment- and context-related fields the model reads.
    """
    def __init__(
            self,
            hidden_size=128,
            state_dim=0,
            act_dim=0,
            act_num=0,
            max_ep_len=20,
            context_len=80,
            token_mask_prob=0.3,
            **kwargs
    ):
        """
        Args:
            hidden_size: width of the token embeddings and of the GPT-2 backbone
            state_dim: dimensionality of the state space (input size of the state embedding)
            act_dim: dimensionality of the action vector fed to the action embedding
            act_num: number of discrete actions in the domain (size of the output logits)
            max_ep_len: maximum episode length; used as the size of the timestep
                embedding table, so timesteps must be smaller than this
            context_len: number of past steps used as the model input
            token_mask_prob: probability of masking a token during training
            **kwargs: forwarded unchanged to GPT2Config
        """
        super().__init__(**kwargs)

        # NOTE(review): assigned after super().__init__ — presumably so GPT2Config's own
        # defaults (it aliases hidden_size to n_embd) do not overwrite it; keep this order.
        self.hidden_size = hidden_size

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.act_num = act_num
        self.max_ep_len = max_ep_len

        self.context_len = context_len
        self.token_mask_prob = token_mask_prob


class PolicyTransformer(nn.Module):
    """
    Transformer with a GPT-2 backbone that predicts the next action from the interaction history.

    The history is interleaved into a token sequence

        s_0, a_0, r_0,   s_1, a_1, r_1,  ...  s_T, a_T (masked), r_T (masked)

    and the encoder output at the last state token ``s_T`` is mapped by an MLP to
    ``act_num`` action logits

        [logit_1  logit_2  ...  logit_|A|]

    from which ``predict`` picks (or samples) the output action.
    """
    def __init__(self, config: GPT2Config, *args):
        super().__init__(*args)

        self.config = config

        self.hidden_size = config.hidden_size
        self.encoder = GPT2Model(config)

        # one linear embedding per modality; actions and rewards enter as float
        # vectors of size act_dim and 1 respectively
        self.embed_state = torch.nn.Linear(config.state_dim, config.hidden_size)
        self.embed_action = torch.nn.Linear(config.act_dim, config.hidden_size)
        self.embed_reward = torch.nn.Linear(1, config.hidden_size)
        # timesteps index this table, so they must lie in [0, max_ep_len)
        self.embed_timestep = nn.Embedding(config.max_ep_len, config.hidden_size)

        self.predict_action = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, config.act_num),
        )

    def forward(
            self,
            states,
            actions,
            rewards,
            timesteps,
            attention_mask = None,
            **kwargs,
    ) -> dict:
        """
        Encode the history and score the action of the last step.

        Args:
            states: (batch_size, seq_length, state_dim) float tensor
            actions: (batch_size, seq_length, act_dim) float tensor; ``actions[:, -1, 0]``
                is used as the target action index
            rewards: (batch_size, seq_length, 1) float tensor
            timesteps: (batch_size, seq_length) integer tensor
            attention_mask: optional (batch_size, seq_length) mask, 1 = real step,
                0 = padding; defaults to all ones
            **kwargs: ignored (e.g. ``return_loss`` added by the collator)

        Returns:
            dict with ``action_pred`` — (batch_size, act_num) logits — and ``loss`` —
            the NLL of the last step's action.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]
        device = states.device

        if attention_mask is None:
            # build the default mask on the inputs' device; a CPU mask clashes with CUDA inputs
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long, device=device)
        else:
            attention_mask = attention_mask.to(device)

        state_embeddings = self.embed_state(states)  # (batch_size, seq_length, hidden_size)
        action_embeddings = self.embed_action(actions)  # (batch_size, seq_length, hidden_size)
        reward_embeddings = self.embed_reward(rewards)  # (batch_size, seq_length, hidden_size)
        time_embeddings = self.embed_timestep(timesteps)  # (batch_size, seq_length, hidden_size)

        # timestep embeddings play the role of positional embeddings
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings
        reward_embeddings = reward_embeddings + time_embeddings

        # interleave into (s_1, a_1, r_1, s_2, a_2, ...): every state token precedes the
        # action it has to predict, which suits autoregressive modelling
        stacked_inputs = (
            torch.stack((state_embeddings, action_embeddings, reward_embeddings), dim=1)  # (batch_size, 3, seq_length, hidden_size)
            .permute(0, 2, 1, 3)  # (batch_size, seq_length, 3, hidden_size)
            .reshape(batch_size, 3 * seq_length, self.hidden_size)
        )

        stacked_attention_mask = (
            torch.stack((attention_mask, attention_mask, attention_mask), dim=1)  # (batch_size, 3, seq_length)
            .permute(0, 2, 1)  # (batch_size, seq_length, 3)
            .reshape(batch_size, 3 * seq_length)  # (batch_size, 3 * seq_length)
        )

        if self.training and self.config.token_mask_prob > 0:
            # random token dropout; sample the mask directly on the target device
            keep = (torch.rand(stacked_attention_mask.shape, device=device) > self.config.token_mask_prob).float()
            stacked_attention_mask = stacked_attention_mask * keep

        # always hide the last step's action and reward tokens (the prediction target and
        # what follows it); GPT-2 attention is causal, so this is a safeguard for s_T
        stacked_attention_mask[:, -2:] = 0

        # the encoder receives embeddings directly (not token ids as in NLP)
        encoder_outputs = self.encoder(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = encoder_outputs[0]

        x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)  # (batch_size, 3, seq_length, hidden_size)

        # the output at the last state token gives the action distribution
        action_pred = self.predict_action(x[:, 0, -1])  # (batch_size, act_num)

        action_true = actions[:, -1, 0].long()  # (batch_size,)

        log_probs = F.log_softmax(action_pred, dim=1)
        nll_loss = F.nll_loss(log_probs, action_true)

        return {'action_pred': action_pred, 'loss': nll_loss}

    def predict(
            self,
            states=None,
            actions=None,
            rewards=None,
            timesteps=None,
            attention_mask=None,
            temperature=0.0,
            **kwargs,
    ):
        """
        Return the action index for a single-sample batch.

        ``temperature == 0`` picks the argmax of the logits; otherwise the action is
        sampled from ``softmax(logits / temperature)``.

        Returns:
            Python int (``.item()`` requires batch_size == 1).
        """
        logits = self.forward(states, actions, rewards, timesteps, attention_mask, **kwargs)['action_pred']

        if temperature == 0.:
            action_idx = torch.argmax(logits, dim=1)
        else:
            action_idx = torch.multinomial(F.softmax(logits / temperature, dim=1), num_samples=1)

        return action_idx.item()

### Функция для поиска пути к конфигу DarkRoom

In [5]:
def find_config_file(env_id: str, alg: str):
    """
    Return the relative path of the JSON config for ``alg`` on a DarkRoom environment.

    Raises:
        ValueError: if ``env_id`` does not name a DarkRoom environment.
    """
    if not env_id.startswith('DarkRoom'):
        raise ValueError(f"Неизвестная среда или алгоритм: {env_id, alg}")
    return f"configs/DarkRoom-{alg}.json"

In [57]:
!git clone https://github.com/me1nna/in-context-ad.git
# клоню со своего репозитория, вытаскиваю из папки data архив data.zip

Cloning into 'in-context-ad'...
remote: Enumerating objects: 111, done.[K
remote: Counting objects: 100% (12/12), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 111 (delta 5), reused 12 (delta 5), pack-reused 99[K
Receiving objects: 100% (111/111), 413.12 MiB | 54.32 MiB/s, done.
Resolving deltas: 100% (47/47), done.
Updating files: 100% (28/28), done.


## В архиве — датасет, окружения (envs) и конфиги (configs)



## Подготовка датасета

In [59]:
!mv in-context-ad/data/data.zip data.zip

In [9]:
!unzip data.zip

Archive:  data.zip
   creating: data/
  inflating: data/get_data.ipynb     
  inflating: data/darkroom_normal_117.pkl  
  inflating: data/darkroom_normal_105.pkl  
  inflating: data/darkroom_normal_67.pkl  
  inflating: data/darkroom_normal_16.pkl  
  inflating: data/darkroom_normal_160.pkl  
  inflating: data/darkroom_normal_113.pkl  
  inflating: data/darkroom_normal_163.pkl  
  inflating: data/darkroom_normal_196.pkl  
  inflating: data/darkroom_normal_10.pkl  
  inflating: data/darkroom_normal_170.pkl  
  inflating: data/darkroom_normal_60.pkl  
  inflating: data/darkroom_normal_6.pkl  
  inflating: data/darkroom_normal_46.pkl  
  inflating: data/darkroom_normal_77.pkl  
  inflating: data/darkroom_normal_188.pkl  
  inflating: data/darkroom_normal_38.pkl  
  inflating: data/darkroom_normal_148.pkl  
  inflating: data/darkroom_normal_162.pkl  
  inflating: data/darkroom_normal_20.pkl  
  inflating: data/darkroom_normal_199.pkl  
  inflating: data/darkroom_normal_171.pkl  
  inflatin

In [16]:
class LifetimeDataset(Dataset):
    """
    Flattens a list of lifetimes (learning histories) into per-step samples.

    Each sample is the window of up to ``context_len`` consecutive transitions of one
    lifetime ending at some step; windows never cross lifetime boundaries.
    """
    def __init__(self, lifetimes: list, context_len: int = 3):
        """
        Args:
            lifetimes: list of lifetimes; each lifetime is a sequence of transitions
                indexed as (state, action, reward, timestep)
            context_len: maximum number of transitions returned per sample
        """
        self.lifetimes = lifetimes
        self.context_len = context_len

        # NOTE(review): a lifetime of length L yields L - 1 samples, so its final
        # transition is never the prediction target — confirm this is intended.
        # max(0, ...) stops an empty lifetime from contributing a negative count.
        self.lifetime_lens = [max(0, len(trajs) - 1) for trajs in self.lifetimes]
        self.total_len = sum(self.lifetime_lens)
        # running totals of lifetime_lens: maps a flat index to its lifetime via bisect
        self._cum_lens = list(itertools.accumulate(self.lifetime_lens))

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        """
        Return the context window for flat sample index ``idx``.

        Returns:
            dict with 'states', 'actions', 'rewards', 'timesteps' lists (oldest first).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``; this also lets
                plain Python iteration over the dataset terminate.
        """
        if idx < 0:
            idx += self.total_len
        if not 0 <= idx < self.total_len:
            raise IndexError(f"index {idx} out of range for dataset of length {self.total_len}")

        # first lifetime whose cumulative sample count exceeds idx (skips empty lifetimes)
        lifetime_idx = bisect.bisect_right(self._cum_lens, idx)
        offset = self._cum_lens[lifetime_idx - 1] if lifetime_idx > 0 else 0
        step_idx = idx - offset + 1

        # up to context_len preceding transitions from the same lifetime
        context = self.lifetimes[lifetime_idx][max(0, step_idx - self.context_len):step_idx]

        return {
            'states': [s[0] for s in context],
            'actions': [s[1] for s in context],
            'rewards': [s[2] for s in context],
            'timesteps': [s[3] for s in context],
        }

In [17]:
import os
import pickle
from tqdm.auto import tqdm

def collect_from_pkl(file_prefix: str, data_dir: str = "data", subsample_gap: int = 1, max_lifetimes: int = 1500):
    """
    Load the lifetimes stored as pickle files in ``data_dir``.

    Args:
        file_prefix: only ``*.pkl`` files whose name starts with this prefix are loaded
        data_dir: directory containing the pickle files
        subsample_gap: keep every ``subsample_gap``-th transition of each lifetime
        max_lifetimes: stop after this many lifetimes have been loaded

    Returns:
        List of lifetimes, loaded in sorted filename order.

    Raises:
        FileNotFoundError: if ``data_dir`` does not exist.
    """
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Directory not found: {data_dir}")

    # sort so that the selected subset (and the downstream train/eval split) does not
    # depend on the arbitrary order returned by os.listdir
    filenames = sorted(
        os.path.join(data_dir, file)
        for file in os.listdir(data_dir)
        if file.startswith(file_prefix) and file.endswith('.pkl')
    )
    lifetimes = []

    print(f'Найдено {len(filenames)} файлов в директории {data_dir}:')

    for filename in tqdm(filenames):
        if len(lifetimes) >= max_lifetimes:
            break
        try:
            # SECURITY: pickle.load can execute arbitrary code — only load trusted files
            with open(filename, 'rb') as f:
                lifetime = pickle.load(f)
        except (EOFError, pickle.UnpicklingError):
            # truncated/corrupt file: skip it instead of aborting the whole load
            print(f'Не получилось загрузить {filename}.')
            continue
        lifetimes.append(lifetime[::subsample_gap])

    print(f'Получено {len(lifetimes)} жизненных циклов.')

    return lifetimes

In [18]:
def collate_fn(batch):
    """
    Collate LifetimeDataset samples into a left-padded batch for PolicyTransformer.

    Shorter samples are zero-padded at the *start* of the sequence; ``attention_mask``
    is 0 on padded steps and 1 on real steps.

    Args:
        batch: list of dicts with equally long 'states', 'actions', 'rewards' and
            'timesteps' lists (one entry per step)

    Returns:
        dict with stacked tensors and ``return_loss=True``:
            states (batch_size, seq_len, state_dim), actions (batch_size, seq_len, action_dim),
            rewards (batch_size, seq_len, 1), timesteps (batch_size, seq_len),
            attention_mask (batch_size, seq_len)
    """
    def _as_float_tensor(values):
        # go through one numpy array: torch builds tensors from lists of ndarrays very slowly
        return torch.tensor(np.asarray(values, dtype=np.float32))

    def _left_pad(tensor, missing):
        # F.pad lists padding from the last dim backwards, so only the first (time) dim is padded
        if missing == 0:
            return tensor
        return F.pad(tensor, (0, 0) * (tensor.dim() - 1) + (missing, 0))

    max_length = max(len(sample['states']) for sample in batch)

    states, actions, rewards, timesteps, attention_mask = [], [], [], [], []
    for sample in batch:
        seq_len = len(sample['states'])
        missing = max_length - seq_len

        states.append(_left_pad(_as_float_tensor(sample['states']), missing))

        action = _left_pad(_as_float_tensor(sample['actions']), missing)
        if action.dim() == 1:  # (seq_len,): scalar actions get a dummy feature dim
            action = action.unsqueeze(-1)
        actions.append(action)

        # rewards are scalars: pad along time, then add a dummy feature dim
        rewards.append(_left_pad(_as_float_tensor(sample['rewards']), missing).unsqueeze(-1))
        timesteps.append(_left_pad(torch.tensor(sample['timesteps'], dtype=torch.long), missing))
        attention_mask.append(torch.tensor([0] * missing + [1] * seq_len))

    return {
        'states': torch.stack(states),  # (batch_size, seq_len, state_dim)
        'actions': torch.stack(actions),  # (batch_size, seq_len, action_dim)
        'rewards': torch.stack(rewards),  # (batch_size, seq_len, 1)
        'timesteps': torch.stack(timesteps),  # (batch_size, seq_len)
        'attention_mask': torch.stack(attention_mask),  # (batch_size, seq_len)

        'return_loss': True,
    }


In [19]:
def prepare_for_prediction(
        states: list,
        actions: list,
        rewards: list,
        timesteps: list,
        env: gym.Env,
        context_len: int = 10,
        device: str = 'cuda',
):
    """
    Build a single-sample model input from the most recent ``context_len`` steps.

    Expects ``states``/``timesteps`` to already include the current step while
    ``actions``/``rewards`` lag one step behind. The current step's action and reward
    slots are filled with placeholders (a sampled action and reward 0); the model
    masks those last two tokens out.

    Args:
        states, actions, rewards, timesteps: full history lists of the current rollout
        env: environment whose ``action_space.sample()`` provides the placeholder action
        context_len: number of steps to feed to the model (>= 1)
        device: device the tensors are moved to

    Returns:
        dict produced by ``collate_fn`` with tensors on ``device`` and ``return_loss=False``.

    Raises:
        ValueError: if ``context_len < 1``.
    """
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")

    # explicit start indices: with negative slicing, ``x[-0:]`` returns the whole
    # history, which happened for actions/rewards when context_len == 1
    n_past = context_len - 1
    datum = {
        'states': states[max(0, len(states) - context_len):],
        'actions': actions[max(0, len(actions) - n_past):] + [env.action_space.sample()],
        'rewards': rewards[max(0, len(rewards) - n_past):] + [0],
        'timesteps': timesteps[max(0, len(timesteps) - context_len):],
    }
    model_input = collate_fn([datum])
    model_input['return_loss'] = False

    return {
        key: value.to(device) if hasattr(value, 'to') else value
        for key, value in model_input.items()
    }

# Training AD

In [20]:
import math
import os
import pprint
import numpy as np

import torch
import wandb
from transformers import TrainingArguments, Trainer
import gym

import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning, module='gym')


def train_model(
        env_id: str,
        model_config: str,
        subsample_gap: int,
        file_prefix: str,
        output_dir: str,
        resume: bool,
        device: str,
) -> None:
    """
    Train a PolicyTransformer on the stored lifetimes with the Hugging Face Trainer.

    Args:
        env_id: environment ID (used in the wandb run name)
        model_config: path of the transformer's JSON config, relative to ``data/``
        subsample_gap: keep every ``subsample_gap``-th transition of each lifetime
        file_prefix: prefix of the pickle files that contain the lifetimes
        output_dir: directory for Trainer checkpoints and the final model
        resume: passed as ``resume_from_checkpoint`` to ``Trainer.train``
        device: device the model is moved to before training

    Raises:
        ValueError: if no lifetimes were loaded.
    """
    import inspect

    lifetimes = collect_from_pkl(file_prefix, subsample_gap=subsample_gap)
    if not lifetimes:
        raise ValueError(f"No lifetimes with prefix '{file_prefix}' were loaded")

    # the config file lives under data/ (e.g. data/configs/...)
    config_path = os.path.join('data', model_config)
    config = PolicyTransformerConfig.from_json_file(config_path)
    model = PolicyTransformer(config)

    model.to(device)

    # hold out the last ~5% of lifetimes for evaluation
    split_idx = math.ceil(len(lifetimes) * 0.95)
    train_dataset = LifetimeDataset(lifetimes[:split_idx], context_len=config.context_len)
    if split_idx == len(lifetimes):
        warnings.warn("Not enough data to split into train and eval sets, using train set for eval")
        eval_dataset = train_dataset
    else:
        eval_dataset = LifetimeDataset(lifetimes[split_idx:], context_len=config.context_len)

    # `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
    # releases; use whichever name the installed version accepts
    eval_strategy_arg = (
        'eval_strategy'
        if 'eval_strategy' in inspect.signature(TrainingArguments).parameters
        else 'evaluation_strategy'
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=5,
        logging_steps=100,
        eval_steps=1000,
        save_steps=1000,
        per_device_train_batch_size=1024,
        per_device_eval_batch_size=1024,
        warmup_ratio=0.1,
        learning_rate=3e-4,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        max_grad_norm=1,
        # fp16 mixed precision requires a GPU; fall back to fp32 on CPU-only machines
        fp16=torch.cuda.is_available(),
        gradient_accumulation_steps=2,
        **{eval_strategy_arg: "steps"},
    )

    wandb.init(
        project='alg-distill',
        name=env_id + '-train',
        config=config.to_dict(),
    )

    try:
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=collate_fn,
        )

        trainer.train(
            resume_from_checkpoint=resume,
        )

        trainer.save_model()
    finally:
        # close the run so a later wandb.init (e.g. for evaluation) starts a fresh one
        wandb.finish()



2024-08-11 06:18:11.878165: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-11 06:18:11.878300: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-11 06:18:12.018012: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [21]:
def _reset_env(env):
    """Reset ``env`` and return only the observation (new gym API returns ``(obs, info)``)."""
    result = env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(env, action):
    """Step ``env`` and return ``(obs, reward, done)`` for both the 4- and 5-tuple gym step APIs."""
    result = env.step(action)
    if len(result) == 5:
        obs, reward, terminated, truncated, _ = result
        return obs, reward, terminated or truncated
    obs, reward, done, _ = result
    return obs, reward, done


@torch.no_grad()
def eval_model(
        env_id: str,
        lifetime_num: int,
        episode_num: int,
        model_config: str,
        model_path: str,
        device: str,
        log_interval: int = 100,
        temperature: float = 1.,
        baseline: str = None,
) -> None:
    """
    Roll out a trained PolicyTransformer (or a random baseline) in-context and log rewards to wandb.

    Args:
        env_id: gym ID of the environment to evaluate on
        lifetime_num: number of lifetimes; each one creates a fresh environment with ``gym.make``
            and starts from an empty in-context history
        episode_num: episodes per lifetime
        model_config: path of the transformer's JSON config, relative to ``data/``
        model_path: directory containing ``model.safetensors``
        device: device to run the model on
        log_interval: every ``log_interval`` episodes, log the average of the last
            ``log_interval`` episode mean rewards and render the latest trajectory
        temperature: sampling temperature for ``model.predict`` (0 = greedy)
        baseline: ``'random'`` to take ``env.action_space.sample()`` actions instead of the model

    Raises:
        FileNotFoundError: if ``model.safetensors`` is missing.
    """
    # the config file lives under data/ (e.g. data/configs/...)
    config_path = os.path.join('data', model_config)
    config = PolicyTransformerConfig.from_json_file(config_path)
    model = PolicyTransformer(config)

    model_file_path = os.path.join(model_path, 'model.safetensors')
    if not os.path.exists(model_file_path):
        raise FileNotFoundError(f"Model file not found: {model_file_path}")

    from safetensors.torch import load_file
    model.load_state_dict(load_file(model_file_path))
    model.eval()
    model.to(device)

    context_len = model.config.context_len

    wandb.init(
        project='alg-distill',
        name=env_id + '-eval',
        config=config.to_dict(),
    )

    try:
        for lifetime_idx in range(lifetime_num):
            env = gym.make(env_id)
            ep_len = env.episode_length

            # every lifetime is a new environment instance: start from an empty history so the
            # context never mixes lifetimes (training windows never cross lifetimes either)
            states, actions, rewards, timesteps = [], [], [], []
            episode_mean_rewards = []

            for i in range(episode_num):
                traj = []
                obs = _reset_env(env)
                for t in range(ep_len):
                    states += [obs]
                    timesteps += [t]
                    if baseline == 'random':
                        action = env.action_space.sample()
                    else:
                        model_input = prepare_for_prediction(states, actions, rewards, timesteps,
                                                             env, context_len, device)
                        # temperature > 0 allows exploration
                        action = model.predict(**model_input, temperature=temperature)
                    new_obs, reward, done = _step_env(env, action)
                    actions += [action]
                    rewards += [reward]
                    traj += [(obs, action, reward, t)]
                    obs = new_obs

                    if done:
                        break

                # average over this episode only (it may end before ep_len steps)
                mean_reward = np.mean([step[2] for step in traj])
                episode_mean_rewards.append(mean_reward)
                print(f"Lifetime {lifetime_idx}, episode {i}, mean reward {mean_reward}")
                if (i + 1) % log_interval == 0:
                    # log_interval counts episodes, so average the last log_interval episode means
                    log_interval_mean_reward = np.mean(episode_mean_rewards[-log_interval:])
                    wandb.log({'mean_reward': log_interval_mean_reward})
                    env.render(trajectory=traj,
                               log_name=f'l{lifetime_idx}/e{i:05}')

            env.close()
    finally:
        wandb.finish()

# Train

In [None]:
from tqdm import tqdm


# GPU when available, otherwise CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# path of the decision-transformer config for DarkRoom
model_config = find_config_file(alg='dt', env_id='DarkRoom-v0')

print(f'Using device: {device}')
print(f'Using model config: {model_config}')

### Регистрируем среду 😝




In [23]:
import sys

# make the DarkRoom module (data/envs/darkroom.py) importable; the guard keeps
# re-runs of this cell from appending the same entry to sys.path again
envs_dir = os.path.abspath('data/envs')
if envs_dir not in sys.path:
    sys.path.append(envs_dir)

from darkroom import DarkRoom
from gym.envs.registration import register

register(
    id='DarkRoom-v0',
    entry_point='darkroom:DarkRoom',
    max_episode_steps=100,
)

## Распакуем последние чекпоинты!

In [60]:
!mkdir output
!unzip in-context-ad/outputs/output_11_08__3.zip

Archive:  in-context-ad/outputs/output_11_08__3.zip
   creating: checkpoint-66000/
  inflating: checkpoint-66000/model.safetensors  
  inflating: checkpoint-66000/scheduler.pt  
  inflating: checkpoint-66000/rng_state.pth  
  inflating: checkpoint-66000/training_args.bin  
  inflating: checkpoint-66000/optimizer.pt  
  inflating: checkpoint-66000/trainer_state.json  


In [61]:
!mv checkpoint-66000 output/checkpoint-66000

## Train

In [None]:
# resume training from the latest checkpoint found in ./output (unpacked above)
train_model(
    device=device,
    env_id='DarkRoom-v0',
    file_prefix='darkroom_normal',
    model_config=model_config,
    output_dir='output',
    resume=True,
    subsample_gap=1,
)

# тренирую на 1500 lifetimes

## Архивируем папку с чекпоинтами и скачиваем архив

In [54]:
!cd output && zip -r output_11_08__3.zip checkpoint-66000

  adding: checkpoint-66000/ (stored 0%)
  adding: checkpoint-66000/model.safetensors (deflated 8%)
  adding: checkpoint-66000/scheduler.pt (deflated 55%)
  adding: checkpoint-66000/rng_state.pth (deflated 25%)
  adding: checkpoint-66000/training_args.bin (deflated 52%)
  adding: checkpoint-66000/optimizer.pt (deflated 94%)
  adding: checkpoint-66000/trainer_state.json (deflated 79%)


In [55]:
from IPython.display import FileLink
# render a download link for the zipped checkpoint
FileLink('output/output_11_08__3.zip')