# Stratified Sampling

Author: Kathryn Meldrum (kmm4ap@virginia.edu)

Code to find a train/dev/test split of the trials that retains a proportionate distribution of each entity

In [65]:
import pandas as pd
import os
import numpy as np

In [None]:
fail_list=[]
def ann_to_dict(nct_id): #took out file path arg
    '''
    nct_id: file name before .ann or .txt
    file_path: pathname to where .ann and .txt files are stored
    return: formatted dictionary
    '''
    
    # change directory
    #os.chdir(file_path)
    
    # read in files
    with open(nct_id+'.ann') as f1:
        ann=f1.read()
        f1.close()
    with open(nct_id+'.txt') as f2:
        txt=f2.read()
        f2.close()
        
    # ANN FILE MANIPULATION
    lines=ann.split('\n')
    ents=[]
    for i in range(len(lines)): 
        line=lines[i].split('\t') 
        if 'T' in line[0]:
            try: 
                start=txt.index(line[2])
                end= start+len(line[2])
                label=line[1].split(' ')[0]
                if label in label_list: 
                    add=True
                    for ent in ents: 
                        if (ent[0]<=start<=ent[1]) or (ent[0]<=end<=ent[1]) or (start<=ent[0]<=end) or (start<=ent[1]<=end): #check if it overlaps existing span
                            add=False 
                            if (start-end) > (ent[1]-ent[0]): #replace existing span if this span is bigger
                                ents.remove(ent)
                                ents.append((start, end, label))

                    if add==True: 
                        ents.append((start, end, label))
                else:
                    fail_list.append(label)
            
            except:
                None
                #print(nct_id, line)
            
    content={'entities': list(set(ents)), 'text': txt}
    
    return content

Get list of docs in chia corpus:

In [68]:
doc_list=[]
os.chdir('/Users/meldrumapple/Desktop/Capstone/chia_noscope_corpus') 
doc_list=os.listdir()
for i in range(len(doc_list)):
    doc_list[i]=doc_list[i][0:-4]
doc_list= list(set([x for x in doc_list if ('NCT') in x])) 
print(len(doc_list))

2000


In [69]:
label_list=['Person', 'Condition', 'Drug', 'Observation', 'Measurement', 'Procedure', 'Device', 'Visit', 'Negation', 'Qualifier', 'Temporal', 'Value', 'Multiplier', 'Mood', 'Informed_consent', 'Post-eligibility', 'Pregnancy_considerations', 'Reference_point']

Make a dict of counts for each entity in each doc:

In [70]:
NCT_dict={}
for d in doc_list:
    counts={'Total':0,'Person':0, 'Condition':0, 'Drug':0, 'Observation':0, 'Measurement':0, 'Procedure':0, 'Device':0, 'Visit':0, 'Negation':0, 'Qualifier':0, 'Temporal':0, 'Value':0, 'Multiplier':0, 'Mood':0, 'Informed_consent':0, 'Post-eligibility':0, 'Pregnancy_considerations':0, 'Reference_point':0}
    doc=ann_to_dict(d)
    if len(doc['entities'])>0: #sometimes only entities we don't care about in doc, remove
        for i in doc['entities']:
            counts[i[2]]+=1
            counts['Total']+=1
        NCT_dict[d]=np.array(list(counts.values()))
    else: 
        doc_list.remove(d)
print(len(doc_list))
print(len(NCT_dict.keys()))

1941
1882


In [116]:
# here something wonky happened and 60 docs didn't get added to the dict for no reason, so I fix below
# The other 60 missing didn't have any entities from our list of keepers

In [74]:
fails=[x for x in doc_list if x not in list(NCT_dict.keys())]
len(fails)

59

In [None]:
for d in fails:
    counts={'Total':0,'Person':0, 'Condition':0, 'Drug':0, 'Observation':0, 'Measurement':0, 'Procedure':0, 'Device':0, 'Visit':0, 'Negation':0, 'Qualifier':0, 'Temporal':0, 'Value':0, 'Multiplier':0, 'Mood':0, 'Informed_consent':0, 'Post-eligibility':0, 'Pregnancy_considerations':0, 'Reference_point':0}
    doc=ann_to_dict(d)
    if len(doc['entities'])>0: #sometimes only entities we don't care about in doc, remove
        for i in doc['entities']:
            counts[i[2]]+=1
            counts['Total']+=1
        #print(counts)
        NCT_dict[d]=np.array(list(counts.values()))
    else: 
        doc_list.remove(d)

