In [1]:
#!/usr/bin/env python
# coding: utf-8

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import notebook_login
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
import datasets
import transformers
from datasets import load_dataset,load_from_disk
from evaluate import load
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader 
from tqdm import tqdm
import emoji
import argparse
from peft import PeftModel    




# 
# MODEL_NAME = "meta-llama/Llama-2-7b-hf"

def add_args(parser: argparse.ArgumentParser):

    parser.add_argument('--checkpoint_dir',
                            type=str,
                            default='experiments/Mistral-7B-v0.1/checkpoint-1000')

    parser.add_argument('--model_name',
                            type=str,
                            default='mistralai/Mistral-7B-v0.1')

    parser.add_argument('--save_file',
                                type=str,
                                default='pred_output.txt')
    
    parser.add_argument('--batch_size',
                            type=int,
                            default=32)
    
    parser.add_argument('--prompt',
                            type=str,
                            default="""
Below is a clue for a decrypting crossword. Your task is to solve this clue. The number of charachters in the answer should be same as the number in the parenthesis. Just output the answer only. Do not output any explanitions, just the words in the answer.
 
### Input:
Desk register taken no further than Ozzie? (7)

### Output:
rolltop

### Input:
Henry has books stolen (3)

### Output:
hot
""")
    
    parser.add_argument('--n_shots',
                            type=int,
                            default=0)
    
    parser.add_argument('--num_examples',
                            type=int,
                            default=100)
    parser.add_argument('--dataset_path',
                            type=str,
                            default='data/unique_targets')
    parser.add_argument('--dataset_type',
                            type=int,
                            default=0)




def concat_length(example):

    example["clue"] = f'{example["clue"]} ({example["orig_lengths"]})'

    return example


# DEFAULT_SYSTEM_PROMPT = """
# Below is a clue for a decrypting crossword. Your task is to solve this clue. The number of charachters in the answer should be same as the number in the parenthesis. Just output the answer only. Do not output any explanitions, just the words in the answer.
 
# ### Input:
# Desk register taken no further than Ozzie? (7)

# ### Output:
# rolltop

# ### Input:
# Henry has books stolen (3)

# ### Output:
# hot
# """.strip()


# def generate_training_prompt(
#     clue: str, prompt: str = DEFAULT_SYSTEM_PROMPT
# ) -> str:
    

#     return f"""### Instruction: {prompt}

# ### Input:
# {clue.strip()}

# """.strip()
     




def map_prompt(ex, base_prompt, shots):


    p = ''

    #add base prompt
    p = f'### Instruction: {base_prompt}\n\n'

    for shot in shots:
        p += f'### Input:\n{shot["clue"]}\n\n### Output:\n{shot["soln_with_spaces"]}\n\n'


    p+= f'### Input:\n{ex["clue"]}'


    ex['prompt'] = p
    return ex




def inference(prompts, tokenizer, generation_config, model):
    
   
    encoding = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)


    answer_lengthes = []

    for t in prompts:
        l = t.split('\n')[-1]
        answer_lengthes. append( l[l.rfind("(")+1:l.rfind(")")].split(',')) 

    answer_lengthes =  [ list(map(int, answer_lengthes[i]))  for i in range(len(answer_lengthes))] 

    # print(answer_lengthes)

    with torch.no_grad():
        outputs = model.generate(
            **encoding,
            max_new_tokens=64,
            do_sample=True,
            temperature=0.00001,
            pad_token_id=tokenizer.eos_token_id,
            generation_config=generation_config,
        )  

    answer_tokens = outputs[:, encoding.input_ids.shape[1] :]
    output_text = tokenizer.batch_decode(answer_tokens, skip_special_tokens=True)

    

    return output_text, answer_lengthes

In [2]:

    


parser = argparse.ArgumentParser('Eval LLMs on crossword solving')

add_args(parser)
args, _ = parser.parse_known_args()
# MODEL_NAME = "mistralai/Mistral-7B-v0.1"

for arg in vars(args):
    print(arg, getattr(args, arg))

MODEL_NAME = args.model_name
batch_size = args.batch_size
prompt = args.prompt
num_examples = args.num_examples
save_file = args.save_file

dataset_path = args.dataset_path





# val_dataset = load_dataset('json', data_files=dataset_path, field="val",split="train")

if args.dataset_type:
    val_dataset = load_dataset('json', data_files=dataset_path, field="train",split="train")
    val_dataset = val_dataset.map(concat_length)
    unique_answers = np.unique(val_dataset['soln'])
    val_dataset = val_dataset.select_columns(['soln_with_spaces', 'clue' ])


else:
    val_dataset = load_from_disk(dataset_path)
    val_dataset = val_dataset['test']
    unique_answers = np.unique(val_dataset['labels'])
    print(val_dataset)

    val_dataset = val_dataset.rename_column('labels', 'soln_with_spaces')


