# Create Datasets

This notebook reads all SweLL-gold source files and converts them into transformers-compatible Datasets.

## Requirements

Make sure the SweLL-gold corpus is downloaded and that a `.env` file in the repository root contains the path to the SweLL-gold directory, for example:
```
SWELL_DIR=<PATH TO SWELL DIRECTORY>
```

## Imports

In [None]:
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from dotenv import load_dotenv
from os import getenv, path, makedirs
from datasets import Dataset
from prompts import minimal_prompt, fluency_prompt

## Load Tokenizer

In [None]:
# Tokenizer of the base model; both datasets below are tokenized with it.
base_model_name = "LumiOpen/Viking-33B"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=base_model_name
)

## Define Helper Functions to Read Corpus Files

In [None]:
# Load variables from a .env file into the environment, then look up the
# SweLL-gold directory; swell_dir is None when SWELL_DIR is not defined.
load_dotenv()
swell_dir = getenv("SWELL_DIR", None)


def get_file_paths(file_type, base_dir=None):
    """
    Gets the paths to the files containing the dev, test, and train splits for the given file_type.

    Arguments:

    file_type --- One of: orig (source), ref1 (minimal edit), or ref2 (fluency edit).
    base_dir --- Directory containing the SweLL-gold Markdown files. Defaults to the
                 module-level ``swell_dir`` (the SWELL_DIR environment value).

    Returns a dict {split: file_path} with keys "dev", "test", and "train".

    Raises:

    ValueError --- if file_type is not one of the accepted values.
    RuntimeError --- if no base_dir is given and SWELL_DIR is not set.
    """
    valid_file_types = ("orig", "ref1", "ref2")
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if file_type not in valid_file_types:
        raise ValueError(
            f"file_type must be one of {valid_file_types}, got {file_type!r}"
        )

    if base_dir is None:
        base_dir = swell_dir
    # Without this, path.join(None, ...) fails with an unhelpful TypeError.
    if base_dir is None:
        raise RuntimeError(
            "SWELL_DIR is not set; add it to the .env file or pass base_dir."
        )

    base_name = f"sv-swell_gold-{file_type}"
    splits = ["dev", "test", "train"]
    return {split: path.join(base_dir, f"{base_name}-{split}.md") for split in splits}


def md_to_dict(md):
    """
    Parse shared-task Markdown into a dictionary {essay_id: essay_text}.

    Adapted from the multigec-2025-data-providers repo:
    - https://github.com/spraakbanken/multigec-2025-data-providers

    Every essay starts with a "### essay_id = <id>" header. The remainder of
    that header line is the ID; everything after it, up to the next header,
    is the essay text with leading/trailing newlines stripped. Text before
    the first header is ignored, and a repeated ID keeps the last essay.

    Arguments:

    md --- a string with the content of a shared task Markdown file.
    """
    header = "### essay_id = "
    # Element 0 is whatever precedes the first header, so it is dropped.
    chunks = md.split(header)[1:]
    return {
        essay_id: body.strip("\n")
        for essay_id, body in (chunk.split("\n", 1) for chunk in chunks)
    }


def read_file_to_dict(fp):
    """
    Reads the file at fp into a dictionary {essay_id: essay_text}, using md_to_dict.

    Arguments:

    fp --- The file path to read.
    """
    # Decode explicitly as UTF-8: without an encoding, open() uses the locale's
    # preferred encoding (e.g. cp1252 on Windows), which mis-decodes Swedish
    # characters such as å, ä and ö.
    with open(fp, encoding="utf-8") as f:
        md = f.read()
    return md_to_dict(md)


def read_files_to_dicts(file_type):
    """
    Read the dev, test, and train files for file_type (orig, ref1, or ref2).

    Returns a dict {split: {essay_id: essay_text}}, where each inner dict
    comes from read_file_to_dict on that split's file.
    """
    split_dicts = {}
    for split_name, md_path in get_file_paths(file_type).items():
        split_dicts[split_name] = read_file_to_dict(md_path)
    return split_dicts

## Read Files into Dictionaries

Each file is read into a dictionary in the format:

```python
{split: {essay_id: essay_text}}
```

where `split` is one of `"dev"`, `"test"`, or `"train"`.

In [None]:
# Source essays (orig), minimal-edit references (ref1) and fluency-edit
# references (ref2), each as {split: {essay_id: essay_text}}.
source_text_dicts, minimal_text_dicts, fluency_text_dicts = (
    read_files_to_dicts(file_type) for file_type in ("orig", "ref1", "ref2")
)

## Define Dataset-Creation Helper Functions

These helper functions build the Hugging Face `DatasetDict` objects used for training.

