In [None]:
import sys

def is_running_in_colab():
    """Report whether this kernel appears to be a Google Colab runtime.

    The check only looks for the ``google.colab`` package among the modules already
    loaded into ``sys.modules``; nothing is imported by this function.
    """
    colab_module_name = 'google.colab'
    return colab_module_name in sys.modules

In [None]:
# Install / upgrade the runtime dependencies of this notebook.
# NOTE(review): nothing is version-pinned, so a later re-run may pick up different releases;
# consider pinning once a working set is known.
# If torch / transformers were already imported in this kernel, restart it after this cell
# so the upgraded versions are the ones actually used.
%pip install --upgrade pip
%pip install wandb
%pip install randomname
%pip install --upgrade torch
# flash-attn is built against the torch installed above (hence --no-build-isolation and the
# ordering after the torch upgrade).
%pip install flash-attn --no-build-isolation
%pip install --upgrade datasets
%pip install --upgrade accelerate
%pip install --upgrade peft
%pip install --upgrade transformers
# The [hf_transfer] extra provides the accelerated downloader enabled in the model cell.
%pip install --upgrade huggingface_hub[hf_transfer]

In [None]:
import randomname

# Free-form notes attached to the W&B run created below.
notes = ""

# Random human-readable run id, reused for the W&B run name, the trainer output
# directory and the Hugging Face repo name.
name = randomname.get_name()
print(name)

In [None]:
from huggingface_hub import login
import os

# Hugging Face token: read from the environment locally, from Colab secrets on Colab.
if not is_running_in_colab():
    hf_token = os.environ['HF_TOKEN']
else:
    from google.colab import userdata
    hf_token = userdata.get('HF_TOKEN')
login(token=hf_token)

# Destination repo for adapters/checkpoints; replace "username" with your own HF account.
hf_name = "username/FrozenLake-" + name

In [None]:
import wandb
import os

# W&B API key: WANDB_TOKEN env var locally, Colab secret on Colab.
if not is_running_in_colab():
    wandb_token = os.environ['WANDB_TOKEN']
else:
    from google.colab import userdata
    wandb_token = userdata.get('WANDB_TOKEN')
wandb.login(key=wandb_token)

# Open the run that later cells (config updates, Trainer with report_to="wandb") log into.
wandb.init(
    project="FrozenLakeTrainRL",
    name=name,
    notes=notes,
    save_code=False,
)

In [None]:
from peft import TaskType, get_peft_model, PeftModel, LoraModel, LoraConfig, IA3Config
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import HfApi
import os

base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
adapter_name = None #only specify if you want to pickup training - example: "username/FrozenLake-hidden-laminate"
adapter_subfolder = None #only specify if you want to pickup training - example: "checkpoint-340"
adapter_type = "ia3"  # "ia3" or "lora"

wandb.config.update({
    "base_model_name": base_model_name,
    "adapter_name": adapter_name,
    "adapter_subfolder": adapter_subfolder,
    "adapter_type": adapter_type,
})

# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER into its `constants` module when it is
# first imported, and the login cell above already imported it. Setting only the env var
# here would be too late, so also flip the already-loaded constant for this kernel.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
from huggingface_hub import constants as hf_hub_constants
hf_hub_constants.HF_HUB_ENABLE_HF_TRANSFER = True

# Pre-download the base model's top-level files only ("*/*" skips files in subfolders).
hf = HfApi()
hf.snapshot_download(
    repo_id=base_model_name,
    cache_dir="./model/",
    ignore_patterns="*/*"
)

tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    cache_dir="./model/",
)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
    cache_dir="./model/",
    attn_implementation="flash_attention_2",
)

print(f"Memory footprint: {base_model.get_memory_footprint()/1024/1024/1024} GB")

