# Split Generator

This notebook generates the splits of the CHV for training, development and testing of machine learning algorithms.

We generate a total of five splits:

1. **Random split**: A completely random split;
2. **Stratified split**: A split where we ensure, where possible, that every ID appearing in the dev and test sets also appears in the training set, i.e.:
   - the model has to link (mostly) *already seen* concepts;
   - for every SNOMED ID $x$, $x \in (Dev \cup Test) \Rightarrow x \in Train$;
   - in addition, $Dev \cap Test = \emptyset$.
3. **Zero-Shot split**: A split where each set contains a different set of concepts, i.e.:
   - the model has to link *unseen* concepts;
   - the sets are pairwise disjoint: $Train \cap Dev = Train \cap Test = Dev \cap Test = \emptyset$.

Each of the last two splits is generated twice: once for the General SNOMED IDs and once for the Specific SNOMED IDs.
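The properties above can be stated as small checks on sets of IDs. This is a minimal plain-Python sketch; `is_stratified` and `is_zero_shot` are hypothetical helper names, not part of this notebook. The zero-shot property requires the sets to be *pairwise* disjoint, which is stronger than an empty triple intersection. Because a singleton ID cannot appear in both training and dev/test, the stratified split below only satisfies its property where possible, so `is_stratified` is expected to return `False` on the real data.

```python
# Hypothetical helpers that check the split properties on sets of SNOMED IDs.

def is_stratified(train_ids, dev_ids, test_ids):
    """Every dev/test ID also occurs in train, and dev and test share no ID."""
    train_ids, dev_ids, test_ids = set(train_ids), set(dev_ids), set(test_ids)
    return (dev_ids | test_ids) <= train_ids and not (dev_ids & test_ids)


def is_zero_shot(train_ids, dev_ids, test_ids):
    """Train, dev and test IDs are pairwise disjoint."""
    train_ids, dev_ids, test_ids = set(train_ids), set(dev_ids), set(test_ids)
    return not (train_ids & dev_ids or train_ids & test_ids or dev_ids & test_ids)


# An empty triple intersection does not imply pairwise disjointness:
a, b, c = {1, 2}, {2, 3}, {3, 4}
print(a & b & c)              # set()
print(is_zero_shot(a, b, c))  # False
```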

## Data loading and imports

In [1]:
import os
import pandas as pd
import math

from sklearn.model_selection import train_test_split

In [2]:
INPUT_FILE = './data/chv.csv'
OUTPUT_FOLDER = './data/splits/'
RANDOM_SEED = 1203

In [3]:
all_terms = pd.read_csv(INPUT_FILE, sep='\t', encoding='utf8')
all_terms.head()

Unnamed: 0,ID,Term,General SNOMED Label,General SNOMED ID,Specific SNOMED Label,Specific SNOMED ID,Example,Example Link,Origin_Sheet
0,0,5FU,Fluorouracil,387172005,Fluorouracil,387172005,I can ' t do 5FU because it causes cardiotoxic...,https://www.reddit.com/4gs56p,./stage_1/Batch 7-Ready.xlsx
1,1,5FU,Fluorouracil,387172005,Fluorouracil,387172005,He is starting 5FU and Leucovorin and Hercepti...,https://www.reddit.com/9tekw3,./stage_1/Batch 7-Ready.xlsx
2,2,5FU,Fluorouracil,387172005,Fluorouracil,387172005,Around treatment 9 they realized that I was al...,https://www.reddit.com/916p80,./stage_1/Batch 7-Ready.xlsx
3,3,5HTP,Oxitriptan,73916008,Oxitriptan,73916008,"Aside from edibles I take Gaba , Taurine , sub...",https://www.reddit.com/8b7cb4,./stage_1/Batch8-Ready.xlsx
4,4,5HTP,Oxitriptan,73916008,Oxitriptan,73916008,I was recomended the 5HTP and as I said it ini...,https://www.reddit.com/771elt,./stage_1/Batch8-Ready.xlsx


In [5]:
print(f'There are {len(all_terms)} samples.')

There are 20015 samples.


In [6]:
general_ids = all_terms['General SNOMED ID'].unique()
print(general_ids)
print(f'There are {len(general_ids)} General IDs.')

[387172005  73916008 386835005 ...  42841002 228262005 712723002]
There are 3645 General IDs.


In [7]:
specific_ids = all_terms['Specific SNOMED ID'].unique()
print(specific_ids)
print(f'There are {len(specific_ids)} Specific IDs.')

[387172005  73916008 386835005 ...  42841002 228262005 712723002]
There are 4003 Specific IDs.


## Random Split

In [8]:
random_train, random_devtest = train_test_split(
    all_terms,
    random_state=RANDOM_SEED,
    shuffle=True,
    train_size=0.7
)

In [9]:
random_train.head()

Unnamed: 0,ID,Term,General SNOMED Label,General SNOMED ID,Specific SNOMED Label,Specific SNOMED ID,Example,Example Link,Origin_Sheet
14005,14005,misaligned,Misalignment,399898009,Misalignment,399898009,But after an hour of two of being upright it g...,https://www.reddit.com/9no91l,./stage_1/batch 9-Ready.xlsx
12743,12743,large intestine,Structure of large intestine,14742008,Structure of large intestine,14742008,They suspect they will also remove a decent se...,https://www.reddit.com/7abwtn,./stage_1/Batch 7-Ready.xlsx
5405,5405,azelastine,Azelastine,372520005,Azelastine,372520005,"I also use flonase and azelastine together , b...",https://www.reddit.com/9asq2v,./stage_1/Batch 10-Ready.xlsx
13194,13194,liver transplant,Transplantation of liver,18027006,Transplantation of liver,18027006,My 44 year old brother just had his liver tran...,https://www.reddit.com/9gb653,./stage_1/batch 4-Ready.xlsx
17623,17623,spacey,Dizziness,404640003,Dizziness,404640003,I am spacey because I am thinking and daydream...,https://www.reddit.com/9fa8se,./stage_1/Batch16-Ready.xlsx


In [10]:
random_dev, random_test = train_test_split(
    random_devtest,
    random_state=RANDOM_SEED,
    shuffle=True,
    train_size=1/3
)
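Chaining the two calls gives overall proportions of 70% / 10% / 20%, since the dev set receives one third of the remaining 30%. The sketch below reproduces the resulting sizes, assuming (consistently with the sizes printed below) that `train_test_split` rounds a float `train_size` down and assigns the remainder to the second output; `two_stage_sizes` is a hypothetical helper.

```python
import math


def two_stage_sizes(n_samples, train_frac=0.7, dev_frac=1/3):
    """Sizes produced by two chained splits with a float train_size (rounded down)."""
    n_train = math.floor(train_frac * n_samples)
    n_devtest = n_samples - n_train          # remainder goes to the second output
    n_dev = math.floor(dev_frac * n_devtest)
    n_test = n_devtest - n_dev
    return n_train, n_dev, n_test


print(two_stage_sizes(20015))  # (14010, 2001, 4004)
```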

### Stats

In [11]:
print(
    f'Random split dataset sizes:\n'
    f'- Training set:    {len(random_train):5.0f}\n'
    f'- Development set: {len(random_dev):5.0f}\n'
    f'- Test set:        {len(random_test):5.0f}')

Random split dataset sizes:
- Training set:    14010
- Development set:  2001
- Test set:         4004


In [12]:
random_train_gids = set(random_train['General SNOMED ID'].values)
random_dev_gids = set(random_dev['General SNOMED ID'].values)
random_test_gids = set(random_test['General SNOMED ID'].values)
random_tt_overlap = 100 * \
    len(random_train_gids.intersection(random_test_gids))/len(general_ids)
print(
    f'There is a {random_tt_overlap:2.2f}% overlap in General IDs between train and test set.')
random_td_overlap = 100 * \
    len(random_train_gids.intersection(random_dev_gids))/len(general_ids)
print(
    f'There is a {random_td_overlap:2.2f}% overlap in General IDs between train and dev set.')
random_dt_overlap = 100 * \
    len(random_test_gids.intersection(random_dev_gids))/len(general_ids)
print(
    f'There is a {random_dt_overlap:2.2f}% overlap in General IDs between test and dev set.')
random_general_overlap = 100 * len(random_train_gids.intersection(
    random_dev_gids.intersection(random_test_gids)))/len(general_ids)
print(
    f'There is a {random_general_overlap:2.2f}% overlap in General IDs between train, dev and test sets.')


There is a 59.64% overlap in General IDs between train and test set.
There is a 37.53% overlap in General IDs between train and dev set.
There is a 24.22% overlap in General IDs between test and dev set.
There is a 23.07% overlap in General IDs between train, dev and test sets.
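The overlap cells in this notebook repeat the same computation for every pair of sets. Below is a minimal sketch of a hypothetical helper, `overlap_report`, that computes all pairwise overlaps and the three-way overlap from plain Python sets. As in the cells above, each percentage is taken over the total number of IDs in the dataset, not over the size of either set.

```python
from itertools import combinations


def overlap_report(splits, n_total_ids):
    """Percentage of all IDs shared by each pair of sets and by all of them.

    splits maps a set name to the set of IDs it contains.
    """
    report = {}
    for (name_a, ids_a), (name_b, ids_b) in combinations(splits.items(), 2):
        report[f'{name_a}-{name_b}'] = 100 * len(ids_a & ids_b) / n_total_ids
    report['all'] = 100 * len(set.intersection(*splits.values())) / n_total_ids
    return report


toy = {'train': {1, 2, 3}, 'dev': {2, 4}, 'test': {3, 4}}
for name, pct in overlap_report(toy, n_total_ids=4).items():
    print(f'{name}: {pct:.2f}%')
```

With the notebook's variables, the same helper could be called as `overlap_report({'train': random_train_gids, 'dev': random_dev_gids, 'test': random_test_gids}, len(general_ids))`.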


In [13]:
random_train_sids = set(random_train['Specific SNOMED ID'].values)
random_dev_sids = set(random_dev['Specific SNOMED ID'].values)
random_test_sids = set(random_test['Specific SNOMED ID'].values)
random_tts_overlap = 100 * \
    len(random_train_sids.intersection(random_test_sids))/len(specific_ids)
print(
    f'There is a {random_tts_overlap:2.2f}% overlap in Specific IDs between train and test set.')
random_tds_overlap = 100 * \
    len(random_train_sids.intersection(random_dev_sids))/len(specific_ids)
print(
    f'There is a {random_tds_overlap:2.2f}% overlap in Specific IDs between train and dev set.')
random_dts_overlap = 100 * \
    len(random_test_sids.intersection(random_dev_sids))/len(specific_ids)
print(
    f'There is a {random_dts_overlap:2.2f}% overlap in Specific IDs between test and dev set.')

random_specific_overlap = 100 * len(random_train_sids.intersection(
    random_dev_sids.intersection(random_test_sids)))/len(specific_ids)
print(
    f'There is a {random_specific_overlap:2.2f}% overlap in Specific IDs between train, dev and test sets.')

There is a 54.68% overlap in Specific IDs between train and test set.
There is a 33.97% overlap in Specific IDs between train and dev set.
There is a 21.51% overlap in Specific IDs between test and dev set.
There is a 20.36% overlap in Specific IDs between train, dev and test sets.


In [14]:
print(f'In the random split,\n'
      f'- there are {random_train.groupby(["General SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per General ID in the training set;\n'
      f'- there are {random_dev.groupby(["General SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per General ID in the dev set;\n'
      f'- there are {random_test.groupby(["General SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per General ID in the test set;\n'
      f'- there are {random_train.groupby(["Specific SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per Specific ID in the training set;\n'
      f'- there are {random_dev.groupby(["Specific SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per Specific ID in the dev set;\n'
      f'- there are {random_test.groupby(["Specific SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per Specific ID in the test set;\n'
      )

In the random split,
- there are 3.93 terms per General ID in the training set;
- there are 1.41 terms per General ID in the dev set;
- there are 1.78 terms per General ID in the test set;
- there are 3.66 terms per Specific ID in the training set;
- there are 1.39 terms per Specific ID in the dev set;
- there are 1.72 terms per Specific ID in the test set;
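In the cell above, `groupby(...).nunique().ID.mean()` counts the distinct values of the `ID` column for each concept and averages them. The data previews suggest that `ID` is a per-row identifier, so the reported figure is the mean number of samples (rows) per concept, with repeated surface forms such as the three "5FU" rows each counted. A plain-Python equivalent, assuming `ID` is unique per row (`mean_rows_per_id` is a hypothetical name):

```python
from collections import Counter


def mean_rows_per_id(concept_ids):
    """Mean number of rows per distinct concept ID (one entry per row)."""
    counts = Counter(concept_ids)
    return sum(counts.values()) / len(counts)


# First five rows of the preview: three for 5FU (387172005), two for 5HTP (73916008)
print(f'{mean_rows_per_id([387172005] * 3 + [73916008] * 2):.2f}')  # 2.50
```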



### Save data

In [15]:
random_folder = os.path.join(OUTPUT_FOLDER, 'random')
os.makedirs(random_folder, exist_ok=True)
random_train.to_csv(os.path.join(random_folder, 'train.csv'), sep='\t', encoding='utf8')
random_dev.to_csv(os.path.join(random_folder, 'dev.csv'), sep='\t', encoding='utf8')
random_test.to_csv(os.path.join(random_folder, 'test.csv'), sep='\t', encoding='utf8')

## Stratified Split

The strategy here is:
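- For each SNOMED ID with at least 2 samples, we put 1 sample in the (dev+test) pool and all its other samples in the training set; the pool is then split at random into dev (1/3) and test (2/3).
- The samples of IDs with a single sample (which are just a few) are split roughly 70% / 15% / 15% across training, dev and test, in the order in which they appear in the input file.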
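The strategy can be sketched in plain Python on a list of concept IDs, one per row. This is a simplified illustration, not the notebook's code: `stratified_split` is a hypothetical helper, its input format is an assumption, and Python's `random.Random` shuffle stands in for `train_test_split`, so the actual row assignment differs from the one produced below.

```python
import math
import random
from collections import defaultdict


def stratified_split(row_concepts, seed=1203):
    """Return (train, dev, test) lists of row indices.

    row_concepts[i] is the concept ID of row i (hypothetical input format).
    """
    groups = defaultdict(list)
    for row, concept in enumerate(row_concepts):
        groups[concept].append(row)

    train, devtest, singletons = [], [], []
    for rows in groups.values():
        if len(rows) > 1:
            devtest.append(rows[0])   # one held-out row per concept
            train.extend(rows[1:])    # all other rows go to training
        else:
            singletons.extend(rows)   # concepts with a single sample

    # Singletons, in input order: ~70% train, the rest halved between dev and test
    first = math.ceil(len(singletons) * .7)
    second = first + math.ceil((len(singletons) - first) / 2)
    train.extend(singletons[:first])

    # Held-out pool: one third (rounded down) to dev, the rest to test
    random.Random(seed).shuffle(devtest)
    n_dev = len(devtest) // 3
    dev = devtest[:n_dev] + singletons[first:second]
    test = devtest[n_dev:] + singletons[second:]
    return train, dev, test


train, dev, test = stratified_split(['a', 'a', 'a', 'b', 'b', 'c', 'c', 'd'])
print(len(train), len(dev), len(test))  # 5 1 2
```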

### General IDs

In [16]:
# Separate IDs with fewer than 2 samples from IDs with at least 2
big_generals = all_terms.groupby(['General SNOMED ID']).nunique()
small_generals = big_generals.loc[big_generals.ID < 2].index.values
big_generals = big_generals.loc[big_generals.ID > 1].index.values
# sanity check
assert len(small_generals) + \
    len(big_generals) == len(all_terms['General SNOMED ID'].unique())

print(f'There are {len(small_generals):4.0f} IDs with less than 2 labels and\n'
      f'          {len(big_generals):4.0f} IDs with at least 2.')

There are   10 IDs with less than 2 labels and
          3635 IDs with at least 2.


In [17]:
big_generals = all_terms.loc[all_terms['General SNOMED ID'].isin(big_generals)]

In [18]:
strat_train_g = []
strat_devtest_g = []
for row_ids in big_generals.groupby(['General SNOMED ID']).groups.values():
    strat_devtest_g.append(row_ids.values[0])
    strat_train_g.extend(row_ids.values[1:])

In [19]:
# Find the IDs of the small samples
small_generals = all_terms.loc[all_terms['General SNOMED ID'].isin(
    small_generals)].index.values

In [20]:
# Singleton samples: the first ~70% go to training,
# the rest is split evenly between dev and test
first_split = math.ceil(len(small_generals)*.7)
second_split = first_split + math.ceil((len(small_generals) - first_split) / 2)

small_train = list(small_generals[:first_split])
small_dev = list(small_generals[first_split:second_split])
small_test = list(small_generals[second_split:])
assert len(small_train) + len(small_dev) + \
    len(small_test) == len(small_generals)
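With this arithmetic the singleton samples are split roughly 70% / 15% / 15% (dev receives the extra sample when the remainder is odd), rather than the 70% / 10% / 20% of the random split. A hypothetical helper that reproduces the sizes for the counts reported in this notebook (10 singleton General IDs and 266 singleton Specific IDs, each with exactly one sample):

```python
import math


def singleton_split_sizes(n):
    """(train, dev, test) shares of n singleton samples, using the formulas above."""
    first_split = math.ceil(n * .7)
    second_split = first_split + math.ceil((n - first_split) / 2)
    return first_split, second_split - first_split, n - second_split


print(singleton_split_sizes(10))   # (7, 2, 1): General IDs
print(singleton_split_sizes(266))  # (187, 40, 39): Specific IDs
```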

In [21]:
# Add the training share of the singleton samples
strat_train_g.extend(small_train)

strat_train_g = all_terms.loc[
    all_terms.index.isin(strat_train_g)]
strat_train_g = strat_train_g.sample(frac=1, random_state=RANDOM_SEED)
strat_train_g.head()

Unnamed: 0,ID,Term,General SNOMED Label,General SNOMED ID,Specific SNOMED Label,Specific SNOMED ID,Example,Example Link,Origin_Sheet
10960,10960,head hurts,Headache,25064002,Nasal headache,44538002,My sinuses also feel blocked as well and when ...,https://www.reddit.com/a3a1jk,./stage_1/batch 9-Ready.xlsx
553,553,Blincyto,Blinatumomab,716122004,Blinatumomab,716122004,I ' m currently on day 11 of blincyto .,https://www.reddit.com/4bsiin,./stage_1/batch 15-Ready.xlsx
2518,2518,Meclizine,Meclozine,372879002,Meclozine,372879002,"Meclizine doesn ' t make me sleepy , but you '...",https://www.reddit.com/4xv6yz,./stage_1/batch 15-Ready.xlsx
14568,14568,neuroblastoma,Neuroblastoma,432328008,Metastatic neuroblastoma,704152002,5yo nephew diagnosed with stage 4 neuroblastoma,https://www.reddit.com/3y34sd,./stage_1/batch 15-Ready.xlsx
784,784,CT,Computerized axial tomography,77477000,Computerized axial tomography,77477000,The amount of radiation in an airport xray is ...,https://www.reddit.com/9a7g6i,./stage_1/Batch6-Ready.xlsx


In [22]:
# generate dev and test
strat_dev_g, strat_test_g = train_test_split(
    strat_devtest_g,
    random_state=RANDOM_SEED,
    shuffle=True,
    train_size=1/3
)

strat_dev_g.extend(small_dev)
strat_test_g.extend(small_test)

strat_dev_g = all_terms.loc[all_terms.index.isin(strat_dev_g)]
strat_test_g = all_terms.loc[all_terms.index.isin(strat_test_g)]

In [23]:
assert len(pd.concat([strat_train_g, strat_dev_g, strat_test_g])) == len(
    all_terms)
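This assertion compares only row counts, so a split that duplicated one row and dropped another would still pass. A stricter check verifies that every row index appears in exactly one set. Below is a minimal sketch with a hypothetical `is_partition` helper; with the DataFrames above it could be called as `is_partition([strat_train_g.index, strat_dev_g.index, strat_test_g.index], all_terms.index)`.

```python
def is_partition(parts, universe):
    """True if every element of `universe` appears in exactly one of `parts`."""
    combined = [x for part in parts for x in part]
    no_repeats = len(combined) == len(set(combined))   # no row used twice
    return no_repeats and set(combined) == set(universe)


# Same total size as range(5), but row 1 is repeated and row 2 is missing
print(is_partition([[0, 1], [1], [3, 4]], range(5)))  # False
print(is_partition([[0, 1], [2], [3, 4]], range(5)))  # True
```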

### Specific IDs

In [24]:
# Separate IDs with fewer than 2 samples from IDs with at least 2
big_generals = all_terms.groupby(['Specific SNOMED ID']).nunique()
small_generals = big_generals.loc[big_generals.ID < 2].index.values
big_generals = big_generals.loc[big_generals.ID > 1].index.values
# sanity check
assert len(small_generals) + \
    len(big_generals) == len(all_terms['Specific SNOMED ID'].unique())

print(f'There are {len(small_generals):4.0f} IDs with less than 2 labels and\n'
      f'          {len(big_generals):4.0f} IDs with at least 2.')

There are  266 IDs with less than 2 labels and
          3737 IDs with at least 2.


In [25]:
big_generals = all_terms.loc[all_terms['Specific SNOMED ID'].isin(
    big_generals)]

In [26]:
strat_train_s = []
strat_devtest_s = []
for row_ids in big_generals.groupby(['Specific SNOMED ID']).groups.values():
    strat_devtest_s.append(row_ids.values[0])
    strat_train_s.extend(row_ids.values[1:])

In [27]:
# Find the IDs of the small samples
small_generals = all_terms.loc[all_terms['Specific SNOMED ID'].isin(
    small_generals)].index.values

In [28]:
# Singleton samples: the first ~70% go to training,
# the rest is split evenly between dev and test
first_split = math.ceil(len(small_generals)*.7)
second_split = first_split + math.ceil((len(small_generals) - first_split) / 2)

small_train = list(small_generals[:first_split])
small_dev = list(small_generals[first_split:second_split])
small_test = list(small_generals[second_split:])
assert len(small_train) + len(small_dev) + \
    len(small_test) == len(small_generals)

In [29]:
# Add the training share of the singleton samples
strat_train_s.extend(small_train)

strat_train_s = all_terms.loc[
    all_terms.index.isin(strat_train_s)]
strat_train_s = strat_train_s.sample(frac=1, random_state=RANDOM_SEED)
strat_train_s.head()

Unnamed: 0,ID,Term,General SNOMED Label,General SNOMED ID,Specific SNOMED Label,Specific SNOMED ID,Example,Example Link,Origin_Sheet
4804,4804,androgens,Androgen,84629008,Androgen,84629008,PCOS is so confusing ( but DHEAs are androgens...,https://www.reddit.com/a6j6w8,./stage_1/Batch6-Ready.xlsx
6362,6362,budesonide,Budesonide,395726003,Budesonide,395726003,Should I ask for Budesonide as an alternative ...,https://www.reddit.com/9tpf17,./stage_1/Batch2-Ready.xlsx
3060,3060,POTS,Postural orthostatic tachycardia syndrome,371073003,Postural orthostatic tachycardia syndrome,371073003,I already do all of this because of POTS .,https://www.reddit.com/a32m26,./stage_1/Batch2-Ready.xlsx
11911,11911,implant,Implant,385286003,Dental implant system,468993001,"So based on this , potentially moving forward ...",https://www.reddit.com/a31qpn,./stage_1/Batch2-Ready.xlsx
13164,13164,liver dumps,Dawn phenomenon,398123003,Dawn phenomenon,398123003,Lots of Type 1 folks have found it easier to m...,https://www.reddit.com/8zu4tz,./stage_1/batch 12-Ready.xlsx


In [30]:
# generate dev and test
strat_dev_s, strat_test_s = train_test_split(
    strat_devtest_s,
    random_state=RANDOM_SEED,
    shuffle=True,
    train_size=1/3
)

strat_dev_s.extend(small_dev)
strat_test_s.extend(small_test)

strat_dev_s = all_terms.loc[all_terms.index.isin(strat_dev_s)]
strat_test_s = all_terms.loc[all_terms.index.isin(strat_test_s)]

In [31]:
assert len(pd.concat([strat_train_s, strat_dev_s, strat_test_s])) == len(
    all_terms)

### Stats

In [32]:
print(
    f'Stratified split for General IDs dataset sizes:\n'
    f'- Training set:    {len(strat_train_g):5.0f}\n'
    f'- Development set: {len(strat_dev_g):5.0f}\n'
    f'- Test set:        {len(strat_test_g):5.0f}')

print(
    f'Stratified split for Specific IDs dataset sizes:\n'
    f'- Training set:    {len(strat_train_s):5.0f}\n'
    f'- Development set: {len(strat_dev_s):5.0f}\n'
    f'- Test set:        {len(strat_test_s):5.0f}')

Stratified split for General IDs dataset sizes:
- Training set:    16377
- Development set:  1213
- Test set:         2425
Stratified split for Specific IDs dataset sizes:
- Training set:    16199
- Development set:  1285
- Test set:         2531
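These sizes follow directly from the strategy: each of the 3635 (General) or 3737 (Specific) IDs with at least two samples contributes one row to the held-out pool, a third of which goes to dev, and the singleton rows are split as computed earlier. As a result, dev and test hold only roughly 6% and 12% of the samples. A hypothetical helper that reproduces the printed sizes, assuming the dev share of the pool is rounded down:

```python
import math


def stratified_sizes(n_samples, n_multi_ids, n_singleton_ids):
    """Expected (train, dev, test) sizes of the stratified split."""
    # One held-out row per ID with >= 2 samples; a third (rounded down) goes to dev
    pool_dev = n_multi_ids // 3
    pool_test = n_multi_ids - pool_dev
    # Singleton rows, split as in the cell computing first_split/second_split
    single_train = math.ceil(n_singleton_ids * .7)
    single_dev = math.ceil((n_singleton_ids - single_train) / 2)
    single_test = n_singleton_ids - single_train - single_dev
    # Every other row of a multi-sample ID stays in training
    n_train = n_samples - n_multi_ids - n_singleton_ids + single_train
    return n_train, pool_dev + single_dev, pool_test + single_test


print(stratified_sizes(20015, 3635, 10))   # (16377, 1213, 2425)
print(stratified_sizes(20015, 3737, 266))  # (16199, 1285, 2531)
```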


In [33]:
strat_train_gids = set(strat_train_g['General SNOMED ID'].values)
strat_dev_gids = set(strat_dev_g['General SNOMED ID'].values)
strat_test_gids = set(strat_test_g['General SNOMED ID'].values)
strat_tt_overlap = 100 * \
    len(strat_train_gids.intersection(strat_test_gids))/len(general_ids)
print(
    f'There is a {strat_tt_overlap:2.2f}% overlap in General IDs between train and test set.')
strat_td_overlap = 100 * \
    len(strat_train_gids.intersection(strat_dev_gids))/len(general_ids)
print(
    f'There is a {strat_td_overlap:2.2f}% overlap in General IDs between train and dev set.')
strat_dt_overlap = 100 * \
    len(strat_test_gids.intersection(strat_dev_gids))/len(general_ids)
print(
    f'There is a {strat_dt_overlap:2.2f}% overlap in General IDs between test and dev set.')
strat_general_overlap = 100 * len(strat_train_gids.intersection(
    strat_dev_gids.intersection(strat_test_gids)))/len(general_ids)
print(
    f'There is a {strat_general_overlap:2.2f}% overlap in General IDs between train, dev and test sets.')


There is a 66.50% overlap in General IDs between train and test set.
There is a 33.22% overlap in General IDs between train and dev set.
There is a 0.00% overlap in General IDs between test and dev set.
There is a 0.00% overlap in General IDs between train, dev and test sets.


In [34]:
strat_train_sids = set(strat_train_s['Specific SNOMED ID'].values)
strat_dev_sids = set(strat_dev_s['Specific SNOMED ID'].values)
strat_test_sids = set(strat_test_s['Specific SNOMED ID'].values)
strat_tts_overlap = 100 * \
    len(strat_train_sids.intersection(strat_test_sids))/len(specific_ids)
print(
    f'There is a {strat_tts_overlap:2.2f}% overlap in Specific IDs between train and test set.')
strat_tds_overlap = 100 * \
    len(strat_train_sids.intersection(strat_dev_sids))/len(specific_ids)
print(
    f'There is a {strat_tds_overlap:2.2f}% overlap in Specific IDs between train and dev set.')
strat_dts_overlap = 100 * \
    len(strat_test_sids.intersection(strat_dev_sids))/len(specific_ids)
print(
    f'There is a {strat_dts_overlap:2.2f}% overlap in Specific IDs between test and dev set.')

strat_specific_overlap = 100 * len(strat_train_sids.intersection(
    strat_dev_sids.intersection(strat_test_sids)))/len(specific_ids)
print(
    f'There is a {strat_specific_overlap:2.2f}% overlap in Specific IDs between train, dev and test sets.')

There is a 62.25% overlap in Specific IDs between train and test set.
There is a 31.10% overlap in Specific IDs between train and dev set.
There is a 0.00% overlap in Specific IDs between test and dev set.
There is a 0.00% overlap in Specific IDs between train, dev and test sets.


In [35]:
print(f'In the stratified split,\n'
      f'- there are {strat_train_g.groupby(["General SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per General ID in the training set;\n'
      f'- there are {strat_dev_g.groupby(["General SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per General ID in the dev set;\n'
      f'- there are {strat_test_g.groupby(["General SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per General ID in the test set;\n'
      f'- there are {strat_train_s.groupby(["Specific SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per Specific ID in the training set;\n'
      f'- there are {strat_dev_s.groupby(["Specific SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per Specific ID in the dev set;\n'
      f'- there are {strat_test_s.groupby(["Specific SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per Specific ID in the test set;\n'
      )

In the stratified split,
- there are 4.50 terms per General ID in the training set;
- there are 1.00 terms per General ID in the dev set;
- there are 1.00 terms per General ID in the test set;
- there are 4.13 terms per Specific ID in the training set;
- there are 1.00 terms per Specific ID in the dev set;
- there are 1.00 terms per Specific ID in the test set;



### Save Data

In [36]:
strat_g_folder = os.path.join(OUTPUT_FOLDER, 'stratified_general')
os.makedirs(strat_g_folder, exist_ok=True)
strat_train_g.to_csv(os.path.join(strat_g_folder, 'train.csv'), sep='\t', encoding='utf8')
strat_dev_g.to_csv(os.path.join(strat_g_folder, 'dev.csv'), sep='\t', encoding='utf8')
strat_test_g.to_csv(os.path.join(strat_g_folder, 'test.csv'), sep='\t', encoding='utf8')

In [37]:
strat_s_folder = os.path.join(OUTPUT_FOLDER, 'stratified_specific')
os.makedirs(strat_s_folder, exist_ok=True)
strat_train_s.to_csv(os.path.join(strat_s_folder, 'train.csv'), sep='\t', encoding='utf8')
strat_dev_s.to_csv(os.path.join(strat_s_folder, 'dev.csv'), sep='\t', encoding='utf8')
strat_test_s.to_csv(os.path.join(strat_s_folder, 'test.csv'), sep='\t', encoding='utf8')

## Zero-Shot Split

### General IDs

In [38]:
train_gids, devtest_gids = train_test_split(
    general_ids,
    random_state=RANDOM_SEED,
    shuffle=True,
    train_size=0.7
)

dev_gids, test_gids = train_test_split(
    devtest_gids,
    random_state=RANDOM_SEED,
    shuffle=True,
    train_size=1/3
)
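Unlike the random split, here the concept IDs themselves are split 70% / 10% / 20%, and every row then follows its ID, so no concept is shared between sets. Below is a minimal plain-Python sketch of the same idea; `zero_shot_split` is a hypothetical helper, and Python's `random.Random` shuffle stands in for `train_test_split`, so the resulting assignment differs from the notebook's.

```python
import random


def zero_shot_split(row_concepts, seed=1203, train_frac=0.7, dev_frac=1/3):
    """Return a dict of row-index lists; all rows of a concept share one split."""
    concepts = list(dict.fromkeys(row_concepts))   # unique IDs, first-seen order
    random.Random(seed).shuffle(concepts)
    n_train = int(train_frac * len(concepts))      # int() rounds down here
    n_dev = int(dev_frac * (len(concepts) - n_train))
    train_ids = set(concepts[:n_train])
    dev_ids = set(concepts[n_train:n_train + n_dev])

    splits = {'train': [], 'dev': [], 'test': []}
    for row, concept in enumerate(row_concepts):
        if concept in train_ids:
            splits['train'].append(row)
        elif concept in dev_ids:
            splits['dev'].append(row)
        else:
            splits['test'].append(row)    # remaining IDs
    return splits


rows = ['a', 'a', 'b', 'c', 'c', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
split = zero_shot_split(rows)
print({name: len(idx) for name, idx in split.items()})
```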

In [39]:
zeroshot_train_g = all_terms.loc[
    all_terms['General SNOMED ID'].isin(train_gids)
].sample(frac=1, random_state=RANDOM_SEED)
zeroshot_train_g.head()

Unnamed: 0,ID,Term,General SNOMED Label,General SNOMED ID,Specific SNOMED Label,Specific SNOMED ID,Example,Example Link,Origin_Sheet
3708,3708,T2,Diabetes mellitus type 2,44054006,Diabetes mellitus type 2,44054006,I have a family history of T2 and diet / metfo...,https://www.reddit.com/8bb2yb,./stage_1/Batch1C-Ready.xlsx
11553,11553,hyperalgesia,Hyperalgesia,55406008,Hyperalgesia,55406008,"it ' s unfortunate , but the reality is opioid...",https://www.reddit.com/876850,./stage_1/Batch6-Ready.xlsx
11185,11185,herbal medicine,Herbal medicine,349365008,Herbal medicine,349365008,So far I know 2x paracetemol and 2x ibuprofen ...,https://www.reddit.com/7b1ofi,./stage_1/Batch 10-Ready.xlsx
1079,1079,DHT,Dihydrotestosterone,103042004,Dihydrotestosterone,103042004,What I ' m thinking is either a pituitary tumo...,https://www.reddit.com/a2gyvz,./stage_1/batch 12-Ready.xlsx
301,301,Arcoxia,Etoricoxib,409134009,Etoricoxib,409134009,Ibuprofen and Arcoxia ( Etoricoxib ).,https://www.reddit.com/42mtzc,./stage_1/batch 12-Ready.xlsx


In [40]:
zeroshot_dev_g = all_terms.loc[
    all_terms['General SNOMED ID'].isin(dev_gids)
]
zeroshot_dev_g.head()

Unnamed: 0,ID,Term,General SNOMED Label,General SNOMED ID,Specific SNOMED Label,Specific SNOMED ID,Example,Example Link,Origin_Sheet
54,54,ADHD medication,Drug therapy for attention deficit hyperactivi...,702538008,Drug therapy for attention deficit hyperactivi...,702538008,20 - 40 something per cent in the UK and absol...,https://www.reddit.com/9skalb,./stage_1/Batch 7-Ready.xlsx
55,55,ADHD medication,Drug therapy for attention deficit hyperactivi...,702538008,Drug therapy for attention deficit hyperactivi...,702538008,I also grew up exploring the woods around my h...,https://www.reddit.com/6woxsp,./stage_1/Batch 7-Ready.xlsx
56,56,ADHD medication,Drug therapy for attention deficit hyperactivi...,702538008,Drug therapy for attention deficit hyperactivi...,702538008,I ' m still very skeptical about telling a doc...,https://www.reddit.com/98mpog,./stage_1/Batch 7-Ready.xlsx
57,57,ADHD meds,Drug therapy for attention deficit hyperactivi...,702538008,Drug therapy for attention deficit hyperactivi...,702538008,Does your ADHD meds affect your migraines ?,https://www.reddit.com/5oytzg,./stage_1/Batch 5-Ready.xlsx
58,58,ADHD meds,Drug therapy for attention deficit hyperactivi...,702538008,Drug therapy for attention deficit hyperactivi...,702538008,My husband is on ADHD meds and has trouble wit...,https://www.reddit.com/94y4bn,./stage_1/Batch 5-Ready.xlsx


In [41]:
zeroshot_test_g = all_terms.loc[
    all_terms['General SNOMED ID'].isin(test_gids)
]
zeroshot_test_g.head()

Unnamed: 0,ID,Term,General SNOMED Label,General SNOMED ID,Specific SNOMED Label,Specific SNOMED ID,Example,Example Link,Origin_Sheet
6,6,6MP,Mercaptopurine,386835005,Mercaptopurine,386835005,We are up to 200 mg per day but my blood work ...,https://www.reddit.com/39qd6y,./stage_1/Batch2-Ready.xlsx
7,7,6MP,Mercaptopurine,386835005,Mercaptopurine,386835005,"Three years ago I was on Remicade , 6MP and Pr...",https://www.reddit.com/4cqh36,./stage_1/Batch2-Ready.xlsx
8,8,6MP,Mercaptopurine,386835005,Mercaptopurine,386835005,"Cimzia has failed now , to go along with the f...",https://www.reddit.com/7n0vir,./stage_1/Batch2-Ready.xlsx
12,12,A1C tests,Hemoglobin A1c measurement,43396009,Hemoglobin A1c measurement,43396009,I never test myself and only go by the A1c tes...,https://www.reddit.com/8w6hsc,./stage_1/batch 9-Ready.xlsx
13,13,A1C tests,Hemoglobin A1c measurement,43396009,Hemoglobin A1c measurement,43396009,Can A1c tests be trusted to diagnose diabetes ?,https://www.reddit.com/7uixud,./stage_1/batch 9-Ready.xlsx


In [42]:
assert len(pd.concat([zeroshot_train_g, zeroshot_dev_g, zeroshot_test_g])) == len(
    all_terms)

### Specific IDs

In [43]:
train_sids, devtest_sids = train_test_split(
    specific_ids,
    random_state=RANDOM_SEED,
    shuffle=True,
    train_size=0.7
)

dev_sids, test_sids = train_test_split(
    devtest_sids,
    random_state=RANDOM_SEED,
    shuffle=True,
    train_size=1/3
)

In [44]:
zeroshot_train_s = all_terms.loc[
    all_terms['Specific SNOMED ID'].isin(train_sids)
].sample(frac=1, random_state=RANDOM_SEED)
zeroshot_train_s.head()

Unnamed: 0,ID,Term,General SNOMED Label,General SNOMED ID,Specific SNOMED Label,Specific SNOMED ID,Example,Example Link,Origin_Sheet
16133,16133,proton pump inhibitor,Hydrogen/potassium adenosine triphosphatase en...,734582004,Hydrogen/potassium adenosine triphosphatase en...,734582004,It is important to pair these drugs with a Pro...,https://www.reddit.com/30x4nu,./stage_1/Batch 10-Ready.xlsx
18575,18575,testicular pain,Pain of bilateral testicles,16675251000119106,Pain of bilateral testicles,16675251000119106,My dad ( who is a complete skeptic ) found suc...,https://www.reddit.com/7ypmri,./stage_1/Batch6-Ready.xlsx
16736,16736,saline,Sodium chloride solution,373757009,Sodium chloride solution,373757009,More antibiotics and saline were pumped into m...,https://www.reddit.com/a3ky3i,./stage_1/Batch2-Ready.xlsx
2565,2565,Metamucil,Psyllium,52370008,Psyllium,52370008,I would then resume the Metamucil which would ...,https://www.reddit.com/8nemn1,./stage_1/batch 4-Ready.xlsx
5631,5631,bile duct,Bile duct structure,28273000,Bile duct structure,28273000,"They took the sample from the duodenum , which...",https://www.reddit.com/4pin1z,./stage_1/Batch8-Ready.xlsx


In [45]:
zeroshot_dev_s = all_terms.loc[
    all_terms['Specific SNOMED ID'].isin(dev_sids)
]
zeroshot_dev_s.head()

Unnamed: 0,ID,Term,General SNOMED Label,General SNOMED ID,Specific SNOMED Label,Specific SNOMED ID,Example,Example Link,Origin_Sheet
27,27,ACE inhibitor,Angiotensin-converting enzyme inhibitor,734579009,Angiotensin-converting enzyme inhibitor,734579009,Just started an ace inhibitor this month .,https://www.reddit.com/46qhh6,./stage_1/batch 11-Ready.xlsx
28,28,ACE inhibitor,Angiotensin-converting enzyme inhibitor,734579009,Angiotensin-converting enzyme inhibitor,734579009,If it is simply hypertension and no other hear...,https://www.reddit.com/a7rdyq,./stage_1/batch 11-Ready.xlsx
29,29,ACE inhibitor,Angiotensin-converting enzyme inhibitor,734579009,Angiotensin-converting enzyme inhibitor,734579009,Talk to your doctor if stopping your ace inhib...,https://www.reddit.com/7w3wgh,./stage_1/batch 11-Ready.xlsx
45,45,ADD,"Attention deficit hyperactivity disorder, pred...",35253001,"Attention deficit hyperactivity disorder, pred...",35253001,I can ' t really tell if it ' s ADD or depress...,https://www.reddit.com/aadwbj,./stage_1/Batch2-Ready.xlsx
46,46,ADD,"Attention deficit hyperactivity disorder, pred...",35253001,"Attention deficit hyperactivity disorder, pred...",35253001,"I take adderall for ADD , so this is just simp...",https://www.reddit.com/3dwg0m,./stage_1/Batch2-Ready.xlsx


In [46]:
zeroshot_test_s = all_terms.loc[
    all_terms['Specific SNOMED ID'].isin(test_sids)
]
zeroshot_test_s.head()

Unnamed: 0,ID,Term,General SNOMED Label,General SNOMED ID,Specific SNOMED Label,Specific SNOMED ID,Example,Example Link,Origin_Sheet
6,6,6MP,Mercaptopurine,386835005,Mercaptopurine,386835005,We are up to 200 mg per day but my blood work ...,https://www.reddit.com/39qd6y,./stage_1/Batch2-Ready.xlsx
7,7,6MP,Mercaptopurine,386835005,Mercaptopurine,386835005,"Three years ago I was on Remicade , 6MP and Pr...",https://www.reddit.com/4cqh36,./stage_1/Batch2-Ready.xlsx
8,8,6MP,Mercaptopurine,386835005,Mercaptopurine,386835005,"Cimzia has failed now , to go along with the f...",https://www.reddit.com/7n0vir,./stage_1/Batch2-Ready.xlsx
12,12,A1C tests,Hemoglobin A1c measurement,43396009,Hemoglobin A1c measurement,43396009,I never test myself and only go by the A1c tes...,https://www.reddit.com/8w6hsc,./stage_1/batch 9-Ready.xlsx
13,13,A1C tests,Hemoglobin A1c measurement,43396009,Hemoglobin A1c measurement,43396009,Can A1c tests be trusted to diagnose diabetes ?,https://www.reddit.com/7uixud,./stage_1/batch 9-Ready.xlsx


In [47]:
assert len(pd.concat([zeroshot_train_s, zeroshot_dev_s, zeroshot_test_s])) == len(
    all_terms)

### Stats

In [48]:
print(
    f'Zero-Shot split for General IDs dataset sizes:\n'
    f'- Training set:    {len(zeroshot_train_g):5.0f}\n'
    f'- Development set: {len(zeroshot_dev_g):5.0f}\n'
    f'- Test set:        {len(zeroshot_test_g):5.0f}')

print(
    f'Zero-Shot split for Specific IDs dataset sizes:\n'
    f'- Training set:    {len(zeroshot_train_s):5.0f}\n'
    f'- Development set: {len(zeroshot_dev_s):5.0f}\n'
    f'- Test set:        {len(zeroshot_test_s):5.0f}')

Zero-Shot split for General IDs dataset sizes:
- Training set:    14062
- Development set:  1958
- Test set:         3995
Zero-Shot split for Specific IDs dataset sizes:
- Training set:    13714
- Development set:  2018
- Test set:         4283
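Because the split is made over IDs, the row counts only approximate 70% / 10% / 20%: how many rows a set receives depends on how many samples its IDs have. Assuming the same rounding as in the random split (floor for the first output of each call), the number of IDs per set can be reconstructed from the ID totals, and dividing the row counts above by it reproduces the per-ID means printed further down. `zero_shot_id_counts` is a hypothetical helper.

```python
import math


def zero_shot_id_counts(n_ids, train_frac=0.7, dev_frac=1/3):
    """Number of IDs in train/dev/test when the IDs themselves are split."""
    n_train = math.floor(train_frac * n_ids)
    n_dev = math.floor(dev_frac * (n_ids - n_train))
    return n_train, n_dev, n_ids - n_train - n_dev


# Row counts printed above and ID totals from the data loading section
row_counts = {'General': (14062, 1958, 3995), 'Specific': (13714, 2018, 4283)}
id_totals = {'General': 3645, 'Specific': 4003}

for kind, rows in row_counts.items():
    ids = zero_shot_id_counts(id_totals[kind])
    means = [f'{r / n:.2f}' for r, n in zip(rows, ids)]
    print(kind, ids, means)
# General (2551, 364, 730) ['5.51', '5.38', '5.47']
# Specific (2802, 400, 801) ['4.89', '5.04', '5.35']
```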


In [49]:
zeroshot_train_gids = set(zeroshot_train_g['General SNOMED ID'].values)
zeroshot_dev_gids = set(zeroshot_dev_g['General SNOMED ID'].values)
zeroshot_test_gids = set(zeroshot_test_g['General SNOMED ID'].values)
zeroshot_tt_overlap = 100 * \
    len(zeroshot_train_gids.intersection(zeroshot_test_gids))/len(general_ids)
print(
    f'There is a {zeroshot_tt_overlap:2.2f}% overlap in General IDs between train and test set.')
zeroshot_td_overlap = 100 * \
    len(zeroshot_train_gids.intersection(zeroshot_dev_gids))/len(general_ids)
print(
    f'There is a {zeroshot_td_overlap:2.2f}% overlap in General IDs between train and dev set.')
zeroshot_dt_overlap = 100 * \
    len(zeroshot_test_gids.intersection(zeroshot_dev_gids))/len(general_ids)
print(
    f'There is a {zeroshot_dt_overlap:2.2f}% overlap in General IDs between test and dev set.')
zeroshot_general_overlap = 100 * len(zeroshot_train_gids.intersection(
    zeroshot_dev_gids.intersection(zeroshot_test_gids)))/len(general_ids)
print(
    f'There is a {zeroshot_general_overlap:2.2f}% overlap in General IDs between train, dev and test sets.')


There is a 0.00% overlap in General IDs between train and test set.
There is a 0.00% overlap in General IDs between train and dev set.
There is a 0.00% overlap in General IDs between test and dev set.
There is a 0.00% overlap in General IDs between train, dev and test sets.


In [50]:
zeroshot_train_sids = set(zeroshot_train_s['Specific SNOMED ID'].values)
zeroshot_dev_sids = set(zeroshot_dev_s['Specific SNOMED ID'].values)
zeroshot_test_sids = set(zeroshot_test_s['Specific SNOMED ID'].values)
zeroshot_tts_overlap = 100 * \
    len(zeroshot_train_sids.intersection(zeroshot_test_sids))/len(specific_ids)
print(
    f'There is a {zeroshot_tts_overlap:2.2f}% overlap in Specific IDs between train and test set.')
zeroshot_tds_overlap = 100 * \
    len(zeroshot_train_sids.intersection(zeroshot_dev_sids))/len(specific_ids)
print(
    f'There is a {zeroshot_tds_overlap:2.2f}% overlap in Specific IDs between train and dev set.')
zeroshot_dts_overlap = 100 * \
    len(zeroshot_test_sids.intersection(zeroshot_dev_sids))/len(specific_ids)
print(
    f'There is a {zeroshot_dts_overlap:2.2f}% overlap in Specific IDs between test and dev set.')

zeroshot_specific_overlap = 100 * len(zeroshot_train_sids.intersection(
    zeroshot_dev_sids.intersection(zeroshot_test_sids)))/len(specific_ids)
print(
    f'There is a {zeroshot_specific_overlap:2.2f}% overlap in Specific IDs between train, dev and test sets.')

There is a 0.00% overlap in Specific IDs between train and test set.
There is a 0.00% overlap in Specific IDs between train and dev set.
There is a 0.00% overlap in Specific IDs between test and dev set.
There is a 0.00% overlap in Specific IDs between train, dev and test sets.


In [51]:
print(f'In the zero shot split,\n'
      f'- there are {zeroshot_train_g.groupby(["General SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per General ID in the training set;\n'
      f'- there are {zeroshot_dev_g.groupby(["General SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per General ID in the dev set;\n'
      f'- there are {zeroshot_test_g.groupby(["General SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per General ID in the test set;\n'
      f'- there are {zeroshot_train_s.groupby(["Specific SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per Specific ID in the training set;\n'
      f'- there are {zeroshot_dev_s.groupby(["Specific SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per Specific ID in the dev set;\n'
      f'- there are {zeroshot_test_s.groupby(["Specific SNOMED ID"]).nunique().ID.mean():2.2f} \
terms per Specific ID in the test set;\n'
      )

In the zero shot split,
- there are 5.51 terms per General ID in the training set;
- there are 5.38 terms per General ID in the dev set;
- there are 5.47 terms per General ID in the test set;
- there are 4.89 terms per Specific ID in the training set;
- there are 5.04 terms per Specific ID in the dev set;
- there are 5.35 terms per Specific ID in the test set;



### Save Data

In [52]:
zeroshot_g_folder = os.path.join(OUTPUT_FOLDER, 'zeroshot_general')
os.makedirs(zeroshot_g_folder, exist_ok=True)
zeroshot_train_g.to_csv(os.path.join(zeroshot_g_folder, 'train.csv'), sep='\t', encoding='utf8')
zeroshot_dev_g.to_csv(os.path.join(zeroshot_g_folder, 'dev.csv'), sep='\t', encoding='utf8')
zeroshot_test_g.to_csv(os.path.join(zeroshot_g_folder, 'test.csv'), sep='\t', encoding='utf8')

In [53]:
zeroshot_s_folder = os.path.join(OUTPUT_FOLDER, 'zeroshot_specific')
os.makedirs(zeroshot_s_folder, exist_ok=True)
zeroshot_train_s.to_csv(os.path.join(zeroshot_s_folder, 'train.csv'), sep='\t', encoding='utf8')
zeroshot_dev_s.to_csv(os.path.join(zeroshot_s_folder, 'dev.csv'), sep='\t', encoding='utf8')
zeroshot_test_s.to_csv(os.path.join(zeroshot_s_folder, 'test.csv'), sep='\t', encoding='utf8')