In [1]:
from pathlib import Path
from string import punctuation
import numpy as np
from collections import defaultdict
from nltk.corpus import wordnet as wn

In [23]:
lines = Path('bc2gm/train_aug2.tsv').read_text().splitlines()

In [24]:
ALT_ENTITIES = {'PROTEIN': 'GENE', 'CL': 'CELL_TYPE', 'CHEBI': 'CHEMICAL', 'GGP': 'GENE', 'SPECIES': 'TAXON',
                'CELLLINE': 'CELL_LINE'}
all_tokens = defaultdict(set)
named_tokens = []
named_tokens_labels = []
words = []
sentences = []
labels = []
tags = []

complex_token = []
complex_token_label = []
for line in lines:
    pair = line.strip().split('\t')
    if len(pair) == 2:
        word, tag = pair
        words.append(word)
        if tag[2:].upper() in ALT_ENTITIES:
            tag = tag[0:2] + ALT_ENTITIES[tag[2:].upper()]
        tags.append(tag)
        if 'B-' in tag or 'I-' in tag or 'E-' in tag:
            complex_token.append(word)
            complex_token_label.append(tag)
            if 'E-' in tag:
                named_tokens.append(complex_token)
                named_tokens_labels.append(complex_token_label)
                complex_token = []
                complex_token_label = []
        elif 'S-' in tag:
            named_tokens.append([word])
            named_tokens_labels.append([tag])
        all_tokens[tag].add(word)
    else:
        sentences.append(words)
        labels.append(tags)
        words, tags = [], []

In [13]:
all_tokens = {key: list(value) for key, value in all_tokens.items()}

In [14]:
def labelwise_replace(sentence, labels):
    new_sentence = []
    for ind, label in enumerate(labels):
        p = np.random.uniform()
        if p >= 0.5:
            new_token = np.random.choice(all_tokens[label])
            new_sentence.append(new_token)
        else:
            new_sentence.append(sentence[ind])
    return new_sentence, labels

In [15]:
def mention_replace(sentence, labels):
    new_sent, new_lab = [], []
    for ind, (label, token) in enumerate(zip(labels, sentence)):
        if any(letter in label for letter in ['B-', 'S-']):
            p = np.random.uniform()
            if p >= 0.5:
                new_ind = np.random.randint(len(named_tokens))
                new_sent.extend(named_tokens[new_ind])
                new_lab.extend(named_tokens_labels[new_ind])
            else:
                if label.startswith('S-'):
                    new_sent.append(token)
                    new_lab.append(label)
                else:
                    ind += 1
                    while any(letter in label for letter in ['I-', 'E-']):
                        new_sent.append(sentence[ind])
                        new_lab.append(labels[ind])
                        ind += 1
        elif any(letter in label for letter in ['I-', 'E-']):
            continue
        else:
            new_sent.append(token)
            new_lab.append(label)
    return new_sent, new_lab

In [16]:
def segments_shuffle(sentence, labels):
    pieces_tokens, pieces_labels = [], []
    prev_label = None
    piece_tokens, piece_labels = [], []
    for token, label in zip(sentence, labels):
        if prev_label == 'O' and label != 'O':
            pieces_tokens.append(piece_tokens)
            pieces_labels.append(piece_labels)
            piece_tokens, piece_labels = [], []
        if 'B-' in label or 'I-' in label or 'E-' in label:
            piece_tokens.append(token)
            piece_labels.append(label)
            if 'E-' in label:
                pieces_tokens.append(piece_tokens)
                pieces_labels.append(piece_labels)
                piece_tokens, piece_labels = [], []
        elif 'S-' in label:
            pieces_tokens.append([token])
            pieces_labels.append([label])
        else:
            piece_tokens.append(token)
            piece_labels.append(label)
        
        prev_label = label
    if len(piece_tokens):
        pieces_tokens.append(piece_tokens)
        pieces_labels.append(piece_labels)
    temp = list(zip(pieces_tokens, pieces_labels))
    np.random.shuffle(temp)
    new_sent, new_lab = zip(*temp)
    return [item for sublist in new_sent for item in sublist], [item for sublist in new_lab for item in sublist]
                    

In [18]:
from tqdm import tqdm
final_sent, final_lab = [], []
for sent, lab in tqdm(zip(sentences, labels)):
    final_sent.append(sent)
    final_lab.append(lab)
    for func in (labelwise_replace, mention_replace, segments_shuffle):
        for i in range(3):
            new_sent, new_lab = func(sent, lab)
            final_sent.append(new_sent)
            final_lab.append(new_lab)

12573it [1:40:51,  2.08it/s]


In [19]:
with open('bc2gm/train_aug.tsv', 'a') as file:
    for sent, lab in zip(final_sent, final_lab):
        for word, label in zip(sent, lab):
            file.write(word + '\t' + label + '\n')
        file.write('\n')