In [20]:
from collections import Counter
from nltk import edit_distance
import pandas as pd
from sklearn.model_selection import train_test_split, ShuffleSplit, GroupShuffleSplit
import datasets
from datasets import Dataset, DatasetDict, load_dataset
from functools import reduce
import numpy as np

import json
import os 
import pandas as pd
from collections import defaultdict
import numpy as np

def process_df(
    df_in,
    aspect_label_encode={
        "Negative":0,
        "Positive":1,
        "unknown":2,
        "no majority": 2,
    },
    sequence_label_encode=None
):
    """Turn a raw CEBaB split into a flat sequence-classification DataFrame.

    Steps:
      1. drop rows whose ``review_majority`` is ``"no majority"``;
      2. keep the id / text / rating / aspect columns, plus every column whose
         name contains ``'prediction'``;
      3. rename ``description`` -> ``text``, ``review_majority`` -> ``label``
         and ``<aspect>_aspect_majority`` -> ``<aspect>_label``;
      4. replace empty strings with ``-1`` in *every* kept column, then encode
         the four aspect columns with ``aspect_label_encode`` and ``label``
         with ``sequence_label_encode``;
      5. drop rows whose (encoded) ``label`` equals ``-1``.

    Parameters
    ----------
    df_in : pandas.DataFrame
        Raw split. It must contain the columns listed in step 2. It is not modified.
    aspect_label_encode : dict or None
        Value mapping applied to the four aspect label columns. ``None`` leaves
        them un-encoded, apart from the ``"" -> -1`` replacement. The default
        dict is shared between calls, so it must never be mutated.
    sequence_label_encode : dict or None
        Value mapping applied to ``label``. Map a value to ``-1`` to have its
        rows dropped. ``None`` (the default) keeps the raw rating strings.
        Previously this default made pandas raise ``TypeError``.

    Returns
    -------
    pandas.DataFrame
        A new frame. The original row index is kept (not reset), which is why a
        ``datasets.Dataset`` built from it gains an ``__index_level_0__`` column.
    """
    aspect_columns = {
        'food_aspect_majority': 'food_label',
        'ambiance_aspect_majority': 'ambiance_label',
        'service_aspect_majority': 'service_label',
        'noise_aspect_majority': 'noise_label',
    }

    df = df_in.copy()
    columns_to_keep = [
        'id', 'original_id', 'edit_id', 'is_original',
        'description', 'review_majority',
        *aspect_columns,
    ]
    # Keep any model-prediction columns that may have been attached to the split.
    columns_to_keep += [col for col in df.columns if 'prediction' in col]
    df = df[df["review_majority"] != "no majority"]
    df = df[columns_to_keep].rename(
        columns={
            'description': 'text',
            'review_majority': 'label',
            **aspect_columns,
        }
    )

    # Build the per-column mapping only from the encoders that were provided:
    # pandas raises TypeError if a nested replace dict mixes mappings with None.
    replacements = {}
    if sequence_label_encode is not None:
        replacements["label"] = sequence_label_encode
    if aspect_label_encode is not None:
        for label_col in aspect_columns.values():
            replacements[label_col] = aspect_label_encode

    # NOTE: this replaces "" with -1 in every kept column, including `text`.
    df = df.replace("", -1)
    if replacements:
        df = df.replace(replacements)
    # -1 marks unusable sequence labels (e.g. "3" in the 2-class setup).
    df = df[df["label"] != -1]

    return df

In [2]:
# Load every CEBaB split from the Hugging Face Hub (authenticated), caching
# the downloaded files under ./train_cache/ so re-runs reuse them.
dataset = load_dataset(
    "CEBaB/CEBaB",
    cache_dir="./train_cache/",
    use_auth_token=True,
)

Using custom data configuration CEBaB--CEBaB-0e2f7ed67c9d7e55
Reusing dataset parquet (./train_cache/CEBaB___parquet/CEBaB--CEBaB-0e2f7ed67c9d7e55/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)


  0%|          | 0/4 [00:00<?, ?it/s]

In [41]:
# Number of rating classes for the sequence label: one of 2, 3, 5.
num_of_class = 5
# Which CEBaB training split to load (used below as f"train_{condition}").
condition = "inclusive"

# Maps the star-rating string in `review_majority` to an integer class id.
#   2-class: {1,2} -> 0, {4,5} -> 1, 3 -> -1 (dropped)
#   3-class: {1,2} -> 0, 3 -> 1, {4,5} -> 2
#   5-class: star - 1
# process_df drops rows whose label encodes to -1; "no majority" rows are
# already filtered out there before encoding.
label_encodings = {
    2: {"5": 1, "4": 1, "3": -1, "2": 0, "1": 0, "no majority": -1},
    3: {"5": 2, "4": 2, "3": 1, "2": 0, "1": 0, "no majority": -1},
    5: {"5": 4, "4": 3, "3": 2, "2": 1, "1": 0, "no majority": -1},
}
if num_of_class not in label_encodings:
    # Fail loudly rather than silently reusing a stale `sequence_label_encode`
    # left in the kernel by an earlier run.
    raise ValueError(
        f"num_of_class must be one of {sorted(label_encodings)}, got {num_of_class!r}"
    )
sequence_label_encode = label_encodings[num_of_class]

In [42]:
# Materialise the three CEBaB splits as pandas DataFrames. The training split
# is selected by `condition` from the config cell.
train, dev, test = (
    dataset[split_name].to_pandas()
    for split_name in (f"train_{condition}", "validation", "test")
)

In [43]:
# Clean and encode each split with the same label mapping, then convert it to a
# Hugging Face Dataset. process_df keeps the filtered (non-contiguous) pandas
# index, so each Dataset also gets an `__index_level_0__` column.
post_train, post_dev, post_test = [
    Dataset.from_pandas(
        process_df(split_df, sequence_label_encode=sequence_label_encode)
    )
    for split_df in (train, dev, test)
]

In [44]:
# Bundle the processed splits under the standard split names.
opentable_seq_cls_dataset = DatasetDict({
    "train": post_train,
    "validation": post_dev,
    "test": post_test,
})

In [45]:
# Display the splits' column names and row counts as a sanity check before saving.
opentable_seq_cls_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'original_id', 'edit_id', 'is_original', 'text', 'label', 'food_label', 'ambiance_label', 'service_label', 'noise_label', '__index_level_0__'],
        num_rows: 9848
    })
    validation: Dataset({
        features: ['id', 'original_id', 'edit_id', 'is_original', 'text', 'label', 'food_label', 'ambiance_label', 'service_label', 'noise_label', '__index_level_0__'],
        num_rows: 1673
    })
    test: Dataset({
        features: ['id', 'original_id', 'edit_id', 'is_original', 'text', 'label', 'food_label', 'ambiance_label', 'service_label', 'noise_label', '__index_level_0__'],
        num_rows: 1689
    })
})

In [40]:
# The output directory name encodes `num_of_class` and `condition` from the
# config cell. Run this only after rebuilding `opentable_seq_cls_dataset` with
# those same settings, so the on-disk data matches its name.
output_dir = f"./datasets/Proxy.CEBaB.sa.{num_of_class}-class.{condition}"
opentable_seq_cls_dataset.save_to_disk(output_dir)