In [1]:
import logging
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
from seqeval.metrics import f1_score, precision_score, recall_score
from torch import nn
from torchcrf import CRF
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
    EarlyStoppingCallback,
    IntervalStrategy,
    BertPreTrainedModel,
    AutoModel,
    BertModel,
)
import torch
import torch.nn.functional as F
# from transformers.modeling_outputs import TokenClassifierOutput
from torch.optim.lr_scheduler import CosineAnnealingLR


from ultil_ner import NerDataset, get_labels
class BertCRFModel(nn.Module):
    def __init__(self, config, model_name_or_path):
        super(BertCRFModel, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name_or_path, config=config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,)
        if labels is not None:
            loss = self.crf(emissions = logits, tags=labels, mask=attention_mask.byte())
            outputs =(-1*loss,)+outputs
            return outputs 
        else:
            return self.crf.decode(logits, attention_mask.byte())

2023-08-20 18:07:19.838754: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
2023-08-20 18:07:21.248743: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2023-08-20 18:07:21.248794: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2023-08-20 18:07:21.248887: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:941] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-08-20 18:07:21.248993: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 2080 Ti computeCapability: 7.5
coreClock: 1.65GHz coreCount: 68 deviceMemorySize: 10.75GiB deviceMemoryBandwidth: 573.69GiB/s
2023-08-20 18:07:21.249011: I tensorflow/stream_executor/platform

In [2]:
id2label= {i: label for i, label in enumerate(get_labels())}

In [3]:
from transformers import BertTokenizer
from torch.utils.data import DataLoader
from transformers import AdamW
from seqeval.metrics import accuracy_score, f1_score, classification_report
from tqdm import tqdm
from seqeval.metrics import f1_score, precision_score, recall_score
output_dir= '/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/output/pubmedbert-crf'
device= 'cuda'
# Training function
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

def idtoLabel(pr, labels=id2label):
    res= []
    for i in pr:
        res.append(labels[i])
    return res
def train(model, epoch, train_dataloader, val_dataloader, optimizer, scheduler,  device='cuda'):
    model.train()
    total_loss = 0
    f1 = 0
    for i in range(epoch):
        print('-----------------Epoch: '+ str(i)+ '-----------------')
        for batch in tqdm(train_dataloader, desc="Training"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            loss = model(input_ids, attention_mask, token_type_ids, labels=labels)[0]
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_dataloader)
        res= evaluate(model, val_dataloader, device)
        scheduler.step()
        if res['f1_macro']> f1:
            print('Epoch: '+ str(i))
            print(res)
            f1= res['f1_macro']
            torch.save(model.state_dict(), output_dir+'/epoch_'+ str(i)+ '_f1_'+ str(f1)+ '.pt')


# Evaluation function
def evaluate(model, eval_dataloader, device='cuda'):
    model.eval()
    predictions = []
    true_labels = []

    with torch.no_grad():
        for i, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating")):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels']

            output = model(input_ids, attention_mask, token_type_ids)
            predicted_labels = output[0]

            predictions.append(idtoLabel(predicted_labels))
            true_labels.append(idtoLabel(labels.numpy().tolist()[0][0: len(predicted_labels)]))
            # if i<5: 
            #     print(predictions)
            #     print(true_labels)

    # Flatten the predictions and true labels lists
    # predictions = [p for sublist in predictions for p in sublist]
    # true_labels = [l for sublist in true_labels for l in sublist]

    # Calculate metrics
    return {"precision": precision_score(true_labels, predictions),
            "recall": recall_score(true_labels, predictions),
            "f1": f1_score(true_labels, predictions),
            "f1_macro": f1_score(true_labels, predictions, average="macro"),
            "recall_macro": recall_score(true_labels, predictions, average="macro"),
            "precision_macro": precision_score(true_labels, predictions, average="macro")
            }
                


# Specify device (e.g., 'cuda' for GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model
labels = get_labels()
label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
num_labels = len(labels)

    # Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(
        '/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/model/BiomedNLP-PubMedBERT-base-uncased-abstract',
        num_labels=num_labels,
        id2label=label_map,
        label2id={label: i for i, label in enumerate(labels)},
)

tokenizer = AutoTokenizer.from_pretrained(
        '/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/model/BiomedNLP-PubMedBERT-base-uncased-abstract',
        use_fast=False,
)

model= BertCRFModel(model_name_or_path='/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/model/BiomedNLP-PubMedBERT-base-uncased-abstract', config=config)
model.to(device)
# Get datasets
train_dataset =NerDataset(
            data_dir='/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/data/Preprocess_PubmedBert_6class',
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=512,
            overwrite_cache=False,
            mode= 'Train'
        )

