In [None]:
#!/usr/bin/env python
# coding: utf-8

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import notebook_login
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
import datasets
import transformers
from datasets import load_dataset
from evaluate import load
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader 
from tqdm import tqdm
import emoji
import argparse




# 
# MODEL_NAME = "meta-llama/Llama-2-7b-hf"

def add_args(parser: argparse.ArgumentParser):

    parser.add_argument('--model_name',
                            type=str,
                            default='meta-llama/Llama-2-7b-hf')

    parser.add_argument('--save_file',
                                type=str,
                                default='pred_output.txt')
    
    parser.add_argument('--batch_size',
                            type=int,
                            default=32)
    
    parser.add_argument('--prompt',
                            type=str,
                            default="""
Below is a clue for a decrypting crossword. Your task is to solve this clue. The number of charachters in the answer should be same as the number in the parenthesis. Just output the answer only. Do not output any explanitions, just the words in the answer.
 
### Input:
Desk register taken no further than Ozzie? (7)

### Output:
rolltop

### Input:
Henry has books stolen (3)

### Output:
hot
""")
    
    parser.add_argument('--num_examples',
                            type=int,
                            default=0)
    



def concat_length(example):

    example["clue"] = f'{example["clue"]} ({example["orig_lengths"]})'

    return example


DEFAULT_SYSTEM_PROMPT = """
Below is a clue for a decrypting crossword. Your task is to solve this clue. The number of charachters in the answer should be same as the number in the parenthesis. Just output the answer only. Do not output any explanitions, just the words in the answer.
 
### Input:
Desk register taken no further than Ozzie? (7)

### Output:
rolltop

### Input:
Henry has books stolen (3)

### Output:
hot
""".strip()


def generate_training_prompt(
    clue: str, prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    return f"""### Instruction: {prompt}

### Input:
{clue.strip()}

""".strip()
     




def map_prompt(ex, prompt):
    ex['prompt'] =  generate_training_prompt(ex["clue"])

    return ex




def inference(prompts, tokenizer, generation_config, model):
    
   
    encoding = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **encoding,
            max_new_tokens=64,
            temperature=0.00001,
            pad_token_id=tokenizer.eos_token_id,
            generation_config=generation_config,
        )  

    answer_tokens = outputs[:, encoding.input_ids.shape[1] :]
    return answer_tokens
        





In [2]:
from datasets import load_dataset

val_dataset = load_dataset('json', data_files="../data/naive_random.json", field="val",split="train")

In [9]:
import numpy as np

idx= np.random.randint(0,100,10)

x= val_dataset.select(idx)['clue']

In [10]:
x

["They're shaken, swapping tips on communism for city",
 'Photo developed thanks to delicate matter',
 "What's missing in a fight is obscure",
 'Photo developed thanks to delicate matter',
 "Derive support from article in house that's new",
 'Targets given by teachers for work',
 'How one comes to confess',
 'Prior to support leader in Eucharist',
 'Ken exits wounded, holding up TV award, though showing good balance',
 "Aphrodisiac - it's in the heart, rising"]