# Stats about Train / Dev / Test sets

In [1]:
import pandas as pd

import sys
sys.path.insert(0, '../')
from utils.config import PATHS

## Load data

In [2]:
datapath = PATHS.getpath('data_expr_sept')
train = pd.read_pickle(datapath / 'clf_domains/train.pkl')
test = pd.read_pickle(datapath / 'clf_domains/test.pkl')
dev = pd.read_pickle(datapath / 'clf_domains/dev.pkl')

In [3]:
domains=['ADM', 'ATT', 'BER', 'ENR', 'ETN', 'FAC', 'INS', 'MBW', 'STM']

## Check correct split

i.e. there are no notes that appear in more than one set

In [4]:
print(test.NotitieID.isin(train.NotitieID).any())
print(dev.NotitieID.isin(train.NotitieID).any())
print(dev.NotitieID.isin(test.NotitieID).any())

False
False
False


## Number of sentences per domain

- A sentence can contain more than one domain and therefore be counted more than once.
- The last column is the total number of sentences in the dataset (incl. all negative examples)

In [5]:
data = pd.concat([
    train.assign(dataset = 'train'),
    test.assign(dataset = 'test'),
    dev.assign(dataset = 'dev'),
])

balance = pd.DataFrame(
    index = pd.MultiIndex.from_frame(data[['dataset', 'pad_sen_id']]),
    columns = domains,
    data = data.labels.to_list()
)

dataset_sizes = balance.pivot_table(
    index='dataset',
    aggfunc='size',
).rename('n_sentences')

piv = balance.pivot_table(
    index='dataset',
    aggfunc='sum',
).join(dataset_sizes)
piv.loc['total'] = piv.sum()

piv

Unnamed: 0_level_0,ADM,ATT,BER,ENR,ETN,FAC,INS,MBW,STM,n_sentences
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
dev,411,22,29,105,225,119,127,96,147,21742
test,405,32,26,98,178,136,133,64,155,22082
train,4988,247,213,989,2420,1063,1067,755,1416,242291
total,5804,301,268,1192,2823,1318,1327,915,1718,286115


## % of domain sentences out of total 

In [6]:
piv.div(piv.n_sentences, axis=0).mul(100).round(2)

Unnamed: 0_level_0,ADM,ATT,BER,ENR,ETN,FAC,INS,MBW,STM,n_sentences
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
dev,1.89,0.1,0.13,0.48,1.03,0.55,0.58,0.44,0.68,100.0
test,1.83,0.14,0.12,0.44,0.81,0.62,0.6,0.29,0.7,100.0
train,2.06,0.1,0.09,0.41,1.0,0.44,0.44,0.31,0.58,100.0
total,2.03,0.11,0.09,0.42,0.99,0.46,0.46,0.32,0.6,100.0
