In [5]:
import os
import sys

# Sanity check: show which interpreter and working directory this kernel uses.
print(sys.executable)
print(os.getcwd())

/home/nikcaryo/miniconda3/envs/amelie/bin/python
/home/nikcaryo/amelie-229/amelie


In [24]:
import argparse
import json
import multiprocessing
import os
import pickle
import random
import sys
from collections import defaultdict
import importlib
import constdb

import text_classification

In [34]:
def load_article(processed_dir, pmid):
    """Load the pickled, pre-processed article for ``pmid``.

    Args:
        processed_dir: Directory holding one ``<pmid>.pkl`` file per article.
        pmid: Article identifier; converted with ``str()`` to build the file name.

    Returns:
        The unpickled object, or ``None`` when the file does not exist or its
        contents cannot be unpickled (empty or truncated file).
    """
    path = os.path.join(processed_dir, str(pmid) + '.pkl')

    if not os.path.isfile(path):
        return None

    with open(path, 'rb') as file:
        try:
            # NOTE(security): pickle.load can execute arbitrary code, so
            # processed_dir must only contain files from a trusted pipeline.
            return pickle.load(file)
        except (EOFError, pickle.UnpicklingError):
            # An empty file raises EOFError, whereas a file cut off in the
            # middle of a record raises UnpicklingError; both mean the article
            # is unusable, so report it and carry on.
            print("Cannot load pmid %s for some reason ..." % pmid, flush=True)
            return None

class ConvertToTextFunction:
    """Callable mapping a PMID to a ``(pmid, text)`` pair.

    An instance carries only its configuration; ``convert_all_to_text`` hands
    it to a multiprocessing pool as the per-item worker function.
    """

    def __init__(self, processed_dir, replace_phenos_with_nothing):
        """Remember where processed articles live and the conversion option."""
        self.replace_phenos_with_nothing = replace_phenos_with_nothing
        self.processed_dir = processed_dir

    def __call__(self, pmid):
        """Return ``(pmid, text)``, or ``None`` if the article could not be loaded."""
        loaded = load_article(self.processed_dir, pmid)
        if loaded is not None:
            text = text_classification.convert_to_text(
                loaded,
                use_main_text=True,
                replace_phenos_with_nothing=self.replace_phenos_with_nothing,
            )
            return pmid, text
        return None


def convert_all_to_text(processed_dir, pmids, replace_phenos_with_nothing,
                        num_processes=100, chunksize=100):
    """Convert many articles to text in parallel, yielding ``(pmid, text)`` pairs.

    Args:
        processed_dir: Directory containing the pickled processed articles.
        pmids: Sized iterable of article identifiers to convert.
        replace_phenos_with_nothing: Forwarded to ``ConvertToTextFunction``.
        num_processes: Number of worker processes in the pool. Defaults to 100
            (the previously hard-coded value); pass ``None`` to let
            ``multiprocessing`` size the pool from the machine's CPU count.
        chunksize: Number of PMIDs handed to a worker at a time.

    Yields:
        ``(pmid, text)`` for every article that could be loaded, in the
        arbitrary order in which workers finish. Articles that could not be
        loaded are silently skipped.
    """
    converter = ConvertToTextFunction(processed_dir,
                                      replace_phenos_with_nothing=replace_phenos_with_nothing)
    total = len(pmids)  # loop-invariant; only used for progress messages

    with multiprocessing.Pool(num_processes) as pool:
        print('Pool created', flush=True)
        results = pool.imap_unordered(converter, pmids, chunksize=chunksize)
        for i, item in enumerate(results):
            if i % 10000 == 0:
                print('Processed ', i, ' out of ', total, flush=True)

            # Workers return None for articles that failed to load.
            if item is not None:
                yield item

def shuffle_articles_labels(articles, labels):
    """Shuffle articles and labels together, keeping each pair aligned.

    Args:
        articles: Sequence of articles.
        labels: Sequence of labels, ``labels[i]`` belonging to ``articles[i]``.

    Returns:
        A two-element list ``[shuffled_articles, shuffled_labels]`` of tuples.
        For empty input this is ``[(), ()]`` so that callers can always unpack
        the result into two names.
    """
    pairs = list(zip(articles, labels))
    if not pairs:
        # zip(*[]) would produce nothing at all and break two-way unpacking.
        return [(), ()]

    random.shuffle(pairs)
    shuffled_articles, shuffled_labels = zip(*pairs)
    return [shuffled_articles, shuffled_labels]

