## Assign Candidate Splits

Import annotations from various sources and use the presence of annotations for freshly imported candidates as a means to move them into appropriate splits for modeling.

More specifically:

- BRAT annotations will all be moved to the ```dev``` split
- All other annotations will be split randomly (stratified by label value) across the ```dev```, ```val``` and ```test``` splits
- A ```train``` split will also be randomly extracted from the initial ```infer``` split

In [1]:
import os
import os.path as osp
import pandas as pd
import numpy as np
import collections
import tqdm
from tcre.env import *
from tcre.supervision import *
from snorkel import SnorkelSession
from snorkel.models import Document, Sentence, Candidate, GoldLabel, GoldLabelKey, StableLabel
from snorkel.models import StableLabel
from snorkel.db_helpers import reload_annotator_labels
session = SnorkelSession()
classes = get_candidate_classes()

### Reset Annotations and Splits

This will delete any gold labels as well as move all candidates on splits with gold labels back to the INFER split:

In [2]:
# Report how many records of each label-related model currently exist
for model in (StableLabel, GoldLabel, GoldLabelKey):
    print(model.__name__, session.query(model).count())

StableLabel 3054
GoldLabel 2853
GoldLabelKey 9


In [3]:
# Wipe all existing label records; the models are cleared in the order
# StableLabel -> GoldLabel -> GoldLabelKey and committed in one transaction
for model in (StableLabel, GoldLabel, GoldLabelKey):
    session.query(model).delete()
session.commit()

In [4]:
# Collect the id and split of every candidate currently on a labeled split
# (dev, val or test)
df = pd.DataFrame(
    session.query(Candidate.id, Candidate.split)
    .filter(Candidate.split.in_([SPLIT_DEV, SPLIT_VAL, SPLIT_TEST]))
    .all()
)
# DataFrame.info() prints its own report and returns None, so it must not be
# wrapped in print() (that appended a stray "None" to the output)
df.info()
# The frame has no columns at all when the query matched nothing, hence the guard.
# The summary is printed explicitly because an expression nested inside an `if`
# block is not echoed as the cell's output
if 'split' in df:
    print(df.groupby('split').size())

<class 'pandas.core.frame.DataFrame'>
Index: 0 entries
Empty DataFrameNone


In [5]:
# Send every candidate collected above back to the INFER split, then commit once
for _, row in df.iterrows():
    cand_id = int(row['id'])
    cand = session.query(Candidate).filter(Candidate.id == cand_id).one()
    cand.split = SPLIT_INFER
session.commit()

In [6]:
del df

### Load Raw Annotations

In [7]:
def load_labels_01():
    """Load the BRAT relation annotations (dev split).

    Reads ``annotation/brat_export.csv`` below ``REPO_DATA_DIR``, converts the
    exclusive end offsets of both entities to inclusive ones, maps the BRAT
    relation names onto the project's relation field constants and marks
    every row as a positive label (``value == 1``) from annotator ``'brat'``.

    Returns:
        DataFrame with the columns doc_id, type, e1_text, e1_start_chr,
        e1_end_chr, e2_text, e2_start_chr, e2_end_chr, value, annotator.

    Raises:
        AssertionError: if an entity's (inclusive) end offset is not strictly
            greater than its start offset, or if a relation name has no mapping.
    """
    df = pd.read_csv(osp.join(REPO_DATA_DIR, 'annotation', 'brat_export.csv'))

    # Note: BRAT end offsets are exclusive, so shift them to be inclusive
    for i in [1, 2]:
        start_col = 'e{}_start_chr'.format(i)
        end_col = 'e{}_end_chr'.format(i)
        df[end_col] = df[end_col] - 1
        # NOTE(review): strict inequality, so a span of exactly one character
        # would fail this check
        assert (df[end_col] > df[start_col]).all()

    df = df.rename(columns={'id': 'doc_id'})

    # Translate BRAT relation names into relation field names; unmapped names
    # become null and are rejected by the assertion below
    relation_fields = {
        'Induction': REL_FIELD_INDUCING_CYTOKINE,
        'Secretion': REL_FIELD_SECRETED_CYTOKINE,
        'Differentiation': REL_FIELD_INDUCING_TRANSCRIPTION_FACTOR
    }
    df['type'] = df['rel_typ'].map(relation_fields)
    assert df['type'].notnull().all()

    # These annotations only ever record positive relations
    df['value'] = 1
    df['annotator'] = 'brat'

    output_columns = [
        'doc_id', 'type', 'e1_text', 'e1_start_chr', 'e1_end_chr',
        'e2_text', 'e2_start_chr', 'e2_end_chr', 'value', 'annotator'
    ]
    return df[output_columns]
