# All functions for building the dataset and the knowledge base are taken from the MedLinker GitHub repository:
# https://github.com/danlou/MedLinker/blob/master/scripts/

In [1]:
!pip install scispacy
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_md-0.4.0.tar.gz

Collecting scispacy
  Downloading scispacy-0.4.0-py3-none-any.whl (44 kB)
[K     |████████████████████████████████| 44 kB 119 kB/s eta 0:00:01
[?25hCollecting conllu
  Downloading conllu-4.4-py2.py3-none-any.whl (15 kB)
Collecting pysbd
  Downloading pysbd-0.3.4-py3-none-any.whl (71 kB)
[K     |████████████████████████████████| 71 kB 831 kB/s eta 0:00:01
[?25hCollecting nmslib>=1.7.3.6
  Downloading nmslib-2.1.1-cp37-cp37m-manylinux2010_x86_64.whl (13.5 MB)
[K     |████████████████████████████████| 13.5 MB 3.5 MB/s eta 0:00:01
Collecting spacy<3.1.0,>=3.0.0
  Downloading spacy-3.0.5-cp37-cp37m-manylinux2014_x86_64.whl (12.8 MB)
[K     |████████████████████████████████| 12.8 MB 7.5 MB/s eta 0:00:01
Collecting pybind11<2.6.2
  Downloading pybind11-2.6.1-py2.py3-none-any.whl (188 kB)
[K     |████████████████████████████████| 188 kB 11.8 MB/s eta 0:00:01
Collecting catalogue<2.1.0,>=2.0.1
  Downloading catalogue-2.0.1-py3-none-any.whl (9.6 kB)
Collecting srsly<3.0.0,>=2.4.0
  Downlo

In [3]:
"""
This script expects you've followed the instructions in https://github.com/chb/py-umls to install UMLS.
"""

import json
from collections import Counter
import sqlite3

from scispacy.umls_semantic_type_tree import construct_umls_tree_from_tsv

# Semantic type tree and the UMLS SQLite database (see the py-umls note in the cell docstring).
umls_tree = construct_umls_tree_from_tsv('../input/preprocessing/umls_semantic_type_tree.tsv')  # change to your location

umls_db_path = '../input/umls-database/umls.db'  # change to your location
conn = sqlite3.connect(umls_db_path)
c = conn.cursor()

# Accumulators filled while scanning the database tables.
cui_data = {}                # CUI -> info dict
source_counter = Counter()   # rows seen per source vocabulary (SAB)
def_mismatches = set()       # CUIs that have a definition row but no entry in cui_data

# Source vocabularies kept when RESTRICT_ST21PV is enabled.
st21pv_sources = {
    'CPT', 'FMA', 'GO', 'HGNC', 'HPO', 'ICD10',
    'ICD10CM', 'ICD9CM', 'MDR', 'MSH', 'MTH',
    'NCBI', 'NCI', 'NDDF', 'NDFRT', 'OMIM',
    'RXNORM', 'SNOMEDCT_US',
}

# Semantic type ids kept when RESTRICT_ST21PV is enabled.
st21pv_types = {
    'T005', 'T007', 'T017', 'T022', 'T031', 'T033', 'T037',
    'T038', 'T058', 'T062', 'T074', 'T082', 'T091', 'T092',
    'T097', 'T098', 'T103', 'T168', 'T170', 'T201', 'T204',
}

# For every kept type, the ids of the nodes the tree reports as its children.
st21pv_types_children = {}
for st in st21pv_types:
    st21pv_types_children[st] = {
        child.type_id
        for child in umls_tree.get_children(umls_tree.get_node_from_id(st))
    }

# Switches controlling filtering and whether definitions are collected.
RESTRICT_ST21PV = False
NO_DEFS = False

# Scan the 'descriptions' table and build one entry per CUI in cui_data.
print('Collecting info from \'descriptions\' table ...')
for row in c.execute('SELECT * FROM descriptions'):

    CUI, LAT, SAB, TTY, STR, STY = row

    # Counted for every row, before any filtering.
    source_counter[SAB] += 1

    # Only English rows are kept.
    if LAT != 'ENG':
        continue

    # Semantic types arrive as a single '|'-separated string.
    STY = STY.split('|')

    if RESTRICT_ST21PV:
        if SAB not in st21pv_sources:
            continue

        # Keep a type if it is one of the ST21pv types, or replace it by the
        # first ST21pv type that lists it among its children.
        valid_row_sts = []
        for row_st in STY:
            if row_st in st21pv_types:
                valid_row_sts.append(row_st)
            else:
                for st in st21pv_types:
                    if row_st in st21pv_types_children[st]:
                        valid_row_sts.append(st)  # not row_st !
                        break

        if not valid_row_sts:
            continue

        # Every element of valid_row_sts is already a member of st21pv_types,
        # so no further intersection check is needed.
        STY = valid_row_sts

    if CUI not in cui_data:
        # First row for this CUI decides SAB and STY; its string is first in STR.
        CUI_info = {}
        CUI_info['SAB'] = SAB
        # CUI_info['TTY'] = TTY
        CUI_info['STY'] = STY

        if not NO_DEFS:
            CUI_info['DEF'] = []

        CUI_info['STR'] = [STR]
        CUI_info['Name'] = ''  # custom

        cui_data[CUI] = CUI_info

    else:
        # Later rows only contribute additional strings.
        cui_data[CUI]['STR'].append(STR)

print('# CUIs:', len(cui_data))

# Attach definitions from the 'MRDEF' table to the CUIs collected so far.
if not NO_DEFS:
    print('Collecting info from \'MRDEF\' table ...')
    for row in c.execute('SELECT * FROM MRDEF'):
        CUI, AUI, ATUI, SATUI, SAB, DEF, SUPPRESS, CVF = row

        if CUI in cui_data:
            cui_data[CUI]['DEF'].append(DEF)
        else:
            # Definition for a CUI that was filtered out or never seen.
            def_mismatches.add(CUI)

# This was the last query against the database: release the connection
# instead of leaving it open for the rest of the session.
conn.close()


# The first collected string becomes the display name; the remaining ones are
# de-duplicated and kept as aliases.
print('Preprocessing data ...')
for info in cui_data.values():
    collected = info['STR']
    info['Name'] = collected[0]
    info['STR'] = list(set(collected[1:]))


print('Storing data as JSON ...')

# The file name encodes the active options.
fn = ''.join([
    'umls.2020AA.active',
    '.st21pv' if RESTRICT_ST21PV else '.full',
    '.no_defs' if NO_DEFS else '',
    '.json',
])

with open(fn, 'w') as json_f:
    json.dump(cui_data, json_f)

Collecting info from 'descriptions' table ...
# CUIs: 4412440
Collecting info from 'MRDEF' table ...
Preprocessing data ...
Storing data as JSON ...


In [2]:
"""
utils for reading MedMentions original format
adapted from scispacy: https://github.com/allenai/scispacy
"""

from typing import NamedTuple, List, Iterator, Dict, Tuple
import tarfile
import atexit
import os
import shutil
import tempfile

from scispacy.file_cache import cached_path

from scispacy.umls_semantic_type_tree import construct_umls_tree_from_tsv
# Semantic type tree; process_example looks type ids up in it to compare their `level`.
umls_tree = construct_umls_tree_from_tsv("../input/preprocessing/umls_semantic_type_tree.tsv")


class MedMentionEntity(NamedTuple):
    """One annotated mention parsed from an entity line of a MedMentions file."""
    # Character offsets of the mention; text[start:end] equals mention_text.
    # As parsed they index the document text; iterate_annotations re-bases
    # them onto the containing sentence.
    start: int
    end: int
    # Surface string of the mention.
    mention_text: str
    # Semantic type id selected by process_example.
    mention_type: str
    # Entity id column of the annotation line.
    umls_id: str

class MedMentionExample(NamedTuple):
    """One MedMentions abstract together with its annotated mentions."""
    title: str
    abstract: str
    # title and abstract joined by a single space (built in process_example).
    text: str
    pubmed_id: str
    entities: List[MedMentionEntity]


def process_example(lines: List[str]) -> MedMentionExample:
    """
    Build a MedMentionExample from the lines describing one abstract.

    Expected layout of `lines`:
    PMID | t | Title text
    PMID | a | Abstract text
    PMID TAB StartIndex TAB EndIndex TAB MentionTextSegment TAB SemanticTypeID TAB EntityID
    ...

    When an annotation lists several comma-separated semantic types, the one
    whose tree node has the highest `level` is kept.
    """
    def _node_level(type_id):
        return umls_tree.get_node_from_id(type_id).level

    pubmed_id, _, title = (part.strip() for part in lines[0].split("|", maxsplit=2))
    _, _, abstract = (part.strip() for part in lines[1].split("|", maxsplit=2))

    entities = []
    for annotation in lines[2:]:
        _, raw_start, raw_end, mention, type_ids, umls_id = annotation.split("\t")
        chosen_type = max(type_ids.split(","), key=_node_level)
        entities.append(
            MedMentionEntity(int(raw_start), int(raw_end), mention, chosen_type, umls_id)
        )

    # The annotation offsets index into "title<space>abstract".
    text = f"{title} {abstract}"

    return MedMentionExample(title, abstract, text, pubmed_id, entities)

def med_mentions_example_iterator(filename: str) -> Iterator[MedMentionExample]:
    """
    Iterates over a MedMentions file, yielding examples.

    Examples are blocks of non-empty lines separated by blank lines. Leading
    blank lines and runs of several blank lines are tolerated: an example is
    only emitted when at least one line has been collected.
    """
    with open(filename, "r") as med_mentions_file:
        lines = []
        for line in med_mentions_file:
            line = line.strip()
            if line:
                lines.append(line)
            elif lines:
                # Blank line closes the current example (if there is one).
                yield process_example(lines)
                lines = []
        # Pick up stragglers: last example not followed by a blank line.
        if lines:
            yield process_example(lines)

# def read_med_mentions(filename: str):
#     """
#     Reads in the MedMentions dataset into Spacy's
#     NER format.
#     """
#     examples = []
#     for example in med_mentions_example_iterator(filename):
#         # spacy_format_entities = [(x.start, x.end, x.mention_type) for x in example.entities]
#         spacy_format_entities = [(x.start, x.end, x.mention_text, x.mention_type, x.umls_id) for x in example.entities]
#         examples.append((example.text, {"entities": spacy_format_entities}))

#     return examples


def read_full_med_mentions(directory_path: str,
                           label_mapping: Dict[str, str] = None,
                           span_only: bool = False):
    """
    Load the MedMentions corpus and split it into train / dev / test.

    directory_path is resolved through `cached_path`. If it contains
    "tar.gz" the archive is extracted into a temporary directory that is
    removed at interpreter exit. The directory must contain
    corpus_pubtator.txt and the three PMID list files for the splits.

    label_mapping and span_only are accepted but not used here.

    Returns a tuple (train_examples, dev_examples, test_examples) of lists of
    MedMentionExample. Examples whose PMID is in none of the lists are dropped.
    """

    def _cleanup_dir(dir_path: str):
        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)

    def _read_pmids(path: str):
        # Read one id per line and close the file when done.
        with open(path) as pmid_file:
            return {line.strip() for line in pmid_file}

    resolved_directory_path = cached_path(directory_path)
    if "tar.gz" in directory_path:
        # Extract dataset to temp dir
        tempdir = tempfile.mkdtemp()
        print(f"extracting dataset directory {resolved_directory_path} to temp dir {tempdir}")
        # NOTE(review): extractall trusts the member paths stored in the
        # archive; only use this with archives from a trusted source.
        with tarfile.open(resolved_directory_path, 'r:gz') as archive:
            archive.extractall(tempdir)
        # Postpone cleanup until exit in case the unarchived
        # contents are needed outside this function.
        atexit.register(_cleanup_dir, tempdir)

        resolved_directory_path = tempdir

    expected_names = ["corpus_pubtator.txt",
                      "corpus_pubtator_pmids_all.txt",
                      "corpus_pubtator_pmids_dev.txt",
                      "corpus_pubtator_pmids_test.txt",
                      "corpus_pubtator_pmids_trng.txt"]

    corpus = os.path.join(resolved_directory_path, expected_names[0])
    examples = med_mentions_example_iterator(corpus)

    train_ids = _read_pmids(os.path.join(resolved_directory_path, expected_names[4]))
    dev_ids = _read_pmids(os.path.join(resolved_directory_path, expected_names[2]))
    test_ids = _read_pmids(os.path.join(resolved_directory_path, expected_names[3]))

    train_examples = []
    dev_examples = []
    test_examples = []

    for example in examples:
        if example.pubmed_id in train_ids:
            train_examples.append(example)

        elif example.pubmed_id in dev_ids:
            dev_examples.append(example)

        elif example.pubmed_id in test_ids:
            test_examples.append(example)

    return train_examples, dev_examples, test_examples


############################################

import itertools
import json


class MedMentionSentenceEntity(NamedTuple):
    """A mention located inside a tokenized sentence (built by get_sent_ents)."""
    # Entity id of the source mention (its umls_id).
    cui: str
    # Semantic type of the source mention (its mention_type).
    st: str
    # Tokens of the mention text.
    tokens: List[str]
    # Token indices within the sentence; `end` is one past the last token.
    start: int
    end: int


def iterate_annotations(sci_nlp, dataset_examples):
    """
    Yield (entity, sentence_text) pairs for every mention of every example.

    sci_nlp is called on the example text and its `.sents` are used for
    sentence splitting. The yielded entity is a MedMentionEntity whose
    start/end are re-based onto the yielded sentence string. The title is
    always treated as its own first sentence. A mention that fits in no
    sentence span is not yielded; a mention is yielded at most once.

    Raises AssertionError if a mention's text does not match the characters
    at its offsets (in the document or in the sentence).
    """
    for ex in dataset_examples:
        raw_text = ex.text
        title_len = len(ex.title)

        # get sentence positions (start inclusive, end exclusive) to delimit
        # annotations to sentences
        sents = list(sci_nlp(raw_text).sents)

        # start by adding title as first sentence
        sent_span_idxs = [(0, title_len)]

        ch_idx = 0
        if sents:
            # first sent will include title (due to composition expected by
            # start/end ent indices); add what follows the title, if anything
            first_len = len(sents[0].text)
            if first_len > title_len + 1:
                sent_span_idxs.append((title_len + 1, first_len))
            ch_idx = first_len + 1

        for sent in sents[1:]:
            start_idx = ch_idx
            end_idx = start_idx + len(sent.text)
            sent_span_idxs.append((start_idx, end_idx))

            # Look at the raw string (not the parsed document) to decide
            # whether a whitespace separates this sentence from the next;
            # same stepping rule as get_sent_boundaries.
            if end_idx < len(raw_text) and raw_text[end_idx] == ' ':
                ch_idx = end_idx + 1
            else:
                ch_idx = end_idx

        for ent in ex.entities:

            # sanity check 1 - mentions match in text
            assert ent.mention_text == raw_text[ent.start:ent.end]

            for sent_start, sent_end in sent_span_idxs:
                if (ent.start >= sent_start) and (ent.end <= sent_end):
                    sent_text = raw_text[sent_start:sent_end]

                    # adjust start and end positions; keep `ent` untouched so
                    # the document-level offsets stay available
                    sent_ent = MedMentionEntity(ent.start - sent_start,
                                                ent.end - sent_start,
                                                ent.mention_text,
                                                ent.mention_type,
                                                ent.umls_id)

                    # sanity check 2 - mentions match in sentence
                    assert sent_ent.mention_text == sent_text[sent_ent.start:sent_ent.end]

                    yield (sent_ent, sent_text)
                    break


# def locate_tokens(all_tokens, subset_tokens):
#     """
#     Returns a list of indices (LoL) for all mention tokens within a list of tokens (i.e. sentence tokens).
#     """
#     # tests all combinations, very slow and fails for long spans
#     # gets the job done for now, to be improved later

#     def get_idxs(elems, e):  # assumes must occurr
#         return [i for i, e_ in enumerate(elems) if e == e_]

#     def is_linear(elems):
#         # return elems == [elems[0] + i for i in range(len(elems))]
#         return all(e1 == e2 - 1 for e1, e2 in zip(elems, elems[1:]))

#     # method isn't tractable for very long lists (also very rare)
#     if len(subset_tokens) > 10:
#         return [-1]

#     all_possible_idxs = []  # indices for overlaps between all_tokens and subset
#     for token in subset_tokens:
#         if token in all_tokens:
#             all_possible_idxs.append(get_idxs(all_tokens, token))
    
#     if len(all_possible_idxs) > 0:
#         for combination in itertools.product(*all_possible_idxs):
#             combination = list(combination)
#             if is_linear(combination):  # only want indices increasing by +1
#                 return combination
    
#     return [-1]


def locate_tokens(all_tokens, subset_tokens, reserved_spans=None):
    """
    Find the first occurrence of subset_tokens as a contiguous run in all_tokens.

    Returns [start, end] with both indices inclusive, or [-1] when there is no
    occurrence, when subset_tokens is empty, or when every occurrence is
    already listed in reserved_spans.

    reserved_spans is an optional collection of (start, end) tuples, using the
    same inclusive convention as the return value, that must be skipped.
    """
    if not subset_tokens:
        return [-1]

    if reserved_spans is None:
        reserved_spans = ()

    span_len = len(subset_tokens)
    first_token = subset_tokens[0]

    for t0_idx, token in enumerate(all_tokens):
        if token != first_token:
            continue

        shift_idx = t0_idx + span_len
        if all_tokens[t0_idx:shift_idx] == subset_tokens:
            start = t0_idx
            end = shift_idx - 1

            if (start, end) not in reserved_spans:
                return [start, end]

    return [-1]


def get_sent_boundaries(sci_nlp, text, title):
    """
    Compute (start, end) character spans of the sentences in `text`.

    Both indices are inclusive. `text` is expected to start with `title`
    followed by one separator character; the title is always emitted as the
    first span, and whatever the splitter attached to it in its first
    sentence becomes the second span.
    """
    # sentence splitting is delegated to the pipeline
    sentence_texts = [s.text for s in sci_nlp(text).sents]
    first_text = sentence_texts[0]
    title_len = len(title)
    first_len = len(first_text)

    # title span, then the remainder of the splitter's first sentence (if any)
    boundaries = [(0, title_len - 1)]
    if first_len > title_len + 1:
        boundaries.append((title_len + 1, first_len - 1))

    # position of the character after the separator following the first sentence
    cursor = first_len + 1
    text_len = len(text)

    for sentence in sentence_texts[1:]:
        last = cursor + len(sentence) - 1
        boundaries.append((cursor, last))

        following = last + 1
        if following >= text_len:
            # end of text
            cursor = last
        elif text[following] == ' ':
            # skip the separating whitespace
            cursor = last + 2
        else:
            # no separator: happens when sentence splitting fails
            cursor = following

    return boundaries


def get_sent_ents(sci_nlp, sent_tokens, sent_start, sent_end, doc_entities):
    """
    Locate the document's mentions that fall inside one sentence.

    sent_start / sent_end are the inclusive character boundaries of the
    sentence in the document text; the entities' `end` offset is exclusive.
    sent_tokens are the sentence's tokens. Each mention is tokenized with
    sci_nlp and searched for in sent_tokens; a token span is assigned to at
    most one mention, so a repeated mention is matched to the next free
    occurrence.

    Returns (sent_ents, skipped_mentions): the list of
    MedMentionSentenceEntity (token `end` exclusive) and the number of
    in-sentence mentions that could not be located.
    """
    sent_ents = []
    reserved_spans = set()  # (first_token_idx, last_token_idx), both inclusive
    skipped_mentions = 0  # failed locating mention

    for ent in doc_entities:
        # only interested in entities located within sentence boundaries;
        # ent.end is exclusive while sent_end is inclusive, hence the + 1
        if ent.start < sent_start or ent.end > sent_end + 1:
            continue

        mention_tokens = [tok.text for tok in sci_nlp(ent.mention_text)]
        if not mention_tokens:
            skipped_mentions += 1  # nothing to search for
            continue

        mention_tokens_idxs = locate_tokens(sent_tokens, mention_tokens, reserved_spans)

        if -1 in mention_tokens_idxs:
            skipped_mentions += 1  # something may have gone wrong with splitting
            continue

        first_idx = mention_tokens_idxs[0]
        last_idx = mention_tokens_idxs[-1]

        # Reserve the span in the inclusive form that locate_tokens compares
        # against, so the next identical mention moves on to the following
        # occurrence instead of being matched here again and discarded.
        reserved_spans.add((first_idx, last_idx))

        sent_ents.append(MedMentionSentenceEntity(cui=ent.umls_id,
                                                  st=ent.mention_type,
                                                  tokens=mention_tokens,
                                                  start=first_idx,
                                                  end=last_idx + 1))  # exclusive, for slicing

    return sent_ents, skipped_mentions


# def iterate_docs_converted(split_path):

#     # load json dataset
#     with open(split_path, 'r') as json_f:
#         dataset = json.load(json_f)

#     for doc in dataset['docs']:
#         yield doc

In [5]:
import json
import logging
from time import time

import spacy

# Pipeline used for sentence splitting and tokenization.
sci_nlp = spacy.load('en_core_sci_md')

logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%d-%b-%y %H:%M:%S',
)

# Examples per split, as returned (train, dev, test) by read_full_med_mentions.
mm_splits = {'train': [], 'dev': [], 'test': []}
# mm_splits['train'], mm_splits['dev'], mm_splits['test'] = read_full_med_mentions('data/MedMentions/full/data/')
for split_name, split_examples in zip(('train', 'dev', 'test'),
                                      read_full_med_mentions('../input/thesis/')):
    mm_splits[split_name] = split_examples

logging.info('Processing Instances ...')

# Convert every split to a sentence-level JSON file with token-level spans.
for split_label in ['dev', 'test', 'train']:
    split_data = {
        'split': split_label,
        'timestamp': int(time()),
        'n_unlocated_mentions': 0,
        'n_located_mentions': 0,
        'docs': [],
    }
    instances = mm_splits[split_label]
    n_instances = len(instances)

    for doc_idx, ex in enumerate(instances):

        if doc_idx % 100 == 0:
            logging.info('[%s] Converted %d/%d instances.', split_label, doc_idx, n_instances)

        sentences = []

        # sentence boundaries are inclusive character indices into ex.text
        for sent_start, sent_end in get_sent_boundaries(sci_nlp, ex.text, ex.title):
            sent_text = ex.text[sent_start:sent_end + 1]

            # tokenize and drop whitespace-only tokens
            stripped_tokens = (tok.text.strip() for tok in sci_nlp(sent_text))
            sent_tokens = [tok for tok in stripped_tokens if tok != '']

            # get gold ents
            gold_ents, n_sent_skipped_mentions = get_sent_ents(
                sci_nlp, sent_tokens, sent_start, sent_end, ex.entities)

            spans = [
                {
                    'cui': mm_entity.cui,
                    'st': mm_entity.st,
                    'tokens': mm_entity.tokens,
                    'start': mm_entity.start,
                    'end': mm_entity.end,
                }
                for mm_entity in gold_ents
            ]

            split_data['n_unlocated_mentions'] += n_sent_skipped_mentions
            split_data['n_located_mentions'] += len(spans)

            sentences.append({
                'text': sent_text,
                'start': sent_start,
                'end': sent_end,
                'tokens': sent_tokens,
                'n_unlocated_mentions': n_sent_skipped_mentions,
                'spans': spans,
            })

        split_data['docs'].append({
            'idx': doc_idx,
            'title': ex.title,
            'abstract': ex.abstract,
            'text': ex.text,
            'pubmed_id': ex.pubmed_id,
            'sentences': sentences,
        })

    logging.info('[%s] Writing converted MedMentions ...', split_label)
    with open('mm_converted.%s.json' % split_label, 'w') as json_f:
        json.dump(split_data, json_f, sort_keys=True, indent=4)

Here


In [44]:
import sys
import logging
import json

# Timestamped DEBUG-level logging for this cell.
# NOTE(review): an earlier cell issues the same call; if both run in one
# session this one presumably has no effect - confirm.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%d-%b-%y %H:%M:%S')


def iterate_docs_converted(split_path):
    """Yield each entry of the 'docs' list stored in the JSON file at split_path."""
    with open(split_path, 'r') as json_f:
        docs = json.load(json_f)['docs']

    yield from docs


if __name__ == '__main__':

    # Settings to change before running:
    #   specify_st  - tag spans with their semantic type instead of 'Entity'
    #   split_label - which converted split to export
    specify_st = True
    split_label = 'dev'

    mm_path = './mm_converted.%s.json' % split_label

    logging.info('Loading MedMentions - %s ...', mm_path)
    mm_docs = list(iterate_docs_converted(mm_path))

    conll_lines = []

    logging.info('Processing Instances ...')
    for doc in mm_docs:

        # document header followed by a blank line
        conll_lines.extend(['-DOCSTART- (%s)' % doc['pubmed_id'], ''])

        for sent in doc['sentences']:

            tokens = sent['tokens']
            tags = ['O'] * len(tokens)

            for ent in sent['spans']:
                tag = ent['st'] if specify_st else 'Entity'

                # single-token mentions only touch their start index
                if len(ent['tokens']) == 1:
                    positions = [ent['start']]
                else:
                    positions = range(ent['start'], ent['end'])

                # first position gets B-, every following one I-
                for offset, tag_idx in enumerate(positions):
                    marker = 'B' if offset == 0 else 'I'
                    tags[tag_idx] = '%s-%s' % (marker, tag)

            conll_lines.extend('%s\tO\tO\t%s' % pair for pair in zip(tokens, tags))
            conll_lines.append('')

    name_template = 'mm_ner_sts.%s.conll' if specify_st else 'mm_ner_ent.%s.conll'
    filepath = name_template % split_label

    logging.info('Writing CONLL - %s ...', filepath)
    with open(filepath, 'w') as f:
        f.writelines('%s\n' % line for line in conll_lines)
    

In [None]:
#text = " ".join(train_sentences[0])
#nlp = en_core_sci_sm.load()
#nlp.add_pipe("scispacy_linker", config={"resolve_abbreviations": True, "linker_name": "umls"})
#doc = nlp(text)
#entity = doc.ents[3]

#print("Name: ", entity)

#linker = nlp.get_pipe("scispacy_linker")

#for umls_ent in entity._.kb_ents:
#    print(umls_ent)
#    print(linker.kb.cui_to_entity['C0010674'])

#db = linker.kb.cui_to_entity
#db['C0857937']




#import nltk
#for a in b:
#    tokens = nltk.word_tokenize(a)
#    tokens = [token.lower() for token in tokens if len(token) > 1]
#    if len(tokens) == 1:
#        print(tokens, " That's unigram")