In [None]:
import pandas as pd
import random
import re
import matplotlib.pyplot as plt
import time
from IPython.display import clear_output
from nnsight import CONFIG
from nnsight import LanguageModel
import os
import numpy as np
import torch
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import HTML, display
import matplotlib.cm as cm
import torch.nn.functional as F 
from torch import nn

# Credentials are read from the environment so real keys never have to be
# written into the notebook: export NDIF_API_KEY and HF_TOKEN before starting
# the kernel.
YOUR_API_KEY = os.environ.get("NDIF_API_KEY", "")
# Only register a key when one is actually provided, so an empty placeholder
# is never pushed into the nnsight configuration.
if YOUR_API_KEY:
    CONFIG.set_default_api_key(YOUR_API_KEY)

HF_TOKEN = os.environ.get('HF_TOKEN', '')
# Likewise, never overwrite an already-exported HF_TOKEN with an empty string.
if HF_TOKEN:
    os.environ['HF_TOKEN'] = HF_TOKEN
clear_output()

llamaInstruct = LanguageModel("meta-llama/Llama-3.3-70B-Instruct", device_map="auto")
llamaInstruct.model.eval()

# Seed every random number generator the notebook imports.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

torch.cuda.empty_cache()

# Attention geometry read from the model config (values noted by the author).
head_dim = llamaInstruct.model.config.head_dim # 128
num_heads = llamaInstruct.model.config.num_attention_heads # 64
num_key_value_heads = llamaInstruct.model.config.num_key_value_heads # 8 


In [2]:
# Tables are stored on disk as lists of row dicts, keyed by table id.
with open('mixed_datatables_updated.json', 'r') as f:
    loaded_datatables = json.load(f)

# Rebuild one DataFrame per table from its list of row dicts.
mixed_datatables = dict(
    (table_key, pd.DataFrame(row_records))
    for table_key, row_records in loaded_datatables.items()
)

# The string form of the tables is loaded as-is from its own JSON file.
with open('mixed_strings_updated.json', 'r') as f:
    mixed_strings = json.load(f)

In [4]:
def decimal_places(num):
    """Return how many digits follow the decimal point in ``str(num)``.

    Exponent notation is taken into account, so ``1e-07`` (which Python
    prints without a '.') reports 7 places and ``1.5e-07`` reports 8, instead
    of counting the characters of the exponent suffix.

    Args:
        num: A number, or anything whose ``str()`` form looks like a number.

    Returns:
        The non-negative count of fractional digits; 0 when there are none.
    """
    text = str(num)
    mantissa, marker, exponent = text.lower().partition('e')
    try:
        # int('') raises as well, so a dangling 'e' also takes the fallback.
        shift = int(exponent) if marker else 0
    except ValueError:
        # The 'e' was not an exponent marker (arbitrary text): keep the plain
        # "characters after the first '.'" rule on the untouched string.
        mantissa, shift = text, 0

    fraction_digits = len(mantissa.split('.')[1]) if '.' in mantissa else 0
    # A negative exponent adds decimal places, a positive one removes them.
    return max(fraction_digits - shift, 0)
def int_to_string(y_val):
    """Format ``y_val`` as text, dropping the fractional part when it is whole.

    Values whose float form is an integer are rendered through ``int`` (so
    ``3.0`` becomes ``"3"``); anything else is rendered unchanged.
    """
    if not float(y_val).is_integer():
        return "{}".format(y_val)
    return str(int(y_val))

def find_subsequence(full, sub):
    """Locate the first occurrence of ``sub`` as a contiguous slice of ``full``.

    Returns:
        ``(start, end)`` such that ``full[start:end] == sub``, or
        ``(None, None)`` when ``sub`` does not occur in ``full``.
    """
    width = len(sub)
    candidate_starts = range(len(full) - width + 1)
    first_hit = next(
        (pos for pos in candidate_starts if full[pos:pos + width] == sub),
        None,
    )
    if first_hit is None:
        return None, None
    return first_hit, first_hit + width

In [5]:
def _generate_extremum_chat(datatable, extremum):
    """Build the system/user chat messages asking for one extreme y-value.

    Args:
        datatable: The table as a string; inserted verbatim into the user
            message.
        extremum: The word naming which extreme to ask for
            ("maximum" or "minimum").

    Returns:
        A two-element list of ``{"role": ..., "content": ...}`` dicts: the
        system instruction followed by the user question.
    """
    return [
        {
            "role": "system",
            "content": (
                f"You are a data scientist. Given a data table, return only the *{extremum} y-value*. "
                "Respond with only the y-value as it appears in the table. Do not explain. "
                "Do not return the x-value or the row index."
            )
        },
        {
            "role": "user",
            "content": (
                f"What is the {extremum} y-value?\n"
                "Data Table:\n" + datatable + "\n\n"
            )
        }
    ]


def generate_max_chat(datatable):
    """Return the chat messages asking the model for the maximum y-value of ``datatable``."""
    return _generate_extremum_chat(datatable, "maximum")


def generate_min_chat(datatable):
    """Return the chat messages asking the model for the minimum y-value of ``datatable``."""
    return _generate_extremum_chat(datatable, "minimum")


