In [1]:
from datasets import load_dataset, Dataset, DatasetDict

In [2]:
# Load the "unannotated" configuration of the Chain-of-Density dataset.
dataset = load_dataset(path="griffin/chain_of_density", name="unannotated")

In [3]:
from pprint import pprint
# Inspect one raw training row to see which fields the dataset provides.
raw_example = dataset["train"][0]
pprint(raw_example)

{'article': 'A hiker was arrested and warned she could face jail after freeing '
            'an eagle from a trap and springing three more traps to protect '
            'other animals. Kathleen Adair, 39, was walking her three dogs up '
            'Davies Creek Trail in Alaska on Christmas Eve when she spotted '
            'the bird with each leg shut inside traps. She spent an hour '
            'freeing the creature before alerting a bird rescue firm. Heading '
            'home, she also sprung another trap which she spotted in the '
            'ground - prompting an investigation by Alaska Wildlife Troopers '
            'that landed her in court. Eventually tracked down by authorities '
            'she was charged and hauled to court facing a $500 fine and 30 '
            'days in jail. Arrested: Kathleen Adair, 39, was charged with '
            'hindering lawful trapping after snaring three traps in Alaska . '
            'The eagle was found and euthanized three days aft

In [4]:
# Keep only the fields needed to build the chat transcripts; drop everything else.
columns_to_keep = {"article", "missing", "num_entities", "prediction"}
columns_to_remove = [
    column
    for column in dataset["train"].column_names
    if column not in columns_to_keep
]

dataset = dataset.remove_columns(columns_to_remove)

In [5]:
# Some rows don't contain a complete chain: their per-step lists have different
# lengths, or are empty. The lists are processed index-by-index below, and an
# empty chain has no summary to turn into a conversation, so remove both cases.
def has_complete_chain(example):
    """Return True if `prediction`, `num_entities` and `missing` are non-empty and equally long."""
    lengths = {len(example[key]) for key in ("prediction", "num_entities", "missing")}
    return len(lengths) == 1 and 0 not in lengths

dataset = dataset.filter(has_complete_chain)
len(dataset["train"])

Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

768

In [6]:
# Remove summaries that don't contain many entities (fewer than their chain's median)
import numpy as np

def remove_short_summaries(example):
    """Drop the summaries in a chain whose entity count is below the chain's median.

    `prediction`, `num_entities` and `missing` are treated as parallel per-step
    lists (index i of each describes the same summary), so the same positions
    are removed from all three. Every other field of `example` is returned
    unchanged.

    Args:
        example: one dataset row with list-valued `prediction`, `num_entities`
            and `missing` fields of equal length.

    Returns:
        The row with the three per-step fields filtered. Filtered values are
        plain Python lists (never numpy arrays), so the element type of
        `missing`/`prediction` does not matter (ragged nested lists are fine)
        and the output type is the same whether or not anything was removed.
    """
    num_entities = list(example["num_entities"])
    if not num_entities:
        # np.median([]) is nan (with a RuntimeWarning); there is nothing to filter.
        return example

    median_num_entities = np.median(num_entities)
    keep = [num >= median_num_entities for num in num_entities]
    if all(keep):
        return example

    def _select(values):
        # Keep only the entries at positions whose summary has enough entities.
        return [value for value, kept in zip(values, keep) if kept]

    return {
        **example,
        "missing": _select(example["missing"]),
        "num_entities": _select(num_entities),
        "prediction": _select(example["prediction"]),
    }

In [7]:
# Apply the per-row summary pruning to every split.
dataset = dataset.map(function=remove_short_summaries)

Map:   0%|          | 0/768 [00:00<?, ? examples/s]

In [8]:
# Look at the first training row again after pruning.
processed_example = dataset["train"][0]
processed_example

{'article': "Customers could soon design their own items, go into a supermarket and have them printed in 3D, if an ambitious major project by Tesco succeeds. The high street retail giant is working on developing new technology for a variety of products in its stores. Ideas include digitally making clothing, furniture, personal gifts and even food in their shops. The supermarket giant is keen to use new technologies to offer a wider range of products to consumers. Company researchers believe 3D printers are a natural progression given that they already offer photo and poster printing . The project could also see Tesco stores repair broken items or print spare parts for a product that has already been purchased. Paul Wilkinson, a lead research specialist with Tesco, revealed the retail giant’s ambitions writing a blog post on tesco.com about the potential of 3D printing. He said: '3D printing] could revolutionise the way we view stores and what we can get from them.' Wilkinson, who is he

In [9]:
def make_chatml(name, role, content):
    """Build one ChatML message dict with keys `name`, `role`, `content` (in that order)."""
    return dict(name=name, role=role, content=content)


def system(name, content):
    """Build a `system`-role message; `name` labels which kind of system message it is."""
    return make_chatml(name=name, role="system", content=content)


def situation(content):
    """System message (named "situation") describing the scenario of the conversation."""
    return system(name="situation", content=content)


def thought(content):
    """System message named "thought"."""
    return system(name="thought", content=content)


def information(content):
    """System message (named "information") carrying reference material such as an article."""
    return system(name="information", content=content)


def me(content, name=None):
    """Message spoken by the assistant (role "assistant"), optionally labelled with `name`."""
    return make_chatml(name=name, role="assistant", content=content)


def person(content, name=None):
    """Message spoken by the user (role "user"), optionally labelled with `name`."""
    return make_chatml(name=name, role="user", content=content)


In [10]:
def to_chatml(row):
    """Turn one Chain-of-Density row into a multi-turn ChatML conversation.

    The user asks for a summary of the article and the assistant answers with
    the first summary. Then, for every later summary in the chain, the user asks
    what is missing. The assistant names that step's entry from `missing` and
    gives the rewritten summary.

    Args:
        row: dataset row with `article` (str), `prediction` (list of summaries)
            and `missing` (one entry per summary, aligned index-by-index with
            `prediction`).

    Returns:
        ``{"chatml": [message, ...]}`` where each message is a dict with
        ``name``, ``role`` and ``content`` keys.

    Raises:
        ValueError: if `prediction` is empty (there is no summary to answer with).
    """
    article = row["article"]
    predictions = row["prediction"]
    missing = row["missing"]
    newline = "\n"

    # Turn into chatml
    chatml = [
        situation("A user is talking to an AI assistant. They are discussing an article that the user has read."),
        person("I just read the article below. Can you summarize it?", name="User"),
        information(f"Article:{newline}{article}"),
    ]

    if len(predictions) == 0:
        raise ValueError("row has no summaries in 'prediction'")

    # First answer: the initial summary.
    first_summary = predictions[0]
    chatml.append(me(f"Summary:{newline}{first_summary}", name="AI Assistant"))

    # `missing` and `prediction` are index-aligned (the equal-length filter and
    # the pruning step drop the same positions from both), so missing[i] belongs
    # to predictions[i]. Pair each later summary with its OWN entry; zipping the
    # full `missing` list against predictions[1:] would shift every entity by one.
    for missing_entity, next_summary in zip(missing[1:], predictions[1:]):
        # ask if anything is missing
        missing_question = "Is there anything missing from this summary that was covered in the article?"
        chatml.append(person(missing_question, name="User"))

        missing_answer = (
            f'Actually yes, the article also mentioned "{missing_entity}".'
            f" Let me rewrite the summary and add that."
            f"{newline}{newline}"
            f"Summary:{newline}{next_summary}"
        )

        chatml.append(me(missing_answer, name="AI Assistant"))

    return dict(chatml=chatml)


In [11]:
# Add a `chatml` column holding the conversation built from each row.
dataset = dataset.map(function=to_chatml)

Map:   0%|          | 0/768 [00:00<?, ? examples/s]

In [12]:
# Upload every split as a private dataset repository on the Hub.
dataset.push_to_hub(repo_id="diwank/chain_of_density-chatml", private=True)

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading metadata:   0%|          | 0.00/718 [00:00<?, ?B/s]