In [1]:
from datasets import load_dataset, Dataset, DatasetDict

In [2]:
# Download the SODA dataset from the Hugging Face Hub (train / validation / test splits).
dataset = load_dataset(path="allenai/soda")

Downloading readme:   0%|          | 0.00/4.92k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/689M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/82.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/84.2M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [3]:
from pprint import pprint

# Inspect one raw training example to see which fields are available.
first_train_example = dataset["train"][0]
pprint(first_train_example)

{'PersonX': 'Veda',
 'PersonY': '',
 'PersonZ': '',
 'dialogue': ["Hi, Father. I'm Veda. I'm new to the area and was curious about "
              'your church. Could you tell me a little bit about it?',
              'Of course, Veda. Our church is based on the teachings of Jesus '
              'Christ. We believe in loving our neighbor and treating others '
              'as we would want to be treated. We strive to live according to '
              "Christ's example and teachings.",
              'That sounds like a really great way to live. I can see why so '
              'many people are drawn to this religion. What do you think makes '
              'Christianity different from other religions?',
              'Well, there are a lot of different interpretations of '
              "Christianity, but for us, it's all about following Jesus "
              "Christ's example. He was a man who loved unconditionally and "
              'forgave easily. He preached compassion and under

In [4]:
# Keep only examples for which all four *_answer columns are "yes"
# (compared case-insensitively); applied to every split.
answer_columns = (
    "head_answer",
    "relation_tail_answer",
    "pmi_head_answer",
    "pmi_relation_tail_answer",
)


def passes_all_answer_checks(example):
    """Return True when every column in ``answer_columns`` equals "yes" (any case)."""
    return all(example[column].lower() == "yes" for column in answer_columns)


dataset = dataset.filter(passes_all_answer_checks)

len(dataset["train"])

Filter:   0%|          | 0/2383164 [00:00<?, ? examples/s]

Filter:   0%|          | 0/292692 [00:00<?, ? examples/s]

Filter:   0%|          | 0/297936 [00:00<?, ? examples/s]

1523126

In [5]:
# Distinct relation types left in the training split after filtering.
all_relations = set(dataset["train"].unique("relation"))
all_relations

{'xAttr', 'xEffect', 'xIntent', 'xNeed', 'xReact', 'xWant'}

In [11]:
# Question template per relation type. Each value is a callable that takes a
# person's name (positionally) and returns the question about that person.
relation_map = {
    "xAttr": "What does this imply about {}?".format,
    "xIntent": "What does {} intend to do?".format,
    "xNeed": "What does this tell you about {}'s needs?".format,
    "xReact": "How will {} react?".format,
    "xEffect": "What effect will this have on {}?".format,
    "xWant": "What does {} want?".format,
}

In [7]:
# Keep at most `num_items` examples per relation: the first ones in dataset
# order, not a random sample.
num_items = 10_000
ds = (
    dataset["train"]
    .to_pandas()
    # groupby().head() keeps the "relation" column. The previous
    # groupby().apply(lambda x: x.head(n)) form relies on apply() operating on
    # the grouping column, which pandas deprecated; once that column is
    # excluded, reset_index(drop=True) would silently drop "relation".
    .groupby("relation")
    .head(num_items)
    # Same layout as before: rows grouped by relation in sorted key order,
    # original order preserved within each group.
    .sort_values("relation", kind="stable")
    .reset_index(drop=True)
)
assert ds.groupby("relation").size().le(num_items).all()

In [8]:
# Convert the capped DataFrame back into a Hugging Face Dataset. This rebinds
# `dataset` from the original DatasetDict to a single Dataset.
dataset = Dataset.from_pandas(df=ds)

In [9]:
def make_chatml(name, role, content):
    """Build one ChatML-style message.

    Returns a dict with keys ``name``, ``role`` and ``content``, in that order.
    """
    return dict(name=name, role=role, content=content)


def system(name, content):
    """Return a ``"system"``-role message tagged with ``name``."""
    return make_chatml(role="system", name=name, content=content)


def situation(content):
    """System message named ``"situation"`` (scene-setting text)."""
    return system(name="situation", content=content)


def thought(content):
    """System message named ``"thought"`` (the model's reflection)."""
    return system(name="thought", content=content)


def information(content):
    """System message named ``"information"``."""
    return system(name="information", content=content)


def me(content, name=None):
    """Return an ``"assistant"``-role message; ``name`` is optional."""
    return make_chatml(role="assistant", content=content, name=name)


def person(content, name=None):
    """Return a ``"user"``-role message; ``name`` is typically the speaker."""
    return make_chatml(role="user", content=content, name=name)


In [10]:
# Peek at the first example of the capped Dataset (rebuilt from the pandas
# frame) before it is converted to ChatML.
dataset[0]

{'head': "PersonX listens to PersonY's thoughts",
 'relation': 'xAttr',
 'tail': 'a good friend',
 'literal': "Rylea is a good friend. Rylea listens to Shavon's thoughts.",
 'narrative': 'Rylea sat down with Shavon and asked him what was wrong. Shavon told Rylea that he was having a hard time and needed someone to talk to. Rylea said that he would be happy to be there for his friend.',
 'dialogue': ["Hey Shavon, what's up? You seem troubled.",
  "Yeah, I am. I'm just having a hard time and needed someone to talk to.",
  "Of course, man. I'm always here for you. What's going on?",
  "It's just everything. Work is stressing me out, my relationship is falling apart, and I feel like I'm losing touch with my friends. I don't know what to do.",
  "Well, let's start with work then. What's going on there?",
  "It's just that everything is so demanding and I can't keep up. I'm constantly behind and it feels like I'm never going to catch up.",
  "Okay, that does sound pretty tough. But it sounds

In [12]:
def format_speakers(speakers):
    """Join the distinct speaker names of a dialogue into readable English.

    ``speakers`` holds one entry per dialogue turn (``to_chatml`` zips it with
    ``dialogue``), so names repeat. Each name is kept once, in order of first
    appearance: ["A", "B", "A"] -> "A and B"; ["A", "B", "C"] -> "A, B and C";
    ["A"] -> "A".

    Raises:
        ValueError: if ``speakers`` is empty.
    """
    unique_speakers = list(dict.fromkeys(speakers))
    if not unique_speakers:
        raise ValueError("cannot describe a dialogue with no speakers")
    if len(unique_speakers) == 1:
        return unique_speakers[0]
    *leading, last = unique_speakers
    return ", ".join(leading) + f" and {last}"


def to_chatml(row):
    """Convert one dataset example into a list of ChatML-style messages.

    Args:
        row: an example with keys "speakers", "dialogue", "narrative",
            "literal", "PersonX" and "relation". The "relation" value must be
            a key of ``relation_map``.

    Returns:
        ``{"chatml": [...]}`` containing, in order:
        - a "situation" system message naming the participants;
        - one user message per dialogue turn, named after its speaker;
        - a "thought" asking what is happening, answered by the narrative;
        - a "thought" asking the relation question about PersonX, answered by
          the literal description.
    """
    speakers = row["speakers"]
    dialogue = row["dialogue"]
    question = relation_map[row["relation"]](row["PersonX"])

    # `speakers` is per turn, so deduplicate before listing participants.
    system_message = (
        f"An AI is analyzing a dialog happening between {format_speakers(speakers)}."
        " It then reflects on the conversation."
    )
    thought_message1 = f"What is happening in this dialog?\n\n{row['narrative']}"
    thought_message2 = f"{question}\n\n{row['literal']}"

    # NOTE: zip() stops at the shorter list, so a turn without a matching
    # speaker entry would be dropped silently.
    dialog_chatml = [
        person(text, name=speaker)
        for speaker, text in zip(speakers, dialogue)
    ]

    chatml = [
        situation(system_message),
        *dialog_chatml,
        thought(thought_message1),
        thought(thought_message2),
    ]
    return dict(chatml=chatml)


In [13]:
# Add a "chatml" column built row-by-row by to_chatml.
dataset = dataset.map(function=to_chatml)

Map:   0%|          | 0/60000 [00:00<?, ? examples/s]

In [16]:
# Drop every column except "chatml" so only the converted messages are uploaded.
columns_to_remove = [column for column in dataset.column_names if column != "chatml"]
dataset = dataset.remove_columns(columns_to_remove)

In [17]:
# Upload the converted dataset to a private Hugging Face Hub repository.
dataset.push_to_hub(repo_id="diwank/soda-chatml", private=True)

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/60 [00:00<?, ?ba/s]