In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import notebook_login
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from datasets import load_from_disk
import numpy as np
from peft import PeftModel    


MODEL_NAME = "mistralai/Mistral-7B-v0.1"
adapters_name  = "experiments/Mistral-7B-v0.1/checkpoint-8000"


model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    return_dict=True,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
     
tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)

model = PeftModel.from_pretrained(model, adapters_name)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:

def gen(model, text: str, tmp=.9, top_p=0.9, max_length=16,repetition_penalty=1.5):
    inputs = tokenizer(text, return_tensors="pt").to('cuda')
    inputs_length = len(inputs["input_ids"][0])
    with torch.inference_mode():
        outputs = model.generate(**inputs,
                                max_new_tokens=max_length,
                                do_sample=True,
                                top_p=top_p,
                                temperature=tmp,
                                repetition_penalty=repetition_penalty
                                )

    return tokenizer.decode(outputs[0][inputs_length:], skip_special_tokens=True)

In [3]:
val_dataset = load_from_disk('data/unique_targets')
val_dataset = val_dataset['test']

In [14]:
base_text = '''
I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the parenthesis in the clue.

### Input:
'''


# text = 'Desk register taken no further than Ozzie? (7)'

idx= np.random.randint(0,len(val_dataset),1)

text = val_dataset.select(idx)['clue'][0]
soln = val_dataset.select(idx)['labels'][0]

In [16]:
with open('outputs/mistral_unique_test.txt', 'a') as f:
    for i in range(3):
        print('-'*20)
        # text = input("Enter a prompt: ")
        print('-'*20)

        text = base_text + text

        # text = 'who are you?'
        
        f.write('='*30 + '\n')
        f.write(text + '\n')
        f.write('-'*20 + '\n')

        print(text,f'   soln: {soln}')
        temps = [0.00001,.1,0.9]
        top_ps = [0.9,0.5,0.2]
    
        output = gen(model, text, tmp=temps[i], top_p=top_ps[i], max_length=16,repetition_penalty=1.5)

        output = output.replace('<newline>', '\n')

        f.write(output + '\n')
        f.write('='*30 + '\n')
        f.flush()

        print('-'*20)
        print(output)
        
        print('-'*20)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


--------------------
--------------------

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the parenthesis in the clue.

### Input:

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the parenthesis in the clue.

### Input:

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the parenthesis in the clue.

### Input:

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


--------------------


### Response:
waisty? That's wrong! I
--------------------
--------------------
--------------------

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the parenthesis in the clue.

### Input:

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the parenthesis in the clue.

### Input:

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the parenthesis in the clue.

### Input:

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this 

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


--------------------


### Response:
waisty? That's wrong! I
--------------------
--------------------
--------------------

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the parenthesis in the clue.

### Input:

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the parenthesis in the clue.

### Input:

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this clue. The number of charachters in the answer should be same as the numbers in the parenthesis in the clue.

### Input:

I will give you a cryptic crossword which is a crossword puzzle in which each clue is a word puzzle. Your task is to solve this 