In [None]:
from typing import Optional
import torch

class LayerWiseDummyOptimizer(torch.optim.Optimizer):
    """Placeholder optimizer whose ``step`` and ``zero_grad`` do nothing.

    Meant to be handed to a training loop when the real parameter updates
    happen somewhere else (e.g. in per-parameter gradient hooks), so the
    loop's own optimizer calls become no-ops.

    Args:
        optimizer_dict: Optional object stored as-is on ``self.optimizer_dict``;
            this class never reads it.
        *args, **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, optimizer_dict=None, *args, **kwargs):
        # The base class refuses an empty parameter list, so register a single
        # throw-away tensor. Use zeros rather than randn so that constructing
        # this no-op object does not advance the global RNG state.
        placeholder = torch.zeros(1, 1)
        self.optimizer_dict = optimizer_dict
        super().__init__([placeholder], {"lr": 1e-03})

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Intentionally do nothing."""
        return None

    def step(self, closure=None) -> Optional[float]:
        """Intentionally do nothing; ``closure`` is never evaluated."""
        return None

class LayerWiseDummyScheduler(torch.optim.lr_scheduler.LRScheduler):
    """Placeholder LR scheduler bound to a ``LayerWiseDummyOptimizer``.

    It never changes the learning rate: ``get_lr`` simply reports whatever is
    currently stored in the dummy optimizer's parameter groups.

    Args:
        *args, **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, *args, **kwargs):
        optimizer = LayerWiseDummyOptimizer()
        last_epoch = -1
        # `verbose` is deliberately not forwarded: it is deprecated in
        # LRScheduler (a warning on some PyTorch versions, rejected on newer
        # ones), and omitting it keeps the scheduler silent.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current ``lr`` of every parameter group, unchanged."""
        return [group["lr"] for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Return the base learning rates recorded at construction."""
        return self.base_lrs

In [None]:
import torch.nn
import bitsandbytes as bnb
from transformers import get_constant_schedule
from functools import partial

from galore_torch import GaLoreAdamW8bit

def layerwise_8bitoptimizer(model, id_galore_params, param_groups):
    """Attach one 8-bit optimizer per trainable parameter via gradient hooks.

    Every trainable parameter gets its own optimizer and constant scheduler.
    Both are stepped from a post-accumulate-grad hook registered on that
    parameter, after which the optimizer's gradients are cleared.

    Args:
        model: Module whose ``parameters()`` are iterated.
        id_galore_params: Collection of ``id(p)`` values for the parameters
            that should be optimized with ``GaLoreAdamW8bit``; every other
            trainable parameter gets ``bnb.optim.Adam8bit``.
        param_groups: Two-element list; ``param_groups[1]`` is the GaLore
            group. Its ``"lr"`` is used for all per-parameter optimizers and
            its remaining non-``"params"`` entries are copied into each
            GaLore parameter's own group.

    Returns:
        None. The optimizers stay alive through the registered hooks.
    """
    def optimizer_hook(p, optimizer, scheduler):
        # Nothing accumulated for this parameter -> nothing to apply.
        if p.grad is None:
            return
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    galore_group = param_groups[1]
    lr = galore_group["lr"]
    # GaLore options without the parameter list itself, so that each
    # per-parameter optimizer can get a fresh group holding just one tensor.
    galore_options = {key: value for key, value in galore_group.items() if key != "params"}

    for p in model.parameters():
        if not p.requires_grad:
            continue
        if id(p) in id_galore_params:
            # One parameter per optimizer. Passing the full `param_groups`
            # here would make every hook step (and zero) an optimizer that
            # spans all model parameters and would share the same group dicts
            # between all of those optimizers.
            optimizer = GaLoreAdamW8bit([dict(params=[p], **galore_options)], lr=lr)
        else:
            optimizer = bnb.optim.Adam8bit([p], lr=lr)
        scheduler = get_constant_schedule(optimizer)
        p.register_post_accumulate_grad_hook(
            partial(optimizer_hook, optimizer=optimizer, scheduler=scheduler)
        )
        
def load_galore_optimizer(model, layerwise, optim_params):
    """Build the (optimizer, scheduler) pair for GaLore training.

    The weights of every ``torch.nn.Linear`` whose module name contains one of
    ``optim_params["target_modules_list"]`` form the GaLore parameter group;
    all remaining model parameters form a plain group.

    Args:
        model: Module to optimize.
        layerwise: If truthy, per-parameter optimizers are installed through
            ``layerwise_8bitoptimizer`` and no-op dummy objects are returned.
            Otherwise a single ``GaLoreAdamW8bit`` over both groups is
            returned with a constant schedule.
        optim_params: Dict holding ``"target_modules_list"`` and ``"lr"``;
            all of its entries are merged into the GaLore parameter group.

    Returns:
        Tuple ``(optimizer, scheduler)``.
    """
    target_keys = optim_params["target_modules_list"]
    # Refer to Linear through `torch.nn` so this function does not depend on
    # an `nn` alias that is only imported in a later notebook cell.
    galore_params = [
        module.weight
        for module_name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and any(target_key in module_name for target_key in target_keys)
    ]
    id_galore_params = {id(p) for p in galore_params}
    regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]

    param_groups = [
        dict(params=regular_params),
        dict(params=galore_params, **optim_params),
    ]
    if layerwise:
        layerwise_8bitoptimizer(model, id_galore_params, param_groups)
        # dummies because stepping is done with hooks added by layerwise_8bitoptimizer()
        optimizer, scheduler = LayerWiseDummyOptimizer(), LayerWiseDummyScheduler()
    else:
        optimizer = GaLoreAdamW8bit(param_groups, lr=param_groups[1]["lr"])
        scheduler = get_constant_schedule(optimizer)
    return optimizer, scheduler


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, set_seed, get_constant_schedule
from trl import SFTTrainer
from datasets import load_dataset
import torch, torch.nn as nn, uuid, wandb


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# use_galore: build the optimizer/scheduler with load_galore_optimizer();
#             when False, (None, None) is handed to the trainer instead.
use_galore = True
# Forwarded as `layerwise` to load_galore_optimizer(); only read when use_galore is True.
use_galore_layerwise = True
# Used both in TrainingArguments and in the GaLore parameter-group options.
lr = 1e-5
# Forwarded as `rank` in the GaLore parameter-group options.
rank = 1024
modelpath = "../models/TinyLlama-1.1B-intermediate-step-1431k-3T"
set_seed(42)
# Unique run name (config summary + random UUID); reused as the output directory
# and as the wandb run name below.
run_id = f"use_galore-{use_galore}-layerwise-{use_galore_layerwise}-rank{rank}-{str(uuid.uuid4())}"

# ---------------------------------------------------------------------------
# Model, tokenizer and data
# ---------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    modelpath,    
    torch_dtype=torch.bfloat16,
    attn_implementation = "flash_attention_2",  
    use_cache = False,
)
# NOTE(review): `tokenizer` is not referenced again in the visible code and is
# not passed to SFTTrainer below — confirm whether that is intended.
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast = False)
train_dataset = load_dataset('imdb', split='train')

# ---------------------------------------------------------------------------
# Trainer arguments
# ---------------------------------------------------------------------------
# NOTE(review): `optim`, `learning_rate` and `lr_scheduler_type` presumably only
# take effect when no custom optimizer/scheduler is supplied (use_galore False)
# — confirm against the transformers Trainer documentation.
args = TrainingArguments(
    output_dir = run_id,
    optim = "adamw_8bit",
    logging_steps = 1, 
    max_steps = 100,
    per_device_train_batch_size = 16,
    learning_rate = lr,
    lr_scheduler_type = "constant",
    gradient_checkpointing = True,
)

# ---------------------------------------------------------------------------
# Optimizer selection
# ---------------------------------------------------------------------------
if use_galore:
    # The whole dict is merged into the GaLore parameter group by
    # load_galore_optimizer(); "target_modules_list" additionally selects which
    # Linear modules (by substring of the module name) belong to that group.
    galore_params = dict(
        target_modules_list=["attn", "mlp"], 
        lr = lr,
        rank=rank, 
        update_proj_gap=200, 
        scale=1, 
        proj_type="std"
    )
    # `optimizers` is an (optimizer, scheduler) tuple.
    optimizers = load_galore_optimizer(model, layerwise = use_galore_layerwise, optim_params = galore_params)
else:
    optimizers = (None, None)

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
trainer = SFTTrainer(
    model = model, 
    train_dataset = train_dataset,
    dataset_text_field = 'text',
    max_seq_length = 512,
    optimizers =  optimizers,
    args = args,
)

# Start the wandb run before training and attach the .py / .ipynb sources to it.
wandb.init(
    project = "galore-tests", 
    name = run_id,
).log_code(include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb"))

trainer.train()