In [16]:
import os
from utils import Sample, chunk_spans, evaluate
import torch
import torch.nn as nn
from torch.optim import Adam
from transformers import AutoModel, AutoTokenizer, AutoConfig

import numpy as np

def get_files(directory,ext):
    """Recursively collect the paths of files whose name ends with ``ext``.

    Args:
        directory: Root folder handed to ``os.walk``.
        ext: Suffix matched against each file name with ``str.endswith``
            (for example ``'.txt'``).

    Returns:
        A list of paths, each formed by joining the containing folder with
        the file name, in the order ``os.walk`` yields them.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(directory)
        for name in names
        if name.endswith(ext)
    ]

# Build one Sample per text document found under 'SampleData'. Each document
# is tokenized and then enriched with the annotations read from the sibling
# '.ann' file that shares its base name.
all_txt=get_files('SampleData','.txt')


# NOTE(review): despite its name this dictionary maps tag -> index. The name
# is left unchanged because code outside this cell may refer to it.
idx_to_tag = {'O': 0, 'B-Tox': 1, 'I-Tox': 2}

samples = list()


lm_version='bert-base-cased'

tokenizer=AutoTokenizer.from_pretrained(lm_version)

for txt_path in all_txt:
    # Context managers make sure both files are closed even if parsing fails.
    with open(txt_path, 'r') as txt_file:
        text = txt_file.read()

    # Special tokens are disabled so token positions line up with the raw text
    # offsets returned in 'offset_mapping'.
    tokenizer_out = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)

    new_sample = Sample(os.path.basename(txt_path), text, tokenizer_out['input_ids'],
                        tokenizer_out['offset_mapping'], tokenizer.tokenize(text))

    # Swap only the trailing '.txt' for '.ann'; str.replace would also rewrite
    # any '.txt' occurring earlier in the path (e.g. in a directory name).
    ann_path = txt_path[:-len('.txt')] + '.ann'

    with open(ann_path, 'r') as annotation_file:
        for raw_line in annotation_file:
            fields = raw_line.rstrip('\n').split('\t')
            # Skip blank or malformed lines instead of failing on fields[0][0].
            if len(fields) < 2 or not fields[0]:
                continue

            header = fields[1].split(' ')
            if fields[0][0] == 'T':
                # Lines whose id starts with 'T' carry "<type> <start> ... <end>";
                # only 'SideEffect' entries are kept.
                if header[0] == 'SideEffect':
                    span = (int(header[1]), int(header[-1]))
                    new_sample.add_anno(fields[0], 'SideEffect', span)
            else:
                # Every other line is read as "<name> <span_id> <value>" and the
                # value is stored in slot 2 of the referenced span (presumably a
                # concept identifier - TODO confirm against Sample). Lines that
                # are too short or that point at a span which was not loaded
                # above are ignored rather than raising IndexError/KeyError.
                if len(header) >= 3 and header[1] in new_sample.spans:
                    new_sample.spans[header[1]][2] = header[2]

    new_sample.add_labels()

    samples.append(new_sample)

In [17]:
class NER(torch.nn.Module):
    """Per-token classifier: a pretrained language model plus a linear layer.

    Args:
        language_model: Name or path forwarded to ``AutoConfig.from_pretrained``
            and ``AutoModel.from_pretrained``.
        num_classes: Number of output scores produced for every token.
            Defaults to 3, the value that used to be hard-coded.
    """

    def __init__(self,language_model,num_classes=3):
        super().__init__()
        self.config = AutoConfig.from_pretrained(language_model)
        self.lm=AutoModel.from_pretrained(language_model)
        self.num_classes = num_classes
        # Maps each hidden state to one score per class.
        self.projection=nn.Linear(self.config.hidden_size,self.num_classes)

    def forward(self,input_ids):
        """Return class scores for every token in ``input_ids``.

        Args:
            input_ids: Tensor of token ids passed straight to the language
                model (the notebook calls it with a batch of one sequence).

        Returns:
            The projected ``last_hidden_state``. A leading axis of size 1 is
            removed, so a single sequence yields a ``(tokens, num_classes)``
            tensor.
        """
        hiddens=self.lm(input_ids)
        # Squeeze only axis 0: a bare squeeze() would also drop the token axis
        # for a one-token sequence and hand back a 1-D tensor.
        return self.projection(hiddens['last_hidden_state']).squeeze(0)

In [18]:
# Fine-tune the NER model on every sample, one sample per optimisation step,
# and report the mean loss after each epoch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model=NER(lm_version).to(device)
optimizer=Adam(model.parameters(),lr=0.00001)
loss_function=nn.CrossEntropyLoss()


n_epoch=20

# Longest sequence the language model accepts; longer samples are truncated.
# The value does not change, so it is looked up once instead of per sample.
max_len=model.config.max_position_embeddings

for epoch in range(n_epoch):
    # from_pretrained hands the encoder back in evaluation mode, and the
    # evaluation cell calls model.eval() as well; switch to training mode
    # explicitly so dropout is active while fitting.
    model.train()
    all_loss=list()
    for sample in samples:
        # A sample without tokens has nothing to learn from and would yield an
        # undefined (NaN) mean loss.
        if len(sample.token_ids) == 0:
            continue

        model.zero_grad()

        token_ids=torch.tensor([sample.token_ids[0:max_len]],dtype=torch.long).to(device)
        pred=model(token_ids)
        # Force a (tokens, classes) layout so a one-token sample still matches
        # the 1-D target expected by CrossEntropyLoss.
        pred=pred.reshape(-1,pred.shape[-1])
        target=torch.tensor(sample.labels[0:max_len],dtype=torch.long).to(device)
        loss=loss_function(pred,target)
        all_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    print('Average loss=',np.mean(all_loss))


RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [11]:
# Evaluate the model on every sample: predict a label per token, convert the
# label sequence into character spans, and compare them with the gold spans.
with torch.no_grad():
    model.eval()
    # Constant for the whole run, so it is read once rather than per sample.
    max_len = model.config.max_position_embeddings
    for sample in samples:
        token_ids = torch.tensor([sample.token_ids[0:max_len]], dtype=torch.long).to(device)
        scores = model(token_ids)
        # Force a (tokens, classes) layout so that a one-token sample, which
        # may come back as a 1-D tensor, can still be arg-maxed per token.
        scores = scores.reshape(-1, scores.shape[-1])

        pred_labels = torch.argmax(scores, dim=-1).cpu().tolist()

        # Tokens cut off by the truncation above are padded with label 0 so
        # indices keep lining up with sample.token_spans.
        n_truncated = max(len(sample.token_ids) - max_len, 0)
        pred_labels = pred_labels + [0] * n_truncated

        # Translate token index pairs from chunk_spans into character offsets
        # using the start of the first token and the end of the last one.
        pred_spans = [(sample.token_spans[span[0]][0], sample.token_spans[span[1]][1], 'SideEffect')
                      for span in chunk_spans(pred_labels)]
        gold_spans = [(sample.spans[key][1][0], sample.spans[key][1][1], sample.spans[key][0])
                      for key in sample.spans]

        # Nothing predicted and nothing annotated: no score to report.
        if len(pred_spans) == 0 and len(gold_spans) == 0:
            continue

        result=evaluate(gold_spans,pred_spans)
        print(result)


{'strict': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}, 'overlapping': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}}
{'strict': {'precision': 1.0, 'recall': 0.3333333333333333, 'f1': 0.5}, 'overlapping': {'precision': 1.0, 'recall': 0.3333333333333333, 'f1': 0.5}}