In [39]:
def train_text_field_classifier(out_dir, process_dir, field_name, save, cross_val, l1,
                                replace_phenos_with_nothing, max_articles=100):
    """Train a text classifier for one OMIM field and optionally pickle it.

    Positive PMIDs are labelled with the first entry of their ``field_name``
    value in ``omim.json``; GWAS PMIDs are labelled ``'gwas'``. The combined
    examples are shuffled and passed to ``text_classification.create_model``.

    Args:
        out_dir: Directory containing ``dataset_meta.json`` and ``omim.json``;
            the trained classifier is also written here when ``save`` is true.
        process_dir: Directory with the pickled processed articles.
        field_name: Key of the OMIM field to predict (e.g. ``"variant_types"``).
        save: When true, pickle the classifier to
            ``<out_dir>/text_field_<field_name>.pkl``.
        cross_val: Forwarded to ``text_classification.create_model``.
        l1: Forwarded to ``text_classification.create_model``.
        replace_phenos_with_nothing: Forwarded to the text conversion step.
        max_articles: Keep only the first ``max_articles`` shuffled examples.
            The default of 100 preserves the previous debugging behaviour;
            pass ``None`` to train on every example.
    """
    with open(os.path.join(out_dir, 'dataset_meta.json')) as file:
        dataset_info = json.load(file)

    with open(os.path.join(out_dir, 'omim.json')) as file:
        omim = json.load(file)

    positives = set(dataset_info['positive_pmids'])
    print('Total positives: ', len(positives), flush=True)

    good_positives = []
    for pmid in positives:
        omim_data = omim[str(pmid)]
        if field_name not in omim_data:
            continue
        field = omim_data[field_name]
        # Skip bad ones: an empty field has no label to learn from; a field
        # with two entries is presumably ambiguous -- TODO confirm.
        if len(field) == 0 or len(field) == 2:
            continue
        good_positives.append(pmid)

    print('Good positives: ', len(good_positives), flush=True)
    articles = []
    labels = []

    # Positive examples: the label is the first entry of the OMIM field.
    for pmid, article in convert_all_to_text(process_dir, good_positives,
                                             replace_phenos_with_nothing=replace_phenos_with_nothing):
        articles.append(article)
        labels.append(omim[str(pmid)][field_name][0])

    # GWAS examples all share the fixed label 'gwas'.
    for pmid, article in convert_all_to_text(process_dir, dataset_info['gwas_pmids'],
                                             replace_phenos_with_nothing=replace_phenos_with_nothing):
        articles.append(article)
        labels.append('gwas')

    articles, labels = shuffle_articles_labels(articles, labels)

    # Optional down-sampling for quick experiments; slicing with None keeps all.
    articles = articles[:max_articles]
    labels = labels[:max_articles]

    print('Have ', len(articles), flush=True)
    counters = defaultdict(int)
    for label in labels:
        counters[label] += 1

    print('Counts per label:', counters, flush=True)
    print('Done converting', flush=True)
    classifier = text_classification.create_model(articles, labels, cross_val=cross_val, l1=l1)

    if save:
        print("SAVING TEXT FIELD RELEVANCE CLASSIFIER", flush=True)
        model_path = os.path.join(out_dir, 'text_field_{}.pkl'.format(field_name))
        with open(model_path, 'wb') as out_file:
            pickle.dump(classifier, out_file)
    else:
        print("NOT SAVING TEXT FIELD RELEVANCE CLASSIFIER", flush=True)

In [40]:
# out_dir holds dataset_meta.json / omim.json and receives the saved
# classifiers; process_dir holds the pickled processed articles.
out_dir = "amelie_out_dir"
process_dir = "amelie_process_dir"

# Options shared by both training runs.
save = True
cross_val = True
l1 = False
replace_phenos_with_nothing = False

# NOTE: train_text_field_classifier currently keeps only 100 articles, and the
# text conversion step runs in a multiprocessing pool, so expect heavy CPU
# load while this cell executes.
train_text_field_classifier(out_dir=out_dir, process_dir=process_dir,
                            field_name="inheritance_modes", save=save,
                            cross_val=cross_val, l1=l1,
                            replace_phenos_with_nothing=replace_phenos_with_nothing)
train_text_field_classifier(out_dir=out_dir, process_dir=process_dir,
                            field_name="variant_types", save=save,
                            cross_val=cross_val, l1=l1,
                            replace_phenos_with_nothing=replace_phenos_with_nothing)

Total positives:  60160
Good positives:  11177
Pool created
Processed  0  out of  11177
Processed  10000  out of  11177
Pool created
Processed  0  out of  3264
Have  14141
Counts per label: defaultdict(<class 'int'>, {'dominant': 4620, 'recessive': 6343, 'gwas': 3178})
Done converting
Pipeline created
Five-fold cross validation
Scoring: micro


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


KeyboardInterrupt: 