In [1]:
import logging
import sys
import argparse
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import re

In [3]:
home_directory = os.path.expanduser('~')  # Gets the home directory
abstract_analogies_path = os.path.join(home_directory, 'AbstractAnalogies')
sys.path.append(abstract_analogies_path)

In [4]:
from models.llama3 import LLama3
from models.mistral7b import Mistral7B
from models.starling7b_beta import Starling7BBeta

In [7]:
SUPPORTED_MODELS = {
    'llama3': LLama3,
    'mistral7b': Mistral7B,
    'starling7b-beta': Starling7BBeta
}

In [14]:
# Args
args_model = 'llama3'
args_prompt = 'basic_prompt.txt'

In [9]:
model_class = SUPPORTED_MODELS[args_model]
model = model_class()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded llama with device cpu


In [12]:
dataset = pd.read_csv('../datasets/verbal_analogy/UCLA_VAT.csv')

In [17]:
# read prompt_templates/story_analogies/basic_prompt.txt
with open(f'../prompt_templates/verbal_analogy/{args_prompt}', 'r', encoding='utf-8') as file:
    prompt_template = file.read()
print(prompt_template)

The notation "A : B :: C : D" should be read as "A is to B like C is to D".

In the following example:

{A} : {B} :: {C} : ?

Replace the question mark with the correct analogy:

D: {D}

or

D': {D_prime}

Provide final answer in tags: <ans> D </ans> or <ans> D' </ans>


In [18]:
def parse_model_generation(generation: str):
    # 1. Try finding <ans> </ans> tags
    pattern = r"<ans>(.*?)</ans>"

    # Find all matches
    matches = re.findall(pattern, generation)
    if len(matches) == 1 and (matches[0].strip() == 'A' or matches[0].strip() == 'B'):
        return matches[0].strip()
    
    # 2. Default to None if answer not found
    return None

In [21]:
def inference(model, rel, A, B, C, D, D_prime, prompt_template, results):
    prompt = prompt_template.format(A=A, B=B, C=C, D=D, D_prime=D_prime)
    generation = model.forward(prompt)
    parsed_answer = parse_model_generation(generation)

    ambiguous = True
    if parsed_answer == None:
        logit_D, logit_D_prime = model.forward_logits(prompt + ' So the final answer is <ans> ')
        parsed_answer = 'D' if logit_D > logit_D_prime else 'D_prime'
    else:
        ambiguous = False
        logit_D = None
        logit_D_prime = None

    results.append({
        'relation': rel,
        'A': A,
        'B': B,
        'C': C,
        'D': D,
        "D'": D_prime,
        'full_prompt': prompt,
        'raw_generation': generation,
        'parsed_answer': parsed_answer,
        'ambiguous': ambiguous,
        'logit_D': logit_D,
        "logit_D'": logit_D_prime
    })

In [20]:
results = []
for _, row in tqdm(dataset.iterrows(), total=len(dataset)):
    rel = row['Relation']
    A = row['A']
    B = row['B']
    C = row['C']
    D = row['D']
    D_prime = row["D'"]
    results = []

    inference(model, rel, A, B, C, D, D_prime, prompt_template, results)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|                                                                                                                                              | 0/80 [01:04<?, ?it/s]


KeyboardInterrupt: 

In [23]:
import pandas as pd

In [30]:
file_path = '../results/verbal_analogies_logits_llama3_basic_prompt.csv'
data = pd.read_csv(file_path, index_col=0)

In [31]:
data

Unnamed: 0,relation,A,B,C,D,D',full_prompt,raw_generation,parsed_answer,ambiguous,logit_D,logit_D'
0,Synonym,easy,simple,sad,unhappy,happy,"The notation ""A : B :: C : D"" should be read a...","A classic analogy!\n\nThe pattern is: ""A is a ...",D,False,,
1,Synonym,hurry,rush,harm,injure,help,"The notation ""A : B :: C : D"" should be read a...","A classic analogy!\n\nThe pattern is: ""verb : ...",D,False,,
2,Synonym,rob,steal,cry,weep,laugh,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is:\n...,D,False,,
3,Synonym,polite,courteous,angry,furious,happy,"The notation ""A : B :: C : D"" should be read a...","A classic analogy!\n\nThe pattern is: ""A is a ...",D,False,,
4,Synonym,beginner,novice,doctor,physician,heal,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is:\n...,D,False,,
...,...,...,...,...,...,...,...,...,...,...,...,...
75,Category Members,bird,crow,sport,football,stadium,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is:\n...,D,False,,
76,Category Members,weapon,sword,flower,daffodil,vase,"The notation ""A : B :: C : D"" should be read a...","A classic analogy!\n\nThe pattern is: ""A is a ...",D,False,,
77,Category Members,clothing,jacket,bird,pigeon,dog,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is:\n...,D,False,,
78,Category Members,furniture,table,fruit,pear,tree,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is: <...,D,False,,


In [38]:
print(data.iloc[0]['raw_generation'])

A classic analogy!

The pattern is: "A is a more complex version of B, like C is a more complex version of D".

* easy is a more complex version of simple
* sad is a more complex version of unhappy

So, the correct answer is: <ans> D </ans>


In [39]:
print(data.iloc[0]['full_prompt'])

The notation "A : B :: C : D" should be read as "A is to B like C is to D".

In the following example:

easy : simple :: sad : ?

Replace the question mark with the correct analogy:

D: unhappy

or

E: happy

Provide final answer in tags: <ans> D </ans> or <ans> E </ans>


In [34]:
data[data['parsed_answer'] == 'E']

Unnamed: 0,relation,A,B,C,D,D',full_prompt,raw_generation,parsed_answer,ambiguous,logit_D,logit_D'
17,Synonym,help,aid,raise,lift,lower,"The notation ""A : B :: C : D"" should be read a...","A classic analogy!\n\nThe pattern is: ""verb : ...",E,False,,
29,Opposite,quiet,noisy,last,first,final,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is: <...,E,False,,
31,Opposite,laugh,cry,cheap,expensive,inexpensive,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is: <...,E,False,,
35,Opposite,teacher,student,calm,stormy,serene,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is:\n...,E,False,,
39,Opposite,crazy,sane,bent,straight,crooked,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is: <...,E,False,,
40,Function,fly,bird,hop,rabbit,leg,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is: <...,E,False,,
44,Function,drive,car,burn,wood,fire,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is:\n...,E,False,,
46,Function,squeeze,juice,shoot,gun,miss,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is:\n...,E,False,,
48,Function,throw,ball,open,envelope,close,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is:\n...,E,False,,
55,Function,run,horse,pull,tractor,muscle,"The notation ""A : B :: C : D"" should be read a...",A classic analogy!\n\nThe correct answer is:\n...,E,False,,


In [40]:
def calculate_accuracy(data):
    correct_predictions = (data['parsed_answer'] == 'D').sum()
    total_predictions = data.shape[0]
    
    accuracy = correct_predictions / total_predictions
    
    return accuracy

In [41]:
calculate_accuracy(data)

0.8625