## Fine-tuning a BioBERT Model on the DrugProt Dataset

In [None]:
# Mount Google Drive at /content/drive; the data and model paths used in the
# cells below all live under /content/drive/MyDrive. Requires a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import re
from bisect import bisect_right

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup

# Data loading functions

def load_drugprot_abstracts(file_path):
    """
    Load abstracts from the DrugProt dataset.

    Each non-blank line is expected to hold ``PMID<TAB>title<TAB>abstract``.
    Blank lines are ignored, a row whose abstract column is empty or missing
    is kept with an empty abstract, and any extra tab stays inside the
    abstract so character offsets into the text are not shifted. Rows with
    fewer than two columns are reported and skipped, like the other loaders.

    Args:
    file_path (str): Path to the abstracts file.

    Returns:
    dict: A dictionary with PMIDs as keys and concatenated title + abstract as values.
    """
    abstracts = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Only remove the line terminator here: a blanket strip() would also
            # eat a trailing tab and turn an empty abstract into a 2-column row.
            record = line.rstrip('\r\n')
            if not record.strip():
                continue
            # maxsplit=2 keeps a stray tab inside the abstract instead of
            # producing a fourth column.
            parts = record.split('\t', 2)
            if len(parts) < 2:
                print(f"Unexpected number of fields in line: {line.strip()}")
                continue
            pmid = parts[0].lstrip()
            title = parts[1]
            abstract = parts[2].rstrip() if len(parts) == 3 else ''
            abstracts[pmid] = title + ' ' + abstract
    return abstracts

def load_drugprot_entities(file_path):
    """
    Read the DrugProt entity annotation file.

    Every line is stripped and split on tabs. Lines with exactly six columns
    become a record; any other line is echoed with a warning and skipped.

    Args:
    file_path (str): Path to the entities file.

    Returns:
    list: Tuples of (pmid, entity_id, entity_type, start, end, text), with
    start and end converted to int.
    """
    records = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            stripped = raw.strip()
            fields = stripped.split('\t')
            if len(fields) != 6:
                print(f"Unexpected number of fields in line: {stripped}")
                continue
            pmid, entity_id, entity_type, begin, finish, mention = fields
            records.append((pmid, entity_id, entity_type, int(begin), int(finish), mention))
    return records

def load_drugprot_relations(file_path):
    """
    Read the DrugProt relation annotation file.

    Every line is stripped and split on tabs. Lines with exactly four columns
    become a record; any other line is echoed with a warning and skipped.

    Args:
    file_path (str): Path to the relations file.

    Returns:
    list: Tuples of (pmid, rel_type, arg1, arg2), all kept as strings.
    """
    records = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            stripped = raw.strip()
            fields = stripped.split('\t')
            if len(fields) != 4:
                print(f"Unexpected number of fields in line: {stripped}")
                continue
            records.append(tuple(fields))
    return records

# Initialize the BioBERT tokenizer at module level.
# NOTE: the training cell further down loads the same checkpoint into
# `tokenizer` again before building the datasets.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

# Named Entity Recognition (NER) data preparation and dataset creation

