In [None]:
from typing import Optional
import torch

class LayerWiseDummyOptimizer(torch.optim.Optimizer):
    """No-op optimizer handed to the Trainer when the real updates happen in gradient hooks.

    In this notebook `load_galore_optimizer` attaches one real optimizer per
    parameter through a gradient hook, so `step` and `zero_grad` here must do
    nothing. A single throw-away tensor is registered as the only parameter so
    the base `torch.optim.Optimizer` constructor has something to manage.

    Args:
        optimizer_dict: optional mapping of the real per-parameter optimizers;
            stored as-is and not used by this class.
        *args: ignored.
        **kwargs: only ``lr`` is read. It sets the learning rate reported by
            the dummy param group (default 1e-3), so anything that reads
            ``param_groups`` or the paired scheduler's last lr sees the real
            value. Other keys are ignored.
    """

    def __init__(self, optimizer_dict=None, *args, **kwargs):
        dummy_tensor = torch.randn(1, 1)
        self.optimizer_dict = optimizer_dict
        super().__init__([dummy_tensor], {"lr": kwargs.get("lr", 1e-03)})

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Do nothing: gradients are cleared by the per-parameter hooks."""

    def step(self, closure=None) -> Optional[float]:
        """Do nothing: parameters are updated by the per-parameter hooks."""


class LayerWiseDummyScheduler(torch.optim.lr_scheduler.LRScheduler):
    """Constant-lr scheduler over a private `LayerWiseDummyOptimizer`, paired with it for the Trainer.

    Accepts and ignores arbitrary arguments, except the keyword ``lr`` (default
    1e-3). ``lr`` is forwarded to the dummy optimizer so that `get_last_lr()`
    reports the learning rate the hooks actually use.
    """

    def __init__(self, *args, **kwargs):
        optimizer = LayerWiseDummyOptimizer(lr=kwargs.get("lr", 1e-03))
        # `verbose` is deliberately not passed: newer PyTorch releases deprecate
        # it and warn on any explicit value, and its default is quiet.
        super().__init__(optimizer, last_epoch=-1)

    def get_lr(self):
        # Constant schedule: report each param group's current lr unchanged.
        return [group["lr"] for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return self.base_lrs

In [None]:
import torch.nn
import bitsandbytes as bnb
from transformers import get_constant_schedule
from functools import partial

from galore_torch import GaLoreAdamW8bit
        
def load_galore_optimizer(model, lr, galore_config):
    """Attach one optimizer per trainable parameter of `model`, stepped from gradient hooks.

    The weight of every Linear module whose qualified name contains any entry
    of ``galore_config["target_modules_list"]`` gets a ``GaLoreAdamW8bit``
    optimizer. Every other trainable parameter, including the biases of those
    targeted Linear modules, gets ``bnb.optim.Adam8bit``. Each optimizer has its
    own constant scheduler. Both are driven from a hook registered with
    ``register_post_accumulate_grad_hook``, which steps, zeroes the gradient and
    advances the scheduler.

    Args:
        model: the ``torch.nn.Module`` to train. Call this after any step that
            replaces its parameters, because hooks are registered on the
            parameter tensors that exist now.
        lr: learning rate for every per-parameter optimizer.
        galore_config: dict holding ``target_modules_list`` (substrings matched
            against module names). All remaining keys are passed to
            ``GaLoreAdamW8bit`` as param-group options.

    Returns:
        ``(optimizer, scheduler)`` no-op dummies to hand to the Trainer. The
        real updates happen inside the hooks.

    Raises:
        KeyError: if ``galore_config`` has no ``target_modules_list``.
    """
    def optimizer_hook(p, optimizer, scheduler):
        # Called with the parameter once its gradient has been accumulated:
        # update it right away and release the gradient.
        if p.grad is not None:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

    target_keys = galore_config["target_modules_list"]
    # Only GaLore hyper-parameters belong in the param group; the module
    # selector is consumed here and not forwarded to the optimizer.
    galore_group_opts = {k: v for k, v in galore_config.items() if k != "target_modules_list"}

    # Match by id(): tensor `in` checks would compare values, not identity.
    # torch.nn.Linear is spelled out so this does not depend on a later cell binding `nn`.
    galore_param_ids = {
        id(module.weight)
        for module_name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and any(target_key in module_name for target_key in target_keys)
    }

    for p in model.parameters():
        if not p.requires_grad:
            continue
        if id(p) in galore_param_ids:
            optimizer = GaLoreAdamW8bit([dict(params=[p], **galore_group_opts)], lr=lr)
        else:
            optimizer = bnb.optim.Adam8bit([p], lr=lr)
        scheduler = get_constant_schedule(optimizer)
        p.register_post_accumulate_grad_hook(
            partial(optimizer_hook, optimizer=optimizer, scheduler=scheduler)
        )

    # Stepping is done by the hooks above; the Trainer only gets no-op dummies.
    # `lr` is forwarded so dummies that honour an `lr` keyword report the real
    # rate (dummies that accept **kwargs but ignore it are unaffected).
    return LayerWiseDummyOptimizer(lr=lr), LayerWiseDummyScheduler(lr=lr)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, set_seed, get_constant_schedule
from trl import SFTTrainer, setup_chat_format, DataCollatorForCompletionOnlyLM
from datasets import load_dataset
import torch, torch.nn as nn, uuid, wandb

# Run configuration + training: fine-tune the checkpoint at `modelpath` with
# per-parameter GaLore / 8-bit Adam updates applied from gradient hooks
# (installed by load_galore_optimizer), driven by TRL's SFTTrainer and logged to wandb.
set_seed(42)

modelpath = "../models/llama2-7b"
# Learning rate for the per-parameter optimizers built by load_galore_optimizer
# (also passed to TrainingArguments below).
lr = 1e-5
galore_config = dict(
    # Substrings matched against module names to select the Linear weights that get GaLore.
    target_modules_list=["attn", "mlp"], 
    # The remaining keys are handed to GaLoreAdamW8bit as param-group options
    # (presumably projection rank, projector refresh interval, update scale and
    # projection type -- confirm against the GaLore docs).
    rank=1024, 
    update_proj_gap=200, 
    scale=2, 
    proj_type="std"
)
# Unique run name, reused for the output directory and the wandb run name.
run_id = f"galore-layerwise-rank{galore_config['rank']}-{str(uuid.uuid4())}"

model = AutoModelForCausalLM.from_pretrained(
    modelpath,    
    torch_dtype=torch.bfloat16,
    attn_implementation = "flash_attention_2",  
    device_map = "auto",
    use_cache = False,
)
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast = False)

# Add the ChatML tokens/template. This must run before load_galore_optimizer:
# its gradient hooks are registered on the parameters that exist at call time,
# so any parameter replaced here (presumably the resized token embeddings)
# would otherwise never be updated.
model, tokenizer = setup_chat_format(model, tokenizer)
# Keep the pad token distinct from EOS by falling back to UNK.
# NOTE(review): presumably so padding-related masking cannot also hit the real
# end-of-turn token -- confirm.
if tokenizer.pad_token in [None, tokenizer.eos_token]: 
    tokenizer.pad_token = tokenizer.unk_token

# Needs "train" and "test" splits (both are used by the trainer below).
dataset = load_dataset("g-ronimo/oasst2_top4k_en")

# NOTE(review): the hooks clear p.grad during backward, so Trainer-side gradient
# clipping (max_grad_norm, left at its default here) presumably finds no
# gradients to act on -- confirm.
training_arguments = TrainingArguments(
    output_dir = f"out_{run_id}",
    # NOTE(review): renamed to `eval_strategy` in newer transformers releases;
    # update this if the installed version changes.
    evaluation_strategy = "steps",
    label_names = ["labels"],
    per_device_train_batch_size = 16,
    # Must stay 1: every hook steps and then zeroes its optimizer on each
    # backward pass, so gradients are never accumulated across micro-batches.
    gradient_accumulation_steps = 1,
    save_steps = 250,
    eval_steps = 250,
    logging_steps = 1, 
    # The optimizers that actually update weights are built from `lr` in
    # load_galore_optimizer and carry their own constant schedules; this value and
    # lr_scheduler_type presumably do not affect updates because custom
    # optimizers are passed to the trainer -- confirm.
    learning_rate = lr,
    num_train_epochs = 3,
    lr_scheduler_type = "constant",
    gradient_checkpointing = True,
    group_by_length = False,
)

# Register the per-parameter optimizer hooks. Returns a no-op dummy
# optimizer/scheduler pair for the trainer, because updates happen in the hooks.
optimizers = load_galore_optimizer(model, lr, galore_config)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset["train"],
    eval_dataset = dataset['test'],
    # Completion-only collator keyed on the ChatML turn markers; presumably it
    # restricts the loss to assistant responses -- confirm against the TRL docs.
    data_collator = DataCollatorForCompletionOnlyLM(
        instruction_template = "<|im_start|>user", 
        response_template = "<|im_start|>assistant", 
        tokenizer = tokenizer, 
        mlm = False),
    max_seq_length = 256,
    # Presumably the chat-formatted text already contains its special tokens.
    dataset_kwargs = dict(add_special_tokens = False),
    optimizers =  optimizers,
    args = training_arguments,
)

# Start the wandb run before training and upload the notebook / .py sources with it.
wandb.init(
    project = "galore-7B", 
    name = run_id,
).log_code(include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb"))

trainer.train()
trainer.save_model(f"out_{run_id}/model_trained")