## Document Loader

In [1]:
# * Be careful not to import anything related to Snorkel here as it will crash pyarrow on staging step *
import os
import pandas as pd
import numpy as np
import dotenv
import shutil
import glob
import tqdm
from tcre.env import *
import os.path as osp

In [2]:
# from sqlalchemy.orm import sessionmaker
# from snorkel import SnorkelSession
# session = SnorkelSession()

## Document Staging

In [3]:
df = pd.concat([
    pd.read_parquet(osp.join(IMPORT_DATA_DIR_02, 'corpus_01.parquet')).assign(src='entrez'),
    pd.read_parquet(osp.join(IMPORT_DATA_DIR_03, 'corpus_03.parquet')).assign(src='pmcoa')
], sort=True)
assert df['id_pmc'].notnull().all()
df = df.drop_duplicates(subset=['src', 'id_pmc'])
df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 69135 entries, 0 to 48693
Data columns (total 18 columns):
abstract          67211 non-null object
arch_archive      48685 non-null object
arch_id           48685 non-null object
arch_name         48685 non-null object
arch_path         48685 non-null object
arch_venue        48685 non-null object
body              58252 non-null object
date_accepted     69135 non-null object
date_pub          69135 non-null object
date_received     69135 non-null object
id_doi            65905 non-null object
id_pmc            69135 non-null object
id_pmid           68306 non-null object
journal_ids       69135 non-null object
journal_titles    69135 non-null object
src               69135 non-null object
text              20450 non-null object
title             69133 non-null object
dtypes: object(18)
memory usage: 10.0+ MB


In [4]:
# Show document intersection
cts = df.assign(ind=1).pivot(index='id_pmc', columns='src', values='ind').fillna(0).astype(int)
cts.groupby(['entrez', 'pmcoa']).size().rename('count').reset_index()

Unnamed: 0,entrez,pmcoa,count
0,0,1,43824
1,1,0,15589
2,1,1,4861


In [5]:
(cts > 0).all(axis=1).value_counts()

False    59413
True      4861
dtype: int64

In [6]:
dupe_ids = cts[(cts > 0).all(axis=1)].index.values

In [7]:
# Prefer entrez doc for duplicates (make docs unique)
dff = df[(df['src'] == 'entrez') | ~df['id_pmc'].isin(dupe_ids)]
assert dff['id_pmc'].is_unique
dff.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 64274 entries, 0 to 48693
Data columns (total 18 columns):
abstract          62360 non-null object
arch_archive      43824 non-null object
arch_id           43824 non-null object
arch_name         43824 non-null object
arch_path         43824 non-null object
arch_venue        43824 non-null object
body              53411 non-null object
date_accepted     64274 non-null object
date_pub          64274 non-null object
date_received     64274 non-null object
id_doi            61191 non-null object
id_pmc            64274 non-null object
id_pmid           63447 non-null object
journal_ids       64274 non-null object
journal_titles    64274 non-null object
src               64274 non-null object
text              20450 non-null object
title             64272 non-null object
dtypes: object(18)
memory usage: 9.3+ MB


In [8]:
# REMOVE: This will temporarily filter to only annotated documents and a small sample of all others
# See: misc/id-set-generator.ipynb
dfid = pd.read_csv(osp.join(DATA_DIR, 'idgroups', 'idgrp_small.csv'))
is_annotated = ('PMC' + dff['id_pmc']).isin(dfid['id'].values)
dff = pd.concat([
    dff[is_annotated],
    dff[~is_annotated].groupby('src', group_keys=False).apply(lambda g: g.sample(n=1000, random_state=TCRE_SEED))
])
dff.groupby('src').size()

src
entrez    2635
pmcoa     1886
dtype: int64

In [9]:
dff['src'].unique()

array(['entrez', 'pmcoa'], dtype=object)