def prepare_ner_data(entities, abstracts):
    """
    Prepare data for NER training by aligning entity annotations with
    whitespace-separated tokens.

    A token is labelled when its character span overlaps the entity's
    ``[start, end)`` span: the first overlapping token gets ``B-<type>`` and
    the following ones ``I-<type>``. Matching on character spans keeps the
    label on the right token even when the entity starts in the middle of a
    whitespace token (for example ``(aspirin)``). When entities overlap, the
    one listed later overwrites the earlier label.

    Args:
    entities (list): List of entity annotations as
        (pmid, entity_id, entity_type, start, end, text) tuples.
    abstracts (dict): Dictionary of abstracts keyed by PMID.

    Returns:
    list: A list of tuples containing tokens and their corresponding NER labels.
    """
    # Group the spans by PMID once so each abstract only visits its own
    # entities instead of rescanning the full entity list.
    spans_by_pmid = {}
    for e_pmid, _, entity_type, start, end, _ in entities:
        spans_by_pmid.setdefault(e_pmid, []).append((entity_type, start, end))

    ner_data = []
    for pmid, abstract_text in tqdm(abstracts.items(), desc="Preparing NER data"):
        # Same tokens as str.split(), but with their character offsets.
        matches = list(re.finditer(r'\S+', abstract_text))
        tokens = [m.group() for m in matches]
        token_starts = [m.start() for m in matches]
        token_ends = [m.end() for m in matches]
        labels = ['O'] * len(tokens)  # Initialize all tokens as 'Outside' entities

        for entity_type, start, end in spans_by_pmid.get(pmid, ()):
            if end <= start:
                continue  # empty span: nothing to label
            # First token that ends after the entity starts (ends are sorted).
            idx = bisect_right(token_ends, start)
            prefix = 'B'
            while idx < len(tokens) and token_starts[idx] < end:
                labels[idx] = f'{prefix}-{entity_type}'
                prefix = 'I'
                idx += 1
        ner_data.append((tokens, labels))
    return ner_data

class NERDataset(Dataset):
    """
    Custom Dataset for NER task using DrugProt data.

    Args:
    data (list): List of tuples, each containing tokens and their corresponding NER labels.
    tokenizer: The tokenizer to use for encoding the text (BioBERT tokenizer).
        It must expose ``word_ids()`` on its output, i.e. be a "fast"
        Hugging Face tokenizer.
    max_len (int): Maximum length of the input sequence.

    The label2id dictionary maps NER labels to integer IDs:
    - 'O': Outside of a named entity
    - 'B-CHEMICAL': Beginning of a chemical entity
    - 'I-CHEMICAL': Inside of a chemical entity
    - 'B-GENE-Y': Beginning of a gene/protein entity that can be normalized
    - 'I-GENE-Y': Inside of a gene/protein entity that can be normalized
    - 'B-GENE-N': Beginning of a gene/protein entity that cannot be normalized
    - 'I-GENE-N': Inside of a gene/protein entity that cannot be normalized

    Each item is a dict with 'input_ids', 'attention_mask' and 'labels', all
    of length max_len. Labels are aligned to wordpieces: the first piece of
    every word carries the word's label, while special tokens, padding and
    continuation pieces carry IGNORE_INDEX so they do not contribute to the loss.
    """
    # Default ignore_index of torch.nn.CrossEntropyLoss.
    IGNORE_INDEX = -100

    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label2id = {'O': 0, 'B-CHEMICAL': 1, 'I-CHEMICAL': 2, 'B-GENE-Y': 3, 'I-GENE-Y': 4, 'B-GENE-N': 5, 'I-GENE-N': 6}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, labels = self.data[idx]
        encoding = self.tokenizer(tokens,
                                  is_split_into_words=True,
                                  max_length=self.max_len,
                                  padding='max_length',
                                  truncation=True,
                                  return_tensors='pt')

        # The word-level labels cannot be paired with input_ids directly:
        # [CLS] shifts everything by one and a word may expand into several
        # wordpieces. word_ids() gives, for every encoded position, the index
        # of the source word (None for special and padding tokens).
        label_ids = []
        previous_word = None
        for word_index in encoding.word_ids(batch_index=0):
            if word_index is None or word_index == previous_word:
                label_ids.append(self.IGNORE_INDEX)
            else:
                label_ids.append(self.label2id[labels[word_index]])
            previous_word = word_index

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label_ids)
        }


# Relationship extraction dataset creation

