In [1]:
from datasets import load_dataset , load_from_disk


In [3]:
# Instruction text placed after "### Instruction:" in every prompt.
# NOTE(review): the wording (e.g. "decrypting" rather than "cryptic") is kept
# verbatim on purpose -- it is part of the training prompt, so editing it would
# change what the model is trained and evaluated on.
DEFAULT_SYSTEM_PROMPT = (
    "Below is a clue for a decrypting crossword. "
    "Your task is to solve this clue. "
    "The number of characters in the answer should be same as the number in the parenthesis. "
    "Just output the answer only."
)

def generate_prompt(example, prompt_head, is_train, field='prompt', old_dataset=False):
    """Render one crossword example into an instruction-style prompt.

    The prompt consists of an ``### Instruction:`` section holding
    ``prompt_head`` and an ``### Input:`` section holding the
    whitespace-stripped clue. Training prompts additionally end with an
    ``### Response:`` section holding the answer.

    Args:
        example: Mapping with a ``'clue'`` string and, when ``is_train`` is
            true, a ``'labels'`` entry (the answer). Modified in place.
        prompt_head: Instruction text placed after ``### Instruction:``.
        is_train: If true, append the ``### Response:`` section with the answer.
        field: Key under which the rendered prompt is stored.
        old_dataset: Unused; accepted so existing callers that pass it keep working.

    Returns:
        The same ``example`` with ``example[field]`` set to the prompt string.
    """
    sections = [
        f"### Instruction: {prompt_head}",
        f"### Input:\n{example['clue'].strip()}",
    ]

    # Only training prompts carry the answer. Reading 'labels' here (not up
    # front) lets unlabeled test/inference rows be formatted without a KeyError.
    if is_train:
        sections.append(f"### Response:\n{example['labels']}")
    # NOTE(review): the inference prompt ends at the clue and has no
    # "### Response:" header, unlike the training prompt. Confirm the
    # generation/eval code expects this (or appends the header itself).

    # Sections are separated by a blank line; the final strip() trims any
    # whitespace left at either end (e.g. trailing spaces in the answer),
    # matching the previous triple-quoted template.
    example[field] = "\n\n".join(sections).strip()
    return example

In [18]:
def get_dataset(dataset_path, split='train', field='prompt', prompt_head=DEFAULT_SYSTEM_PROMPT, old_dataset=False):
    """Load a crossword-clue dataset and add a formatted prompt column.

    Two on-disk layouts are supported:

    * ``old_dataset=True``: a JSON file read with ``load_dataset('json', ...)``.
      Its ``idx`` column is dropped, and ``target``/``input`` are renamed to
      ``labels``/``clue``.
    * ``old_dataset=False``: a directory written by ``save_to_disk``. The
      requested ``split`` is selected from it (presumably a ``DatasetDict``;
      verify against how the data was saved).

    Args:
        dataset_path: JSON file path (legacy layout) or ``save_to_disk`` directory.
        split: Split to select (non-legacy layout only). In both layouts,
            ``'train'`` makes the prompts include the answer.
        field: Name of the column that receives the rendered prompt.
        prompt_head: Instruction text passed to ``generate_prompt``.
        old_dataset: Whether ``dataset_path`` uses the legacy JSON layout.

    Returns:
        The loaded split, with ``field`` added by ``generate_prompt``.

    Raises:
        AssertionError: If ``split`` is not present in the saved dataset.
    """
    if old_dataset:
        # The json builder exposes every record as a single 'train' split,
        # regardless of which `split` the caller asked for.
        dataset = load_dataset('json', data_files=dataset_path, split='train')
        dataset = dataset.remove_columns(['idx'])
        dataset = dataset.rename_columns({'target': 'labels', 'input': 'clue'})
    else:
        dataset = load_from_disk(dataset_path)
        assert split in dataset.keys(), f"Split {split} not found in dataset {dataset_path}"
        dataset = dataset[split]
        # NOTE(review): this banner is printed for every split, not only 'train'.
        print('------------------ TRAINING ON UNIQUE CLUES ------------------')

    is_train = split == 'train'
    return dataset.map(
        generate_prompt,
        fn_kwargs={
            "field": field,
            "prompt_head": prompt_head,
            "is_train": is_train,
            # Forward the caller's flag instead of hard-coding False.
            "old_dataset": old_dataset,
        },
    )

In [24]:
# Build the training split with answers included in each prompt.
ds = get_dataset(
    'data/unique_targets',
    split='train',
    field='prompt',
    prompt_head=DEFAULT_SYSTEM_PROMPT,
    old_dataset=False,  # save_to_disk layout, not the legacy JSON file
)

------------------ TRAINING ON UNIQUE CLUES ------------------


Map:   0%|          | 0/50659 [00:00<?, ? examples/s]

In [25]:
ds  # rich-display the Dataset summary: feature columns and row count

Dataset({
    features: ['labels', 'clue', 'prompt'],
    num_rows: 50659
})

In [26]:
# Inspect the first rendered prompt. Select the row first: `ds['prompt']` can
# materialize the entire 'prompt' column just to read a single entry.
ds[0]['prompt']

'### Instruction: Below is a clue for a decrypting crossword. Your task is to solve this clue. The number of characters in the answer should be same as the number in the parenthesis. Just output the answer only.\n\n### Input:\nHairy thing faced extinction initially, certainly around US city, ending in death (7)\n\n### Response:\neyelash'