In [1]:
import wandb

    
api = wandb.Api()
project = "lora"
workspace = "username"

# Get our two main experiments so far
experiment_tags = ["baseline"]#"mt_eval"]

# get all runs that both: 1.  match any experiment tag and 2. are finished
runs = api.runs(f"{workspace}/{project}",
                {"$and": [
                    {"tags": {"$in": experiment_tags}},
                ]})


In [8]:
#for run in runs:
    #print(run.name) 

In [11]:
import os
import lightning as L
from lit_gpt.lora import GPT,  Config
from lit_gpt.utils import (
    check_valid_checkpoint_dir,
    lazy_load,
)
from pathlib import Path

import gc

import lightning as L
from functools import partial

from scripts.convert_lit_checkpoint import check_conversion_supported, copy_weights_llama, incremental_save

from lit_gpt.lora import GPT, Config, lora_filter, merge_lora_weights
from lit_gpt.model import Config as ModelConfig
from lit_gpt.utils import check_valid_checkpoint_dir,  lazy_load

import contextlib


def setup(run, checkpoints ="checkpoints/meta-llama/Llama-2-7b-hf",
           adapter_path = "downloads/", out_dir= "merged/", out_name="model.pth"):
    config = run.config
    checkpoints = Path(checkpoints)
    conf = Config.from_name(
            name=checkpoints.name,
            r=config["lora_r"],
            alpha=config["alpha"],
            dropout=config["lora_dropout"],
            to_query=config["lora_query"],
            to_key=config["lora_key"],
            to_value=config["lora_value"],
            to_projection=config["lora_projection"],
            to_mlp=config["lora_mlp"],
            to_head=config["lora_head"],
            joint_qkvp=config["joint_qkvp"],
            tensor_lora=config["tensor_lora"],
        )
    fabric = L.Fabric(devices=1, strategy="auto", precision="bf16-true", plugins=None)
    fabric.seed_everything(0)  # same seed for every process to init model (FSDP)
    
    with fabric.init_module(empty_init=(False)):
        model = GPT(conf)
    #tokenizer = Tokenizer(io.checkpoint_dir)
    check_valid_checkpoint_dir(checkpoints)

    if fabric.global_rank == 0:
        os.makedirs(out_dir, exist_ok=True)

    checkpoint_path = checkpoints / "lit_model.pth"
    #print(adapter_path)
    adapter = lazy_load(Path(adapter_path))
    base = lazy_load(Path(checkpoint_path))
    params = {**adapter.sd["model"], **base.sd}
    #print(params.keys())
    model.load_state_dict(params, strict=True)
    #model.eval()
    
    merge_lora_weights(model)

    save_path = out_dir / "lit_model.pth"
    fabric.print(f"Saving weights to {str(save_path)!r}")
    # remove lora parameters and the lora linear substring
    state_dict = {k.replace("linear.", ""): v for k, v in model.state_dict().items() if not lora_filter(k, v)}

    conf = ModelConfig.from_name(name=checkpoints.name)
    copy_fn = partial(copy_weights_llama, conf)

    pth_file = out_dir / out_name
    bin_file = pth_file.with_suffix(".bin")

    # initialize a new empty state dict to hold our new weights
    sd = {}

    with incremental_save(bin_file) as saver:
        with contextlib.ExitStack() as stack:
            lit_weights = state_dict.get("model", state_dict)
            check_conversion_supported(lit_weights)
            copy_fn(sd, lit_weights, saver=saver)
            gc.collect()
        saver.save(sd)


    
    #wandb.init(id=run.id, project=run.project, resume="allow")
    #wandb.log(results['results'])
    #wandb.finish()




In [12]:
for run in runs:
    print(run)
    weights = run.file(f"checkpoints/meta-llama/Llama-2-7b-hf/{run.name}_lora_finetuned.pth")
    # create the directory if it doesn't exist
    os.makedirs(f"download/{run.id}", exist_ok=True)
    out = weights.download(f"download/{run.id}", replace=True)
    setup(run, adapter_path=out.name,  out_dir=Path(f"merged/{run.id}"))
    

<Run username/lora/7c1k3q9r (finished)>


/home/chiche/miniconda3/envs/lorta/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/chiche/miniconda3/envs/lorta/lib/python3.9/sit ...
Seed set to 0


Using original matrix low rank adapters
Saving weights to 'merged/7c1k3q9r/lit_model.pth'


In [None]:
print(out.name)

download/wieemkak/checkpoints/meta-llama/Llama-2-7b-hf/tensor_lora_r_48_joint_heads_joint_layers_joint_qkvp_lora_finetuned.pth