df1 = load_labels_01()
df1.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 392 entries, 0 to 391
Data columns (total 10 columns):
doc_id          392 non-null object
type            392 non-null object
e1_text         392 non-null object
e1_start_chr    392 non-null int64
e1_end_chr      392 non-null int64
e2_text         392 non-null object
e2_start_chr    392 non-null int64
e2_end_chr      392 non-null int64
value           392 non-null int64
annotator       392 non-null object
dtypes: int64(5), object(5)
memory usage: 30.7+ KB


In [8]:
df1.head()

Unnamed: 0,doc_id,type,e1_text,e1_start_chr,e1_end_chr,e2_text,e2_start_chr,e2_end_chr,value,annotator
0,PMC4451961,secreted_cytokine,IL-17,1338,1342,Vγ4 T,1291,1295,1,brat
1,PMC4451961,secreted_cytokine,IL-17,1460,1464,Vγ4 T,1291,1295,1,brat
2,PMC4451961,secreted_cytokine,IL-17,1903,1907,Vγ4 T,1850,1854,1,brat
3,PMC4451961,secreted_cytokine,(IFN)-γ,3730,3736,γδ T,3825,3828,1,brat
4,PMC4451961,secreted_cytokine,(IL)-17,3754,3760,γδ T,3825,3828,1,brat


In [9]:
def load_labels_02():
    """Load labels from SnorkelNgramViewer annotation.

    Reads ``annotation/ngramviewer_export.csv`` below ``REPO_DATA_DIR``, keeps
    the rows whose ``split`` column equals 3, turns the sentence-relative
    character offsets into document-level offsets by adding
    ``sent_char_offset`` and renames the columns to the shared
    ``e1_*`` / ``e2_*`` / ``*_start_chr`` / ``*_end_chr`` / ``type`` layout.

    Returns:
        DataFrame with the columns doc_id, type, e1_text, e1_start_chr,
        e1_end_chr, e2_text, e2_start_chr, e2_end_chr, value, annotator
        (annotator is always ``'snorkelngramviewer'``).
    """
    # Note: Labels are already inclusive on ends of char range
    path = osp.join(REPO_DATA_DIR, 'annotation', 'ngramviewer_export.csv')
    df = pd.read_csv(path)
    # Take an explicit copy of the filtered rows: the column assignments below
    # would otherwise write into a slice of the raw export (chained assignment,
    # which pandas reports with SettingWithCopyWarning)
    # NOTE(review): 3 is presumably the split id the annotations were exported
    # for -- confirm and consider replacing it with a named constant
    df = df[df['split'] == 3].copy()
    # Offsets are relative to the sentence start; shift them by the sentence offset
    for c in ['c1_char_start', 'c1_char_end', 'c2_char_start', 'c2_char_end']:
        df[c] = df[c] + df['sent_char_offset']
    df = df[['doc_id', 'field', 'c1_text', 'c1_char_start', 'c1_char_end', 'c2_text', 'c2_char_start', 'c2_char_end', 'value']]
    # Align column names with the other annotation sources
    df = df.rename(columns=lambda c: c.replace('c1', 'e1').replace('c2', 'e2'))
    df = df.rename(columns=lambda c: c.replace('char_start', 'start_chr').replace('char_end', 'end_chr'))
    df = df.rename(columns={'field': 'type'})
    df['annotator'] = 'snorkelngramviewer'
    return df
df2 = load_labels_02()
df2.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 612 entries, 300 to 911
Data columns (total 10 columns):
doc_id          612 non-null object
type            612 non-null object
e1_text         612 non-null object
e1_start_chr    612 non-null int64
e1_end_chr      612 non-null int64
e2_text         612 non-null object
e2_start_chr    612 non-null int64
e2_end_chr      612 non-null int64
value           612 non-null int64
annotator       612 non-null object
dtypes: int64(5), object(5)
memory usage: 52.6+ KB


In [10]:
df2.head()

