In [None]:
# default_exp data.dataset
%load_ext autoreload
%autoreload 2

In [None]:
# export
# hide
from typing import List, Union, Any
from pathlib import Path
from pymemri.data.itembase import Item, EdgeList
from pymemri.exporters.exporters import Query
from pymemri.data import _central_schema

In [None]:
# hide
from nbdev import show_doc

# Datasets

A dataset is a central item in the pod that organizes your project data and label annotations. To make `Dataset` items easy to use in a data science workflow, the `Dataset` class provides methods to convert the data to popular data science formats, such as a Pandas DataFrame, or to save a dataset to disk.

In [None]:
# export
# hide
def filter_rows(dataset: dict, filter_val: Any = None) -> dict:
    """
    Drop every row in which any column holds `filter_val`.

    `dataset` maps column names to lists of values (presumably all of the same
    length); row `i` is made of the `i`-th value of every column. If any column
    has `filter_val` at position `i`, that position is removed from all columns.

    Args:
        dataset (dict): Mapping of column name to list of values.
        filter_val (Any, optional): Value that marks a row for removal.
            Defaults to None, in which case values are matched by identity
            (`is None`) so array-like or NA-like values, whose `==` is
            elementwise, do not raise.

    Returns:
        dict: A new mapping with the same keys and the marked rows removed.
    """
    if filter_val is None:
        def is_missing(value) -> bool:
            return value is None
    else:
        def is_missing(value) -> bool:
            return value == filter_val

    # Collect the offending row indices across all columns first, so a row is
    # removed from every column even if only one column holds `filter_val`.
    missing_idx = {
        i
        for column in dataset.values()
        for i, value in enumerate(column)
        if is_missing(value)
    }
    return {
        key: [item for i, item in enumerate(values) if i not in missing_idx]
        for key, values in dataset.items()
    }

In [None]:
# export
class Dataset(_central_schema.Dataset):
    """
    The main Dataset class.

    Entries are reached through the `entry` edge and are fetched from the pod
    the first time data is requested. Every data-access method requires
    `self._client` to be set to a PodClient and raises `ValueError` otherwise.
    """
    # NOTE(review): presumably tells the PodClient to attach itself as `_client`
    # when it loads this item — confirm against PodClient.
    requires_client_ref = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # PodClient used for all pod access; None until one is attached.
        self._client = None

    def _get_items(self):
        """
        Return this dataset's `entry` edges, fetching them from the pod if none are loaded.

        Every edge returned by the pod is added via `add_edge` under its own name;
        only `self.entry` is returned.

        Raises:
            ValueError: If the dataset has no associated PodClient.
        """
        if self._client is None:
            raise ValueError("Dataset does not have associated PodClient.")
        if not len(self.entry):
            # Fetch the edges a single time; each call is a round-trip to the pod.
            for edge in self._client.get_edges(self.id):
                self.add_edge(edge["name"], edge["item"])

        return self.entry

    def _get_data(self, dtype: str, columns: List[str], filter_missing: bool = True):
        """
        Query `columns` (plus a leading `id` column) for all entries and convert to `dtype`.

        Raises:
            ValueError: If the dataset has no associated PodClient.
        """
        if self._client is None:
            raise ValueError("Dataset does not have associated PodClient.")
        items = self._get_items()

        # "id" is always queried first, so it is the first column of the result.
        query = Query("id", *columns)
        result = query.execute(self._client, items)
        if filter_missing:
            result = filter_rows(result, filter_val=None)
        return query.convert_dtype(result, dtype)

    def to(self, dtype: str, columns: List[str], filter_missing: bool = True):
        """
        Converts Dataset to a different format.

        Available formats:
        list: a 2-dimensional list, containing one dataset entry per row
        dict: a list of dicts, where each dict contains {column: value} for each column
        pd: a Pandas dataframe

        An `id` column is always included before the requested columns.

        Args:
            dtype (str): Datatype of the returned dataset
            columns (List[str]): Column names of the dataset
            filter_missing (bool, optional): If true, all rows that contain `None` values are omitted.
                Defaults to True.

        Returns:
            Any: Dataset formatted according to `dtype`

        Raises:
            ValueError: If the dataset has no associated PodClient.
        """
        return self._get_data(dtype, columns, filter_missing)

    def save(
        self, path: Union[Path, str], columns: List[str], filter_missing: bool = True
    ):
        """
        Save dataset to CSV.

        The dataset is converted with the "pandas" dtype and written without the
        DataFrame index. An `id` column is always written before `columns`.

        Args:
            path (Union[Path, str]): Destination passed to `to_csv`.
            columns (List[str]): Column names of the dataset.
            filter_missing (bool, optional): If true, all rows that contain `None` values are omitted.
                Defaults to True.

        Raises:
            ValueError: If the dataset has no associated PodClient.
        """
        result = self._get_data("pandas", columns, filter_missing)
        result.to_csv(path, index=False)

