In [None]:
import gc
import time

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

In [None]:
def start_memory_tracking():
    """Reset CUDA peak-memory statistics so later reads measure from this point on.

    When no CUDA device is available, a warning is printed instead and nothing
    is reset.
    """
    if not torch.cuda.is_available():
        print("This notebook is intended for CUDA GPUs but CUDA is not available.")
        return
    torch.cuda.reset_peak_memory_stats()


def print_memory_usage():
    """Print the peak CUDA memory allocated since the last reset, in GB (1024**3 bytes)."""
    bytes_per_gb = 1024**3
    peak_gb = torch.cuda.max_memory_allocated() / bytes_per_gb
    print(f"Maximum GPU memory allocated: {peak_gb:.1f} GB")


def cleanup(device=None):
    """Release cached GPU memory, then reset and report the CUDA peak-memory counter.

    Runs Python's garbage collector and ``torch.cuda.empty_cache()``, waits a few
    seconds, and resets the peak-memory statistics for ``device``. Because the
    counter is reset right before it is read, the printed figure should equal the
    memory still allocated after cleanup, i.e. a fresh baseline for the next
    measurement.

    Args:
        device: CUDA device whose statistics are reset and reported. ``None``
            (the default) means the current CUDA device, so the function no
            longer depends on a notebook-level ``device`` variable that may not
            exist yet.

    When CUDA is unavailable, only garbage collection runs and a warning is printed.
    """
    gc.collect()
    if not torch.cuda.is_available():
        print("This notebook is intended for CUDA GPUs but CUDA is not available.")
        return
    torch.cuda.empty_cache()
    time.sleep(3)  # some buffer time to allow memory to clear
    # Reset and read the SAME device; previously reset used the current device
    # while the read used the global `device`.
    torch.cuda.reset_peak_memory_stats(device)
    max_memory_allocated = torch.cuda.max_memory_allocated(device) / (1024**3)
    print(f"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB")

In [None]:
import os
import psutil
from threading import Thread


def monitor_memory_usage_in_gb(func):
    """Decorator that reports the peak CPU (RSS) memory growth of each call to ``func``.

    While ``func`` runs, a background thread samples this process's resident set
    size roughly every 0.1 s. After the call finishes, the peak sample minus the
    RSS measured just before the call is printed in GB (1024**3 bytes), and
    ``func``'s return value is passed through unchanged. If ``func`` raises, the
    sampler is still stopped and the exception propagates without a report.
    """
    import functools
    from threading import Event

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        process = psutil.Process(os.getpid())

        def rss_gb():
            return process.memory_info().rss / 1024**3

        # Baseline taken before the call so the report shows growth caused by func.
        baseline_mem = rss_gb()

        mem_usage = []
        stop = Event()

        def monitor_memory():
            while not stop.is_set():
                mem_usage.append(rss_gb())
                stop.wait(0.1)  # like sleep(0.1), but wakes immediately once stopped

        t = Thread(target=monitor_memory, daemon=True)
        t.start()
        try:
            result = func(*args, **kwargs)
        finally:
            # Always stop the sampler; otherwise an exception in func would leave
            # the thread looping forever.
            stop.set()
            t.join()

        # A final sample guarantees at least one reading even if func returned
        # before the sampler thread took its first one (max() of [] would raise).
        mem_usage.append(rss_gb())
        peak_mem_usage_gb = max(mem_usage) - baseline_mem
        print(f"-> Maximum CPU memory allocated: {peak_mem_usage_gb:.1f} GB")

        return result

    return wrapper

In [None]:
@monitor_memory_usage_in_gb
def load_model(model_name, device, checkpoint_path="model.pth"):
    """Build ``model_name``'s architecture without allocating weights, then load a saved state dict.

    The model is instantiated under ``torch.device("meta")``, so parameters have
    shapes but no storage. The tensors loaded from ``checkpoint_path`` are then
    assigned directly into the module (``assign=True``) instead of being copied
    into pre-allocated buffers. ``mmap=True`` memory-maps the checkpoint file
    rather than reading it fully into RAM up front. ``weights_only=True``
    restricts unpickling to tensor data and primitive types.

    Args:
        model_name: Hugging Face model id; only its config is fetched here.
        device: ``map_location`` passed to ``torch.load``.
        checkpoint_path: Path of the saved state dict (defaults to ``"model.pth"``).
            Its keys must match the architecture described by ``model_name``'s config.

    Returns:
        The model with the checkpoint's weights.
    """
    config = AutoConfig.from_pretrained(model_name)

    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)

    # NOTE(review): buffers registered as non-persistent (e.g. rotary `inv_freq`)
    # are not part of a state dict and would remain on the meta device after
    # this load — verify the returned model can run a forward pass.
    model.load_state_dict(
        torch.load(checkpoint_path, map_location=device, weights_only=True, mmap=True),
        assign=True,
    )

    return model