Unnamed: 0,doc_id,type,e1_text,e1_start_chr,e1_end_chr,e2_text,e2_start_chr,e2_end_chr,value,annotator
300,PMC5799164,inducing_cytokine,IL-12,2956,2960,cytotoxic T,2753,2763,-1,snorkelngramviewer
301,PMC3378591,inducing_cytokine,IFN-γ,50143,50147,Th1,50014,50016,-1,snorkelngramviewer
302,PMC6279938,inducing_cytokine,IFN-γ,51442,51446,Th1,51465,51467,-1,snorkelngramviewer
303,PMC2646562,inducing_cytokine,IL-6,17571,17574,Th17,17637,17640,-1,snorkelngramviewer
304,PMC3486158,inducing_cytokine,IFN-γ,38187,38191,NKT,38130,38132,-1,snorkelngramviewer


In [11]:
def load_labels_03():
    """Load labels from Doccano annotation.

    Reads ``annotation/doccano_export.csv`` below ``REPO_DATA_DIR``, keeps the
    relation columns, converts every ``*_start_chr`` / ``*_end_chr`` column
    from sentence-relative to document-level offsets by adding
    ``sent_abs_offset`` (which is then dropped) and tags the rows with
    annotator ``'doccano'``.

    Returns:
        DataFrame with the columns doc_id, type, e1_text, e1_start_chr,
        e1_end_chr, e2_text, e2_start_chr, e2_end_chr, value, annotator.
    """
    # The start/end character ranges are passed through doccano from snorkel
    # as metadata, so they are already inclusive on the right
    path = osp.join(REPO_DATA_DIR, 'annotation', 'doccano_export.csv')
    df = pd.read_csv(path)
    # Take an explicit copy of the column subset: the assignments below would
    # otherwise write into a slice of the raw export (chained assignment,
    # which pandas reports with SettingWithCopyWarning)
    df = df[[
        'doc_id', 'type',
        'e1_text', 'e1_start_chr', 'e1_end_chr',
        'e2_text', 'e2_start_chr', 'e2_end_chr',
        'value', 'sent_abs_offset'
    ]].copy()
    chr_cols = df.filter(regex='(start|end)_chr').columns.tolist()
    # Add character offset of first char in sentence since snorkel provides
    # start/end chars as relative to sentence start
    for c in chr_cols:
        df[c] = df[c] + df['sent_abs_offset']
    df = df.drop('sent_abs_offset', axis=1)
    df['annotator'] = 'doccano'
    return df
df3 = load_labels_03()
df3.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2050 entries, 0 to 2049
Data columns (total 10 columns):
doc_id          2050 non-null object
type            2050 non-null object
e1_text         2050 non-null object
e1_start_chr    2050 non-null int64
e1_end_chr      2050 non-null int64
e2_text         2050 non-null object
e2_start_chr    2050 non-null int64
e2_end_chr      2050 non-null int64
value           2050 non-null int64
annotator       2050 non-null object
dtypes: int64(5), object(5)
memory usage: 160.2+ KB


In [12]:
df3.head()

Unnamed: 0,doc_id,type,e1_text,e1_start_chr,e1_end_chr,e2_text,e2_start_chr,e2_end_chr,value,annotator
0,PMC4224555,secreted_cytokine,IL-12,4171,4175,Th2,4143,4145,-1,doccano
1,PMC4224555,inducing_cytokine,IL-12,4171,4175,Th2,4143,4145,-1,doccano
2,PMC5818395,secreted_cytokine,Tr1,558,560,effector T,761,770,-1,doccano
3,PMC5818395,inducing_cytokine,Tr1,558,560,effector T,761,770,-1,doccano
4,PMC5634439,secreted_cytokine,IL-23,18873,18877,Th17,18604,18607,-1,doccano


## Merge

In [13]:
def get_stable_id(r):
    """Build the stable id of a single span.

    Args:
        r: Mapping-like row providing ``doc_id``, ``start_chr`` and ``end_chr``.

    Returns:
        String of the form ``<doc_id>::span:<start_chr>:<end_chr>``.
    """
    doc_id, start_chr, end_chr = r['doc_id'], r['start_chr'], r['end_chr']
    return '{}::span:{}:{}'.format(doc_id, start_chr, end_chr)

