In [1]:
import re
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer


In [2]:
dataset_path = '../data/disjoint_word_init.json'
model_name = 'meta-llama/Llama-2-7b-hf'
# train_dataset = load_dataset('json', data_files=dataset_path, field="train",split="train")
# val_dataset = load_dataset('json', data_files=dataset_path, field="val",split="train")
# test_dataset = load_dataset('json', data_files=dataset_path, field="test",split="train")



Dataset({
    features: ['unique_clue_id', 'soln_with_spaces', 'across_or_down', 'clue', 'number', 'creator', 'lengths', 'soln', 'dataset', 'id', 'lengths_punctuation', 'type', 'idx', 'pos', 'orig_lengths'],
    num_rows: 75847
})

In [3]:

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'

DEFAULT_SYSTEM_PROMPT = """
Below is a clue for a decrypting crossword. Your task is to solve this clue. The number of characters in the answer should be same as the number in the parenthesis. Just output the answer only.
""".strip()

def generate_prompt(example, prompt_head, is_train, field='prompt'):


    augmented_clue= f'{example["clue"]} ({example["orig_lengths"]})'

    example['clue'] = augmented_clue
    solution = example['soln_with_spaces']

    
    ## For training, we need to provide the system prompt, the idea and the story
    if is_train:
        example[field] =  f"""
### Instruction: {prompt_head}

### Input:
{augmented_clue.strip()}

### Response:
{solution}
""".strip()
    
    ## For validation and testing, we only need to provide the idea
    else:
        example[field] = f"""
### Instruction: {prompt_head}

### Input:
{augmented_clue.strip()}
""".strip()
        
    return example

def get_dataset(path,split = 'train',tokenizer = tokenizer , field='prompt', prompt_head = DEFAULT_SYSTEM_PROMPT):

    dataset = load_dataset('json', data_files=path, field=split , split='train')

    if split == 'train':
        dataset = dataset.map(generate_prompt , fn_kwargs={"field": field, "prompt_head": prompt_head, "is_train": True})
    
    else:
        dataset = dataset.map(generate_prompt , fn_kwargs={"field": field, "prompt_head": prompt_head, "is_train": False})
        
    dataset = dataset.select_columns([field,'clue', 'soln_with_spaces'])


    dataset = dataset.map(lambda x: tokenizer(x[field], padding=True, truncation=True), batched=True)
    dataset = dataset.rename_column('soln_with_spaces', 'labels')

    return dataset

In [4]:

train = get_dataset(dataset_path, split='val')

Map:   0%|          | 0/32628 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [6]:
print(train[0])

{'prompt': '### Instruction: Below is a clue for a decrypting crossword. Your task is to solve this clue. The number of characters in the answer should be same as the number in the parenthesis. Just output the answer only.\n\n### Input:\nTension in an arm? Slightly (1,6)', 'clue': 'Tension in an arm? Slightly (1,6)', 'labels': 'a trifle', 'input_ids': [1, 835, 2799, 4080, 29901, 13866, 338, 263, 23960, 363, 263, 1602, 4641, 292, 4891, 1742, 29889, 3575, 3414, 338, 304, 4505, 445, 23960, 29889, 450, 1353, 310, 4890, 297, 278, 1234, 881, 367, 1021, 408, 278, 1353, 297, 278, 3847, 29882, 6656, 29889, 3387, 1962, 278, 1234, 871, 29889, 13, 13, 2277, 29937, 10567, 29901, 13, 29911, 2673, 297, 385, 5075, 29973, 317, 4366, 368, 313, 29896, 29892, 29953, 29897, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [9]:
val = get_dataset(dataset_path, split='val')

Map:   0%|          | 0/32628 [00:00<?, ? examples/s]

In [10]:
print(val[0]['prompt'])


### Instruction: Below is a clue for a decrypting crossword. Your task is to solve this clue. The number of characters in the answer should be same as the number in the parenthesis. Just output the answer only.

### Input:
Tension in an arm? Slightly (1,6)