In [None]:
# Render the API documentation of the public Dataset methods.
show_doc(Dataset.to)

show_doc(Dataset.save)

In [None]:
# hide
from pymemri.pod.client import PodClient
from pymemri.data.schema import Account, Person, Message, CategoricalLabel, DatasetEntry
from pymemri.data.itembase import Edge
import random
import tempfile
import pandas as pd

## Usage

To convert the data in the pod to a different format, `Dataset` implements the `Dataset.to` method. In the `columns` argument, you can define which features will be included in your dataset. A `column` is either a property of an entry in the dataset, or a property of an item connected to an entry in the dataset.

The Pod uses the following schema for Dataset items. Note that the `DatasetEntry` item is always included, and the actual data can be found by traversing the `entry.data` Edge.

![dataset schema](images/dataset-diagram.png)

For example, if a dataset consists of `Message` items and the message content should be included as a column, the column name is `data.content`. To include the handle of the message's `sender`, use the column name `data.sender.handle`.

The following example retrieves an example dataset of `Message` items and converts it to a Pandas DataFrame:

In [None]:
client = PodClient()
# Add the Dataset and DatasetEntry item types to the pod schema before using them.
client.add_to_schema(Dataset, DatasetEntry)

In [None]:
# hide
# Build an example dataset: `num_items` entries, each linked to a Message whose
# sender Account is owned by a Person, plus a CategoricalLabel annotation.
client.add_to_schema(Account, Person, Message, CategoricalLabel, Dataset, DatasetEntry)

dataset = Dataset(name="example-dataset")

num_items = 10
messages = []
items = [dataset]
edges = []
for i in range(num_items):
    entry = DatasetEntry()
    msg = Message(content=f"content_{i}", service="my_service")
    account = Account(handle=f"account_{i}")
    person = Person(firstName=f"firstname_{i}")
    label = CategoricalLabel(labelValue=f"label_{i}")
    items += [entry, msg, account, person, label]

    # dataset -entry-> entry -data-> message -sender-> account -owner-> person,
    # and entry -annotation-> label.
    edges += [
        Edge(dataset, entry, "entry"),
        Edge(entry, msg, "data"),
        Edge(msg, account, "sender"),
        Edge(entry, label, "annotation"),
        Edge(account, person, "owner"),
    ]
    messages.append(msg)

client.bulk_action(create_items=items, create_edges=edges)

In [None]:
dataset = client.get_dataset("example-dataset")

# Column names are property paths starting at a DatasetEntry: `data.*` follows the
# entry's `data` edge (a Message here), `annotation.*` follows its `annotation` edge.
columns = ["data.content", "data.sender.handle", "annotation.labelValue"]
dataframe = dataset.to("pd", columns=columns)
dataframe.head()

In [None]:
# hide
columns = ["data.content", "data.sender.owner.firstName", "annotation.labelValue"]
dataframe = dataset.to("pd", columns=columns)

# `Dataset.to` always puts an `id` column first. Compare the columns as plain lists
# so a column-count mismatch fails the assertion instead of raising inside pandas.
assert isinstance(dataframe, pd.DataFrame)
assert list(dataframe.columns) == ["id"] + columns
assert len(dataframe) == num_items
dataframe.head()

In [None]:
# hide
# TODO tempfile does not work in CI
# with tempfile.TemporaryFile(mode='w+') as f:
#     dataset.save(f, columns=["content", "sender.owner.firstName", "label.name"])
#     f.seek(0)
#     result = pd.read_csv(f)
    
# assert result.equals(dataframe)

In [None]:
# hide
# Export the `# export` cells of this notebook to the `default_exp` module.
from nbdev.export import notebook2script
notebook2script()