In [None]:
# BioLinkBERT-RNN (bidirectional GRU head) trained on the NCBI Disease dataset


import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler

# Load the NCBI Disease corpus and pull out its three splits.
dataset = load_dataset("ncbi/ncbi_disease")
train_dataset, val_dataset, test_dataset = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

# Load the tokenizer and the pretrained encoder. BioLinkBERTBiGRU only reads
# the encoder's hidden states; the sequence-classification head created via
# ``num_labels=3`` does not feed the final logits.
model_name = "michiyasunaga/BioLinkBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
encoder_model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=3,
)

# Define the BioLinkBERT-BiGRU model
class BioLinkBERTBiGRU(nn.Module):
    def __init__(self, encoder_model, hidden_dim, num_labels):
        super(BioLinkBERTBiGRU, self).__init__()
        self.encoder = encoder_model
        self.gru = nn.GRU(input_size=encoder_model.config.hidden_size, hidden_size=hidden_dim, bidirectional=True, batch_first=True)
        self.classifier = nn.Linear(hidden_dim * 2, num_labels)  # Hidden_dim * 2 for bidirectional GRU

    def forward(self, input_ids, attention_mask, labels=None):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Get the last hidden state
        encoder_outputs = encoder_outputs.hidden_states[-1]
        # Get GRU outputs
        output, hn = self.gru(encoder_outputs)  # hn will have 2 hidden states (forward and backward)
        # Concatenate forward and backward hidden states
        logits = self.classifier(torch.cat((hn[-2], hn[-1]), dim=-1))  # Concatenate forward and backward GRU outputs

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}

# Set up the model
num_labels = 3
hidden_dim = 512
model = BioLinkBERTBiGRU(encoder_model, hidden_dim, num_labels)

# Preprocess the data
def preprocess_function(examples):
    inputs = [x if isinstance(x, list) else [x] for x in examples['tokens']]
    labels = examples['ner_tags']

    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding='max_length', is_split_into_words=True)

    all_new_labels = []
    for i in range(len(examples['tokens'])):
        word_ids = model_inputs.word_ids(batch_index=i)
        new_labels = []
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                new_labels.append(-100)
            elif word_idx != previous_word_idx:
                new_labels.append(labels[i][word_idx])
            else:
                new_labels.append(-100)
            previous_word_idx = word_idx
        all_new_labels.append(new_labels)

    model_inputs["labels"] = all_new_labels
    return model_inputs

# Tokenize and label-align each split.
train_dataset = train_dataset.map(preprocess_function, batched=True)
val_dataset = val_dataset.map(preprocess_function, batched=True)
test_dataset = test_dataset.map(preprocess_function, batched=True)

# Expose only the model inputs and labels, as torch tensors.
_TORCH_COLUMNS = ['input_ids', 'attention_mask', 'labels']
for _split in (train_dataset, val_dataset, test_dataset):
    _split.set_format(type='torch', columns=list(_TORCH_COLUMNS))

# Training arguments. ``eval_strategy`` is the current name of the deprecated
# ``evaluation_strategy`` keyword. Recent transformers releases reject the old
# spelling; those are the same releases whose Trainer passes
# ``num_items_in_batch`` to ``compute_loss``, which CustomTrainer expects.
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    learning_rate=1e-5,
)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Define metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    true_labels = [l for label_list in labels for l in label_list if l != -100]
    true_predictions = []
    for i, label_list in enumerate(labels):
        true_predictions.extend([predictions[i]] * len([l for l in label_list if l != -100]))

    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, true_predictions, average='weighted', zero_division=0)
    acc = accuracy_score(true_labels, true_predictions)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

class CustomTrainer(Trainer):
    """Trainer whose loss treats each sequence as a single classification example.

    The batches carry word-aligned token labels, with ``-100`` on positions to
    skip. Each row is reduced to one target: the first label that is not
    ``-100``, or ``0`` when the row has none. Cross-entropy is then applied
    against the model's per-sequence logits.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the sequence-level cross-entropy loss.

        Args:
            model: module called as ``model(**inputs)``; its output must support
                ``.get("logits")``.
            inputs: batch dict. Its ``"labels"`` entry is popped before the
                forward pass, so the model never receives it.
            return_outputs: when True, return ``(loss, outputs)`` instead of ``loss``.
            num_items_in_batch: accepted for Trainer API compatibility; unused.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        targets = self._sequence_targets(labels).to(device=logits.device, dtype=torch.long)
        loss = nn.CrossEntropyLoss()(logits, targets)
        return (loss, outputs) if return_outputs else loss

    @staticmethod
    def _sequence_targets(labels):
        """Return one target per row: its first non-``-100`` label, else 0.

        1-D ``labels`` are assumed to already hold one label per sequence and are
        returned unchanged.
        """
        labels = torch.as_tensor(labels)
        if labels.dim() == 1:
            return labels
        if labels.size(1) == 0:
            return torch.zeros(labels.size(0), dtype=labels.dtype, device=labels.device)
        keep = labels != -100
        # argmax returns the first maximal index, i.e. the first kept position.
        first = keep.long().argmax(dim=1)
        picked = labels.gather(1, first.unsqueeze(1)).squeeze(1)
        # Rows with no kept position fall back to class 0.
        return torch.where(keep.any(dim=1), picked, torch.zeros_like(picked))

# NAdam with plateau-based LR decay. Both objects are handed to the Trainer via
# ``optimizers=`` below; otherwise Trainer builds its own default optimizer and
# scheduler and these two are never used (TrainingArguments.learning_rate is
# then not the LR in effect).
optimizer = torch.optim.NAdam(model.parameters(), lr=1e-5)
# Bound to ``scheduler`` (not ``lr_scheduler``) so the
# ``torch.optim.lr_scheduler`` module alias stays intact and this cell can be
# re-run.
# NOTE(review): the scheduler uses default settings and training runs only 5
# epochs — confirm the default patience ever allows an LR reduction. Trainer is
# expected to step a ReduceLROnPlateau scheduler on the eval metric after each
# evaluation; verify against the installed transformers version.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer)

# Initialize Trainer. ``processing_class`` is the current name of the
# deprecated ``tokenizer`` argument.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, scheduler),
)

# Train the model
trainer.train()

# Evaluate the model on each split. The training split is passed explicitly so
# the "Training Results" line reports training-set metrics; without an
# ``eval_dataset`` argument, evaluate() would use the validation split again.
eval_results = trainer.evaluate(eval_dataset=train_dataset)
print("Training Results:", eval_results)

val_results = trainer.evaluate(eval_dataset=val_dataset)
print("Validation Results:", val_results)

test_results = trainer.evaluate(eval_dataset=test_dataset)
print("Test Results:", test_results)
