In [None]:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch as t
import dictionary_learning.dictionary_learning.utils as utils


In [None]:
# Checkpoint to load. Exactly one assignment should be active; the others are
# alternatives that can be swapped in by moving the comment marker.
# model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
# model_name = "google/gemma-2-2b"
model_name = "EleutherAI/pythia-70m-deduped"

# Weights are loaded in bfloat16, with device placement delegated to
# `device_map="auto"`.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=t.bfloat16)

In [None]:
# Sanity check of what was loaded: checkpoint name, architecture class name
# (the key `truncate_model` dispatches on), and the full module tree.
# A single block can also be inspected, e.g. print(model.model.layers[1]).
print(model.name_or_path, model.config.architectures[0], model, sep="\n")

In [None]:
def truncate_model(model: t.nn.Module, layer: int):
    """Remove every transformer block after ``layer`` and the output head, in place.

    Adapted from tilde-research/activault:
    https://github.com/tilde-research/activault/blob/db6d1e4e36c2d3eb4fdce79e72be94f387eccee1/pipeline/setup.py#L74

    Blocks ``0..layer`` (inclusive) are kept. All later blocks are dropped and
    the output projection is replaced with ``torch.nn.Identity``. This provides
    significant memory savings by deleting all layers that aren't needed for
    the given layer. You should probably test this before using it.

    Args:
        model: A loaded causal LM whose ``config.architectures[0]`` is one of
            ``Qwen2ForCausalLM``, ``Gemma2ForCausalLM`` or
            ``GPTNeoXForCausalLM``.
        layer: Zero-based index of the last block to keep.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: If no truncation rule exists for the model's architecture,
            or if ``layer`` is not a valid block index for this model.
    """
    import gc

    # architecture name -> (attribute holding the block stack's parent module,
    #                       attribute holding the output projection)
    layouts = {
        "Qwen2ForCausalLM": ("model", "lm_head"),
        "Gemma2ForCausalLM": ("model", "lm_head"),
        "GPTNeoXForCausalLM": ("gpt_neox", "embed_out"),
    }

    total_params_before = sum(p.numel() for p in model.parameters())
    print(f"Model parameters before truncation: {total_params_before:,}")

    layout = layouts.get(model.config.architectures[0])
    if layout is None:
        raise ValueError(f"Please add truncation for model {model.name_or_path}")
    backbone_attr, head_attr = layout
    backbone = getattr(model, backbone_attr)

    # Reject out-of-range indices up front: a negative index would make the
    # slice below drop blocks that should be kept (layer=-1 drops all of them),
    # and an index past the end would silently truncate nothing.
    num_layers = len(backbone.layers)
    if not 0 <= layer < num_layers:
        raise ValueError(
            f"layer must satisfy 0 <= layer < {num_layers} "
            f"for model {model.name_or_path}, got {layer}"
        )

    # Rebinding the attribute releases the only reference this function holds
    # to the trailing blocks, so they can be garbage collected below.
    backbone.layers = backbone.layers[: layer + 1]

    # Swap the output projection for a parameter-free pass-through.
    setattr(model, head_attr, t.nn.Identity())

    total_params_after = sum(p.numel() for p in model.parameters())
    print(f"Model parameters after truncation: {total_params_after:,}")

    # Collect the dropped modules first, then hand cached blocks back to the
    # CUDA allocator so the savings show up in the reserved-memory figure.
    gc.collect()
    t.cuda.empty_cache()

    return model


def _print_cuda_memory():
    """Print PyTorch's allocated and reserved CUDA memory, scaled by 1e9 (GB)."""
    print(t.cuda.memory_allocated() / 1e9, "GB actual")
    print(t.cuda.memory_reserved()  / 1e9, "GB in cache")


# Index of the last transformer block to keep.
LAST_KEPT_LAYER = 5

# Report memory on both sides of the truncation so the saving is visible.
_print_cuda_memory()
model = truncate_model(model, LAST_KEPT_LAYER)
_print_cuda_memory()