In [None]:
def create_source_target_dict(src_dict, dst_dict, prompt):
    """
    Create Hugging-Face transformers-compatible source-target columns.

    Pairs each source essay with the target essay that has the same essay_id,
    orders both columns by sorted essay_id, and returns a dict with the format:

    ```python
    {
        "source": [prompt + source_text, ...],
        "target": [target_text, ...],
    }
    ```
    where the prompt is either the minimal-edit prompt or the fluency-edit prompt.

    Arguments:

    src_dict --- A dict {essay_id: essay_text} containing the source (uncorrected) texts.
    dst_dict --- A dict {essay_id: essay_text} containing the target (corrected) texts.
    prompt --- The prompt to insert before each source text.

    Raises:

    ValueError --- if the two dicts do not contain exactly the same essay IDs.
    """
    # Explicit check rather than ``assert``. Asserts are stripped under
    # ``python -O``, which could let mismatched IDs pair the wrong essays.
    only_in_src = src_dict.keys() - dst_dict.keys()
    only_in_dst = dst_dict.keys() - src_dict.keys()
    if only_in_src or only_in_dst:
        raise ValueError(
            "Source and target essay IDs differ: "
            f"only in source {sorted(only_in_src)}, "
            f"only in target {sorted(only_in_dst)}"
        )

    # Index both dicts with the same sorted key list so that source[i] and
    # target[i] always belong to the same essay.
    essay_ids = sorted(src_dict)
    return {
        "source": [prompt + src_dict[essay_id] for essay_id in essay_ids],
        "target": [dst_dict[essay_id] for essay_id in essay_ids],
    }


def create_dataset_dict(src_dict, dst_dict, split, prompt):
    """
    Build a single Hugging-Face Dataset (columns "source" and "target") for one split.

    Arguments:
    src_dict --- A dict {split: {essay_id: essay_text}} containing the source (uncorrected) texts.
    dst_dict --- A dict {split: {essay_id: essay_text}} containing the target (corrected) texts.
    split --- The split key ("dev", "test", or "train") selected from both dicts.
    prompt --- The prompt to insert before each source text.
    """
    columns = create_source_target_dict(
        src_dict=src_dict[split],
        dst_dict=dst_dict[split],
        prompt=prompt,
    )
    return Dataset.from_dict(columns)


def create_dataset(src_dicts, dst_dicts, prompt):
    """
    Creates a Hugging-Face datasets DatasetDict that can be used for training.

    The resulting DatasetDict has the splits "train", "validation", and "test".
    They are built from the corpus splits "train", "dev", and "test" respectively.

    Arguments:

    src_dicts --- A dict {split: {essay_id: essay_text}} containing the source (uncorrected) texts.
    dst_dicts --- A dict {split: {essay_id: essay_text}} containing the target (corrected) texts.
    prompt --- The prompt to insert before each source text.
    """
    # DatasetDict split name -> split key used in the SweLL-gold file names.
    split_mapping = {"train": "train", "validation": "dev", "test": "test"}
    return DatasetDict(
        {
            dataset_split: create_dataset_dict(src_dicts, dst_dicts, corpus_split, prompt)
            for dataset_split, corpus_split in split_mapping.items()
        }
    )

## Create the datasets

In [None]:
# Pair the original essays with each reference set, prefixing the matching prompt.
minimal_dataset = create_dataset(
    src_dicts=source_text_dicts, dst_dicts=minimal_text_dicts, prompt=minimal_prompt
)

fluency_dataset = create_dataset(
    src_dicts=source_text_dicts, dst_dicts=fluency_text_dicts, prompt=fluency_prompt
)

## Define Tokenization Function

In [None]:
def tokenize_function(dataset):
    """
    Tokenize the "source" and "target" fields of a dataset row (or batch of rows).

    The source text is tokenized as the model input. The target text is passed
    through the tokenizer's ``text_target`` argument, which transformers returns
    under the "labels" key. The tokenizer accepts either a string or a list of
    strings, so this works with both per-row and batched ``Dataset.map``.

    NOTE(review): Viking-33B appears to be a decoder-only model; tokenizing the
    target separately as labels follows the seq2seq convention. Confirm that the
    training code aligns/concatenates source and labels as intended.
    """
    source_text = dataset["source"]
    target_text = dataset["target"]
    return tokenizer(source_text, text_target=target_text)

## Tokenize Both Datasets

This creates tokenized versions of both the minimal-edit dataset and the fluency-edit dataset.

In [None]:
# batched=True passes many rows per tokenizer call instead of one call per row,
# which is much faster. The token ids are unchanged because no padding is applied.
tokenized_minimal_dataset = minimal_dataset.map(tokenize_function, batched=True)

tokenized_fluency_dataset = fluency_dataset.map(tokenize_function, batched=True)

In [None]:
# Check the length of each tokenized sequence to ensure they are all within 4096 tokens.
# NOTE(review): this checks source (input_ids) and target (labels) separately;
# if training concatenates them, their combined length is what must fit.
MAX_TOKENS = 4096

for variant_name, tokenized in (
    ("minimal", tokenized_minimal_dataset),
    ("fluency", tokenized_fluency_dataset),
):
    for split_name, split_ds in tokenized.items():
        # Longer of the source and target token counts for each example.
        lengths = [
            max(len(input_ids), len(labels))
            for input_ids, labels in zip(split_ds["input_ids"], split_ds["labels"])
        ]
        n_over = sum(length > MAX_TOKENS for length in lengths)
        print(
            f"{variant_name}/{split_name}: longest = {max(lengths, default=0)} tokens, "
            f"{n_over} of {len(lengths)} examples exceed {MAX_TOKENS}"
        )

## Save Datasets to Disk

In [None]:
# Write both tokenized DatasetDicts to datasets/minimal and datasets/fluency.
datasets_dir = "datasets"
minimal_dataset_path = path.join(datasets_dir, "minimal")
fluency_dataset_path = path.join(datasets_dir, "fluency")

# Make sure the output folder exists before saving into it.
makedirs(datasets_dir, exist_ok=True)
tokenized_minimal_dataset.save_to_disk(minimal_dataset_path)
tokenized_fluency_dataset.save_to_disk(fluency_dataset_path)