eval_dataset =  NerDataset(
            data_dir='/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/data/Preprocess_PubmedBert_6class',
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=512,
            overwrite_cache=False,
            mode='Dev',
        )



def collate_fn(batch):
    input_ids = [feature.input_ids for feature in batch]
    attention_mask = [feature.attention_mask for feature in batch]
    token_type_ids = [feature.token_type_ids for feature in batch]
    label_ids = [feature.label_ids for feature in batch]

    # Convert the label_ids to a tensor
    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    token_type_ids = torch.tensor(token_type_ids)
    label_ids = torch.tensor(label_ids)
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'token_type_ids': token_type_ids, 'labels': label_ids}

    # return input_ids, attention_mask, token_type_ids, label_ids

train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True, collate_fn=collate_fn)
eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Specify training parameters
num_epochs = 20
learning_rate = 7e-5
optimizer = AdamW(model.parameters(), lr=learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
# evaluate(model, eval_dataloader, device='cuda')
train(model, num_epochs, train_dataloader, eval_dataloader, optimizer, scheduler, device)

Some weights of the model checkpoint at /media/data3/users/longnd/ehr-relation-extraction/biobert_ner/model/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


-----------------Epoch: 0-----------------


Training: 100%|██████████| 430/430 [04:29<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1106/1106 [00:46<00:00, 24.01it/s]


Epoch: 0
{'precision': 0.8493227990970654, 'recall': 0.8519671667138409, 'f1': 0.850642927794263, 'f1_macro': 0.850642927794263, 'recall_macro': 0.8519671667138409, 'precision_macro': 0.8493227990970654}
-----------------Epoch: 1-----------------


Training: 100%|██████████| 430/430 [04:24<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:46<00:00, 23.95it/s]


Epoch: 1
{'precision': 0.862675089261192, 'recall': 0.8890461364279649, 'f1': 0.8756621131865068, 'f1_macro': 0.8756621131865068, 'recall_macro': 0.8890461364279649, 'precision_macro': 0.862675089261192}
-----------------Epoch: 2-----------------


Training: 100%|██████████| 430/430 [04:23<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.07it/s]


-----------------Epoch: 3-----------------


Training: 100%|██████████| 430/430 [04:23<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.08it/s]


-----------------Epoch: 4-----------------


Training: 100%|██████████| 430/430 [04:23<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:46<00:00, 23.99it/s]


-----------------Epoch: 5-----------------


Training: 100%|██████████| 430/430 [04:24<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:46<00:00, 23.77it/s]


-----------------Epoch: 6-----------------


Training: 100%|██████████| 430/430 [04:50<00:00,  1.48it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.19it/s]


-----------------Epoch: 7-----------------


Training: 100%|██████████| 430/430 [04:23<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.29it/s]


Epoch: 7
{'precision': 0.8712351478308925, 'recall': 0.8924426832720068, 'f1': 0.8817114093959733, 'f1_macro': 0.8817114093959733, 'recall_macro': 0.8924426832720068, 'precision_macro': 0.8712351478308925}
-----------------Epoch: 8-----------------


Training: 100%|██████████| 430/430 [04:23<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:46<00:00, 23.85it/s]


-----------------Epoch: 9-----------------


Training: 100%|██████████| 430/430 [04:23<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.22it/s]


-----------------Epoch: 10-----------------


Training: 100%|██████████| 430/430 [04:24<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.20it/s]


-----------------Epoch: 11-----------------


Training: 100%|██████████| 430/430 [04:22<00:00,  1.64it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.17it/s]


-----------------Epoch: 12-----------------


Training: 100%|██████████| 430/430 [04:23<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.30it/s]


-----------------Epoch: 13-----------------


Training: 100%|██████████| 430/430 [04:23<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.07it/s]


-----------------Epoch: 14-----------------


Training: 100%|██████████| 430/430 [04:23<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.39it/s]


-----------------Epoch: 15-----------------


Training: 100%|██████████| 430/430 [04:23<00:00,  1.63it/s]
Evaluating: 100%|██████████| 1106/1106 [00:46<00:00, 23.99it/s]


-----------------Epoch: 16-----------------


Training: 100%|██████████| 430/430 [04:50<00:00,  1.48it/s]
Evaluating: 100%|██████████| 1106/1106 [00:45<00:00, 24.19it/s]


-----------------Epoch: 17-----------------


Training:   2%|▏         | 8/430 [00:05<04:54,  1.43it/s]


KeyboardInterrupt: 