In [79]:
print(len(doc_list))
print(len(NCT_dict.keys())) #IDK why that happened but now all doc_list trials are in NCT_dict

1940
1940


Pick the best of 10000 random splits:

In [80]:
import random

In [125]:
# percent error
best_score_pe=300
for i in range(10000):
    l=len(doc_list)
    test_idx=random.sample(doc_list, 3*(l//10))
    train_idx = [x for x in doc_list if x not in test_idx]
    dev_idx=random.sample(train_idx, 3*(l//10))
    train_idx = [x for x in train_idx if x not in dev_idx]

    test_split=np.array([0]+[0 for i in range(len(label_list))])
    for d in test_idx:
        test_split+=NCT_dict[d]
    test_goal=(dev_split+train_split+test_split)*0.30
    test_score=np.sum(np.absolute(test_split-test_goal)/test_goal)

    train_split=np.array([0]+[0 for i in range(len(label_list))])
    for d in train_idx:
        train_split+=NCT_dict[d]
    train_goal=(dev_split+train_split+test_split)*0.49
    train_score=np.sum(np.absolute(train_split-train_goal)/train_goal)

    dev_split=np.array([0]+[0 for i in range(len(label_list))])
    for d in dev_idx:
        dev_split+=NCT_dict[d]
    dev_goal=(dev_split+train_split+test_split)*0.21
    dev_score=np.sum(np.absolute(dev_split-dev_goal)/dev_goal)

    total_score=test_score+train_score+dev_score
    
    if total_score<best_score_pe: 
        best_score_pe=total_score
        best_idx_pe={'train':train_idx,'dev':dev_idx, 'test':test_idx}
        best_splits_pe={'labels':['Total']+label_list, 'train':train_split,'dev':dev_split, 'test':test_split}

totals=best_splits_pe['train']+best_splits_pe['dev']+best_splits_pe['test']
best_splits_pe['goal_train']=totals*0.49
best_splits_pe['goal_dev']=totals*0.21
best_splits_pe['goal_test']=totals*0.3

pd.DataFrame(best_splits_pe)

Unnamed: 0,labels,train,dev,test,goal_train,goal_dev,goal_test
0,Total,14828,9416,10999,17269.07,7401.03,10572.9
1,Person,578,422,458,714.42,306.18,437.4
2,Condition,4269,2837,3289,5093.55,2182.95,3118.5
3,Drug,1361,933,989,1608.67,689.43,984.9
4,Observation,442,271,333,512.54,219.66,313.8
5,Measurement,1256,794,925,1457.75,624.75,892.5
6,Procedure,1278,808,892,1459.22,625.38,893.4
7,Device,186,66,77,161.21,69.09,98.7
8,Visit,52,40,47,68.11,29.19,41.7
9,Negation,290,186,212,337.12,144.48,206.4


In [120]:
# raw error: 
best_score_raw=100000
for i in range(10000):
    l=len(doc_list)
    test_idx=random.sample(doc_list, 3*(l//10))
    train_idx = [x for x in doc_list if x not in test_idx]
    dev_idx=random.sample(train_idx, 3*(l//10))
    train_idx = [x for x in train_idx if x not in dev_idx]

    test_split=np.array([0]+[0 for i in range(len(label_list))])
    for d in test_idx:
        test_split+=NCT_dict[d]
    test_goal=(dev_split+train_split+test_split)*0.30
    test_score=np.sum(np.absolute(test_split-test_goal))

    train_split=np.array([0]+[0 for i in range(len(label_list))])
    for d in train_idx:
        train_split+=NCT_dict[d]
    train_goal=(dev_split+train_split+test_split)*0.49
    train_score=np.sum(np.absolute(train_split-train_goal))

    dev_split=np.array([0]+[0 for i in range(len(label_list))])
    for d in dev_idx:
        dev_split+=NCT_dict[d]
    dev_goal=(dev_split+train_split+test_split)*0.21
    dev_score=np.sum(np.absolute(dev_split-dev_goal))

    total_score=test_score+train_score+dev_score
    
    if total_score<best_score_raw: 
        best_score_raw=total_score
        best_idx_raw={'train':train_idx,'dev':dev_idx, 'test':test_idx}
        best_splits_raw={'labels':['Total']+label_list, 'train':train_split,'dev':dev_split, 'test':test_split}

totals=best_splits_raw['train']+best_splits_raw['dev']+best_splits_raw['test']
best_splits_raw['goal_train']=totals*0.49
best_splits_raw['goal_dev']=totals*0.21
best_splits_raw['goal_test']=totals*0.3

print('best_score: ', best_score_raw)
print('best_splits:')
pd.DataFrame(best_splits_raw)

best_score:  8759.68
best_splits:


Unnamed: 0,labels,train,dev,test,goal_train,goal_dev,goal_test
0,Total,15516,9274,10453,17269.07,7401.03,10572.9
1,Person,589,411,458,714.42,306.18,437.4
2,Condition,4582,2698,3115,5093.55,2182.95,3118.5
3,Drug,1457,853,973,1608.67,689.43,984.9
4,Observation,465,282,299,512.54,219.66,313.8
5,Measurement,1401,762,812,1457.75,624.75,892.5
6,Procedure,1256,845,877,1459.22,625.38,893.4
7,Device,154,79,96,161.21,69.09,98.7
8,Visit,50,32,57,68.11,29.19,41.7
9,Negation,301,165,222,337.12,144.48,206.4


In [127]:
print(best_idx_pe)

{'train': ['NCT03187379_exc', 'NCT02735577_inc', 'NCT01803438_inc', 'NCT03104816_inc', 'NCT02348918_exc', 'NCT02607319_inc', 'NCT02053246_exc', 'NCT03363295_exc', 'NCT00425789_exc', 'NCT02015494_exc', 'NCT02964715_inc', 'NCT01579604_exc', 'NCT00401245_exc', 'NCT02604459_exc', 'NCT01118871_inc', 'NCT01942915_exc', 'NCT03131050_exc', 'NCT02888704_inc', 'NCT02721017_exc', 'NCT03255044_exc', 'NCT02567214_inc', 'NCT00806936_inc', 'NCT02816762_inc', 'NCT01531257_inc', 'NCT02273791_exc', 'NCT00461136_inc', 'NCT03212352_inc', 'NCT02579733_exc', 'NCT02536976_exc', 'NCT02365870_inc', 'NCT01774019_inc', 'NCT03249311_exc', 'NCT02600000_exc', 'NCT01084993_inc', 'NCT02385045_exc', 'NCT02282319_exc', 'NCT02464813_exc', 'NCT02951754_exc', 'NCT00785213_exc', 'NCT03199560_inc', 'NCT02653131_inc', 'NCT01261832_inc', 'NCT03541980_exc', 'NCT02746900_inc', 'NCT02368743_inc', 'NCT02695992_inc', 'NCT02429583_exc', 'NCT02762851_exc', 'NCT02755701_exc', 'NCT01349413_exc', 'NCT03208465_exc', 'NCT02396732_inc', '

In [128]:
print(best_idx_raw)

{'train': ['NCT00379366_inc', 'NCT03340740_exc', 'NCT02466113_inc', 'NCT01669369_inc', 'NCT02348918_exc', 'NCT02607319_inc', 'NCT02053246_exc', 'NCT03376763_inc', 'NCT03363295_exc', 'NCT02015494_exc', 'NCT01696617_exc', 'NCT03444142_exc', 'NCT00639795_inc', 'NCT02361892_inc', 'NCT03115320_inc', 'NCT02406885_inc', 'NCT02323399_inc', 'NCT01118871_inc', 'NCT03216967_inc', 'NCT02877485_exc', 'NCT02721017_exc', 'NCT03518034_inc', 'NCT02429765_inc', 'NCT00461136_inc', 'NCT03335904_exc', 'NCT02536976_exc', 'NCT02365870_inc', 'NCT02056301_inc', 'NCT01774019_inc', 'NCT03249311_exc', 'NCT02600000_exc', 'NCT02385045_exc', 'NCT02579928_inc', 'NCT02282319_exc', 'NCT02924090_exc', 'NCT03043495_exc', 'NCT02570230_inc', 'NCT03315975_exc', 'NCT02954029_exc', 'NCT03195153_exc', 'NCT03080493_inc', 'NCT03318874_exc', 'NCT02368743_inc', 'NCT02692651_exc', 'NCT03388840_inc', 'NCT02034019_exc', 'NCT02035904_inc', 'NCT02498483_exc', 'NCT02762851_exc', 'NCT02838810_inc', 'NCT02951520_inc', 'NCT02827526_inc', '