In [32]:
import os
import h5py, sys, time, json

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from jokes import dataset

In [38]:
# Hugging Face checkpoint whose tokenizer and causal-LM weights are loaded below.
model_name = "google/gemma-2-2b-it"
print(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Time only the weight loading, which dominates start-up cost.
start_time = time.time()
model = AutoModelForCausalLM.from_pretrained(model_name)
print('Time to load model', time.time()- start_time)

# `is_built()` only reports that this PyTorch binary was compiled with MPS
# support; `is_available()` additionally checks that the current machine can
# actually use the MPS device, so the CPU fallback really triggers when needed.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
print(device)

google/gemma-2-2b-it


Loading checkpoint shards: 100%|██████████| 2/2 [00:21<00:00, 10.70s/it]


Time to load model 22.386863946914673
mps


In [39]:
# Most recent captured output per hook name; overwritten on every forward pass.
activations = {}

def get_activation(name):
    """Build a forward hook that stores a layer's output under ``activations[name]``.

    Args:
        name: Key under which the captured tensor is stored in ``activations``.

    Returns:
        A ``hook(module, inputs, output)`` callable suitable for
        ``register_forward_hook``. It stores a detached tensor and returns
        ``None``, so the layer's output is left unmodified.
    """
    def hook(module, inputs, output):
        # The hooked layers in this notebook were observed returning a
        # ``(hidden_states, None)`` tuple. Some transformers versions may return
        # the hidden-state tensor directly (TODO confirm for the installed
        # version); indexing ``[0]`` on a bare tensor would silently select
        # batch element 0 instead, so only unpack real tuples/lists.
        hidden_states = output[0] if isinstance(output, (tuple, list)) else output
        # detach() so the stored tensor does not keep an autograd graph alive.
        activations[name] = hidden_states.detach()
    return hook

# Indices of the decoder layers whose outputs are captured into `activations`
# under the keys "layer_13" and "layer_14".
HOOKED_LAYERS = (13, 14)

# Re-running this cell would otherwise stack a second set of hooks on the same
# layers, so remove any handles left over from a previous execution first.
for stale_handle in globals().get("hook_handles", []):
    stale_handle.remove()

# Keep the handles so the hooks can be detached later with `handle.remove()`.
hook_handles = [
    model.model.layers[layer_idx].register_forward_hook(
        get_activation(f"layer_{layer_idx}")
    )
    for layer_idx in HOOKED_LAYERS
]

<torch.utils.hooks.RemovableHandle at 0x11d8eeae0>

In [40]:
# Parse the JSON file of text records into `jokes_data`; the file handle is
# closed as soon as its contents have been read.
with open('jokes_data.json', 'r') as jokes_file:
    jokes_data = json.loads(jokes_file.read())

In [42]:
# torch.save does not create missing parent directories, so make sure both
# per-joke output folders exist before the (slow) inference loop starts.
os.makedirs("outputs", exist_ok=True)
os.makedirs("activations", exist_ok=True)

# One record per joke: metadata plus the hooked layer outputs for that text.
joke_activations = []
for joke in jokes_data:
    inputs = tokenizer(joke['text'], return_tensors = "pt").to(device)
    # Drop the previous joke's captures: if a hook fails to fire, the lookups
    # below raise KeyError instead of silently reusing stale tensors.
    activations.clear()
    with torch.no_grad():
        start_time = time.time()
        output = model(**inputs)
        # Double-quoted f-string: reusing the same quote character inside the
        # replacement field is a SyntaxError before Python 3.12.
        print(f"Time for inference, at joke {joke['id']}" , time.time() - start_time)
        torch.save(output, f"outputs/output_{joke['id']}.pt")
    # Store the activations along with joke ID and type
    joke_activations.append({
        "id": joke["id"],
        "type": joke["type"],
        "text":joke["text"],
        "layer_13": activations["layer_13"],
        "layer_14": activations["layer_14"],
    })
    # Also persist each record on its own so a crash mid-run keeps earlier results.
    torch.save(joke_activations[-1], f'activations/joke_activations_{joke["id"]}.pt')
# Combined file holding every record collected above.
torch.save(joke_activations, 'joke_activations.pt')

(tensor([[[ 0.6860,  1.0977,  3.0149,  ..., -1.5590,  0.3709, -1.1333],
         [-2.3003,  2.6347,  0.3840,  ...,  2.5172, -0.6138,  2.6097],
         [ 3.7302, -2.0365, -2.2039,  ...,  2.8923, -2.7236, -1.1162],
         ...,
         [ 3.1971,  2.3497,  2.4329,  ...,  1.2405, -4.0154,  1.8298],
         [ 2.5727,  0.7036, -0.3393,  ..., -1.0310, -6.2209,  2.1541],
         [ 2.8339,  2.0653,  0.7723,  ..., -0.2841, -6.6302,  0.6895]]],
       device='mps:0'), None)
(tensor([[[ 1.2421,  0.7380,  1.1539,  ..., -1.5935,  0.1396, -1.5648],
         [-2.6604,  1.9934, -1.1163,  ...,  3.0065, -1.2558,  2.1968],
         [ 1.9155, -1.4863, -1.7710,  ...,  3.1962, -4.1634, -1.2415],
         ...,
         [ 2.6815,  1.2122,  1.9535,  ...,  0.5117, -3.0156,  1.8856],
         [ 2.8668, -2.2862, -2.7558,  ..., -1.2716, -7.0495,  2.7435],
         [ 3.4234,  1.9954, -2.7699,  ...,  1.2236, -4.7838,  0.7781]]],
       device='mps:0'), None)
Time for inference, at joke 0 19.360169887542725
(tens

In [43]:
# Dump every loaded record, one dict per output line, to eyeball the dataset.
for record in jokes_data:
    print(record)

{'id': 0, 'text': 'Methamphetamine is a powerful, highly addictive stimulant drug that affects the central nervous system', 'type': 0}
{'id': 1, 'text': "Why don't cannibals eat clowns? They taste funny.", 'type': -1}
{'id': 2, 'text': 'I used to play piano by ear, but now I use my hands.', 'type': 1}
{'id': 3, 'text': "Why is the little girl's ice cream melting? ... Because she was on fire", 'type': -1}
{'id': 4, 'text': "What's the hardest part of a vegetable to eat? The wheelchair.", 'type': -1}
{'id': 5, 'text': 'The Eiffel Tower was completed in 1889 and stands at a height of 324 meters.', 'type': 0}
{'id': 6, 'text': "Why don't some couples go to the gym? Because some relationships don't work out.", 'type': 1}
{'id': 7, 'text': 'In many martial arts, the red belt is often associated with high ranks', 'type': 0}
{'id': 8, 'text': "The error message indicates that you're trying to access a gated repository", 'type': 0}
{'id': 9, 'text': "Why don't skeletons fight each other? They d