idx= np.random.randint(0,len(val_dataset),args.n_shots)

shots = val_dataset.select(idx)

for shot in shots:
    print(shot['clue'], shot['soln_with_spaces'])

    
val_dataset = val_dataset.map(map_prompt,fn_kwargs={"base_prompt": prompt,"shots":shots})



print(f' total number of examples: {len(val_dataset)},    number of unique answers: {len(unique_answers)}')

if num_examples == 0:
    num_examples = len(val_dataset)



val_dataloader = DataLoader(val_dataset.select(range(num_examples)),batch_size = batch_size)




model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    return_dict=True,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

if args.checkpoint_dir:
    adapter_checkpoint  = args.checkpoint_dir
    model = PeftModel.from_pretrained(model, adapter_checkpoint)


tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
acc_metric = load("accuracy")


model = model.eval()
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)









# Define PAD Token = BOS Token
tokenizer.pad_token = tokenizer.bos_token
model.config.pad_token_id = model.config.bos_token_id


predictions = []
labels = []
original_predictions = []

torch.cuda.empty_cache()



checkpoint_dir experiments/Mistral-7B-v0.1/checkpoint-1000
model_name mistralai/Mistral-7B-v0.1
save_file pred_output.txt
batch_size 32
prompt 
Below is a clue for a decrypting crossword. Your task is to solve this clue. The number of charachters in the answer should be same as the number in the parenthesis. Just output the answer only. Do not output any explanitions, just the words in the answer.
 
### Input:
Desk register taken no further than Ozzie? (7)

### Output:
rolltop

### Input:
Henry has books stolen (3)

### Output:
hot

n_shots 0
num_examples 100
dataset_path data/unique_targets
dataset_type 0
Dataset({
    features: ['labels', 'clue'],
    num_rows: 5629
})
 total number of examples: 5629,    number of unique answers: 5629


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:

for batch in tqdm(val_dataloader):

    prompts = batch['prompt']

    # for x in prompts:
    #     print(x)   
    # break

    # labels.extend (batch['soln_with_spaces'])
    ans = []

    output_text, answer_lengths = inference(prompts=prompts, tokenizer=tokenizer, generation_config=generation_config, model=model)
    


    # print(output_text)
    # break
    for i,t in enumerate(output_text):

        lines = t.split('\n')
        for j,l in enumerate(lines):
            if l=='### Response:' or l=='### Output:':
                labels.append( batch['soln_with_spaces'][i].lower())

                ## Cut the answer to the length of the answer given in the clue
                answer = []
                original_words = lines[j+1].lower().split(' ')
                if len(original_words) >= len(answer_lengths[i]):
                    for idx, length in enumerate(answer_lengths[i]):

                        answer.append(original_words[idx][:length])


                    predictions.append(' '.join(answer))
                else:
                    predictions.append(lines[j+1].lower())

                original_predictions.append(lines[j+1].lower())


                break
        # print( answer_lengths[i])
        
        # print(output_text)
        # break

print(len(predictions), len(labels))
assert (len(predictions) == len(labels))


correct = 0
length_error =0



save_file = 'outputs/' + 'mistral_1k_unique.txt'
with open(save_file, 'w') as f:
    for original,pred,label in zip(original_predictions,predictions,labels):
    # for pred,label in zip(predictions,labels):

        pred  = " ".join(pred.split())
        label = " ".join(label.split())

        correctly_predicted = False
        if pred == label:
            correct +=1
            correctly_predicted = True

        if len(pred) != len(label):
            length_error +=1

        f.write(f'Original output: {original}\n')
        if correctly_predicted:
            f.write(emoji.emojize(f'{pred} | {label}  :check_mark_button: \n'))
        else:
            f.write(emoji.emojize(f'{pred} | {label}  :cross_mark: \n'))

        f.write('---------------------------------------------------------------------------------- \n\n')
    f.seek(0)
    f.write(f'Dataset: {args.dataset_path}\n')
    f.write(f'Number of Examples {num_examples}\n')
    print(f'Number of Examples {num_examples}')
    f.write(f'ACCURACY:  { float (correct / num_examples)}\n')
    print(f'ACCURACY:  { float (correct / num_examples)}')
    f.write(f'Length error:  { float ((length_error / num_examples) )}\n')
    print(f'Length error:  { float ((length_error / num_examples) )}')
    f.write('----------------------------------------------------- \n\n')

    # for output in output_text:
    #     f.write(output)
    #     f.write('\n\n')


100%|██████████| 4/4 [01:11<00:00, 17.86s/it]

300 300
Number of Examples 100
ACCURACY:  0.11
Length error:  0.23



