In [1]:
import os
os.environ['HF_HOME'] = '/workspace/huggingface'

from transformer_lens import HookedTransformer, ActivationCache, utils
import torch

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

import plotly.graph_objects as go
from plotly.offline import init_notebook_mode, iplot
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial
tqdm.pandas()

init_notebook_mode(connected=True)

Device: cuda



## Data Loading

In [None]:
%%bash
git clone https://github.com/asaparov/prontoqa.git
cd prontoqa
mkdir json
unzip generated_ood_data.zip -d json

In [4]:
import json

def load_json(file):
    with open(file, 'r') as f:
        return json.load(f)
    
data = load_json("prontoqa/json/1hop_ProofsOnly_random_noadj.json")

In [88]:
from random import randint

labels = ["is", "is not"]

def prepare_example(example, n_shots=5):
    prompt = ""
    
    for i in range(n_shots):
        change = randint(0, 1)
        prompt += example[f"in_context_example{i}"]['question'] + "\n"
        query = example[f"in_context_example{i}"]['query']
        
        if "is not" in query:
            label = "True" if change == 1 else "False"
            query = query.replace("is not", labels[change])
        else:
            label = "True" if change == 0 else "False"
            query = query.replace("is", labels[change])
    
        prompt += query.replace("Prove: ", "True or False: ") + " Think step-by-step." + "\n\n"
    
        for j, step in enumerate(example[f"in_context_example{i}"]['chain_of_thought']):
            prompt += f"({j+1}) {step}\n"
        
        prompt += f"\nAnswer: {label}\n\n\n"
    
    change = randint(0, 1)
    prompt += example["test_example"]['question'] + "\n"
    query = example["test_example"]['query']
    cot = example["test_example"]['chain_of_thought']
    
    if "is not" in query:
        label = "True" if change == 1 else "False"
        query = query.replace("is not", labels[change])
    else:
        label = "True" if change == 0 else "False"
        query = query.replace("is", labels[change])
    
    prompt += query.replace("Prove: ", "True or False: ") + " Think step-by-step." + "\n\n"
        
    return prompt, label, cot

In [89]:
prompts = [prepare_example(x, n_shots=5) for x in data.values()]
print(prompts[1])

('Yumpuses are grimpuses. Impuses are sterpuses. Sterpuses are dumpuses. Each sterpus is a numpus. Each impus is windy. Lempuses are melodic. Rompuses are lorpuses. Rompuses are not earthy. Numpuses are sweet. Rompuses are impuses. Sterpuses are transparent. Every yumpus is not dull. Lorpuses are luminous. Impuses are lempuses. Alex is a rompus. Alex is a yumpus.\nTrue or False: Alex is earthy. Think step-by-step.\n\n(1) Alex is a rompus.\n(2) Rompuses are not earthy.\n(3) Alex is not earthy.\n\nAnswer: False\n\n\nEvery impus is a tumpus. Sterpuses are jompuses. Impuses are grimpuses. Each sterpus is not large. Grimpuses are not fast. Every impus is not red. Lorpuses are not happy. Every tumpus is a rompus. Every tumpus is a lorpus. Every lorpus is a sterpus. Rompuses are metallic. Each tumpus is muffled. Every jompus is not aggressive. Gorpuses are sunny. Every lorpus is a gorpus. Shumpuses are earthy. Shumpuses are zumpuses. Every sterpus is a yumpus. Alex is a shumpus. Alex is a ste

### Model loading and evaluation

In [84]:
model = HookedTransformer.from_pretrained('gemma-2b')

model.eval()
model.set_use_attn_result(True)
model.set_use_attn_in(True)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [None]:
import re

def extract_answer(text):
    pattern = r'Answer:\s*(.*)'
    match = re.search(pattern, text)
    
    if match:
        answer = match.group(1)
    else:
        answer = "NaN"

    return answer

scores = []

for i in tqdm(range(100)):
    prompt, label = prompts[i]
    #out = model.generate(prompt, 64, temperature=0, verbose=False)[len(prompt):]
    out = generate_chat(prompt, max_new_tokens=64, temperature=0, verbose=False)[len(prompt):]

    prediction = extract_answer(out)
    scores.append(prediction == label)

In [61]:
import numpy as np

print(f"Score: {np.array(scores).mean()}")

The model achieves an accuracy of 66% with 5 shots.

## Attribution analysis 