In [10]:
def stage_docs(df, batch_size=500, staging_dir='/tmp/corpus_staging'):
    if osp.exists(staging_dir):
        shutil.rmtree(staging_dir)
    os.makedirs(staging_dir)

    n_batches = len(df) // batch_size
    batches = np.array_split(np.arange(len(df)), n_batches)
    for i, batch in tqdm.tqdm(enumerate(batches), total=len(batches)):
        path = osp.join(staging_dir, 'B{:05d}.feather'.format(i))
        df.iloc[batch].reset_index(drop=True).to_feather(path)

#stage_docs(dff[dff['src'] == 'entrez'])
stage_docs(dff)

100%|██████████| 9/9 [00:00<00:00, 13.44it/s]


### Define Jobs To Process

In [11]:
from snorkel import SnorkelSession
from snorkel.parser import CorpusParser
from snorkel.models import Document, Sentence, Candidate
from dask.distributed import Client, progress
from tcre import processing
from tcre import supervision
import dask
import logging

In [12]:
def get_jobs(staging_dir='/tmp/corpus_staging'):
    jobs = []
    for f in os.listdir(staging_dir):
        i = int(f.split('.')[0][1:])
        jobs.append((i, osp.join(staging_dir, f)))
    return sorted(jobs)
jobs = get_jobs()
len(jobs)

9

In [12]:
# jobs = [jobs[i] for i in [  
#     0,   1,   2,   3,   4,   5,   6,   8,   9,  11,  12,  13,  14,
# ]]
# len(jobs)

### Clear Existing Contexts

In [13]:
def get_doc_ct():
    session = SnorkelSession()
    ct = session.query(Document).count()
    session.close()
    return ct
    
def clear_documents():
    session = SnorkelSession()
    parser = CorpusParser(parser=lambda v: None)
    parser.clear(session)
    session.commit()
    session.close()
    
n = get_doc_ct()
clear_documents()
print('Num docs before =', n, 'after =', get_doc_ct())

Num docs before = 0 after = 0


## Document Import

In [14]:
client = Client(
    threads_per_worker=1, n_workers=12, processes=True, 
    memory_limit='256GB', direct_to_workers=True, 
    silence_logs=logging.INFO
)
client

distributed.scheduler - INFO - Clear task state
distributed.scheduler - INFO -   Scheduler at:     tcp://127.0.0.1:34689
distributed.scheduler - INFO - Clear task state
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:33015'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:46731'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:37205'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:41233'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:36377'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:44125'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:38563'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:35957'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:46483'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:40481'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:33

0,1
Client  Scheduler: tcp://127.0.0.1:34689,Cluster  Workers: 12  Cores: 12  Memory: 3.07 TB


In [15]:
def process(job):
    import logging
    logging.basicConfig()
    batch_id, batch_file = job
    logging.info('Processing job %s (%s)', batch_id, batch_file)
    loader = processing.DocLoader(batch_file)
    loader.run(limit=None)

In [16]:
# Time estimates:
# docs=20450 -> time=2hr 22min
futures = client.map(process, jobs)
dask.distributed.progress(futures, notebook=True)

VBox()

In [17]:
_ = dask.distributed.wait(futures)

In [18]:
assert all([f.status == 'finished' for f in futures])

In [19]:
np.array([f.status == 'finished' for f in futures]).sum()

9

In [20]:
np.argwhere([f.status != 'finished' for f in futures]).squeeze()

array([], dtype=int64)

In [21]:
client.shutdown()

distributed.scheduler - INFO - Scheduler closing...
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:44587
distributed.core - INFO - Removing comms to tcp://172.17.0.2:44587
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:33607
distributed.core - INFO - Removing comms to tcp://172.17.0.2:33607
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:33471
distributed.core - INFO - Removing comms to tcp://172.17.0.2:33471
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:46221
distributed.core - INFO - Removing comms to tcp://172.17.0.2:46221
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:39595
distributed.core - INFO - Removing comms to tcp://172.17.0.2:39595
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:46619
distributed.core - INFO - Removing comms to tcp://172.17.0.2:46619
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:38661
distributed.core - INFO - Removing comms to tcp://172.17.0.2:38661
dis

In [22]:
session = SnorkelSession()
session.query(Document).count()

4521