In [1]:
import os
os.environ['HF_HOME'] = '/workspace/huggingface'

from transformer_lens import HookedTransformer, ActivationCache, utils
import torch

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

import plotly.graph_objects as go
from plotly.offline import init_notebook_mode, iplot
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial
tqdm.pandas()

init_notebook_mode(connected=True)

Device: cuda



## Data Loading

In [None]:
%%bash
git clone https://github.com/asaparov/prontoqa.git
cd prontoqa
mkdir json
unzip generated_ood_data.zip -d json

In [4]:
import json

def load_json(file):
    with open(file, 'r') as f:
        return json.load(f)
    
data = load_json("prontoqa/json/1hop_ProofsOnly_random_noadj.json")

In [28]:
from random import randint

labels = ["is", "is not"]

def prepare_example(example, n_shots=3):
    prompt = ""
    
    for i in range(n_shots):
        change = randint(0, 1)
        prompt += example[f"in_context_example{i}"]['question'] + "\n"
        query = example[f"in_context_example{i}"]['query']
        
        if "is not" in query:
            label = "True" if change == 1 else "False"
            query = query.replace("is not", labels[change])
        else:
            label = "True" if change == 0 else "False"
            query = query.replace("is", labels[change])
    
        prompt += query.replace("Prove: ", "True or False: ") + " Think step-by-step." + "\n\n"
    
        for j, step in enumerate(example[f"in_context_example{i}"]['chain_of_thought']):
            prompt += f"({j+1}) {step}\n"
        
        prompt += f"\nAnswer: {label}\n\n\n"
    
    change = randint(0, 1)
    prompt += example["test_example"]['question'] + "\n"
    query = example["test_example"]['query']
    
    if "is not" in query:
        label = "True" if change == 1 else "False"
        query = query.replace("is not", labels[change])
    else:
        label = "True" if change == 0 else "False"
        query = query.replace("is", labels[change])
    
    prompt += query.replace("Prove: ", "True or False: ") + " Think step-by-step." + "\n\n"
        
    return prompt, label

In [29]:
prompts = [prepare_example(x) for x in data.values()]
print(prompts[0][0])

Dumpuses are numpuses. Impuses are jompuses. Each yumpus is not spicy. Every dumpus is mean. Lorpuses are snowy. Each lempus is not transparent. Numpuses are tumpuses. Numpuses are moderate. Every tumpus is luminous. Jompuses are not blue. Impuses are gorpuses. Every gorpus is not hot. Each dumpus is a yumpus. Every gorpus is a lempus. Lorpuses are sterpuses. Every impus is muffled. Every numpus is an impus. Gorpuses are rompuses. Polly is an impus. Polly is a lorpus.
True or False: Polly is muffled. Think step-by-step.

(1) Polly is an impus.
(2) Every impus is muffled.
(3) Polly is muffled.

Answer: True


Every lempus is a rompus. Each rompus is a jompus. Each jompus is a lorpus. Each rompus is a tumpus. Grimpuses are feisty. Jompuses are cold. Each dumpus is transparent. Each lempus is a dumpus. Rompuses are rainy. Vumpuses are gorpuses. Each tumpus is earthy. Every vumpus is sweet. Jompuses are grimpuses. Lempuses are angry. Alex is a rompus. Alex is a vumpus.
True or False: Alex 

### Model loading

In [10]:
model = HookedTransformer.from_pretrained('gemma-2b')

model.eval()
model.set_use_attn_result(True)
model.set_use_attn_in(True)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [17]:
import re

def extract_answer():
    pattern = r'Answer:\s*(.*)'
    match = re.search(pattern, text)
    
    if match:
        answer = match.group(1)
    else:
        answer = "NaN"

    return answer

for i in range(100):
    prompt, label = prompts[i]
    out = model.generate(prompt, 64, temperature=0)
    
    
print(prompt[0], sep='', end='')
print('\033[1m' + out[len(prompt):] + '\033[0m')

  0%|          | 0/64 [00:00<?, ?it/s]

Dumpuses are numpuses. Impuses are jompuses. Each yumpus is not spicy. Every dumpus is mean. Lorpuses are snowy. Each lempus is not transparent. Numpuses are tumpuses. Numpuses are moderate. Every tumpus is luminous. Jompuses are not blue. Impuses are gorpuses. Every gorpus is not hot. Each dumpus is a yumpus. Every gorpus is a lempus. Lorpuses are sterpuses. Every impus is muffled. Every numpus is an impus. Gorpuses are rompuses. Polly is an impus. Polly is a lorpus.
True or False: Polly is muffled.

(1) Polly is an impus.
(2) Every impus is muffled.
(3) Polly is muffled.

Answer: True


Every lempus is a rompus. Each rompus is a jompus. Each jompus is a lorpus. Each rompus is a tumpus. Grimpuses are feisty. Jompuses are cold. Each dumpus is transparent. Each lempus is a dumpus. Rompuses are rainy. Vumpuses are gorpuses. Each tumpus is earthy. Every vumpus is sweet. Jompuses are grimpuses. Lempuses are angry. Alex is a rompus. Alex is a vumpus.
True or False: Alex is not rainy.

(1) A

In [18]:
label

'True'