In [None]:
import glob
import gzip
import os
import re
import string
import xml.etree.ElementTree as ET
from multiprocessing import Pool

from joblib import Parallel, delayed

"""If the corpus contains multiple documents, separate them with newline
characters and use newlines for nothing else (a document must not itself
contain a newline). Co-occurrence contexts for words do not extend past
newline characters."""

def preprocessing(path, corpus_file='corpus_file.txt'):
    """
    Creates corpus based on articles from annual baseline of Medline/Pubmed Database

    Every ``*.gz`` file directly inside ``path`` is parsed in a worker
    process by ``process_article_set``. The resulting text is written to
    ``corpus_file`` (UTF-8) in sorted file-name order, so repeated runs
    produce the same corpus.

    Parameters
    ----------
    path : string
        Path to folder with files .gz from pubmed
    corpus_file: string
        Name of file, in which corpus will be created
    """
    # os.path.join works whether or not ``path`` ends with a separator.
    # Sorting makes the output deterministic: glob order is arbitrary, and
    # imap yields results in input order.
    files = sorted(glob.glob(os.path.join(path, '*.gz')))
    # Pool() defaults to os.cpu_count() worker processes.
    with open(corpus_file, 'w', encoding='utf-8') as corpus, Pool() as pool:
        for article_set_text in pool.imap(process_article_set, files):
            corpus.write(article_set_text)


def process_article_set(file):
    """Convert one gzipped PubMed XML file into corpus text.

    Parameters
    ----------
    file : string
        Path to a gzip-compressed XML file whose root element holds the
        articles as direct children.

    Returns
    -------
    string
        Concatenation of ``process_article`` output for every child of the
        root, in document order. Returns '' if the XML cannot be parsed, in
        which case the file is skipped.
    """
    # Deletion table that removes every ASCII punctuation character.
    translator = str.maketrans('', '', string.punctuation)
    with gzip.open(file) as xml_file:
        try:
            article_set = ET.parse(xml_file).getroot()
        except ET.ParseError:
            # Best effort: skip a malformed file rather than abort the run.
            return ''
    # ''.join avoids quadratic string concatenation over large article sets.
    return ''.join(process_article(article, translator) for article in article_set)


def process_article(article, translator):
    """Build one corpus line from a single article element.

    The line contains, in this order, the normalised article title, the text
    of every AbstractText section, and the MeSH descriptor names. The parts
    are separated by single spaces.

    Parameters
    ----------
    article : xml.etree.ElementTree.Element
        One child of the article-set root (presumably <PubmedArticle>;
        confirm against the input files).
    translator : dict
        Deletion table from ``str.maketrans`` used to strip punctuation.

    Returns
    -------
    string
        The article text followed by exactly one newline, or just a newline
        when the article has no usable text.
    """
    parts = []
    title = article.find('MedlineCitation/Article/ArticleTitle')
    if title is not None:
        parts.append(process_raw_text(title, translator))
    # findall, not find: structured abstracts split their text over several
    # labelled AbstractText elements.
    for abstract in article.findall('MedlineCitation/Article/Abstract/AbstractText'):
        parts.append(process_raw_text(abstract, translator))
    # process_mesh_heading returns a plain, already-lowercased string (not an
    # Element), so only punctuation removal is applied here.
    mesh_heading = process_mesh_heading(article.find('MedlineCitation/MeshHeadingList'))
    parts.append(mesh_heading.translate(translator))
    # Joining with spaces keeps the last word of one part from fusing with the
    # first word of the next. Re-splitting collapses whitespace runs (including
    # newlines), so each article occupies exactly one line.
    article_data = ' '.join(' '.join(parts).split())
    return '{}\n'.format(article_data)


def process_raw_text(data, translator):
    """Removes punctuation and uppercase from given string.

    Parameters
    ----------
    data : xml.etree.ElementTree.Element or str or None
        For an Element, all text inside it is used, including text nested in
        child elements. A plain string is used as-is.
    translator : dict
        Deletion table from ``str.maketrans`` applied after lowercasing.

    Returns
    -------
    string
        Lowercased text with translated characters removed and whitespace
        collapsed to single spaces. Returns '' for None or for objects that
        are neither strings nor elements.
    """
    if isinstance(data, str):
        raw = data
    else:
        try:
            # itertext() also yields text nested in inline markup such as
            # <i>, <sup> or <sub>; ``.text`` alone stops at the first child.
            raw = ''.join(data.itertext())
        except AttributeError:
            # None (element not found) or another object without text.
            return ''
    return ' '.join(raw.lower().translate(translator).split())


def process_mesh_heading(data):
    """Reads meshheadinglist and returns names of descriptors.

    Parameters
    ----------
    data : xml.etree.ElementTree.Element or None
        A MeshHeadingList element, or None when the article has none.

    Returns
    -------
    string
        Lowercased text of each MeshHeading/DescriptorName, in document order
        and joined by single spaces. Returns '' when ``data`` is None or has no
        named descriptors.
    """
    if data is None:
        return ''
    # Skip DescriptorName elements with no text instead of failing on None.lower().
    names = [descriptor.text.lower()
             for descriptor in data.findall('MeshHeading/DescriptorName')
             if descriptor.text]
    return ' '.join(names)


def length_of_local_context(path):
    """Finds the length of local context window, which fully covers every single article

    Each line of the corpus is treated as one article, so the answer is the
    largest number of whitespace-separated words on any single line.
    Progress is printed roughly every 10% of lines.

    Parameters
    ----------
    path : string
        Path to the corpus file.

    Returns
    -------
    int
        Maximum word count over all lines; 0 for an empty corpus.
    """
    # Decode as UTF-8 but never abort on undecodable bytes: only the
    # whitespace boundaries matter for counting words.
    open_kwargs = {'encoding': 'utf-8', 'errors': 'replace'}
    # First pass: count articles so progress can be reported as "i/total".
    with open(path, 'r', **open_kwargs) as corpus:
        total = sum(1 for _ in corpus)
    # max() prevents a modulo-by-zero when there are fewer than 10 lines.
    progress = max(total // 10, 1)
    max_length = 0
    # Second pass, in-process: str.split is far cheaper than shipping every
    # line to a worker process.
    with open(path, 'r', **open_kwargs) as corpus:
        for index, line in enumerate(corpus, start=1):
            max_length = max(max_length, count_words_in_string(line))
            if index % progress == 0:
                print('Processed {}/{} articles.'.format(index, total))
    return max_length

def count_words_in_string(text):
    """Return how many whitespace-separated tokens ``text`` contains."""
    tokens = text.split()
    return len(tokens)