def get_context_stable_id(r):
    """Build the combined stable id of a relation's two entity spans.

    Args:
        r: pandas Series for one annotation row holding ``doc_id`` plus the
            ``e1_*`` and ``e2_*`` entity fields.

    Returns:
        The stable id of entity 1 and the stable id of entity 2 joined by ``~~``.
    """
    # Narrow the row to one entity (plus doc_id) and strip the entity prefix so
    # the generic span helper sees plain 'start_chr' / 'end_chr' keys
    first = r.filter(regex='^e1_|^doc_id$')
    first = first.rename(lambda v: v.replace('e1_', ''))
    first_id = get_stable_id(first)

    second = r.filter(regex='^e2_|^doc_id$')
    second = second.rename(lambda v: v.replace('e2_', ''))
    second_id = get_stable_id(second)

    return "~~".join((first_id, second_id))

# Stack the annotations of all three sources into one frame. Each source keeps
# its own row index (no ignore_index), so index labels can repeat in the result
df = pd.concat([df1, df2, df3])
# Row-wise: derive the "<e1 span id>~~<e2 span id>" key of every annotated relation
df['context_stable_ids'] = df.apply(get_context_stable_id, axis=1)
df.head()

Unnamed: 0,doc_id,type,e1_text,e1_start_chr,e1_end_chr,e2_text,e2_start_chr,e2_end_chr,value,annotator,context_stable_ids
0,PMC4451961,secreted_cytokine,IL-17,1338,1342,Vγ4 T,1291,1295,1,brat,PMC4451961::span:1338:1342~~PMC4451961::span:1...
1,PMC4451961,secreted_cytokine,IL-17,1460,1464,Vγ4 T,1291,1295,1,brat,PMC4451961::span:1460:1464~~PMC4451961::span:1...
2,PMC4451961,secreted_cytokine,IL-17,1903,1907,Vγ4 T,1850,1854,1,brat,PMC4451961::span:1903:1907~~PMC4451961::span:1...
3,PMC4451961,secreted_cytokine,(IFN)-γ,3730,3736,γδ T,3825,3828,1,brat,PMC4451961::span:3730:3736~~PMC4451961::span:3...
4,PMC4451961,secreted_cytokine,(IL)-17,3754,3760,γδ T,3825,3828,1,brat,PMC4451961::span:3754:3760~~PMC4451961::span:3...


### Import Labels

Set labels for candidates within DB 

In [14]:
def load_labels(session, relations, annotator, candidate_class, split):
    """Store annotations as StableLabel records and reload them as annotator labels.

    Args:
        session: Database session used for queries, inserts and the commit.
        relations: DataFrame of annotations for ONE relation type and ONE
            annotator; needs the columns ``type``, ``annotator``, ``value``
            (-1 or 1) and ``context_stable_ids``.
        annotator: Name of the annotation source; every row must carry it.
        candidate_class: Candidate class descriptor providing ``field`` (the
            relation type every row must carry) and ``subclass``.
        split: Candidate split handed to ``reload_annotator_labels``.

    Returns:
        The result of ``reload_annotator_labels`` for the annotator name
        ``<annotator>-<candidate_class.field>``.

    Raises:
        AssertionError: if the frame mixes relation types or annotators, or a
            row's value is not -1 or 1.
    """
    field = candidate_class.field
    assert (relations['type'] == field).all(), 'Relations data frame has conflicting relation classes'
    assert (relations['annotator'] == annotator).all(), 'Relations data frame has conflicting annotator'
    annotator_name = annotator + '-' + field

    for _, row in relations.iterrows():
        value = int(row['value'])
        assert value in [-1, 1], 'Row has invalid relation value: {}'.format(row)
        stable_ids = row['context_stable_ids']

        # Only the first value seen for an (annotator_name, context_stable_ids)
        # pair is stored; later rows with the same pair are skipped
        existing = session.query(StableLabel).filter(
            StableLabel.context_stable_ids == stable_ids
        ).filter(
            StableLabel.annotator_name == annotator_name
        ).count()
        if existing > 0:
            continue
        session.add(StableLabel(
            context_stable_ids=stable_ids,
            annotator_name=annotator_name,
            value=value
        ))

    session.commit()

    # reload_annotator_labels selects the StableLabel records by annotator_name
    # and creates the labels for them. Because of the first-value-wins rule
    # above, the annotator name is made specific to the candidate class so that
    # identical context_stable_ids in different classes do not shadow each other.
    # The split is used to locate the candidates for the labels; it is not used
    # to filter the StableLabel records (filter_label_split=False) since those
    # were created above without any split
    return reload_annotator_labels(
        session, candidate_class.subclass, annotator_name, split=split,
        filter_label_split=False, create_missing_cands=False)

