# Requirements and Initialization

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd

In [None]:
# Base directory for data files: the parent of the current working directory.
root = Path('..')

# Split and Save Datasets

In [None]:
def train_test_dev_split(dataset, test_ratio=0.1, dev_ratio=0.1, random_state=None):
    """Shuffle ``dataset`` and cut it into train, dev and test partitions.

    The training share is whatever remains after the dev and test shares are
    taken: ``1 - test_ratio - dev_ratio``.

    Args:
        dataset: DataFrame to split. It is not modified.
        test_ratio: Fraction of rows assigned to the test partition.
        dev_ratio: Fraction of rows assigned to the dev partition.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            shuffle can be reproduced. ``None`` keeps the unseeded behaviour.

    Returns:
        A ``(train, dev, test)`` tuple of DataFrames that together contain
        every row of ``dataset`` exactly once, keeping the original index
        labels.

    Raises:
        ValueError: If a ratio is negative or the two ratios exceed 1.
    """
    if test_ratio < 0 or dev_ratio < 0 or test_ratio + dev_ratio > 1 + 1e-9:
        raise ValueError(
            'test_ratio and dev_ratio must be non-negative and sum to at most 1'
        )

    train_ratio = 1 - test_ratio - dev_ratio
    total = len(dataset)

    def cut_point(fraction):
        # Round before truncating: 1 - .05 - .05 evaluates to
        # 0.8999999999999999, which would otherwise shave one row off train.
        return int(round(fraction * total, 6))

    train_end = cut_point(train_ratio)
    dev_end = cut_point(train_ratio + dev_ratio)

    shuffled = dataset.sample(frac=1, random_state=random_state)
    # Positional slicing replaces np.split, which relies on the deprecated
    # DataFrame.swapaxes when handed a DataFrame.
    train = shuffled.iloc[:train_end]
    dev = shuffled.iloc[train_end:dev_end]
    test = shuffled.iloc[dev_end:]
    return train, dev, test

In [None]:
# Name of the dataset to process; it selects which '<name>-sim.csv' is read.
dataset_name = 'ost'  # tat, ted

# Read <root>/datasets/<dataset_name>-sim.csv into a DataFrame.
metric_df = pd.read_csv(root.joinpath('datasets', f'{dataset_name}-sim.csv'))

In [None]:
# Keep one row per (src, tgt) pair. De-duplicating on both columns directly
# avoids the false matches that a separator-less string concatenation produces
# (e.g. 'ab' + 'c' vs 'a' + 'bc') and needs no temporary helper column.
metric_df = metric_df.drop_duplicates(subset=['src', 'tgt']).reset_index(drop=True)

In [None]:
sim_alias = 'emrecan'  # column of metric_df holding the score used for filtering
threshold = 0.85  # only rows scoring strictly above this value are used
step_size = 0.02  # width of each score bucket

In [None]:
# Bucket the rows by score so each bucket can be split separately and the
# train/dev/test sets follow the same score distribution.
# Buckets are half-open on the left: (0.85, 0.87], (0.87, 0.89], ...
#
# The bucket count is rounded UP so the final, possibly partial, bucket is kept.
# Truncating instead gives int(7.500000000000001) == 7 for the settings in this
# notebook, which silently discards every row scoring above 0.99.
# Rounding to 9 decimals first stops float noise from adding a spurious bucket.
num_splits = int(np.ceil(np.round((1 - threshold) / step_size, 9)))
bin_edges = [threshold + step_size * i for i in range(num_splits + 1)]
scores = metric_df[sim_alias]
filtered_dfs = [
    metric_df[(scores > lower) & (scores <= upper)]
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:])
]

In [None]:
# Replace every score bucket with its (train, dev, test) triple.
filtered_dfs = [
    train_test_dev_split(bucket, test_ratio=0.05, dev_ratio=0.05)
    for bucket in filtered_dfs
]

In [None]:
# Each entry of filtered_dfs is a (train, dev, test) triple for one bucket;
# stitch the pieces at the same position back together across buckets.
train, dev, test = (
    pd.concat([parts[position] for parts in filtered_dfs])
    for position in range(3)
)

# Shuffle each merged set so rows are no longer grouped by bucket, then
# renumber the index from zero.
train, dev, test = (
    merged.sample(frac=1).reset_index(drop=True)
    for merged in (train, dev, test)
)

In [None]:
# Write each split to its own CSV under <root>/datasets, without the index column.
# Every file must receive its own frame: writing `train` to all three would make
# the dev and test files copies of the training data.
train.to_csv(root / 'datasets' / f'{dataset_name}-train.csv', index=False)
dev.to_csv(root / 'datasets' / f'{dataset_name}-dev.csv', index=False)
test.to_csv(root / 'datasets' / f'{dataset_name}-test.csv', index=False)

In [None]:
# Print the size of each split as a fraction of all rows scoring above the
# threshold, in the order train, dev, test.
filtered_len = len(metric_df.loc[metric_df[sim_alias] > threshold])
for split in (train, dev, test):
    print(len(split) / filtered_len)