# Two adapters share one base model: "<default>" is the online network optimized by
# CustomTrainer, "<delayed>" is the target network soft-updated by the polyak_update callback.
if adapter_name is not None:
    # Resume: each adapter lives in a subfolder named after it. When no checkpoint
    # subfolder is given, those adapter folders are looked up at the repo root instead of
    # failing on `None + str`.
    checkpoint_prefix = "" if adapter_subfolder is None else adapter_subfolder + "/"

    # "<default>" is the adapter being trained, so it must be loaded trainable:
    # PeftModel.from_pretrained defaults to is_trainable=False (frozen, inference mode).
    model = PeftModel.from_pretrained(
        base_model,
        adapter_name,
        adapter_name="<default>",
        subfolder=checkpoint_prefix + "<default>",
        is_trainable=True,
        #mixed=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )

    model.load_adapter(
        adapter_name,
        "<delayed>",
        subfolder=checkpoint_prefix + "<delayed>",
        is_trainable=True,
    )

else:
    # Fresh start: build one config and instantiate it twice (online + target adapter).
    if adapter_type == "ia3":
        peft_config = IA3Config(
            #target_modules="all-linear",
            #feedforward_modules=["w0"],
            task_type=TaskType.CAUSAL_LM,
            modules_to_save=[
                #"lm_head",
                #"embed_tokens"
            ],
        )

    elif adapter_type == "lora":
        peft_config = LoraConfig(
            #target_modules="all-linear",
            task_type=TaskType.CAUSAL_LM,
            modules_to_save=[
                #"lm_head",
                #"embed_tokens"
            ],
            r=32,
            lora_alpha=32,
            lora_dropout=0.0,
            bias="none",
            use_rslora=True,
        )

    else:
        raise ValueError(f"Unknown adapter type: {adapter_type!r}")

    model = get_peft_model(
        base_model,
        peft_config=peft_config,
        adapter_name="<default>",
    )

    model.add_adapter(
        peft_config=peft_config,
        adapter_name="<delayed>",
    )

model.config.use_cache = False

model.print_trainable_parameters()

print(model)

In [None]:
from datasets import load_dataset

# Uncomment exactly one of the dataset variants below.
dataset_name = 'micahr234/FrozenLakeNotSlipperyPrepared20eto40ecombineto4096'
#dataset_name = 'micahr234/FrozenLakeNotSlipperyPrepared20eto40ecombineto4096low'
#dataset_name = 'micahr234/FrozenLakeNotSlipperyPrepared20eto40ecombineto4096hi'
percentage_of_train_data = 100  # percent of the train split to load
percentage_of_test_data = 100  # percent of the test split to load

wandb.config.update(dict(
    dataset_name=dataset_name,
    percentage_of_train_data=percentage_of_train_data,
    percentage_of_test_data=percentage_of_test_data,
))

# "split[:N%]" slicing loads only the first N percent of each split.
train_split_spec = f"train[:{percentage_of_train_data}%]"
test_split_spec = f"test[:{percentage_of_test_data}%]"

ds_train = load_dataset(dataset_name, split=train_split_spec, cache_dir="./dataset/")
print(ds_train)

ds_test = load_dataset(dataset_name, split=test_split_spec, cache_dir="./dataset/")
print(ds_test)

In [None]:
import torch
from transformers import Trainer, TrainerCallback
from collections import defaultdict
import inspect

# Loss / target-network hyperparameters, read as globals by CustomTrainer.dqn_loss and
# the polyak_update callback.
action_value_discount = 0.9   # gamma applied to the bootstrapped next-action value
reward_scale = 30.0           # multiplies (cumulative_reward - reward_offset) on terminal steps
reward_offset = 0.0
polyak_const = 0.1            # tau: share of "<default>" blended into "<delayed>" per optimizer step
episode_value_discount = 0.0  # NOTE(review): logged to W&B but not referenced by the loss code
value_weight = 1.0            # weight of the normalized TD loss
imitation_weight = 0.0        # weight of -log p(action token)
observation_weight = 0.0      # weight of -log p(non-action token)

#os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

wandb.config.update(dict(
    action_value_discount=action_value_discount,
    reward_scale=reward_scale,
    reward_offset=reward_offset,
    polyak_const=polyak_const,
    episode_value_discount=episode_value_discount,
    value_weight=value_weight,
    imitation_weight=imitation_weight,
    observation_weight=observation_weight,
))

def polyak_dict(source, target, tau):
    """Soft-update the tensors of ``target`` toward ``source`` in place (Polyak averaging).

    For every key: ``target[k] <- tau * source[k] + (1 - tau) * target[k]``.

    Args:
        source: mapping name -> tensor to blend in; left unchanged.
        target: mapping with exactly the same keys; its tensors are overwritten in place.
        tau: blend factor (1.0 copies ``source``, 0.0 leaves ``target`` untouched).

    Raises:
        ValueError: if ``source`` and ``target`` do not have the same set of keys.
    """
    # Match tensors by key (not by position) and fail loudly on any mismatch, so a missing
    # or extra adapter tensor can never be skipped silently.
    if source.keys() != target.keys():
        differing = sorted(map(str, set(source) ^ set(target)))
        raise ValueError(f"source/target keys differ, e.g. {differing[:5]}")
    for key, p_target in target.items():
        p_source = source[key]
        # `.data` keeps the update outside autograd.
        p_target.data.copy_(tau * p_source.data + (1 - tau) * p_target.data)

def get_adapter_state_dict(model, adapter_name):
    """Collect one adapter's entries from ``model.state_dict()`` under adapter-neutral keys.

    Keeps every entry whose key contains ``"." + adapter_name`` and rewrites each such
    segment to the literal ``".adapter_name"``. As a result, the dicts returned for two
    adapters of the same model share identical keys, which is what ``polyak_dict`` compares.
    """
    marker = f".{adapter_name}"
    selected = {}
    for key, tensor in model.state_dict().items():
        if marker in key:
            selected[key.replace(marker, ".adapter_name")] = tensor
    return selected

def index_of_next_nonzero(x, default=-1):
    """For each position ``i`` of ``x``, return the smallest index ``j >= i`` with ``x[j] != 0``.

    ``x`` is treated as a flat (row-major) sequence. Positions with no nonzero element at
    or after them get ``default``. The result has ``x``'s shape, dtype int64, and lives on
    ``x``'s device.
    """
    flat = x.reshape(-1)
    # squeeze(-1) rather than squeeze(): it keeps a 1-D tensor even when there is exactly one
    # nonzero, because torch.searchsorted needs 1-D boundaries.
    positions_of_nonzeros = torch.nonzero(flat).squeeze(-1)
    indices = torch.arange(flat.numel(), device=x.device)
    # right=False -> first nonzero position >= i (position i itself counts).
    search_indices = torch.searchsorted(positions_of_nonzeros, indices, right=False)

    y = torch.full((flat.numel(),), fill_value=default, dtype=torch.int64, device=x.device)
    has_next = search_indices < positions_of_nonzeros.numel()
    y[has_next] = positions_of_nonzeros[search_indices[has_next]]

    return y.reshape(x.shape)

class CustomTrainer(Trainer):
    """Hugging Face ``Trainer`` that optimizes a double-DQN style loss over token logits.

    The PEFT model carries two adapters: ``"<default>"`` (online network, receives
    gradients) and ``"<delayed>"`` (target network, run under ``torch.no_grad`` and
    soft-updated by the ``polyak_update`` callback). Besides ``input_ids`` and
    ``attention_mask``, batches must provide the per-token columns ``action_ids_mask``,
    ``end_ids_mask``, ``cumulative_reward`` and ``env_id``. Per-step loss components are
    stored and averaged into the next ``log`` call.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # metrics[train_eval][metric_name] -> list of per-step scalar tensors
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        self._update_count = 0
        # dqn_loss returns a per-batch mean and never uses `num_items_in_batch`. Declaring
        # that the loss does not consume loss kwargs makes the Trainer apply its own
        # gradient-accumulation normalization, instead of assuming the loss was already
        # normalized over the global item count.
        self.model_accepts_loss_kwargs = False

    def dqn_loss(self, outputs_1, outputs_1_old, inputs, num_items_in_batch=None):
        """Compute the weighted value / imitation / observation loss for one batch.

        Args:
            outputs_1: outputs of the online ``"<default>"`` adapter; its ``logits`` are
                used directly as per-token action values.
            outputs_1_old: outputs of the target ``"<delayed>"`` adapter (detached here).
            inputs: the collated batch (see the class docstring for required columns).
            num_items_in_batch: accepted for Trainer API compatibility; unused.

        Returns:
            ``(loss, metrics)``. ``loss`` is the sum over components of
            ``weight * component``, each averaged over all selected tokens in the batch.
            ``metrics`` holds the unweighted per-component means as detached CPU scalars.
        """
        # Shift by one: logits at position t score the token at position t + 1.
        input_ids = inputs['input_ids'][:, 1:].detach()
        logits_1 = outputs_1['logits'][:, :-1, ...]
        logits_1_old = outputs_1_old['logits'][:, :-1, ...].detach()
        action_ids_mask = inputs['action_ids_mask'][:, 1:].detach()
        end_ids_mask = inputs['end_ids_mask'][:, 1:].detach()
        cumulative_reward = inputs['cumulative_reward'][:, 1:].detach()
        env_id = inputs['env_id'][:, 1:].detach()

        metrics_list = []
        loss_list = []
        # Per sample: am=action mask, em=episode-end mask, cr=cumulative reward,
        # lo/olo=online/target logits, ii=target token ids, vd=env id (unused below).
        for am, em, cr, lo, olo, ii, vd in zip(action_ids_mask, end_ids_mask, cumulative_reward, logits_1, logits_1_old, input_ids, env_id):
            with torch.no_grad():
                # `~am` assumes action_ids_mask is boolean. TODO confirm against the dataset.
                nm = (~am).float()
                am = am.float()
                em = em.float()
                cr = cr.float()

                # Episode index of each token: number of end markers strictly before it.
                episode_num = torch.cumsum(em, dim=-1) - em

                action_indices = torch.nonzero(am).squeeze(-1)
                observation_indices = torch.nonzero(nm).squeeze(-1)

                # Transitions pair each action token with the following action token; the
                # last action of the sequence has no successor and is not a transition.
                current_action_indices = action_indices[:-1]
                next_action_indices = action_indices[1:]

                current_action_episode_num = episode_num[current_action_indices]
                next_action_episode_num = episode_num[next_action_indices]

                # 1.0 where the next action belongs to a different episode (terminal step).
                current_episode_end = (current_action_episode_num != next_action_episode_num).float()
                current_episode_reward = cr[current_action_indices]
                # Reward is only paid on terminal transitions.
                current_reward = reward_scale * (current_episode_reward - reward_offset) * current_episode_end

                current_action = ii[current_action_indices]

            # Double DQN: the online logits pick the next action, the target logits value it.
            next_max_value_indices = lo[next_action_indices, :].argmax(dim=-1)
            next_max_value = olo[next_action_indices, next_max_value_indices]

            target_value = current_reward + action_value_discount * (1.0 - current_episode_end) * next_max_value

            current_value = lo[current_action_indices, current_action]
            # NOTE(review): normalizing by the mean target is unstable when that mean is near
            # zero. A sample with < 2 action tokens yields empty tensors here, which add
            # nothing to the concatenation below.
            normalized_value_loss = ((current_value - target_value) / target_value.mean().detach())**2

            logprob = torch.nn.functional.log_softmax(lo, dim=-1)

            with torch.no_grad():
                action = ii[action_indices]
                observation = ii[observation_indices]

            action_logprob = logprob[action_indices, action]
            imitation_loss = -action_logprob

            observation_logprob = logprob[observation_indices, observation]
            observation_loss = -observation_logprob

            metrics_list.append({
                f'normalized_value_loss': normalized_value_loss,
                f'current_value': current_value,
                f'target_value': target_value,
                f'imitation_loss': imitation_loss,
                f'observation_loss': observation_loss,
            })

            loss_list.append({
                f'normalized_value_loss': normalized_value_loss * value_weight,
                f'imitation_loss': imitation_loss * imitation_weight,
                f'observation_loss': observation_loss * observation_weight,
            })

        # Concatenate per-token values across the batch, then average per component.
        metrics = {key: torch.cat([i[key] for i in metrics_list], dim=-1).mean().detach().cpu() for key in metrics_list[0]}

        loss_dict = {key: torch.cat([i[key] for i in loss_list], dim=-1).mean() for key in loss_list[0]}
        loss = sum(loss_dict.values())

        return loss, metrics

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the target and online forward passes, compute the DQN loss and store metrics.

        Only ``input_ids`` and ``attention_mask`` are fed to the model; the other batch
        columns are consumed by ``dqn_loss``. Metrics go to the "train" bucket when called
        from ``training_step`` and to "eval" (keys prefixed ``eval_``) when called from
        ``prediction_step``.

        Returns:
            ``loss``, or ``(loss, online_outputs)`` when ``return_outputs`` is True.

        Raises:
            ValueError: if called from any function other than the two above.
        """
        filtered_inputs = {k: v for k, v in inputs.items() if k in ['input_ids', 'attention_mask']}

        # Target-network pass. Always switch back to "<default>" afterwards, even if the
        # forward raises, so the model is never left with the target adapter active.
        try:
            with torch.no_grad():
                model.set_adapter("<delayed>")
                old_outputs = model(**filtered_inputs)
        finally:
            model.set_adapter("<default>")

        # Online-network pass (with gradients when grad mode is enabled).
        outputs = model(**filtered_inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        loss, metrics = self.dqn_loss(outputs, old_outputs, inputs, num_items_in_batch=num_items_in_batch)

        # Name of the immediate caller. Reading one frame back is equivalent to
        # inspect.stack()[1].function, but avoids building every stack frame (with source
        # context) on each step.
        frame = inspect.currentframe()
        try:
            if frame is not None and frame.f_back is not None:
                caller_function = frame.f_back.f_code.co_name
            else:
                caller_function = inspect.stack()[1].function
        finally:
            del frame

        if caller_function == "training_step":
            train_eval = "train"
            prefix = ""
        elif caller_function == "prediction_step":
            train_eval = "eval"
            prefix = "eval_"
        else:
            raise ValueError(f"Invalid calling function {caller_function}")

        metrics = {f"{prefix}{k}": v for k, v in metrics.items()}
        self.store_metrics(metrics, train_eval=train_eval)

        # No `* num_processes` rescaling: the loss is a per-process mean that does not use
        # `num_items_in_batch`, so the usual cross-device averaging of gradients is correct.
        return (loss, outputs) if return_outputs else loss

    def store_metrics(self, metrics, train_eval = "train"):
        """Append each metric value to the ``train_eval`` bucket for later averaging."""
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def log(self, logs, start_time = None):
        """Merge averaged stored metrics into ``logs``, clear them, and defer to ``Trainer.log``.

        A log dict containing ``"loss"`` is treated as a train log; anything else as eval.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]

        return super().log(logs, start_time)

class polyak_update(TrainerCallback):
    """Trainer callback that soft-updates the "<delayed>" adapter toward "<default>".

    After every optimizer step, each "<delayed>" tensor becomes
    ``polyak_const * default + (1 - polyak_const) * delayed``.
    """

    def on_optimizer_step(self, *args, **kwargs):
        peft_model = kwargs['model']
        polyak_dict(
            source=get_adapter_state_dict(peft_model, "<default>"),
            target=get_adapter_state_dict(peft_model, "<delayed>"),
            tau=polyak_const,
        )

In [None]:
from transformers import TrainingArguments
import transformers

# --- Optimization ---
learning_rate = 1e-2
lr_scheduler_type = "constant_with_warmup"
warmup_steps = 10
train_batch_size = 5
gradient_accumulation_steps = 2
num_train_epochs = 1
weight_decay = 0.01

# --- Evaluation / checkpointing (one checkpoint per evaluation) ---
eval_steps = 10
eval_batch_size = 5
eval_accumulation_steps = 2
run_eval = True
save_steps = eval_steps

#transformers.utils.logging.set_verbosity_debug()

training_config = TrainingArguments(
    # run bookkeeping
    output_dir=f"./trainer/{name}/",
    overwrite_output_dir=True,
    run_name=name,
    report_to="wandb",
    disable_tqdm=True,
    # optimizer & schedule
    optim="adamw_torch",
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    lr_scheduler_type=lr_scheduler_type,
    warmup_steps=warmup_steps,
    num_train_epochs=num_train_epochs,
    # batching
    per_device_train_batch_size=train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_eval_batch_size=eval_batch_size,
    eval_accumulation_steps=eval_accumulation_steps,
    # precision: bf16 for both training and evaluation
    bf16=True,
    fp16=False,
    bf16_full_eval=True,
    fp16_full_eval=False,
    # memory
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # keep the extra batch columns (action_ids_mask, end_ids_mask, ...) read by compute_loss
    remove_unused_columns=False,
    # logging / evaluation / saving
    logging_strategy="steps",
    logging_steps=1,
    eval_strategy="steps" if run_eval else "no",
    eval_steps=eval_steps,
    eval_on_start=run_eval,
    save_strategy="steps",
    save_steps=save_steps,
    save_only_model=True,
    # Hugging Face Hub
    push_to_hub=True,
    hub_model_id=hf_name,
    hub_private_repo=True,
    hub_strategy="all_checkpoints",
    hub_always_push=True,
)

trainer = CustomTrainer(
    model=model,
    args=training_config,
    train_dataset=ds_train,
    eval_dataset=ds_test,
    callbacks=[polyak_update],
)

In [None]:
# Push the freshly created (or resumed) adapters to the private Hub repo before training.
model.push_to_hub(hf_name, private=True)

In [None]:
# Attach the notebook source to the W&B run (wandb.init used save_code=False); only files
# under the current directory whose path ends in ".ipynb" are included.
wandb.run.log_code(".", name=name, include_fn=lambda path: path.endswith(".ipynb"))

In [None]:
# Train; evaluation, checkpointing and Hub pushes follow `training_config`.
trainer.train()

In [None]:
# Mark the W&B run as finished.
wandb.finish()