In [15]:
labels = collections.OrderedDict()
# Index the candidate classes by the relation field they model
clmap = {classes[name].field: classes[name] for name in classes}
grps = df.groupby(['type', 'annotator'])

# One import pass per (relation type, annotator) combination
for k, g in grps:
    rel_type, annotator = k
    candidate_class = clmap[rel_type]
    print('Loading labels for class={}, annotator={}'.format(rel_type, annotator))
    # The stored result is used as a tuple with:
    # 0: List of GoldLabel instances added
    # 1: List of StableLabel instances for which no candidate could be found
    labels[k] = load_labels(session, g, annotator, candidate_class, SPLIT_INFER)

Loading labels for class=inducing_cytokine, annotator=brat
AnnotatorLabels created: 123, missed: 22
Loading labels for class=inducing_cytokine, annotator=doccano
AnnotatorLabels created: 750, missed: 0
Loading labels for class=inducing_cytokine, annotator=snorkelngramviewer
AnnotatorLabels created: 177, missed: 20
Loading labels for class=inducing_transcription_factor, annotator=brat
AnnotatorLabels created: 64, missed: 31
Loading labels for class=inducing_transcription_factor, annotator=doccano
AnnotatorLabels created: 550, missed: 0
Loading labels for class=inducing_transcription_factor, annotator=snorkelngramviewer
AnnotatorLabels created: 199, missed: 18
Loading labels for class=secreted_cytokine, annotator=brat
AnnotatorLabels created: 82, missed: 70
Loading labels for class=secreted_cytokine, annotator=doccano
AnnotatorLabels created: 724, missed: 26
Loading labels for class=secreted_cytokine, annotator=snorkelngramviewer
AnnotatorLabels created: 184, missed: 14


In [16]:
# Show missing candidates
from snorkel.models import Context 

# Show which entity types and relations were unable to be matched with spans
# extracted and inserted into snorkel db
def summarize_missing_candidates(labels):
    """Tabulate, for every unmatched label, which of its entity spans are unknown.

    Args:
        labels: Mapping of ``(relation type, annotator)`` to the result stored
            by the import loop; element 1 of each result holds the labels for
            which no candidate was found.

    Returns:
        DataFrame with one row per unmatched label and the columns
        ``relation``, ``doc_id`` and ``missing``. ``missing`` is a
        comma-separated list of the entity types whose span has no Context
        record; it is empty when both spans exist as contexts.
    """
    rows = []
    for key, result in labels.items():
        class_name = key[0]
        candidate_class = clmap[class_name]
        for mc in result[1]:
            stable_ids = mc.context_stable_ids
            # The document id is the part before the first '::'
            doc_id = stable_ids.split('::')[0]
            # Position i in the '~~'-separated id list corresponds to entity type i
            missing_types = [
                candidate_class.entity_types[i]
                for i, sid in enumerate(stable_ids.split('~~'))
                if len(session.query(Context).filter(Context.stable_id == sid).all()) == 0
            ]
            rows.append((class_name, doc_id, ','.join(missing_types)))
    return pd.DataFrame(rows, columns=['relation', 'doc_id', 'missing'])
df_miss = summarize_missing_candidates(labels)
df_miss.groupby(['relation', 'missing']).size()

relation                       missing                              
inducing_cytokine                                                        8
                               cytokine                                 17
                               immune_cell_type                         17
inducing_transcription_factor                                            7
                               immune_cell_type                         10
                               transcription_factor                     25
                               transcription_factor,immune_cell_type     7
secreted_cytokine                                                       34
                               cytokine                                 20
                               cytokine,immune_cell_type                 2
                               immune_cell_type                         54
dtype: int64

### Debug Missing Candidates

In [17]:
# ml1 = labels[('secreted_cytokine', 'brat')][1][1]
# ml1.context_stable_ids

In [53]:
# df[df['context_stable_ids'] == ml1.context_stable_ids]

In [49]:
# doc = session.query(Document).filter(Document.name == 'PMC4451961').one()
# sents = doc.sentences

