In [97]:
import numpy as np
import os
from datasets import Dataset
import re

In [2]:
# Directory containing the CADEC text files; read_files() joins each
# `<name>.<id>.txt` file name onto this path.
root_txt = 'data/cadec/text/'
# Directory containing the matching `.ann` annotation files; read_files()
# looks them up here under the same base file name.
root_ann = 'data/cadec/original/'

# Data pre-processing

## Functions

In [21]:
def read_files(fn):
    """Load one document and its annotation lines.

    Args:
        fn: File name of the form `<med>.<id>.txt`, relative to `root_txt`.
            The annotation file with the same stem and an `.ann` extension
            is read from `root_ann`.

    Returns:
        A tuple `(i, med, text, annotations)`: the integer id and the name
        parsed from `fn`, the full text of the document, and the stripped,
        non-empty annotation lines that do not start with '#'.

    Raises:
        IndexError: if `fn` does not match `<med>.<id>.txt`.
        OSError: if either file cannot be opened.
    """
    med, i = re.findall(r'(\w+)\.(\d+)\.txt', fn)[0]
    i = int(i)

    # Swap only the extension: str.replace('txt', 'ann') would also rewrite
    # any 'txt' that happens to appear inside the file stem.
    ann_fn = os.path.splitext(fn)[0] + '.ann'

    # An explicit encoding keeps character offsets independent of the
    # platform's default codec (assumes the corpus is UTF-8 -- TODO confirm).
    with open(os.path.join(root_txt, fn), 'r', encoding='utf-8') as infile:
        text = infile.read()
    with open(os.path.join(root_ann, ann_fn), 'r', encoding='utf-8') as infile:
        # Drop '#' lines and blank lines; a blank line cannot be parsed as
        # an annotation downstream.
        annotations = [
            l.strip() for l in infile
            if l.strip() and not l.startswith('#')
        ]
    return i, med, text, annotations
    

def parse_annotations(lines):
    """Turn raw annotation lines into a dict of entity records.

    Each line must contain `<label> <offsets>\\t<text>`, where `<label>` is
    one of Finding, ADR, Drug, Disease or Symptom and `<offsets>` is one or
    more `start end` pairs separated by ';'.

    Args:
        lines: Sequence of annotation strings.

    Returns:
        A dict keyed by the line's position in `lines`. Each value holds
        'ner' (the label), 'boundaries' (a list of integer lists, one per
        ';'-separated fragment) and 'text' (everything after the tab).

    Raises:
        IndexError: if a line does not match the expected pattern.
    """
    pattern = r'(Finding|ADR|Drug|Disease|Symptom) ([\d; ]+)\t(.*)$'
    parsed = {}
    for position, line in enumerate(lines):
        label, offsets, surface = re.findall(pattern, line)[0]
        fragments = [
            [int(number) for number in chunk.split()]
            for chunk in offsets.split(';')
        ]
        parsed[position] = {
            'ner': label,
            'boundaries': fragments,
            'text': surface,
        }
    return parsed


def get_current_annot(annots, idx, start):
    """Advance the annotation cursor to the first one not yet passed.

    Starting at `idx`, skips every annotation whose last fragment ends at or
    before `start`. End offsets are treated as exclusive, matching the
    `token_end <= end` containment test in get_IOB_tags: a token that starts
    exactly at an annotation's end can never be inside it, so the cursor
    must move on to the next annotation.

    Implemented as a loop rather than recursion so long annotation lists
    cannot hit the interpreter's recursion limit.

    Args:
        annots: Dict keyed by consecutive integers from 0; each value has a
            'boundaries' list of `[start, end]` fragments.
        idx: Index to start looking from.
        start: Character offset at which the current token begins.

    Returns:
        The index of the first annotation from `idx` onward that ends after
        `start`, or `idx - 1` measured from the point where the cursor ran
        off the end (i.e. the last annotation when `idx` began in range).
    """
    while idx < len(annots):
        if start < annots[idx]['boundaries'][-1][-1]:
            return idx
        idx += 1
    return idx - 1