class RelationDataset(Dataset):
    """
    Custom Dataset for Relation Extraction task using DrugProt data.

    Args:
    data (list): Tuples of (text, entity1, entity2, relation).
    tokenizer: Hugging Face tokenizer used to encode the inputs.
    max_len (int): Maximum length of the encoded sequence.
    relation2id (dict, optional): Mapping from relation name to class id.
        Pass the same mapping to every split so that a class id means the
        same relation in training, validation and inference. When omitted,
        the mapping is derived from the relation names found in `data`,
        in sorted order so that it is reproducible.

    Each item is a dict with 'input_ids', 'attention_mask' and 'labels'.
    """
    def __init__(self, data, tokenizer, max_len, relation2id=None):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        if relation2id is None:
            # Sorted, because iterating a set of strings has no stable order
            # across interpreter runs.
            relation2id = {rel: i for i, rel in enumerate(sorted({d[3] for d in data}))}
        self.relation2id = relation2id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, entity1, entity2, relation = self.data[idx]
        # Encode as a pair: "[CLS] text [SEP] entity1 [SEP] entity2 [SEP]".
        # With truncation='only_first' only the abstract is shortened, so the
        # entity pair at the end survives even for long abstracts.
        encoded = self.tokenizer(
            text,
            f"{entity1} [SEP] {entity2}",
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation='only_first',
            return_tensors='pt'
        )
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': torch.tensor(self.relation2id[relation])
        }

def prepare_relation_data(relations, abstracts, entities):
    """
    Prepare data for Relation Extraction training.

    Every relation is resolved to the abstract text and the surface text of
    its two argument entities. Relations whose abstract or argument entities
    cannot be found are skipped and counted; the counts are printed at the end.

    Args:
    relations (list): List of relation annotations as
        (pmid, rel_type, arg1, arg2), where arg1/arg2 look like 'Arg1:T15'.
    abstracts (dict): Dictionary of abstracts keyed by PMID.
    entities (list): List of entity annotations as
        (pmid, entity_id, entity_type, start, end, text).

    Returns:
    list: A list of tuples containing text, entity1, entity2, and relation type.
    """
    data = []
    skipped_no_abstract = 0
    skipped_no_entity1 = 0
    skipped_no_entity2 = 0

    # Create a dictionary for faster entity lookup: pmid -> {entity_id: text}
    entity_dict = {}
    for pmid, entity_id, _entity_type, _start, _end, text in entities:
        entity_dict.setdefault(pmid, {})[entity_id] = text

    print(f"Number of PMIDs in entity_dict: {len(entity_dict)}")
    if entity_dict:
        # Peek at the first entry without copying the whole dict; guarded so an
        # empty entity list does not raise.
        print(f"Sample entity_dict entry: {next(iter(entity_dict.items()))}")

    for pmid, rel_type, arg1, arg2 in relations:
        if pmid not in abstracts:
            skipped_no_abstract += 1
            continue

        text = abstracts[pmid]

        # Extract entity IDs from arg1 and arg2 ('Arg1:T15' -> 'T15')
        entity1_id = arg1.split(':')[1]
        entity2_id = arg2.split(':')[1]

        doc_entities = entity_dict.get(pmid)
        if doc_entities is None:
            print(f"PMID {pmid} not found in entity_dict")
            skipped_no_entity1 += 1
            continue

        if entity1_id not in doc_entities:
            print(f"Entity1 ID {entity1_id} not found for PMID {pmid}")
            print(f"Available entity IDs for this PMID: {list(doc_entities.keys())}")
            skipped_no_entity1 += 1
            continue

        if entity2_id not in doc_entities:
            skipped_no_entity2 += 1
            continue

        data.append((text, doc_entities[entity1_id], doc_entities[entity2_id], rel_type))

    print(f"Total relations: {len(relations)}")
    print(f"Skipped due to missing abstract: {skipped_no_abstract}")
    print(f"Skipped due to missing entity1: {skipped_no_entity1}")
    print(f"Skipped due to missing entity2: {skipped_no_entity2}")
    print(f"Final number of relations: {len(data)}")

    return data

# Model training function