In [50]:
# dftkn = pd.concat([
#     pd.DataFrame(list(zip(sent.entity_types, sent.abs_char_offsets, sent.words)))
#     for sent in sents[:32]
# ])
# pd.set_option('display.max_rows', 10000)
# # Merge to results from misc/missing-cand-debug.py
# pd.concat([
#     dftkn.reset_index(drop=True), pd.read_csv('/tmp/tokens.csv')
# ], axis=1)

### Assign Splits

In [18]:
# Create data frame containing ALL candidates with a gold label
# One row per GoldLabel record: the labeled candidate's id and type, the label
# value, the full label key name and the part of that name before the first '-'
# (the annotator, e.g. 'brat' for 'brat-inducing_cytokine')
dfg = pd.DataFrame([
    dict(
        type=lbl.candidate.type, 
        annotator=lbl.key.name.split('-')[0], 
        annotator_name=lbl.key.name, 
        cand_id=lbl.candidate.id, 
        value=lbl.value
    )
    for lbl in session.query(GoldLabel).all()
])
dfg.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2853 entries, 0 to 2852
Data columns (total 5 columns):
annotator         2853 non-null object
annotator_name    2853 non-null object
cand_id           2853 non-null int64
type              2853 non-null object
value             2853 non-null int64
dtypes: int64(2), object(3)
memory usage: 111.5+ KB


In [19]:
dfg.groupby(['type', 'annotator', 'annotator_name']).size()

type                           annotator           annotator_name                                  
inducing_cytokine              brat                brat-inducing_cytokine                              123
                               doccano             doccano-inducing_cytokine                           750
                               snorkelngramviewer  snorkelngramviewer-inducing_cytokine                177
inducing_transcription_factor  brat                brat-inducing_transcription_factor                   64
                               doccano             doccano-inducing_transcription_factor               550
                               snorkelngramviewer  snorkelngramviewer-inducing_transcription_factor    199
secreted_cytokine              brat                brat-secreted_cytokine                               82
                               doccano             doccano-secreted_cytokine                           724
                               snorkelngramv

In [20]:
from tcre.modeling import sampling

def assign_split(g):
    """Add a ``split`` column to the gold-label rows of a single annotator.

    Args:
        g: DataFrame of gold labels from exactly one annotator, with the
            columns ``annotator`` and ``value``.

    Returns:
        Copy of ``g`` with a ``split`` column: ``SPLIT_DEV`` for every BRAT
        row; for 'snorkelngramviewer' and 'doccano' a split into dev / val /
        test with proportions 0.3 / 0.3 / 0.4, stratified on ``value``.

    Raises:
        AssertionError: if ``g`` mixes annotators.
        ValueError: if the annotator is not one of the supported sources.
    """
    assert g['annotator'].nunique() == 1
    annotator = g['annotator'].iloc[0]

    # BRAT annotations all go to dev
    if annotator == 'brat':
        return g.assign(split=SPLIT_DEV)

    if annotator not in ['snorkelngramviewer', 'doccano']:
        raise ValueError('Annotator {} not supported'.format(annotator))

    # Stratified assignment across dev, val and test (in that order)
    proportions = [.3, .3, .4]
    grps = sampling.get_stratified_split(g['value'].values, proportions)
    # Starts as all zeros; each position is overwritten with its split id below
    splits = np.zeros(len(g), dtype=np.int64)
    for pos, split_id in enumerate((SPLIT_DEV, SPLIT_VAL, SPLIT_TEST)):
        splits[grps[pos]] = split_id
    return g.assign(split=splits)
    
# Assign splits independently within each (annotator, relation type) group
dfgs = pd.concat([assign_split(g) for k, g in dfg.groupby(['annotator', 'type'])])
# assign_split fills its split array with zeros before assigning, so a remaining
# 0 would mean a row never received a split
assert (dfgs['split'] > 0).all()
assert dfgs['split'].dtype == np.int64
dfgs.head()

Unnamed: 0,annotator,annotator_name,cand_id,type,value,split
0,brat,brat-inducing_cytokine,71578,inducing_cytokine,1,1
1,brat,brat-inducing_cytokine,71576,inducing_cytokine,1,1
2,brat,brat-inducing_cytokine,71583,inducing_cytokine,1,1
3,brat,brat-inducing_cytokine,71581,inducing_cytokine,1,1
4,brat,brat-inducing_cytokine,82291,inducing_cytokine,1,1


In [21]:
dfgs.groupby(['annotator', 'type', 'value']).size().unstack()

