In [1]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
from tqdm import tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from jaxtyping import Float, Int
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
# import circuitsvis as cv

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# Everything in this notebook is inference / activation inspection, so switch autograd
# off globally rather than wrapping each forward pass in `torch.no_grad()`.
torch.set_grad_enabled(mode=False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a heatmap via `px.imshow`.

    By default the colour scale is the diverging "RdBu" scale centred on 0.

    Args:
        tensor: anything `utils.to_numpy` accepts (e.g. a tensor from the activation cache).
        renderer: optional plotly renderer name, forwarded to `Figure.show`.
        xaxis, yaxis: titles for the x / y axes.
        **kwargs: forwarded to `px.imshow`. `color_continuous_midpoint` and
            `color_continuous_scale` may be overridden (e.g. a sequential scale for
            attention patterns). A `labels` dict (e.g. {"color": "prob"}) is merged
            with the axis titles; its entries win on conflict.
    """
    # Merge instead of passing `labels` twice, which previously raised a TypeError.
    labels = {"x": xaxis, "y": yaxis, **(kwargs.pop("labels", None) or {})}
    kwargs.setdefault("color_continuous_midpoint", 0.0)
    kwargs.setdefault("color_continuous_scale", "RdBu")
    px.imshow(utils.to_numpy(tensor), labels=labels, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a line chart via `px.line`.

    Args:
        tensor: anything `utils.to_numpy` accepts (e.g. a tensor from the activation cache).
        renderer: optional plotly renderer name, forwarded to `Figure.show`.
        xaxis, yaxis: axis titles. When non-empty they are also applied directly to
            the figure's axes.
        **kwargs: forwarded to `px.line`. A `labels` dict is merged with the x/y
            titles; its entries win on conflict.
    """
    # Merge instead of passing `labels` twice, which previously raised a TypeError.
    labels = {"x": xaxis, "y": yaxis, **(kwargs.pop("labels", None) or {})}
    fig = px.line(utils.to_numpy(tensor), labels=labels, **kwargs)
    # For a bare array (wide-form input), plotly express names the axes after its own
    # columns (typically "index"/"value"), so the "x"/"y" label keys above never reach
    # the titles. Set them on the axes explicitly.
    if xaxis:
        fig.update_xaxes(title_text=xaxis)
    if yaxis:
        fig.update_yaxes(title_text=yaxis)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a scatter plot of `y` against `x` via `px.scatter`.

    Args:
        x, y: anything `utils.to_numpy` accepts; used as the x / y coordinates.
        xaxis, yaxis: axis titles.
        caxis: title for the colour axis (relevant when a `color=` kwarg is given).
        renderer: optional plotly renderer name, forwarded to `Figure.show`.
        **kwargs: forwarded to `px.scatter`. A `labels` dict is merged with the
            x / y / colour titles; its entries win on conflict.
    """
    # Merge instead of passing `labels` twice, which previously raised a TypeError.
    labels = {"x": xaxis, "y": yaxis, "color": caxis, **(kwargs.pop("labels", None) or {})}
    x_np = utils.to_numpy(x)
    y_np = utils.to_numpy(y)
    px.scatter(y=y_np, x=x_np, labels=labels, **kwargs).show(renderer)

# Device for the model and its activation cache. Fall back to CPU so the notebook at
# least starts on a machine without a GPU (a 7B model on CPU will be slow and memory-hungry).
device = "cuda" if torch.cuda.is_available() else "cpu"

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
from transformers import LlamaForCausalLM, LlamaTokenizer

# Local directory with the Hugging Face-format LLaMA checkpoint (tokenizer files plus
# sharded weights; loading reports 2 shards).
# TODO(review): hardcoded relative path — point it at your own converted checkpoint.
MODEL_PATH = "llama_model/hf_llama_model"

tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)

# low_cpu_mem_usage: per the transformers docs, tries to keep peak CPU memory during
# loading close to a single copy of the weights.
hf_model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    low_cpu_mem_usage=True,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Wrap the already-loaded HF LLaMA weights in a HookedTransformer. LayerNorm folding and
# weight centering are switched off, presumably so the hooked weights match the HF
# checkpoint as-is (confirm against the TransformerLens from_pretrained docs).
model = HookedTransformer.from_pretrained(
    "llama-7b",
    hf_model=hf_model,
    device=device,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)

model.tokenizer = tokenizer
# Give the tokenizer a pad token WITHOUT growing its vocabulary. Adding a brand-new
# "[PAD]" string via add_special_tokens appends an id past the model's d_vocab, and the
# embedding has no row for it, so any padded batch would index out of range.
model.tokenizer.pad_token = model.tokenizer.unk_token
assert len(model.tokenizer) <= model.cfg.d_vocab, "tokenizer has ids the model cannot embed"

# Quick sanity check that the converted model produces fluent text.
model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)

Loaded pretrained model llama-7b into HookedTransformer


  0%|          | 0/20 [00:00<?, ?it/s]

'<s>The capital of Germany is a city of contrasts. It is a city of culture, history, and tradition. It is'

In [4]:
# Conversational prompt. temperature=0 presumably means greedy decoding, so the saved
# output should be reproducible (confirm in the TransformerLens generate docs).
model.generate(
    "How's it going, man?",
    max_new_tokens=20,
    temperature=0,
)

  0%|          | 0/20 [00:00<?, ?it/s]

"<s>How's it going, man?\nI'm a 20 year old guy from the UK. I'm a"

In [5]:
# Long (200-token) continuation. The Canberra passage it produced is pasted verbatim
# (minus the leading "<s>") as the run_with_cache prompt in the next cell.
model.generate(
    "The capital of Australia is",
    max_new_tokens=200,
    temperature=0,
)

  0%|          | 0/200 [00:00<?, ?it/s]

'<s>The capital of Australia is Canberra, the country’s largest inland city. It is located in the Australian Capital Territory (ACT), which is a federal territory. The city is located on the banks of the Molonglo River, 280 km south-west of Sydney.\nThe city was founded in 1913 as a compromise between the six Australian colonies. The city was built to be the capital of the country, and it was named after the Aboriginal word for “meeting place”.\nThe city is the seat of the Australian Parliament, and the Parliament House is the most important building in the city. The Parliament House is a modern building, which was built in the 1980s. It is a circular building, which is surrounded by a lake.\nThe city is also the seat of the Australian Government. The city is the home of the High Court of Australia, the Supreme Court of the Australian Capital Territory, and the Australian Def'

In [6]:
# Cache every activation on the model's own 200-token Canberra continuation (copied from
# the generate output above, without the "<s>" prefix; presumably run_with_cache adds BOS).
# The pieces below are implicitly concatenated into one string.
outputs, cache = model.run_with_cache(
    "The capital of Australia is Canberra, the country’s largest inland city. "
    "It is located in the Australian Capital Territory (ACT), which is a federal territory. "
    "The city is located on the banks of the Molonglo River, 280 km south-west of Sydney.\n"
    "The city was founded in 1913 as a compromise between the six Australian colonies. "
    "The city was built to be the capital of the country, and it was named after the Aboriginal word for “meeting place”.\n"
    "The city is the seat of the Australian Parliament, and the Parliament House is the most important building in the city. "
    "The Parliament House is a modern building, which was built in the 1980s. "
    "It is a circular building, which is surrounded by a lake.\n"
    "The city is also the seat of the Australian Government. "
    "The city is the home of the High Court of Australia, the Supreme Court of the Australian Capital Territory, and the Australian Def"
)

In [7]:
# Sanity check: which device do the cached activations live on?
# (The saved output shows cuda:0, the device the model was loaded onto.)
cache["z", 0].device

device(type='cuda', index=0)

In [26]:
# Attention pattern of layer 20, head 30 for the single prompt in the batch.
# Assumes TransformerLens' [batch, head, query_pos, key_pos] layout for cache["attn", layer].
# NOTE(review): this cell was executed as In[26], i.e. AFTER In[15] further down had
# replaced `cache` with the short "capital of the United States" prompt. On Restart &
# Run All it will instead plot the long Canberra passage cached above.
imshow(
    cache["attn", 20][0, 30],
    xaxis="Key position",
    yaxis="Query position",
    title="Layer 20, head 30 attention pattern",
)

In [15]:
# Re-cache on a short two-fact prompt. NOTE: this overwrites `outputs` and `cache` from
# the long-passage run above, so later cells see only this prompt's activations.
(outputs, cache) = model.run_with_cache(
    "The capital of Australia is Canberra, the capital of the United States is"
)

In [22]:
# Attention pattern of layer 5, head 7 on the short prompt cached just above.
# Assumes TransformerLens' [batch, head, query_pos, key_pos] layout for cache["attn", layer].
imshow(
    cache["attn", 5][0, 7],
    xaxis="Key position",
    yaxis="Query position",
    title="Layer 5, head 7 attention pattern",
)