In [None]:
# For every table: ask the model for the maximum y-value, record the answer,
# then trace the same prompt and save the q/k projections and rotary cos/sin
# of every layer. One .pt file is written per table and existing files are
# skipped, so the loop can be re-run after an interruption.
output_dir = 'synthetic_dataset/attention/largest'
# Create the folder up front: otherwise a missing directory would only fail at
# the save step, after both remote jobs of the first table have already run.
os.makedirs(output_dir, exist_ok=True)

answer_path = os.path.join(output_dir, 'answers.json')
if os.path.exists(answer_path):
    with open(answer_path, "r") as f:
        answers = json.load(f)
else:
    answers = {}

# Layer count comes from the model config instead of a hard-coded 80.
layers = list(range(llamaInstruct.model.config.num_hidden_layers))
for table_id in tqdm(mixed_strings):
    activation_path = os.path.join(output_dir, f'{table_id}.pt')
    if os.path.exists(activation_path):
        continue

    # generating the prompt 
    data_table = mixed_strings[table_id]
    prompt = generate_max_chat(data_table)
    prompt_text = llamaInstruct.tokenizer.apply_chat_template(
        prompt, tokenize=False, add_generation_prompt=True
    )

    # generating the answer
    with llamaInstruct.generate(prompt_text, remote = True, max_new_tokens=100):
        output = llamaInstruct.generator.output.save()

    text = llamaInstruct.tokenizer.batch_decode(output)
    # No DOTALL flag: '.' stops at newlines, so only a single-line message
    # between a header and <|eot_id|> can match.
    match = re.search(r"<\|end_header_id\|>\n\n(.*?)<\|eot_id\|>", text[0], re.IGNORECASE)
    if match is None:
        # E.g. a multi-line reply, or generation stopped before <|eot_id|>.
        # Skip instead of raising AttributeError; nothing is cached for this
        # table, so a later run will retry it.
        tqdm.write(f"Skipping table {table_id}: could not parse an answer from the model output.")
        continue
    answer = match.group(1).strip().lower()

    # tokenizing the answer and getting indices
    tokenized = llamaInstruct.tokenizer(
        prompt_text,
        return_offsets_mapping=True,
        return_tensors='pt'
    )
    answer_token_ids = llamaInstruct.tokenizer(answer, add_special_tokens=False)['input_ids']
    flat_input_ids = tokenized['input_ids'][0].tolist()
    # NOTE(review): the answer is lower-cased and tokenized on its own, so its
    # ids may not occur verbatim in the prompt; find_subsequence then returns
    # (None, None) and both indices are stored as None.
    start_idx, end_idx = find_subsequence(flat_input_ids, answer_token_ids)
    answers[table_id] = {'value': answer, 'start_idx': start_idx, 'end_idx': end_idx}

    layer_dict = {}
    with llamaInstruct.trace(prompt_text, remote=True) as runner:
        for layer_idx in layers:
            # output[0] presumably selects the single batch element -- TODO confirm.
            q = llamaInstruct.model.layers[layer_idx].self_attn.q_proj.output[0].save()
            k = llamaInstruct.model.layers[layer_idx].self_attn.k_proj.output[0].save()

            # The rotary embedding output is stored alongside every layer's q/k.
            cos = llamaInstruct.model.rotary_emb.output[0].save()
            sin = llamaInstruct.model.rotary_emb.output[1].save()
            layer_dict[layer_idx] = {'q': q, 'k': k, 'cos': cos, 'sin': sin}

    # Replace the saved proxies by their concrete values before serializing.
    for layer_idx in layer_dict:
        for key in ['q', 'k', 'cos', 'sin']:
            layer_dict[layer_idx][key] = layer_dict[layer_idx][key].value
    torch.save(layer_dict, activation_path)

    # Persist the answers after every table so progress survives a crash.
    with open(answer_path, 'w') as f:
        json.dump(answers, f)
    

  0%|          | 0/100 [00:00<?, ?it/s]2025-07-07 12:53:50,850 63f27fcd-f7b5-402d-b535-86b2c6687ca7 - RECEIVED: Your job has been received and is waiting approval.
2025-07-07 12:53:56,974 63f27fcd-f7b5-402d-b535-86b2c6687ca7 - APPROVED: Your job was approved and is waiting to be run.
2025-07-07 12:53:57,284 63f27fcd-f7b5-402d-b535-86b2c6687ca7 - RUNNING: Your job has started running.
2025-07-07 12:53:58,698 63f27fcd-f7b5-402d-b535-86b2c6687ca7 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 2.98k/2.98k [00:00<00:00, 3.29MB/s]
2025-07-07 12:54:01,710 96525215-aa56-4444-a3c0-6a2101b54bad - RECEIVED: Your job has been received and is waiting approval.
2025-07-07 12:54:01,874 96525215-aa56-4444-a3c0-6a2101b54bad - APPROVED: Your job was approved and is waiting to be run.
2025-07-07 12:54:02,253 96525215-aa56-4444-a3c0-6a2101b54bad - RUNNING: Your job has started running.
2025-07-07 12:54:04,603 96525215-aa56-4444-a3c0-6a2101b54bad - COMPLETED: Your job has be