def train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs, device):
    """
    Run the training loop with a validation pass after every epoch.

    For each epoch the model is put in train mode and stepped once per batch
    (optimizer step followed by scheduler step); then it is put in eval mode
    and the validation loss is accumulated under torch.no_grad(). The mean
    training and validation losses are printed per epoch.

    Args:
    model: The model to be trained; called with input_ids, attention_mask
        and labels, and expected to return an object with a `.loss`.
    train_loader: DataLoader for training data.
    val_loader: DataLoader for validation data.
    optimizer: Optimizer for updating model parameters.
    scheduler: Learning rate scheduler, stepped once per training batch.
    num_epochs (int): Number of training epochs.
    device: Device to run the model on (CPU or GPU).

    Returns:
    model: The trained model.
    """
    def forward(batch):
        # Move the three tensors the model consumes to the target device.
        return model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            labels=batch['labels'].to(device),
        )

    for epoch_idx in range(num_epochs):
        epoch_no = epoch_idx + 1

        # Training phase
        model.train()
        running_train = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch_no}/{num_epochs}"):
            optimizer.zero_grad()
            batch_loss = forward(batch).loss
            running_train += batch_loss.item()
            batch_loss.backward()
            optimizer.step()
            scheduler.step()

        mean_train = running_train / len(train_loader)
        print(f"Epoch {epoch_no}/{num_epochs}, Training Loss: {mean_train:.4f}")

        # Validation phase
        model.eval()
        running_val = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                running_val += forward(batch).loss.item()

        mean_val = running_val / len(val_loader)
        print(f"Epoch {epoch_no}/{num_epochs}, Validation Loss: {mean_val:.4f}")

    return model

