## Exploring LLM Models

This notebook contains the following:
1. Dumping the model architecture in a readable format
2. Capturing the call flow of modules and submodules during LLM inference
3. Visualizing the result of step 2 using an awfully hacky HTML solution

In [None]:
import json
import time
import os
import gc
import hashlib
from pathlib import Path
from types import MethodType
import random

import torch
import torch.nn as nn
from transformers import pipeline

In [None]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
HF_TOKEN = os.environ["HF_TOKEN"]

The following are a few simple functions for exploring a PyTorch model. The key method is `module.named_children()`, which yields the direct submodules of a module along with their names. Applying it recursively to each submodule walks the whole model tree.

In [None]:
def traverse_layers(module: nn.Module, depth: int = 0) -> None:
    # Print each child indented by its depth, then descend into it
    for name, child in module.named_children():
        print("  " * depth + f"{name}: {child.__class__.__name__}")
        traverse_layers(child, depth + 1)


def module_to_dict(
    module: nn.Module, depth: int = 0, with_module: bool = False
) -> dict:
    layers = {}
    for name, child in module.named_children():
        # Recursive step
        children = module_to_dict(child, depth + 1, with_module)

        layers[name] = {
            "depth": depth,
            "type": child.__class__.__name__,
            "children": children,
        }
        if with_module:
            # Note: module objects are not JSON-serializable
            layers[name]["module"] = child
    return layers

We can try the above on a small model; `Llama-3.2-1B` is used here. Llama models are gated, so you need permission from Meta to access them. If you do not have access, try another small model instead.

In [None]:
model_id = "meta-llama/Llama-3.2-1B"

pipe = pipeline(
    "text-generation",
    model=model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
gc.collect()

In [None]:
module = pipe.model
model_dict = module_to_dict(module, with_module=False)

Path("tmp").mkdir(parents=True, exist_ok=True)
with open("tmp/model_arch.json", "w") as writer:
    json.dump(model_dict, writer, indent=4)

The output JSON looks like the excerpt below. A decent JSON viewer (several are available online) makes it easier to explore.

In [None]:
!head -n 20 "tmp/model_arch.json"
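
As an alternative to a JSON viewer, the nested dictionary can be flattened into dotted module paths. The following is a minimal sketch: `flatten_arch` is a hypothetical helper, and `sample_arch` is hand-written in the shape that `module_to_dict` produces rather than taken from a real model.

```python
from collections import Counter


def flatten_arch(layers: dict, prefix: str = "") -> list[tuple[str, str]]:
    """Return (dotted_path, type) pairs in traversal order."""
    rows = []
    for name, info in layers.items():
        path = f"{prefix}.{name}" if prefix else name
        rows.append((path, info["type"]))
        rows.extend(flatten_arch(info["children"], path))
    return rows


# Hand-written sample mimicking module_to_dict output (not real model output)
sample_arch = {
    "model": {
        "depth": 0,
        "type": "ToyModel",
        "children": {
            "embed": {"depth": 1, "type": "Embedding", "children": {}},
            "proj": {"depth": 1, "type": "Linear", "children": {}},
        },
    },
    "head": {"depth": 0, "type": "Linear", "children": {}},
}

rows = flatten_arch(sample_arch)
for path, type_name in rows:
    print(f"{path}: {type_name}")

# Count how often each module type appears in the tree
type_counts = Counter(type_name for _, type_name in rows)
print(type_counts.most_common())
```

For the real model, the same helper could be called on `model_dict` from the cell above.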

The above does not capture the flow of information through the model. To see it, we can wrap the `.forward()` method of every module so that each call records itself. That way, we get to see which modules are called and in what order.

In [None]:
module_log_stack = []


def mock_forward(module: nn.Module, depth: int, name: str) -> None:
    if hasattr(module, "_original_forward_func"):
        return  # Already mocked

    # Take a copy of the original forward function
    module._original_forward_func = module.forward  # type: ignore

    def new_forward(self, *args, **kwargs):
        saved_stack = module_log_stack.copy()
        module_log_stack.clear()

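        # perf_counter is monotonic, so it suits interval timing better than time.time.
        # On a GPU, kernels run asynchronously, so these timings are approximate.
        start_time = time.perf_counter()
        output = self._original_forward_func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time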

        module_name = self.__class__.__name__
        msg = {
            "depth": depth,
            "name": name,
            "module": module_name,
            "time": elapsed_time,
        }
        if module_log_stack:
            msg["children"] = module_log_stack.copy()

        saved_stack.append(msg)
        module_log_stack[:] = saved_stack

        return output

    # Use the wrapper `forward` function
    module.forward = MethodType(new_forward, module)
    # print("Mocked: ", str(module.__class__))


def apply_mocking(model: nn.Module, depth: int = 0):
    for name, module in model.named_children():
        if hasattr(module, "forward"):
            mock_forward(module, depth, name)
        apply_mocking(module, depth + 1)
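
The save, clear, and restore steps on the shared list are what turn a flat sequence of calls into a tree. The sketch below reproduces the same pattern on plain Python functions so it can be run without a model. The decorator `traced` and the toy functions `norm`, `attention`, and `block` are hypothetical names used only for illustration.

```python
import time
from functools import wraps

call_log: list[dict] = []


def traced(func):
    """Record each call, nesting calls made inside it under "children"."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Set aside the caller's entries and start with an empty log
        saved = call_log.copy()
        call_log.clear()

        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start

        entry = {"name": func.__name__, "time": elapsed}
        if call_log:  # anything that ran inside func logged itself here
            entry["children"] = call_log.copy()

        # Restore the caller's entries and append this call after them
        saved.append(entry)
        call_log[:] = saved
        return result

    return wrapper


@traced
def norm(x):
    return x


@traced
def attention(x):
    return x + 1


@traced
def block(x):
    return attention(norm(x))


result = block(1)
call_names = [entry["name"] for entry in call_log]
child_names = [child["name"] for child in call_log[0]["children"]]
print(call_names, child_names)
```

After the call, the log holds a single top-level entry for `block`, with `norm` and `attention` listed as its children in call order.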

In [None]:
module = pipe.model
apply_mocking(module)

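Now we can run the text generation, and `module_log_stack` will hold the logs from all the `forward()` calls as a nested hierarchy. Because `apply_mocking` patches only the descendants of `pipe.model` and not the model itself, the top-level entries are its direct submodules, repeated for each forward pass.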

In [None]:
prompt = "Once upon a time"
module_log_stack.clear()
output = pipe(prompt, max_length=8, do_sample=True)

In [None]:
with open("tmp/call_logs.json", "w") as fp:
    json.dump(module_log_stack, fp, indent=4)

In [None]:
!head -n20 tmp/call_logs.json
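
Before visualizing, it can help to aggregate the log by module type. The following is a rough sketch, not part of the original workflow: `summarize_calls` is a hypothetical helper, and `sample_log` is a hand-written stand-in for the call-log format with made-up timings.

```python
from collections import defaultdict


def summarize_calls(entries: list[dict], totals=None) -> dict:
    """Accumulate call count and total time per module class over the tree."""
    if totals is None:
        totals = defaultdict(lambda: {"calls": 0, "time": 0.0})
    for entry in entries:
        stats = totals[entry["module"]]
        stats["calls"] += 1
        stats["time"] += entry["time"]
        # Descend into nested calls, sharing the same accumulator
        summarize_calls(entry.get("children", []), totals)
    return totals


# Hand-written sample in the call-log format (timings are made up)
sample_log = [
    {
        "depth": 0, "name": "model", "module": "ToyModel", "time": 0.5,
        "children": [
            {"depth": 1, "name": "q_proj", "module": "Linear", "time": 0.125},
            {"depth": 1, "name": "k_proj", "module": "Linear", "time": 0.125},
        ],
    },
    {"depth": 0, "name": "lm_head", "module": "Linear", "time": 0.25},
]

summary = summarize_calls(sample_log)
for module_type, stats in sorted(summary.items(), key=lambda kv: -kv[1]["time"]):
    print(f'{module_type}: {stats["calls"]} calls, {stats["time"]:.3f}s')
```

A parent's time includes the time of its children, so the totals of different module types overlap and should not be added together.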

We need a way to display these logs in a structured way. Let's cook up a hacky HTML solution.

In [None]:
def get_color_for_key(key: str) -> str:
    if key == "root":
        return "hsl(0, 100%, 100%)"  # white; the hue must be unitless, not a percentage
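    # Derive a stable color from the key; MD5 is used only as a cheap, repeatable hash
    hash_value = int(hashlib.md5(key.encode()).hexdigest(), 16)
    # A local generator avoids reseeding the global `random` state
    rng = random.Random(hash_value + 5)
    hue = rng.randint(0, 360)
    sat = rng.randint(50, 100)
    light = rng.randint(80, 94)
    return f"hsl({hue}, {sat}%, {light}%)"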


def generate_html(entry: dict) -> str:
    color = get_color_for_key(entry["module"])
    style = (
        f"background-color: {color}; min-width: 100px; "
        "font-size:1.1rem; "
        "border: 1px solid #aaa; margin: 5px; "
        "padding-left: 5px; color: #000;"
    )
    children = entry.get("children", [])
    children_html = "\n".join(generate_html(child) for child in children)
    return (
        f'<div style="{style}">'
        f'{entry["name"]} {entry["module"]} ({entry["time"]:.4f}s){children_html}'
        '</div>'
    )

children = json.loads(Path("tmp/call_logs.json").read_text())
root = {"name": "root", "module": "root", "children": children, "time": float("NAN")}
html_content = generate_html(root)
Path("tmp/call_logs_viz.html").write_text(html_content, encoding="utf-8")


The visualization is saved as HTML in the previous step. We can now display it inside the notebook.

In [None]:
from IPython.display import HTML

with open("tmp/call_logs_viz.html", "r", encoding="utf-8") as f:
    html_content = f.read()

HTML(
    f"""
    <div style="max-height: 950px; overflow-y: auto; width: 600px;">
    {html_content}
    </div>"""
)
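
Once the logs are collected, the patched `forward` methods can be removed again. The sketch below assumes the patch stores both attributes on the instance, as `mock_forward` does above. `remove_mocking` is a hypothetical helper, and `ToyLayer` is a stand-in for an `nn.Module` so the example runs without PyTorch.

```python
from types import MethodType


def remove_mocking(modules) -> int:
    """Undo the patch on each given object; return how many were restored."""
    restored = 0
    for module in modules:
        if "_original_forward_func" in vars(module):
            # Deleting the instance attributes re-exposes the class-level forward
            del module.forward
            del module._original_forward_func
            restored += 1
    return restored


class ToyLayer:  # stand-in for an nn.Module (assumption for illustration)
    def forward(self, x):
        return x * 2


layer = ToyLayer()
# Patch in the same way as mock_forward: keep the original, install a wrapper
layer._original_forward_func = layer.forward
layer.forward = MethodType(
    lambda self, x: self._original_forward_func(x) + 1, layer
)
patched_output = layer.forward(3)  # goes through the wrapper

restored_count = remove_mocking([layer])
restored_output = layer.forward(3)  # plain class method again
print(patched_output, restored_count, restored_output)
```

For the model in this notebook, the call would presumably be `remove_mocking(pipe.model.modules())`, since `nn.Module.modules()` yields the model and all of its descendants.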