# Split the dataset into train / validation / test sets

In [1]:
import pandas as pd
from datasets import (
    Features, 
    ClassLabel, 
    Value,
    Dataset,
    DatasetDict,
)

In [2]:
# Directory holding the input TSV (data_all.tsv). It is used as a plain string
# prefix in an f-string, so it must keep its trailing slash.
DATA_DIRECTORY = "data/"

In [5]:
# Load the full annotated sentence table. The literal string "None" marks a
# missing value; the text/identifier columns are forced to str so that
# numeric-looking gene or variant names are not parsed as numbers.
table_raw = pd.read_csv(
    f"{DATA_DIRECTORY}data_all.tsv",
    sep='\t',
    na_values="None",
    dtype={'gene1': 'str', 'gene2': 'str', 'variant1': 'str', 'variant2': 'str', 'sentence': 'str'},
)

# Drop incomplete rows and rows labelled -1 (not part of the binary 0/1 task
# encoded by the ClassLabel below), then narrow the integer columns to int32
# so they match the Features schema.
table_data = (
    table_raw
    .dropna()
    .loc[lambda df: df['label'] != -1]
    .astype({'label': 'int32', 'pmcid': 'int32'})
)
# The ClassLabel below declares exactly two classes; any other value is a data error.
assert set(table_data['label'].unique()) <= {0, 1}, "unexpected label values"

features = Features({
    'sentence': Value('string'),
    'pmcid': Value('int32'),
    'gene1': Value('string'),
    'gene2': Value('string'),
    'variant1': Value('string'),
    'variant2': Value('string'),
    # ClassLabel names are documented as strings; 0 = negative, 1 = positive.
    'label': ClassLabel(names=['0', '1']),
})

# `num_tokens` is not part of the schema, so drop it before conversion.
# preserve_index=False: the schema declares no index column. With True, the
# (non-Range, post-dropna) index becomes an extra `__index_level_0__` column
# that older `datasets` releases silently discarded and newer ones may refuse
# to cast.
data = Dataset.from_pandas(
    table_data.drop(columns='num_tokens'),
    features=features,
    preserve_index=False,
)

In [6]:
# Number of examples after cleaning. len(data) gives the row count directly
# instead of first loading the whole 'pmcid' column into a Python list.
len(data)

8442

In [7]:
# Hold out 20% of all examples as the test set. Stratifying on `label` keeps the
# ~9% positive rate in both parts; the fixed seed makes the split reproducible.
split = data.train_test_split(
    seed=13,
    test_size=0.2,
    stratify_by_column='label',
    shuffle=True,
)

In [8]:
# Training portion of the first split; the validation set is carved out of it next.
split_train = split['train']

In [9]:
# Take a fixed 200 examples (an absolute count, not a fraction) from the training
# portion as the validation set, again stratified on `label` with the same seed.
train_dev = split_train.train_test_split(
    seed=13,
    test_size=200,
    stratify_by_column='label',
    shuffle=True,
)

In [10]:
# Assemble the final three-way split: train/validation come from the second
# split of the training data, test from the first 80/20 split.
dataset = DatasetDict(
    {
        'train': train_dev['train'],      # remaining training examples
        'validation': train_dev['test'],  # 200-example held-out slice
        'test': split['test'],            # 20% hold-out
    }
)

In [11]:
# Display the split names, their columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'pmcid', 'gene1', 'gene2', 'variant1', 'variant2', 'label'],
        num_rows: 6553
    })
    validation: Dataset({
        features: ['sentence', 'pmcid', 'gene1', 'gene2', 'variant1', 'variant2', 'label'],
        num_rows: 200
    })
    test: Dataset({
        features: ['sentence', 'pmcid', 'gene1', 'gene2', 'variant1', 'variant2', 'label'],
        num_rows: 1689
    })
})

In [12]:
def stats(data):
    """Print summary statistics for one split of the sentence dataset.

    Parameters
    ----------
    data : datasets.Dataset
        Any object supporting ``len(data)`` and column access ``data[name]``
        works. The ``'label'`` column (1 = positive, 0 = negative) and the
        ``'pmcid'`` column (article id) are read.

    Prints the total number of examples, the count and fraction of positive
    examples, the number of distinct articles, and the number of distinct
    articles contributing positive / negative examples. Returns ``None``, so
    a bare ``stats(...)`` at the end of a cell displays nothing extra.
    """
    # Read each column once and count in Python, rather than running a
    # separate Dataset.filter pass for every statistic.
    labels = list(data['label'])
    pmcids = list(data['pmcid'])
    number_examples = len(data)

    number_positive = sum(1 for label in labels if label == 1)
    number_articles = len(set(pmcids))
    number_positive_articles = len({pmcid for pmcid, label in zip(pmcids, labels) if label == 1})
    number_negative_articles = len({pmcid for pmcid, label in zip(pmcids, labels) if label == 0})
    # An empty split has no defined positive fraction: report NaN instead of
    # raising ZeroDivisionError.
    positive_fraction = number_positive / number_examples if number_examples else float('nan')

    print(f"Total number of examples : {number_examples}")
    print(f"Number of positive examples : {number_positive}")
    print(f"Fraction of positive examples : {positive_fraction}")
    print(f"Total number of articles : {number_articles}")
    print(f"Number of articles for positive examples : {number_positive_articles}")
    print(f"Number of articles for negative examples : {number_negative_articles}")

# The loop variable is `split_name`, not `split`: reusing `split` would overwrite
# the train/test DatasetDict from the first train_test_split and break
# re-running the cells that index into it.
for split_name in ['test', 'validation', 'train']:
    print(f"Stats for the {split_name} split")
    stats(dataset[split_name])

Stats for the test split


  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

Total number of examples : 1689
Number of positive examples : 159
Fraction of positive examples : 0.0941385435168739
Total number of articles : 75
Numer of articles for positive examples : 51
Numer of articles for negative examples : 73
Stats for the validation split


  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Total number of examples : 200
Number of positive examples : 19
Fraction of positive examples : 0.095
Total number of articles : 51
Numer of articles for positive examples : 12
Numer of articles for negative examples : 50
Stats for the train split


  0%|          | 0/7 [00:00<?, ?ba/s]

  0%|          | 0/7 [00:00<?, ?ba/s]

  0%|          | 0/7 [00:00<?, ?ba/s]

  0%|          | 0/7 [00:00<?, ?ba/s]

Total number of examples : 6553
Number of positive examples : 616
Fraction of positive examples : 0.09400274683351137
Total number of articles : 79
Numer of articles for positive examples : 61
Numer of articles for negative examples : 78


In [13]:
# Same statistics for the full (unsplit) dataset, for comparison with the splits.
stats(data)

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

Total number of examples : 8442
Number of positive examples : 794
Fraction of positive examples : 0.09405354181473584
Total number of articles : 81
Numer of articles for positive examples : 64
Numer of articles for negative examples : 79


In [None]:
# Write each split to <data_dir><split>.csv without the row index; na_rep='None'
# mirrors the na_values="None" used when reading the source TSV.
# NOTE(review): this output directory differs from DATA_DIRECTORY used for
# reading — confirm the relative path is intended.
data_dir = "../../data/"
# `split_name`, not `split`, so the earlier train/test DatasetDict is not clobbered.
for split_name in ['test', 'train', 'validation']:
    dataset[split_name].to_csv(f"{data_dir}{split_name}.csv", na_rep='None', index=False)