In [None]:
if __name__ == "__main__":
    # End-to-end script: load DrugProt, fine-tune a NER model and a relation
    # classifier from the same BioBERT checkpoint, and save both to Drive.

    # Shared settings, defined once so the scheduler length always matches the
    # number of epochs actually trained.
    MODEL_NAME = "dmis-lab/biobert-v1.1"
    NUM_EPOCHS = 5
    MAX_LEN = 256
    BATCH_SIZE = 16
    LEARNING_RATE = 2e-5

    # Define the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    abstracts = load_drugprot_abstracts('/content/drive/MyDrive/drugprot-data/training/drugprot_training_abstracts.tsv')
    entities = load_drugprot_entities('/content/drive/MyDrive/drugprot-data/training/drugprot_training_entities.tsv')
    relations = load_drugprot_relations('/content/drive/MyDrive/drugprot-data/training/drugprot_training_relations.tsv')

    # Print some debug information
    print(f"Number of abstracts: {len(abstracts)}")
    print(f"Number of entities: {len(entities)}")
    print(f"Number of relations: {len(relations)}")

    # Prepare data
    ner_data = prepare_ner_data(entities, abstracts)
    relation_data = prepare_relation_data(relations, abstracts, entities)

    if len(relation_data) == 0:
        print("No valid relations found. Check your data and the prepare_relation_data function.")
    else:
        # Split data
        ner_train, ner_val = train_test_split(ner_data, test_size=0.1, random_state=42)
        relation_train, relation_val = train_test_split(relation_data, test_size=0.1, random_state=42)

        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        # Define model save paths
        ner_model_save_path = '/content/drive/MyDrive/trained_ner_model'
        relation_model_save_path = '/content/drive/MyDrive/trained_relation_model'
        save_path_tokenizer = '/content/drive/MyDrive/tokenizer'

        # NER model
        ner_train_dataset = NERDataset(ner_train, tokenizer, max_len=MAX_LEN)
        ner_val_dataset = NERDataset(ner_val, tokenizer, max_len=MAX_LEN)
        ner_train_loader = DataLoader(ner_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        ner_val_loader = DataLoader(ner_val_dataset, batch_size=BATCH_SIZE)

        # Store the label names in the model config so the saved model can map
        # predicted class ids back to tags.
        ner_label2id = dict(ner_train_dataset.label2id)
        ner_model = AutoModelForTokenClassification.from_pretrained(
            MODEL_NAME,
            num_labels=len(ner_label2id),
            id2label={i: label for label, i in ner_label2id.items()},
            label2id=ner_label2id,
        ).to(device)

        ner_optimizer = AdamW(ner_model.parameters(), lr=LEARNING_RATE)
        ner_scheduler = get_linear_schedule_with_warmup(ner_optimizer, num_warmup_steps=0, num_training_steps=len(ner_train_loader) * NUM_EPOCHS)

        trained_ner_model = train_model(ner_model, ner_train_loader, ner_val_loader, ner_optimizer, ner_scheduler, num_epochs=NUM_EPOCHS, device=device)
        # Previously saved to an undefined `model_save_path`.
        trained_ner_model.save_pretrained(ner_model_save_path)

        # One deterministic relation -> id mapping, built from ALL relations so
        # the train and validation splits (and the saved model) agree on it.
        relation2id = {rel: i for i, rel in enumerate(sorted({r[3] for r in relation_data}))}
        num_relation_classes = len(relation2id)
        print(f"Number of relation classes: {num_relation_classes}")

        # Relation extraction model
        relation_model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME,
            num_labels=num_relation_classes,
            id2label={i: rel for rel, i in relation2id.items()},
            label2id=relation2id,
        ).to(device)
        relation_train_dataset = RelationDataset(relation_train, tokenizer, max_len=MAX_LEN)
        relation_val_dataset = RelationDataset(relation_val, tokenizer, max_len=MAX_LEN)
        # Replace the per-split mappings with the shared one.
        relation_train_dataset.relation2id = relation2id
        relation_val_dataset.relation2id = relation2id
        relation_train_loader = DataLoader(relation_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        relation_val_loader = DataLoader(relation_val_dataset, batch_size=BATCH_SIZE)

        relation_optimizer = AdamW(relation_model.parameters(), lr=LEARNING_RATE)
        relation_scheduler = get_linear_schedule_with_warmup(relation_optimizer, num_warmup_steps=0, num_training_steps=len(relation_train_loader) * NUM_EPOCHS)

        trained_relation_model = train_model(relation_model, relation_train_loader, relation_val_loader, relation_optimizer, relation_scheduler, num_epochs=NUM_EPOCHS, device=device)
        trained_relation_model.save_pretrained(relation_model_save_path)
        tokenizer.save_pretrained(save_path_tokenizer)

        print("Training completed. Models saved in Google Drive.")

Using device: cuda
Number of abstracts: 3500
Number of entities: 89529
Number of relations: 17288


Preparing NER data: 100%|██████████| 3500/3500 [00:29<00:00, 120.27it/s]


Number of PMIDs in entity_dict: 3500
Sample entity_dict entry: ('11808879', {'T1': 'diazoxide', 'T2': 'Diazoxide', 'T3': 'diazoxide', 'T4': 'glutamate', 'T5': 'glucose', 'T6': 'glucose', 'T7': 'diazoxide', 'T8': 'insulin', 'T9': 'SUR1', 'T10': 'KIR6.2', 'T11': 'SUR1', 'T12': 'KIR6.2', 'T13': 'glutamate dehydrogenase', 'T14': 'glucokinase'})
Total relations: 17288
Skipped due to missing abstract: 0
Skipped due to missing entity1: 0
Skipped due to missing entity2: 0
Final number of relations: 17288


Some weights of BertForTokenClassification were not initialized from the model checkpoint at dmis-lab/biobert-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/5: 100%|██████████| 197/197 [02:18<00:00,  1.42it/s]


Epoch 1/5, Training Loss: 0.5718


Validation: 100%|██████████| 22/22 [00:05<00:00,  4.11it/s]


Epoch 1/5, Validation Loss: 0.5037


Epoch 2/5: 100%|██████████| 197/197 [02:23<00:00,  1.37it/s]


Epoch 2/5, Training Loss: 0.5155


Validation: 100%|██████████| 22/22 [00:05<00:00,  4.10it/s]


Epoch 2/5, Validation Loss: 0.4948


Epoch 3/5: 100%|██████████| 197/197 [02:23<00:00,  1.37it/s]


Epoch 3/5, Training Loss: 0.5019


Validation: 100%|██████████| 22/22 [00:05<00:00,  4.13it/s]


Epoch 3/5, Validation Loss: 0.4854


Epoch 4/5: 100%|██████████| 197/197 [02:23<00:00,  1.37it/s]


Epoch 4/5, Training Loss: 0.4914


Validation: 100%|██████████| 22/22 [00:05<00:00,  4.13it/s]


Epoch 4/5, Validation Loss: 0.4768


Epoch 5/5: 100%|██████████| 197/197 [02:23<00:00,  1.37it/s]


Epoch 5/5, Training Loss: 0.4839


Validation: 100%|██████████| 22/22 [00:05<00:00,  4.12it/s]


Epoch 5/5, Validation Loss: 0.4741
Number of relation classes: 13


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dmis-lab/biobert-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/5: 100%|██████████| 973/973 [11:40<00:00,  1.39it/s]


Epoch 1/5, Training Loss: 1.3901


Validation: 100%|██████████| 109/109 [00:25<00:00,  4.21it/s]


Epoch 1/5, Validation Loss: 1.1036


Epoch 2/5: 100%|██████████| 973/973 [11:39<00:00,  1.39it/s]


Epoch 2/5, Training Loss: 1.0215


Validation: 100%|██████████| 109/109 [00:25<00:00,  4.23it/s]


Epoch 2/5, Validation Loss: 1.0073


Epoch 3/5: 100%|██████████| 973/973 [11:40<00:00,  1.39it/s]


Epoch 3/5, Training Loss: 0.9024


Validation: 100%|██████████| 109/109 [00:25<00:00,  4.22it/s]


Epoch 3/5, Validation Loss: 0.9832


Epoch 4/5: 100%|██████████| 973/973 [11:39<00:00,  1.39it/s]


Epoch 4/5, Training Loss: 0.8305


Validation: 100%|██████████| 109/109 [00:25<00:00,  4.26it/s]


Epoch 4/5, Validation Loss: 0.9604


Epoch 5/5: 100%|██████████| 973/973 [11:39<00:00,  1.39it/s]


Epoch 5/5, Training Loss: 0.7787


Validation: 100%|██████████| 109/109 [00:25<00:00,  4.25it/s]


Epoch 5/5, Validation Loss: 0.9650
Training completed. Models saved in Google Drive.


In [None]:
# Write the fine-tuned NER model to the directory held in `ner_model_save_path`.
# Both names are created by the training cell, so that cell must run first.
trained_ner_model.save_pretrained(ner_model_save_path)

In [None]:
# Show the first five raw lines of the entity and relation files to check
# their column layout.
for heading, sample_path in (
    ("Raw entity data sample:", '/content/drive/MyDrive/drugprot-data/training/drugprot_training_entities.tsv'),
    ("\nRaw relation data sample:", '/content/drive/MyDrive/drugprot-data/training/drugprot_training_relations.tsv'),
):
    print(heading)
    with open(sample_path, 'r', encoding='utf-8') as f:
        # zip with range(5) stops after five lines (or earlier at end of file)
        for _, line in zip(range(5), f):
            print(line.strip())

Raw entity data sample:
11808879	T1	CHEMICAL	1165	1174	diazoxide
11808879	T2	CHEMICAL	1450	1459	Diazoxide
11808879	T3	CHEMICAL	1901	1910	diazoxide
11808879	T4	CHEMICAL	1993	2002	glutamate
11808879	T5	CHEMICAL	917	924	glucose

Raw relation data sample:
23017395	INHIBITOR	Arg1:T15	Arg2:T21
23017395	INHIBITOR	Arg1:T16	Arg2:T21
12181427	PART-OF	Arg1:T3	Arg2:T22
12181427	INHIBITOR	Arg1:T6	Arg2:T23
12181427	INHIBITOR	Arg1:T7	Arg2:T23