In [None]:
# Hugging Face checkpoint under study and the GPU it is placed on.
model_name = "Qwen/Qwen2.5-7B-Instruct"
device = "cuda:0"

# Load with transformers' default settings, then move every parameter to `device`
# (nn.Module.to returns the module itself).
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

In [None]:
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES

# Register the Hugging Face id in TransformerLens's list of official model names
# if it is missing (presumably needed when the installed TransformerLens version
# does not list this model — verify).
if model_name not in OFFICIAL_MODEL_NAMES:
    OFFICIAL_MODEL_NAMES.append(model_name)

# Load a separate TransformerLens copy of the same checkpoint on `device`,
# skipping TransformerLens's weight processing and padding batches on the left.
hooked_model = HookedTransformer.from_pretrained_no_processing(
    model_name, device=device, default_padding_side="left"
)

In [None]:
# Grouped-query attention: how many query heads share each key/value head in
# the first HF attention layer (presumably num_heads // num_key_value_heads).
getattr(model.model.layers[0].self_attn, "num_key_value_groups")

In [None]:
# Spot-check the TransformerLens conversion of block 0 against the HF module.
# Previously the allclose result and the o_proj shape were computed and then
# discarded (not the cell's last expression), so nothing was shown.
hf_attn0 = model.model.layers[0].self_attn
tl_block0 = hooked_model.blocks[0]

# Flatten b_Q so it lines up with HF's 1-D q_proj.bias. Both sides are cast to
# float32 so the check does not error out just because the two copies were
# loaded with different dtypes.
q_bias_matches = torch.allclose(
    tl_block0.attn.b_Q.detach().reshape(-1).float(),
    hf_attn0.q_proj.bias.detach().float(),
)
print(f"b_Q matches q_proj.bias: {q_bias_matches}")

# Output-projection shapes on both sides, for comparing the weight layouts.
print(f"HF o_proj.weight shape: {tuple(hf_attn0.o_proj.weight.shape)}")
print(f"TL attn.W_O shape: {tuple(tl_block0.attn.W_O.shape)}")
print(f"TL mlp.W_out shape: {tuple(tl_block0.mlp.W_out.shape)}")

In [None]:
import plotly.express as px
import numpy as np

# Each transformer block writes to the residual stream twice, so the
# 2 * n_blocks "sublayers" are indexed as:
#   even k -> blocks[k // 2].attn.W_O,  odd k -> blocks[k // 2].mlp.W_out.
num_layers = len(hooked_model.blocks) * 2


def _sublayer_output_weight(blocks, idx):
    """Return sublayer ``idx``'s output weight flattened to 2-D ``(rows, last_dim)``.

    Even ``idx`` selects ``blocks[idx // 2].attn.W_O``, and odd ``idx`` selects
    ``blocks[idx // 2].mlp.W_out``. All leading dims are merged and the last one
    is kept; it is assumed to be the shared residual (d_model) dimension for both
    kinds — verify against the printed shapes. Without this step, a multi-dim
    per-head ``W_O`` breaks ``A @ B.T``: ``.T`` on a 3-D tensor reverses every
    dim, so the product either raises a shape mismatch or is silently batched.
    """
    block = blocks[idx // 2]
    weight = block.attn.W_O if idx % 2 == 0 else block.mlp.W_out
    return weight.reshape(-1, weight.shape[-1])


# mean[i][j] / std[i][j]: statistics over all pairwise dot products between rows
# of sublayer i's and sublayer j's output weights. The diagonal stays 0.
mean = [[0.0] * num_layers for _ in range(num_layers)]
std = [[0.0] * num_layers for _ in range(num_layers)]

# inference_mode: the weights are presumably trainable Parameters. Without it,
# every product records an autograd graph, and std()'s graph node keeps its
# (possibly multi-GB) input alive for as long as the stored result exists.
with torch.inference_mode():
    weights = [_sublayer_output_weight(hooked_model.blocks, k) for k in range(num_layers)]
    for i in range(num_layers):
        for j in range(i + 1, num_layers):
            products = weights[i] @ weights[j].T
            # (A @ B.T).T == B @ A.T, so (j, i) has the same mean/std as (i, j)
            # up to floating-point rounding; compute each unordered pair once.
            # .item() stores plain Python floats instead of GPU tensors.
            mean[i][j] = mean[j][i] = products.mean().item()
            std[i][j] = std[j][i] = products.std().item()
            del products  # free before the next (large) product is allocated