In [None]:
pip install torchcrf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
)
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, Dataset
from TorchCRF import CRF
from transformers import (
    BertModel,
    BertTokenizer,
    DataCollatorForTokenClassification,
    EvalPrediction,
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

# Load the three CSV splits. NERDataset below reads column 0 of each row as
# the word and column 1 as its tag.
train_df, dev_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/kaggle/input/biobert/train.csv',
        '/kaggle/input/biobert/dev.csv',
        '/kaggle/input/biobert/test.csv',
    )
)

class NERDataset(Dataset):
    """Token-classification dataset built from a word-per-row DataFrame.

    Column 0 of ``dataframe`` holds words and column 1 holds tags. A row whose
    word is missing (NaN) or the empty string ends the current sentence.
    Each word is split with ``tokenizer.tokenize`` and every resulting piece
    receives that word's label id. Items are truncated to ``max_len - 2``
    pieces, wrapped in CLS/SEP tokens and right-padded to ``max_len``.

    Args:
        dataframe: rows of (word, tag).
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
            Its ``cls_token``, ``sep_token`` and ``pad_token_id`` are used when
            defined; otherwise the BERT defaults ``[CLS]``, ``[SEP]`` and 0.
        label_map: mapping from tag to integer label id.
        max_len: final sequence length, including special tokens and padding.
        pad_label_id: label id given to the CLS, SEP and padding positions
            (default 33, the value used throughout this notebook).
    """

    def __init__(self, dataframe, tokenizer, label_map, max_len, pad_label_id=33):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.label_map = label_map
        self.max_len = max_len
        self.pad_label_id = pad_label_id
        # Prefer the tokenizer's own special tokens; fall back to the BERT
        # values that were previously hard-coded here.
        self.cls_token = getattr(tokenizer, 'cls_token', None) or '[CLS]'
        self.sep_token = getattr(tokenizer, 'sep_token', None) or '[SEP]'
        pad_token_id = getattr(tokenizer, 'pad_token_id', None)
        self.pad_token_id = 0 if pad_token_id is None else pad_token_id
        self.sentences, self.labels = self._split_sentences_and_labels(dataframe)

    def _split_sentences_and_labels(self, dataframe):
        """Group consecutive (word, tag) rows into sentences.

        Returns:
            ``(sentences, labels)``: parallel lists of word lists and tag lists.
        """
        sentences, labels = [], []
        sentence, label = [], []
        # Iterate the two columns directly; a per-row ``.iloc`` lookup is slow
        # on large frames.
        for word, tag in zip(dataframe.iloc[:, 0], dataframe.iloc[:, 1]):
            if pd.isna(word) or word == '':
                if sentence:
                    sentences.append(sentence)
                    labels.append(label)
                    sentence, label = [], []
            else:
                sentence.append(word)
                label.append(tag)
        if sentence:
            sentences.append(sentence)
            labels.append(label)
        return sentences, labels

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` as long tensors of length ``max_len``."""
        sentence = self.sentences[index]
        tags = self.labels[index]

        tokens = []
        label_ids = []
        for word, label in zip(sentence, tags):
            word_tokens = self.tokenizer.tokenize(word)
            tokens.extend(word_tokens)
            # Every word piece inherits the label of the word it came from.
            label_ids.extend([self.label_map[label]] * len(word_tokens))

        # Leave room for the CLS and SEP tokens.
        tokens = tokens[:self.max_len - 2]
        label_ids = label_ids[:self.max_len - 2]

        tokens = [self.cls_token] + tokens + [self.sep_token]
        label_ids = [self.pad_label_id] + label_ids + [self.pad_label_id]

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        attention_mask = [1] * len(input_ids)

        padding_length = self.max_len - len(input_ids)
        input_ids = input_ids + [self.pad_token_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length
        label_ids = label_ids + [self.pad_label_id] * padding_length

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(label_ids, dtype=torch.long)
        }

MAX_LEN = 204

# The datasets need a tokenizer now. Without this line the tokenizer is only
# created later in the file, and a top-down run fails with NameError. Load the
# same BlueBERT tokenizer the model uses.
tokenizer = BertTokenizer.from_pretrained("bionlp/bluebert_pubmed_uncased_L-12_H-768_A-12")

# NOTE(review): `label_map` is not defined anywhere in this view of the
# notebook. It presumably maps each tag to an id 0..32 in the same order as
# `class_supports`, with 33 reserved for CLS/SEP/padding — TODO confirm and
# define it before this point.
train_dataset = NERDataset(train_df, tokenizer, label_map, MAX_LEN)
dev_dataset = NERDataset(dev_df, tokenizer, label_map, MAX_LEN)
test_dataset = NERDataset(test_df, tokenizer, label_map, MAX_LEN)

class BertBiLSTMCRF(nn.Module):
    """Frozen BERT encoder -> BiLSTM -> linear emission layer -> CRF tagger.

    Args:
        bert_model_name: name or path passed to ``BertModel.from_pretrained``.
        num_labels: number of tag ids produced by the emission layer and CRF.
        hidden_dim: size of BERT's last hidden state fed to the LSTM.
        lstm_dim: hidden size of each LSTM direction (output is ``2 * lstm_dim``).
        class_weights: optional tensor with one weight per label. When given,
            each label's emission score is multiplied by its weight before the
            CRF, during both training and decoding.

    NOTE(review): multiplying emissions is not equivalent to weighting the
    CRF loss per class. A negative emission multiplied by a large weight
    becomes more negative, so a heavier weight does not consistently favour
    that label.
    """

    def __init__(self, bert_model_name, num_labels, hidden_dim=768, lstm_dim=256, class_weights=None):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        # BERT is frozen: only the LSTM, emission layer and CRF are trained.
        for param in self.bert.parameters():
            param.requires_grad = False

        self.lstm = nn.LSTM(hidden_dim, lstm_dim, batch_first=True, bidirectional=True)
        self.hidden2tag = nn.Linear(lstm_dim * 2, num_labels)
        self.crf = CRF(num_labels)
        self.num_labels = num_labels
        self.class_weights = class_weights

    @staticmethod
    def _resolve_mask(input_ids, attention_mask):
        """Return ``attention_mask``, or an all-ones mask like ``input_ids`` if it is None."""
        return torch.ones_like(input_ids) if attention_mask is None else attention_mask

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute emission scores and, when ``labels`` is given, the CRF loss.

        A missing ``attention_mask`` is treated as all ones. Previously it
        crashed on ``None.byte()`` when ``labels`` was given.

        Returns:
            ``(loss, emissions)`` when ``labels`` is not None, otherwise
            ``emissions``. ``loss`` is the negated batch-mean CRF log-likelihood.
        """
        attention_mask = self._resolve_mask(input_ids, attention_mask)
        sequence_output = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state

        lstm_out, _ = self.lstm(sequence_output)
        emissions = self.hidden2tag(lstm_out)

        if self.class_weights is not None:
            # One weight per label, broadcast over (batch, seq). The product is
            # a new tensor, so the stored weights never need cloning.
            emissions = emissions * self.class_weights.to(emissions.device).view(1, 1, -1)

        if labels is not None:
            log_likelihood = self.crf(emissions, labels, mask=attention_mask.byte())
            return -log_likelihood.mean(), emissions
        return emissions

    def predict(self, input_ids, attention_mask=None):
        """Viterbi-decode the best tag sequence for each batch item.

        A missing ``attention_mask`` is treated as all ones. Previously it
        raised AttributeError.
        """
        attention_mask = self._resolve_mask(input_ids, attention_mask)
        emissions = self.forward(input_ids, attention_mask)
        return self.crf.viterbi_decode(emissions, mask=attention_mask.byte())

def compute_train_accuracy(trainer, dataloader, ignore_label=33):
    """Return token accuracy of the model's CRF Viterbi predictions.

    Unlike ``compute_metrics``, which takes an argmax over emission scores,
    this decodes with ``trainer.model.crf.viterbi_decode``. Excluded positions:

    - positions where ``attention_mask`` is 0;
    - positions whose gold label equals ``ignore_label`` (33 is used for
      CLS/SEP/padding in this notebook).

    This function is called from inside a training callback, so the model's
    previous train/eval mode is restored on exit.

    Args:
        trainer: object exposing ``model`` and ``args.device``.
        dataloader: yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        ignore_label: gold label id to leave out of the score.

    Returns:
        Accuracy as a float, or NaN when no position survives the filtering.
    """
    model = trainer.model
    was_training = model.training
    model.eval()
    all_preds = []
    all_labels = []

    try:
        with torch.no_grad():
            for batch in dataloader:
                inputs = {k: v.to(trainer.args.device) for k, v in batch.items()}
                mask = inputs['attention_mask']
                emissions = model(input_ids=inputs['input_ids'], attention_mask=mask)
                preds = model.crf.viterbi_decode(emissions, mask=mask.byte())

                labels = inputs['labels'].cpu().numpy()
                mask_np = mask.cpu().numpy()
                for label_seq, pred_seq, mask_seq in zip(labels, preds, mask_np):
                    active_labels = label_seq[mask_seq == 1]
                    all_labels.extend(active_labels)
                    # The decoded path covers only the unmasked positions.
                    all_preds.extend(pred_seq[:len(active_labels)])
    finally:
        # Without this, training would continue in eval mode (dropout off).
        model.train(was_training)

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)

    active_indices = all_labels != ignore_label
    if not active_indices.any():
        return float('nan')
    return accuracy_score(all_labels[active_indices], all_preds[active_indices])

def compute_metrics(p: EvalPrediction, ignore_label=33):
    """Token-level accuracy used by ``Trainer`` during evaluation.

    Positions whose gold label is ``ignore_label`` (33 for CLS/SEP/padding in
    this notebook) are excluded.

    NOTE(review): ``p.predictions`` are the model's emission scores, so this
    takes a per-token argmax and ignores the CRF transition scores. The value
    can differ from the Viterbi-decoded accuracy reported by
    ``compute_train_accuracy``.

    Returns:
        ``{"accuracy": float}``.
    """
    predictions, labels = p.predictions, p.label_ids
    # Tolerate models that return several outputs after the loss.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    true_labels = np.asarray(labels).ravel()
    pred_labels = np.argmax(predictions, axis=2).ravel()

    active = true_labels != ignore_label
    accuracy = accuracy_score(true_labels[active], pred_labels[active])

    return {"accuracy": accuracy}

class LoggingCallback(TrainerCallback):
    """Collect per-epoch training/validation loss and accuracy, and plot them.

    NOTE(review): ``on_epoch_end`` reads the notebook globals ``trainer``,
    ``train_dataset``, ``dev_dataset`` and ``data_collator``. They must exist
    before training starts.
    """

    def __init__(self):
        self.train_losses = []
        self.train_accuracies = []
        self.eval_losses = []
        self.eval_accuracies = []

    @staticmethod
    def _latest_train_loss(log_history):
        """Return the newest ``"loss"`` value in ``log_history``, or NaN if there is none.

        The last log entry may be an evaluation log, and the history may still
        be empty. So search backwards instead of reading ``log_history[-1]``;
        that could raise IndexError or produce a non-numeric "N/A" that breaks
        ``plot_metrics``.
        """
        for entry in reversed(log_history):
            if "loss" in entry:
                return entry["loss"]
        return float("nan")

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Record and print the finished epoch's metrics.

        Records the train loss, the train accuracy (Viterbi-decoded) and the
        dev-set loss/accuracy.
        """
        epoch = state.epoch
        train_loss = self._latest_train_loss(state.log_history)

        train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size, collate_fn=data_collator)
        train_accuracy = compute_train_accuracy(trainer, train_dataloader)

        # Monitor on the dev split so the test split stays untouched until the
        # final evaluation.
        validation_metrics = trainer.evaluate(dev_dataset)
        validation_loss = validation_metrics['eval_loss']
        validation_accuracy = validation_metrics['eval_accuracy']

        self.train_losses.append(train_loss)
        self.train_accuracies.append(train_accuracy)
        self.eval_losses.append(validation_loss)
        self.eval_accuracies.append(validation_accuracy)

        print(f"Epoch: {epoch}")
        print(f"Training Loss: {train_loss}")
        print(f"Training Accuracy: {train_accuracy}")
        print(f"Validation Loss: {validation_loss}")
        print(f"Validation Accuracy: {validation_accuracy}")

    def plot_metrics(self):
        """Show side-by-side loss and accuracy curves for all recorded epochs."""
        epochs = range(1, len(self.train_losses) + 1)

        plt.figure(figsize=(14, 6))

        plt.subplot(1, 2, 1)
        plt.plot(epochs, self.train_losses, 'b', label='Training loss')
        plt.plot(epochs, self.eval_losses, 'r', label='Validation loss')
        plt.title('Training and Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(epochs, self.train_accuracies, 'b', label='Training accuracy')
        plt.plot(epochs, self.eval_accuracies, 'r', label='Validation accuracy')
        plt.title('Training and Validation Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.show()

# Token count for each real tag. Presumably index i corresponds to label id i
# (the weights below are applied per label dimension) — TODO confirm against
# label_map.
class_supports = [103, 38, 2197, 3827, 476, 22, 8886, 145, 803, 286, 1195, 61, 247, 166, 2693, 609, 24, 17, 1333, 2582, 140, 2, 5463, 7, 336, 27, 439, 8, 42, 82, 1660, 310, 74915]
total_samples = sum(class_supports)
num_classes = len(class_supports)

# "Balanced" weighting, n_samples / (n_classes * count): rarer tags weigh more.
# The original skipped i == 33 here, but that branch was unreachable because
# the indices only run 0..32.
manual_class_weights = [total_samples / (num_classes * support) for support in class_supports]

# Label 33 (CLS/SEP/padding) gets a near-zero weight.
manual_class_weights.insert(33, 0.0001)

print("Manual Class Weights:", manual_class_weights)

manual_class_weights = torch.tensor(manual_class_weights, dtype=torch.float)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
manual_class_weights = manual_class_weights.to(device)

# Build the BERT-BiLSTM-CRF tagger with 34 output labels (33 tags plus the
# padding label 33), moved to the selected device in one call.
# nn.Module.to returns the module itself.
model = BertBiLSTMCRF(
    "bionlp/bluebert_pubmed_uncased_L-12_H-768_A-12",
    num_labels=34,
    class_weights=manual_class_weights,
).to(device)
tokenizer = BertTokenizer.from_pretrained("bionlp/bluebert_pubmed_uncased_L-12_H-768_A-12")

# Trainer configuration:
# - output and logging locations, logging every 10 steps;
# - 150 epochs, batch size 64 for training and evaluation, weight decay 0.01;
# - evaluation once per epoch, no checkpoints saved.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version.
training_args = TrainingArguments(
    output_dir='./results',
    logging_dir='./logs',
    logging_steps=10,
    num_train_epochs=150,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="no",
)

logging_callback = LoggingCallback()

data_collator = DataCollatorForTokenClassification(tokenizer)

# Per-epoch evaluation runs on the dev split. The test split is reserved for
# the final evaluation below; evaluating on it during training leaks test
# information into model selection.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[logging_callback]
)

trainer.train()

logging_callback.plot_metrics()

def evaluate_on_test(trainer, test_dataset, ignore_label=33):
    """Print token-level accuracy, macro precision/recall/F1 and a per-class report.

    Any dataset may be passed; the notebook also runs this on the training
    split. Positions whose gold label is ``ignore_label`` (33 for
    CLS/SEP/padding here) are excluded.

    NOTE(review): predictions are an argmax over emission scores, not a CRF
    Viterbi decode.
    """
    predictions, labels, _ = trainer.predict(test_dataset)
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    flat_true = np.asarray(labels).ravel()
    flat_pred = np.argmax(predictions, axis=2).ravel()
    active = flat_true != ignore_label
    true_labels = flat_true[active].tolist()
    pred_labels = flat_pred[active].tolist()

    accuracy = accuracy_score(true_labels, pred_labels)
    # zero_division=0 matches the classification_report call below.
    # Undefined precision/recall count as 0 without emitting warnings.
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, pred_labels, average='macro', zero_division=0
    )
    report = classification_report(true_labels, pred_labels, zero_division=0)

    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1 Score: {f1}")
    print(report)

# Report metrics on the held-out test split first, then on the training split
# for comparison.
evaluate_on_test(trainer=trainer, test_dataset=test_dataset)
evaluate_on_test(trainer=trainer, test_dataset=train_dataset)

def predict_and_get_labels(trainer, dataset, ignore_label=33):
    """Return flat ``(true_labels, pred_labels)`` lists of ints for ``dataset``.

    Predictions are the argmax over the model's emission scores. Positions
    whose gold label is ``ignore_label`` (33 for CLS/SEP/padding in this
    notebook) are dropped from both lists.
    """
    predictions, labels, _ = trainer.predict(dataset)
    # Tolerate models that return several outputs after the loss.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    flat_true = np.asarray(labels).ravel()
    flat_pred = np.argmax(predictions, axis=2).ravel()

    active = flat_true != ignore_label
    return flat_true[active].tolist(), flat_pred[active].tolist()

true_labels, pred_labels = predict_and_get_labels(trainer, test_dataset)

# confusion_matrix orders rows/columns by the sorted set of label ids that
# actually occur. Pass those ids explicitly, and reuse them as display_labels,
# so the axis ticks show real label ids rather than positions 0..k-1. The two
# differ whenever some label never occurs.
cm_label_ids = sorted(set(true_labels) | set(pred_labels))
conf_matrix = confusion_matrix(true_labels, pred_labels, labels=cm_label_ids)

# Per-class precision/recall/F1 for the test split.
report = classification_report(true_labels, pred_labels, zero_division=0)
print(report)

# Display and save the confusion matrix.
fig, ax = plt.subplots(figsize=(20, 20))
disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=cm_label_ids)
disp.plot(ax=ax, cmap='Blues')
plt.savefig("confmat tes3")
plt.show()

# Persist only the model's state dict (its parameters and registered buffers),
# not the pickled module object.
save_path = './bert_bilstm_crf_model.pth'
torch.save(obj=model.state_dict(), f=save_path)
print(f"Model saved to {save_path}")