# core

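> Helpers for finding the `ClassLabel` columns of a Hugging Face `datasets` dataset, counting the labels in each split, and rendering the label distribution as markdown tables.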

In [None]:
# | default_exp core

In [None]:
# | export
from datasets import load_dataset
from datasets import ClassLabel
from datasets import IterableDataset, Dataset, DatasetDict, IterableDatasetDict
from typing import Dict, Any, List
from collections import Counter
from tabulate import tabulate
from typing import Union

In [None]:
# | hide
from nbdev.showdoc import *

In [None]:
# Load imdb in streaming mode. The displayed `ds` maps each split name
# ("train", "test", "unsupervised") to an IterableDataset.
ds = load_dataset("imdb", streaming=True)
ds

{'train': <datasets.iterable_dataset.IterableDataset>,
 'test': <datasets.iterable_dataset.IterableDataset>,
 'unsupervised': <datasets.iterable_dataset.IterableDataset>}

In [None]:
# | export
def get_label_column_names(features: Union[Dict[str, Any], None]) -> List[str]:
    """Return the names of the columns whose feature type is `ClassLabel`.

    Args:
        features: Mapping of column name to feature type, e.g. a split's
            ``.features``. ``None`` is accepted because an ``IterableDataset``
            may report ``None`` when its schema is not known; it is treated
            like an empty schema instead of raising ``AttributeError``.

    Returns:
        The matching column names in the mapping's iteration order, or an
        empty list when there are none.
    """
    if not features:
        return []
    return [
        name for name, feature in features.items() if isinstance(feature, ClassLabel)
    ]

In [None]:
# Smoke test: the imdb train split should expose at least one ClassLabel
# column, and the helper should hand it back as a list.
label_columns = get_label_column_names(ds["train"].features)
assert label_columns
assert isinstance(label_columns, list)

In [None]:
# | export
def yield_label_column(
    dataset: IterableDataset, column_name: str, features: Dict[str, Any]
):
    """Lazily translate each row's integer label into its string name.

    Args:
        dataset: Iterable of rows (indexable by column name) that contain
            `column_name`.
        column_name: Name of the label column to read from every row.
        features: Mapping of column name to feature; ``features[column_name]``
            must provide ``int2str`` (as a `ClassLabel` does).

    Yields:
        The label name for each row, in dataset order. A stored value of -1
        (an unlabelled example) is reported as ``"no label"``.
    """
    for example in dataset:
        label_id = example[column_name]
        if label_id == -1:
            yield "no label"
        else:
            yield features[column_name].int2str(label_id)

In [None]:
# | export
def get_label_counts(
    ds: Union[IterableDatasetDict, DatasetDict]
) -> Dict[str, Dict[str, int]]:
    """Count label occurrences for every `ClassLabel` column of every split.

    Args:
        ds: Mapping of split name to dataset, e.g. the `DatasetDict` or
            `IterableDatasetDict` returned by `load_dataset`. Each split is
            iterated once per label column, so a streaming split is
            re-streamed for every label column it has.

    Returns:
        A dict mapping a result name to a ``{label_name: count}`` dict. When a
        split has exactly one label column, the result name is the split name
        itself. When it has several, each column gets its own entry keyed
        ``"<split>/<column>"`` so that no column's counts are lost. Splits
        without any `ClassLabel` column are omitted.
    """
    results: Dict[str, Dict[str, int]] = {}
    for split_name in ds:
        split = ds[split_name]
        split_features = split.features
        label_columns = get_label_column_names(split_features)
        # Previously every column was written to results[split_name], so only
        # the last label column survived. Keep the plain split name for the
        # common single-column case and disambiguate otherwise.
        single_column = len(label_columns) == 1
        for column in label_columns:
            key = split_name if single_column else f"{split_name}/{column}"
            labels = yield_label_column(split, column, split_features)
            results[key] = dict(Counter(labels))
    return results

In [None]:
# Count labels for every split. Each labelled split is iterated in full,
# which for this streaming dataset means reading it over the network.
results = get_label_counts(ds)
assert results

In [None]:
# | export
def generate_label_breakdown_tables(results):
    """Render one GitHub-flavoured markdown table per entry of `results`.

    Args:
        results: Mapping of name (e.g. split name) to a ``{label: count}``
            dict, as produced by `get_label_counts`.

    Returns:
        A list of ``(name, table)`` tuples in the iteration order of
        `results`. Each table has the columns Label, Count and Percentage,
        where the percentage is the label's share of that entry's total,
        rounded to two decimals.
    """
    column_headers = ("Label", "Count", "Percentage")

    def _render(counts):
        # Build the rows for one entry and format them with tabulate.
        grand_total = sum(counts.values())
        rows = [
            (label, count, f"{round((count/grand_total)*100,2)}%")
            for label, count in counts.items()
        ]
        return tabulate(rows, tablefmt="github", headers=column_headers)

    return [(split_name, _render(counts)) for split_name, counts in results.items()]

In [None]:
# Reuse the `results` computed by the previous cell. Calling
# get_label_counts(ds) again would re-stream every split over the network.
tables = generate_label_breakdown_tables(results)
assert tables

In [None]:
# Show each split's name followed by its markdown label-breakdown table.
for name, table in tables:
    print(name)
    print(table)

train
| Label   |   Count | Percentage   |
|---------|---------|--------------|
| neg     |   12500 | 50.0%        |
| pos     |   12500 | 50.0%        |
test
| Label   |   Count | Percentage   |
|---------|---------|--------------|
| neg     |   12500 | 50.0%        |
| pos     |   12500 | 50.0%        |
unsupervised
| Label    |   Count | Percentage   |
|----------|---------|--------------|
| no label |   50000 | 100.0%       |


In [None]:
# | hide
import nbdev

# Write the `# | export` cells of this notebook out to the library module
# named by the `default_exp` directive at the top of the notebook.
nbdev.nbdev_export()