Unnamed: 0_level_0,value,-1,1
annotator,type,Unnamed: 2_level_1,Unnamed: 3_level_1
brat,inducing_cytokine,,123.0
brat,inducing_transcription_factor,,64.0
brat,secreted_cytokine,,82.0
doccano,inducing_cytokine,701.0,49.0
doccano,inducing_transcription_factor,443.0,107.0
doccano,secreted_cytokine,598.0,126.0
snorkelngramviewer,inducing_cytokine,152.0,25.0
snorkelngramviewer,inducing_transcription_factor,161.0,38.0
snorkelngramviewer,secreted_cytokine,148.0,36.0


In [22]:
dfgs.groupby(['annotator', 'type', 'split', 'value']).size().unstack()

Unnamed: 0_level_0,Unnamed: 1_level_0,value,-1,1
annotator,type,split,Unnamed: 3_level_1,Unnamed: 4_level_1
brat,inducing_cytokine,1,,123.0
brat,inducing_transcription_factor,1,,64.0
brat,secreted_cytokine,1,,82.0
doccano,inducing_cytokine,1,210.0,15.0
doccano,inducing_cytokine,2,281.0,19.0
doccano,inducing_cytokine,3,210.0,15.0
doccano,inducing_transcription_factor,1,133.0,32.0
doccano,inducing_transcription_factor,2,177.0,43.0
doccano,inducing_transcription_factor,3,133.0,32.0
doccano,secreted_cytokine,1,179.0,38.0


