In [None]:
import os

from datasets import load_dataset

## Load Dataset
dataset = load_dataset("RiTA-nlp/ITALIC", "hard_speaker")
ds_train = dataset["train"]
ds_validation = dataset["validation"]

## Mapping intents to labels
# Ids are stored as strings in both directions (label2id values, id2label keys).
# Intents are numbered in sorted order: iterating a plain `set` of strings follows
# Python's per-process hash seed, so the label <-> id assignment would otherwise
# differ between kernel sessions and make results/checkpoints non-reproducible.
intents = set(ds_train['intent'])
label2id = {label: str(i) for i, label in enumerate(sorted(intents))}
id2label = {idx: label for label, idx in label2id.items()}
num_labels = len(id2label)



In [None]:
import pandas as pd
# Tabular (pandas) views of both splits for exploratory analysis;
# train is built first, then validation, and train is displayed.
df_train, df_validation = (pd.DataFrame(split) for split in (ds_train, ds_validation))
df_train

In [None]:
# Inspect the validation split as built from `ds_validation` (no speaker filtering yet).
df_validation

In [None]:
speaker_ids = df_train["speaker_id"].unique()

# Drop validation rows whose speaker also occurs in the training split.
# The row count is printed before and after, so any removal is visible.
print(len(df_validation))
seen_in_train = df_validation["speaker_id"].isin(speaker_ids)
df_validation = df_validation.loc[~seen_in_train]
print(len(df_validation))

In [None]:
# Class balance of the training split: number of samples per intent.
df_train.intent.value_counts().plot(kind='bar', title='Intent distribution (train)', ylabel='samples');

In [None]:
# Samples contributed by each training speaker (value_counts sorts descending).
speakerids = df_train['speaker_id'].value_counts()

# number of distinct speakers in the training split
speakerids.size

In [None]:
# Samples per training speaker, ordered from most to least samples.
speakerids.plot(kind='bar', title='Samples per speaker (train)', ylabel='samples');

In [None]:
from utils import set_seed
import random

def get_forget_retain_split(df_train, min_samples_forget=100, ratio=0.025, seed=42, speaker_col='speakerId'):
    """Split a DataFrame into speaker-disjoint *forget* and *retain* subsets.

    Speakers with strictly more than ``min_samples_forget`` rows are eligible.
    Eligible speakers are drawn at random, without replacement, until their rows
    add up to at least ``ratio * len(df_train)``. Every row of a drawn speaker
    goes to the forget set; all other rows go to the retain set.

    Args:
        df_train: DataFrame with one row per sample.
        min_samples_forget: a speaker needs more than this many rows to be eligible.
        ratio: minimum fraction of rows to place in the forget set.
        seed: passed to ``set_seed`` (project utils) before sampling.
        speaker_col: name of the column holding the speaker identifier.

    Returns:
        ``(df_forget, df_retain)``: row subsets of ``df_train`` that keep its index.

    Raises:
        ValueError: if all eligible speakers together cannot reach the target size.
    """
    speaker_counts = df_train[speaker_col].value_counts()

    set_seed(seed)

    # Eligible speakers have strictly more than `min_samples_forget` samples.
    candidates = speaker_counts[speaker_counts > min_samples_forget].index.tolist()
    target = len(df_train) * ratio

    total_samples = 0
    speakers_to_sample = []
    while total_samples < target:
        if not candidates:
            raise ValueError(
                f"eligible speakers (more than {min_samples_forget} samples each) cover "
                f"only {total_samples} rows, fewer than the {target:g} needed for ratio={ratio}"
            )
        speaker = random.choice(candidates)
        # Draw without replacement: a speaker picked twice would have its rows
        # counted twice, leaving the forget set smaller than `ratio` requires.
        candidates.remove(speaker)
        speakers_to_sample.append(speaker)
        total_samples += speaker_counts.loc[speaker]

    in_forget = df_train[speaker_col].isin(speakers_to_sample)
    df_forget = df_train[in_forget]
    df_retain = df_train[~in_forget]
    return df_forget, df_retain

# Build the forget/retain split on the dataset's speaker column, then sanity-check it.
speakerl_col = 'speaker_id'
df_forget, df_retain = get_forget_retain_split(df_train, speaker_col=speakerl_col)

# every training row ends up in exactly one of the two subsets ...
assert len(df_forget) + len(df_retain) == len(df_train)
# ... and no speaker appears in both
assert set(df_forget[speakerl_col]).isdisjoint(df_retain[speakerl_col])

In [None]:
# Fraction of training rows that landed in the forget set.
df_forget.shape[0] / df_train.shape[0]

In [None]:
# Intent distribution of the forget set (compare with the train/retain plots).
df_forget.intent.value_counts().plot(kind='bar', title='Intent distribution (forget set)', ylabel='samples');

In [None]:
# Intent distribution of the retain set (compare with the train/forget plots).
df_retain.intent.value_counts().plot(kind='bar', title='Intent distribution (retain set)', ylabel='samples');

In [None]:
# Persist the row positions of each subset (one index per line, forget file first)
# so the same split can be reloaded later without re-sampling.
forget_indexes = df_forget.index.tolist()
retain_indexes = df_retain.index.tolist()

for out_path, positions in (('forget_indexes.txt', forget_indexes),
                            ('retain_indexes.txt', retain_indexes)):
    with open(out_path, 'w') as out_file:
        out_file.writelines("%s\n" % pos for pos in positions)

In [None]:
def _read_index_file(path):
    """Return the integers stored one per line in ``path``, skipping blank lines."""
    with open(path) as f:
        return [int(line) for line in f if line.strip()]


def get_forget_retain_datasets(ds_train, data_path):
    """Rebuild the forget/retain subsets of ``ds_train`` from saved index files.

    Reads ``forget_indexes.txt`` and ``retain_indexes.txt`` (one integer row
    position per line) from ``data_path`` and selects those rows from ``ds_train``.

    Args:
        ds_train: dataset exposing ``select(indices)``, e.g. the HF train split.
        data_path: directory (str or path-like) containing both index files. It is
            combined with ``os.path.join``, so a trailing separator is optional and
            ``''`` means the current directory.

    Returns:
        ``(ds_forget, ds_retain)`` as returned by ``ds_train.select``.
    """
    forget_indexes = _read_index_file(os.path.join(data_path, 'forget_indexes.txt'))
    retain_indexes = _read_index_file(os.path.join(data_path, 'retain_indexes.txt'))

    ds_forget = ds_train.select(forget_indexes)
    ds_retain = ds_train.select(retain_indexes)

    return ds_forget, ds_retain

In [None]:
# Split the validation set in half: shard 0 stays validation, shard 1 becomes test.
# The "drop speakers seen in training" filter earlier only touched the pandas
# `df_validation`; apply the same rule to the HF dataset that is actually sharded,
# so neither half contains training speakers. With `input_columns`, the predicate
# only receives the `speaker_id` value of each row.
train_speakers = set(ds_train['speaker_id'])
ds_validation_unseen = ds_validation.filter(
    lambda speaker: speaker not in train_speakers, input_columns='speaker_id'
)

len_ds_validation = len(ds_validation)

ds_validation_half = ds_validation_unseen.shard(num_shards=2, index=0)

ds_test = ds_validation_unseen.shard(num_shards=2, index=1)

# the two shards must together cover the filtered validation set exactly
assert len(ds_validation_half) + len(ds_test) == len(ds_validation_unseen)

len(ds_validation_half), len(ds_test)