def print_tags_tokens(data):
    """Print tokens and their tags as two column-aligned lines.

    Each token/tag pair is padded to the width of the longer of the two plus
    one space, so every tag starts directly under its token.

    Args:
        data: Mapping with a 'tokens' sequence and a 'ner' sequence of
            strings; pairs are taken with zip, so extra items in the longer
            sequence are ignored.
    """
    token_cells = []
    tag_cells = []
    for token, tag in zip(data['tokens'], data['ner']):
        width = max(len(token), len(tag)) + 1
        token_cells.append(token.ljust(width))
        tag_cells.append(tag.ljust(width))
    print(''.join(token_cells))
    print(''.join(tag_cells))


def get_IOB_tags(text: str, annotations: dict):
    """Tokenize `text` and assign an IOB tag to every token.

    Tokens are runs of word characters or single non-word, non-space
    characters. Character offsets come straight from the regex matches, so
    the text is scanned once instead of being re-searched and re-sliced for
    every token.

    A token is tagged when it lies entirely inside one fragment of the
    current annotation: 'B-<label>' if it starts exactly at the start of the
    first fragment, 'I-<label>' otherwise. Every other token gets 'O'.

    Args:
        text: Document text that the annotation offsets refer to.
        annotations: Dict keyed by consecutive integers from 0; each value
            has 'ner' (label) and 'boundaries' (list of `[start, end]`
            fragments with exclusive ends).
            NOTE(review): the annotation cursor only moves forward and only
            one annotation is checked per token, so this presumably expects
            annotations ordered by position and non-overlapping -- confirm
            against the corpus.

    Returns:
        A tuple `(tokens, tags)` of two equally long lists.
    """
    matches = list(re.finditer(r'\w+|[^\w\s]', text))
    tokens = [m.group() for m in matches]
    if len(annotations) == 0:
        return tokens, ['O' for _ in tokens]

    idx = 0
    tags = []
    for match in matches:
        tok_start, tok_end = match.span()
        # Move the cursor past annotations that end before this token.
        idx = get_current_annot(annotations, idx, tok_start)
        tag = 'O'
        for i, (start, end) in enumerate(annotations[idx]['boundaries']):
            if tok_start >= start and tok_end <= end:
                # Only the first token of the first fragment opens the entity.
                prefix = 'I-' if (i > 0 or tok_start > start) else 'B-'
                tag = prefix + annotations[idx]['ner']
                break
        tags.append(tag)
    return tokens, tags

## Parsing data into the desired format

In [51]:
# Nested mapping: data[med][i] -> {'tokens': [...], 'ner': [...]}.
data = {}

# os.listdir returns entries in arbitrary order; sorting keeps the document
# order (and everything derived from it) stable from run to run.
for fn in sorted(os.listdir(root_txt)):
    # Skip stray entries (e.g. hidden files) that read_files cannot parse.
    if not fn.endswith('.txt'):
        continue
    i, med, text, annotations = read_files(fn)
    annots = parse_annotations(annotations)
    tokens, tags = get_IOB_tags(text, annots)
    data.setdefault(med, {})[i] = {
        'tokens': tokens, 'ner': tags
    }
    

## Creating datasets

In [None]:
# Flatten the nested `data` mapping into consecutively numbered documents.
raw_data = {}
idx = 0
tt_split = 0.8 # train/test split - should maybe add validation?
split_seed = 42 # fixed seed so the split is reproducible across runs
for v in data.values():
    for vv in v.values():
        raw_data[idx] = vv
        idx += 1

# Shuffle with a locally seeded generator so the split does not depend on
# (or disturb) NumPy's global random state.
indices = np.arange(len(raw_data))
np.random.default_rng(split_seed).shuffle(indices)
n_train = int(tt_split * len(indices))
train_idx = indices[:n_train]
test_idx = indices[n_train:]

# Convert each split from a list of records into column lists
# ({'tokens': [...], 'ner': [...]}), the layout Dataset.from_dict is given.
data_split = {}
for split_name, split_indices in (('train', train_idx), ('test', test_idx)):
    columns = {}
    for idx in split_indices:
        for k, column_value in raw_data[idx].items():
            columns.setdefault(k, []).append(column_value)
    data_split[split_name] = columns

datasets = {k: Dataset.from_dict(data_split[k]) for k in data_split.keys()}