In [23]:
pd.concat([
    dfgs.groupby(['split'])['value'].value_counts().rename('count'),
    dfgs.groupby(['split'])['value'].value_counts(normalize=True).rename('percent')
], axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,count,percent
split,value,Unnamed: 2_level_1,Unnamed: 3_level_1
1,-1,660,0.63279
1,1,383,0.36721
2,-1,884,0.854106
2,1,151,0.145894
3,-1,659,0.850323
3,1,116,0.149677


### Move Candidates

In [25]:
from tcre import query as tcre_query
df_cand_doc = tcre_query.DocToCand.all(session, classes)
assert (df_cand_doc.groupby('cand_id')['doc_id'].nunique() == 1).all()
df_cand_doc.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 220213 entries, 0 to 36534
Data columns (total 4 columns):
doc_id         220213 non-null int64
sentence_id    220213 non-null int64
cand_id        220213 non-null int64
cand_type      220213 non-null object
dtypes: int64(3), object(1)
memory usage: 8.4+ MB


In [26]:
def get_transitive_candidates(df):
    """Expand labeled candidates to ALL candidates of the documents they belong to.

    Used for BRAT annotations, where the entire document should be moved and
    not just the labeled candidate. Reads the notebook-level ``df_cand_doc``
    frame for the candidate/document relationship.

    Args:
        df: DataFrame with the columns ``cand_id``, ``annotator`` and
            ``split``; annotator and split are taken from its first row.

    Returns:
        DataFrame with one row per candidate of every affected document and
        the columns ``cand_id``, ``split`` and ``annotator``.
    """
    annotator = df['annotator'].iloc[0]
    split = df['split'].iloc[0]

    # Lookup tables: cand_id -> doc_id and doc_id -> array of cand_ids
    cand_to_doc = df_cand_doc.set_index('cand_id')['doc_id']
    doc_to_cands = df_cand_doc.groupby('doc_id')['cand_id'].unique()

    # Documents that contain at least one of the given candidates
    labeled_cands = df['cand_id'].unique()
    doc_ids = cand_to_doc[labeled_cands].unique()

    # Union of the candidates of all those documents
    cand_ids = set()
    for doc_cands in doc_to_cands.loc[list(doc_ids)]:
        cand_ids.update(doc_cands)

    result = pd.DataFrame(cand_ids, columns=['cand_id'])
    return result.assign(split=split, annotator=annotator)
    
    
# Candidates to move: for BRAT, every candidate of each annotated document
# (expanded by get_transitive_candidates); for the other annotators, only the
# labeled candidates themselves. Both parts share the same three columns
df_mv = pd.concat([
    get_transitive_candidates(dfgs[dfgs['annotator'] == 'brat'])[['annotator', 'cand_id', 'split']],
    dfgs[dfgs['annotator'] != 'brat'][['annotator', 'cand_id', 'split']]
])
df_mv.groupby(['annotator', 'split']).size()

annotator           split
brat                1        1897
doccano             1         607
                    2         810
                    3         607
snorkelngramviewer  1         167
                    2         225
                    3         168
dtype: int64

In [27]:
df_mv.head()

Unnamed: 0,annotator,cand_id,split
0,brat,73847,1
1,brat,73848,1
2,brat,73849,1
3,brat,73850,1
4,brat,73851,1


In [28]:
# Write the assigned split onto every selected candidate, then commit once
for i, r in df_mv.iterrows():
    cand = session.query(Candidate).filter(Candidate.id == int(r['cand_id'])).one()
    # Cast to a builtin int, like the candidate id above, so the ORM attribute
    # never receives a numpy integer taken from the frame
    cand.split = int(r['split'])
session.commit()

## Validate

Check that no more gold labels are associated with candidates on the initial INFER split:

In [29]:
df_gl = pd.DataFrame([
    dict(split=gl.candidate.split, type=gl.candidate.type, value=gl.value)
    for gl in session.query(GoldLabel).all()
])
df_gl.groupby(['type', 'split', 'value']).size()

type                           split  value
inducing_cytokine              1      -1       256
                                       1       145
                               2      -1       342
                                       1        29
                               3      -1       255
                                       1        23
inducing_transcription_factor  1      -1       181
                                       1       107
                               2      -1       242
                                       1        58
                               3      -1       181
                                       1        44
secreted_cytokine              1      -1       223
                                       1       131
                               2      -1       300
                                       1        64
                               3      -1       223
                                       1        49
dtype: int64

In [30]:
assert (df_gl['split'] != SPLIT_INFER).all()

In [31]:
# Finally, show candidate distribution across splits which should have many extra unlabeled
# candidates on the dev split 
pd.DataFrame(session.query(Candidate.type, Candidate.split).all()).groupby(['type', 'split']).size()

type                           split
inducing_cytokine              0        10000
                               1         1025
                               2          371
                               3          278
                               9        80165
inducing_transcription_factor  0        10000
                               1          627
                               2          300
                               3          225
                               9        25383
secreted_cytokine              0        10000
                               1         1019
                               2          364
                               3          272
                               9        80184
dtype: int64

### Train Candidates

In [33]:
df_tr_cand = pd.DataFrame(session.query(Candidate.id, Candidate.type, Candidate.split).filter(Candidate.split == SPLIT_INFER))
assert (df_tr_cand['split'] == SPLIT_INFER).all()
df_tr_cand.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 185732 entries, 0 to 185731
Data columns (total 3 columns):
id       185732 non-null int64
type     185732 non-null object
split    185732 non-null int64
dtypes: int64(2), object(1)
memory usage: 4.3+ MB


In [34]:
df_tr_cand.groupby(['type', 'split']).size()

type                           split
inducing_cytokine              9        80165
inducing_transcription_factor  9        25383
secreted_cytokine              9        80184
dtype: int64

In [35]:
# Sample 10k candidates for each class
df_tr_mv = df_tr_cand.groupby('type', group_keys=False).apply(lambda g: g.sample(n=10000, random_state=TCRE_SEED))
df_tr_mv.groupby(['type', 'split']).size()

type                           split
inducing_cytokine              9        10000
inducing_transcription_factor  9        10000
secreted_cytokine              9        10000
dtype: int64

In [36]:
tr_ct = session.query(Candidate).filter(Candidate.split == SPLIT_TRAIN).count() 
tr_ct

30000

In [37]:
# Move the sampled candidates to TRAIN only while the TRAIN split is still empty,
# so re-running the notebook does not add a second sample
if tr_ct != 0:
    print('Skipping TRAIN split candidate move (they already exist)')
else:
    for _, row in df_tr_mv.iterrows():
        cand_id = int(row['id'])
        cand = session.query(Candidate).filter(Candidate.id == cand_id).one()
        cand.split = SPLIT_TRAIN
    session.commit()

Skipping TRAIN split candidate move (they already exist)


In [2]:
# Finally, show distribution across splits
pd.DataFrame(session.query(Candidate.type, Candidate.split).all()).groupby(['type', 'split']).size().unstack()

split,0,1,2,3,9
type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
inducing_cytokine,10000,1025,371,278,566192
inducing_transcription_factor,10000,627,300,225,158977
secreted_cytokine,10000,1019,364,272,566211
