In [2]:
import os
import gzip
import shutil
import torch
import urllib.request
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [3]:
def download_medmentions(archive_path="./corpus/corpus_pubtator.txt.gz",
                         output_path="corpus_pubtator.txt",
                         url="https://github.com/chanzuckerberg/MedMentions/blob/master/full/data/corpus_pubtator.txt.gz?raw=true"):
    """Download the MedMentions PubTator archive and decompress it if needed.

    Each step is skipped when its target file already exists, so the function
    is cheap to call repeatedly.

    Args:
        archive_path: where the gzip archive is stored; its directory is
            created if missing.
        output_path: where the decompressed PubTator text is written. The
            default is the path ``prepare_medmentions_data`` reads.
        url: source of the gzip archive.

    Returns:
        ``output_path``.
    """
    if not os.path.exists(archive_path):
        os.makedirs(os.path.dirname(archive_path) or ".", exist_ok=True)
        print(f"Downloading {archive_path}...")
        # Download under a temporary name so an interrupted transfer never
        # leaves a truncated archive that later calls would treat as complete.
        partial_archive = archive_path + ".part"
        urllib.request.urlretrieve(url, partial_archive)
        os.replace(partial_archive, archive_path)

    # Check the same path that is written below, otherwise the archive is
    # re-extracted on every call.
    if not os.path.exists(output_path):
        print("Extracting dataset...")
        partial_output = output_path + ".part"
        with gzip.open(archive_path, 'rb') as f_in, open(partial_output, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(partial_output, output_path)

    return output_path

In [4]:
def read_medmentions(file_path):
    """Parse a PubTator-format corpus into ``(text, entities)`` pairs.

    Documents are blank-line separated blocks::

        PMID|t|title
        PMID|a|abstract
        PMID<TAB>start<TAB>end<TAB>mention<TAB>types<TAB>concept_id
        ...

    Args:
        file_path: path to the decompressed PubTator text file (UTF-8).

    Returns:
        List of ``(text, entities)`` tuples. ``text`` is the title and abstract
        joined by a single space. ``entities`` holds, for every annotation line
        (any line after the first two that contains a tab), the list of its
        raw tab-separated fields.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read().strip()

    # Extra blank lines between documents would otherwise yield chunks whose
    # first line is empty or not a title line; trim and drop them.
    entries = [chunk.strip('\n') for chunk in content.split('\n\n')]
    entries = [entry for entry in entries if entry.strip()]

    data = []
    for entry in tqdm(entries, desc="Reading entries"):
        lines = entry.split('\n')
        # maxsplit=1 keeps any literal '|t|' / '|a|' inside the text intact.
        title = lines[0].split('|t|', 1)[1]
        abstract = lines[1].split('|a|', 1)[1]
        text = title + ' ' + abstract
        entities = []
        for line in lines[2:]:
            fields = line.split('\t')
            if len(fields) > 1:
                entities.append(fields)
        data.append((text, entities))

    return data

In [5]:
def create_bio_tags(text, entities, symptom_types=None):
    """Whitespace-tokenize ``text`` and BIO-tag the words covered by symptom-like mentions.

    Args:
        text: document text that the annotation character offsets refer to.
        entities: annotation field lists. Rows with 5 or more fields are read
            as PubTator annotations ``[pmid, start, end, mention, types, ...]``
            (the format produced by ``read_medmentions``), where ``types`` may
            be a comma-separated list. Shorter rows are read as
            ``[start, end, types, ...]``.
        symptom_types: semantic types whose mentions are tagged, compared
            case-insensitively against each comma-separated type. Defaults to
            the UMLS TUIs T184, T046, T047, T033 plus their abbreviations
            sosy, patf, dsyn, fndg.

    Returns:
        ``(words, tags)`` with ``words == text.split()`` and one tag per word,
        each one of ``'O'``, ``'B-SYMPTOM'`` or ``'I-SYMPTOM'``.
    """
    import bisect

    if symptom_types is None:
        # Sign or Symptom, Pathologic Function, Disease or Syndrome, Finding.
        # MedMentions stores the TUIs; the abbreviations are accepted too.
        symptom_types = ('T184', 'T046', 'T047', 'T033', 'sosy', 'patf', 'dsyn', 'fndg')
    wanted_types = {t.lower() for t in symptom_types}

    words = text.split()
    tags = ['O'] * len(words)

    # Character span [begin, end) of each word. Searching forward from the end
    # of the previous word maps repeated words to the right occurrence.
    begins, ends = [], []
    cursor = 0
    for word in words:
        begin = text.index(word, cursor)
        cursor = begin + len(word)
        begins.append(begin)
        ends.append(cursor)

    for entity in entities:
        if len(entity) >= 5:
            start, end, type_field = int(entity[1]), int(entity[2]), entity[4]
        else:
            start, end, type_field = int(entity[0]), int(entity[1]), entity[2]

        entity_types = {t.strip().lower() for t in type_field.split(',')}
        if not entity_types & wanted_types:
            continue

        # Tag every word overlapping [start, end), so mentions beginning or
        # ending inside a word (e.g. "(fever)") still land on that word.
        first = bisect.bisect_right(ends, start)   # first word ending after start
        stop = bisect.bisect_left(begins, end)     # first word beginning at/after end
        if first < stop:
            tags[first] = 'B-SYMPTOM'
            for i in range(first + 1, stop):
                tags[i] = 'I-SYMPTOM'

    return words, tags

In [7]:
def prepare_medmentions_data(corpus_path='corpus_pubtator.txt', output_dir='./data',
                             test_size=0.2, val_size=0.1, random_state=42):
    """Build BIO-tagged MedMentions train/val/test CSV files.

    Calls ``download_medmentions``, parses the corpus with ``read_medmentions``,
    tags each document with ``create_bio_tags``, and writes
    ``medmentions_train.csv``, ``medmentions_val.csv`` and ``medmentions_test.csv``
    into ``output_dir``. Each CSV has a ``text`` column (space-joined words) and
    a ``labels`` column (space-joined tags, one per word).

    Args:
        corpus_path: decompressed PubTator file to read.
        output_dir: destination directory for the CSVs; created if missing.
        test_size: fraction of documents held out as the test split.
        val_size: fraction of the remaining documents used for validation.
        random_state: seed for both ``train_test_split`` calls.
    """
    download_medmentions()

    all_data = read_medmentions(corpus_path)

    processed_data = [
        create_bio_tags(text, entities)
        for text, entities in tqdm(all_data, desc="Processing entries")
    ]

    df = pd.DataFrame(processed_data, columns=['text', 'labels'])
    df['text'] = df['text'].apply(' '.join)
    df['labels'] = df['labels'].apply(' '.join)

    train_df, test_df = train_test_split(df, test_size=test_size, random_state=random_state)
    train_df, val_df = train_test_split(train_df, test_size=val_size, random_state=random_state)

    # DataFrame.to_csv does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
        split_df.to_csv(os.path.join(output_dir, f'medmentions_{split_name}.csv'), index=False)

    print(f"Saved {len(train_df)} training samples, {len(val_df)} validation samples, and {len(test_df)} test samples.")

In [8]:
def load_medical_dataset(file_path, max_seq_length=128):
    """Load texts and fixed-length integer label lists from a CSV.

    Args:
        file_path: CSV with ``text`` and ``labels`` columns. ``labels`` holds
            integer label ids separated by spaces; an empty cell means no labels.
        max_seq_length: every label list is truncated to, or right-padded with
            0 up to, this length.

    Returns:
        ``(texts, labels)``: the ``text`` column as a list, and a list of int
        lists each of length ``max_seq_length``.
    """
    # Read labels as strings: a column whose rows are all single ids would
    # otherwise be parsed as int64, which has no .split().
    df = pd.read_csv(file_path, dtype={'labels': str})

    texts = df['text'].tolist()

    labels = []
    # Empty label cells arrive as NaN; treat them as "no labels".
    for raw in df['labels'].fillna(''):
        ids = [int(token) for token in raw.split()][:max_seq_length]
        labels.append(ids + [0] * (max_seq_length - len(ids)))

    return texts, labels


In [9]:
class MedicalNERDataset(Dataset):
    """Token-classification dataset pairing raw texts with pre-computed label id lists.

    Texts are tokenized on access and padded/truncated to ``max_len``. Labels
    are returned as given, without alignment to sub-word tokens.
    NOTE(review): each ``labels[i]`` presumably already has length ``max_len``
    (e.g. from ``load_medical_dataset``); otherwise batches will not collate.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        tokenizer_kwargs = dict(
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer(str(self.texts[item]), **tokenizer_kwargs)

        # The tokenizer returns (1, max_len) tensors; flatten to (max_len,).
        sample = {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(self.labels[item], dtype=torch.long)
        return sample

In [10]:
def train_epoch(model, data_loader, optimizer, device, scheduler=None):
    """Run one training epoch and return the mean batch loss.

    NOTE(review): this function is redefined in a later cell (with a tqdm
    progress bar); when the notebook runs top to bottom the later one wins.

    Args:
        model: model whose ``forward(input_ids, attention_mask=..., labels=...)``
            returns an object with a scalar ``.loss`` tensor.
        data_loader: iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' tensors.
        optimizer: stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: optional LR scheduler, stepped after every optimizer step.

    Returns:
        Mean loss over the batches, or 0.0 when ``data_loader`` is empty.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in data_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        num_batches += 1

    # Guard against ZeroDivisionError on an empty loader.
    return total_loss / num_batches if num_batches else 0.0

In [12]:
# Download/extract MedMentions, BIO-tag every document and write the
# train/val/test CSVs under ./data.
prepare_medmentions_data()

Reading entries: 100%|██████████| 4392/4392 [00:01<00:00, 2868.07it/s]
Processing entries: 100%|██████████| 4392/4392 [00:07<00:00, 597.22it/s]


Saved 3161 training samples, 352 validation samples, and 879 test samples.


In [11]:
class MedMentionsDataset(Dataset):
    """Token-classification dataset for whitespace-tokenized texts with word-level tags.

    ``texts[i]`` and ``labels[i]`` are space-separated strings with one tag per
    whitespace-separated word (the format written by
    ``prepare_medmentions_data``). Word tags are aligned to the tokenizer's
    sub-word tokens: the first sub-word of each word gets the word's label id,
    while continuation sub-words, special tokens and padding get -100, which the
    loss ignores.

    The tokenizer must accept ``is_split_into_words=True`` and return an
    encoding exposing ``word_ids()`` (Hugging Face "fast" tokenizers do).

    Args:
        texts: sequence of space-separated texts.
        labels: sequence of space-separated tag strings, parallel to ``texts``.
        tokenizer: tokenizer callable as described above.
        max_len: padded/truncated length in tokens.
        label_map: optional ``tag -> id`` mapping. Pass the training set's
            ``label_map`` to validation/test datasets so ids agree. When omitted,
            it is built from ``labels``: ``'O'`` is 0, followed by the other tags
            in sorted order.
    """

    def __init__(self, texts, labels, tokenizer, max_len, label_map=None):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        if label_map is None:
            # Build the full mapping up front. The classifier size is read from
            # get_num_labels() before any sample is fetched, and ids must not
            # depend on the (shuffled) order in which samples are visited.
            tags = {tag for raw in labels for tag in self._split(raw)}
            tags.discard('O')
            label_map = {'O': 0}
            for tag in sorted(tags):
                label_map[tag] = len(label_map)
        self.label_map = dict(label_map)
        self.num_labels = max(self.label_map.values(), default=-1) + 1

    @staticmethod
    def _split(value):
        """Split a space-separated cell; non-strings (e.g. NaN from an empty CSV cell) give []."""
        return value.split() if isinstance(value, str) else []

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        words = self._split(self.texts[item])
        tags = self._split(self.labels[item])
        if len(words) != len(tags):
            raise ValueError(f"Sample {item}: {len(words)} words but {len(tags)} labels")

        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        word_label_ids = self.convert_labels_to_ids(tags)
        label_ids = []
        previous_word = None
        for word_idx in encoding.word_ids():
            if word_idx is None or word_idx == previous_word:
                # Special tokens, padding and continuation sub-words.
                label_ids.append(-100)
            else:
                label_ids.append(word_label_ids[word_idx])
            previous_word = word_idx

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label_ids, dtype=torch.long),
        }

    def convert_labels_to_ids(self, label_list):
        """Map tags to ids with ``label_map``.

        Raises:
            ValueError: if a tag is not in ``label_map``. Adding it on the fly
                would create an id the already-built model cannot predict.
        """
        try:
            return [self.label_map[label] for label in label_list]
        except KeyError as missing:
            raise ValueError(f"Label {missing.args[0]!r} not in label_map {self.label_map}") from None

    def get_num_labels(self):
        return self.num_labels

In [12]:
def load_medmentions_dataset(file_path):
    """Load a MedMentions split CSV as parallel lists of texts and tag strings.

    Args:
        file_path: CSV with ``text`` and ``labels`` columns.

    Returns:
        ``(texts, labels)``: two lists of strings, one entry per row.
    """
    # Keep every cell as a string: by default empty cells (and literal "NA",
    # "null", ...) become NaN, which later turns into the text/label 'nan'.
    df = pd.read_csv(file_path, dtype=str, keep_default_na=False)
    texts = df['text'].tolist()
    labels = df['labels'].tolist()
    return texts, labels

In [13]:
def train_epoch(model, data_loader, optimizer, device, scheduler=None):
    """Run one training epoch (with a tqdm progress bar) and return the mean batch loss.

    Args:
        model: model whose ``forward(input_ids, attention_mask=..., labels=...)``
            returns an object with a scalar ``.loss`` tensor.
        data_loader: iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' tensors.
        optimizer: stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: optional LR scheduler, stepped after every optimizer step.
            Optional so callers without a scheduler, such as
            ``train_epoch(model, loader, optimizer, device)``, do not fail.

    Returns:
        Mean loss over the batches, or 0.0 when ``data_loader`` is empty.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(data_loader, desc="Training"):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        num_batches += 1

    # Guard against ZeroDivisionError on an empty loader.
    return total_loss / num_batches if num_batches else 0.0

In [14]:
def evaluate(model, data_loader, device, average='weighted', metric_labels=None):
    """Compute token-level accuracy and precision/recall/F1 on ``data_loader``.

    Positions whose label is -100 (special tokens, padding, continuation
    sub-words) are excluded from all metrics.

    Args:
        model: model whose ``forward(input_ids, attention_mask=...)`` returns an
            object with ``.logits`` of shape (batch, seq, num_labels).
        data_loader: iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' tensors.
        device: device the batch tensors are moved to.
        average: ``average`` argument for sklearn's
            ``precision_recall_fscore_support``. With the default 'weighted',
            a dominant 'O' class can mask poor entity performance.
        metric_labels: optional label ids to score (e.g. only the entity tag
            ids); ``None`` scores all labels present.

    Returns:
        ``(accuracy, precision, recall, f1)``.
    """
    model.eval()
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids, attention_mask=attention_mask).logits
            preds = logits.argmax(dim=2)

            scored = labels != -100
            predictions.extend(preds[scored].cpu().numpy())
            true_labels.extend(labels[scored].cpu().numpy())

    accuracy = accuracy_score(true_labels, predictions)
    # zero_division=0 gives the same numbers as sklearn's default, but does not
    # emit an UndefinedMetricWarning whenever a class is never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average=average, labels=metric_labels, zero_division=0
    )

    return accuracy, precision, recall, f1

In [15]:
# Hyperparameters
MAX_LEN = 128
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 32
EPOCHS = 5
LEARNING_RATE = 2e-5
SEED = 42

# Seed the global RNGs consumed later (DataLoader shuffling, classifier-head
# initialisation) so training runs are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [16]:
# Load datasets
# Read the training and validation splits written by prepare_medmentions_data().
(train_texts, train_labels) = load_medmentions_dataset("./data/medmentions_train.csv")
(val_texts, val_labels) = load_medmentions_dataset("./data/medmentions_val.csv")

In [19]:
# Prepare datasets and initialize tokenizers
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# NOTE(review): the train and val datasets derive their tag -> id mappings
# independently; they must agree for validation metrics to be meaningful.
train_dataset, val_dataset = (
    MedMentionsDataset(texts, labels, tokenizer, MAX_LEN)
    for texts, labels in ((train_texts, train_labels), (val_texts, val_labels))
)

# Shuffle only the training data; validation order is kept fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=TRAIN_BATCH_SIZE)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=VALID_BATCH_SIZE)



In [20]:
# Initialize tokenizer and model
# One classifier output per tag id known to the training dataset.
num_labels = train_dataset.get_num_labels()
# Store tag names in the config so the saved model can decode its predictions
# without relying on the dataset object.
id2label = {idx: tag for tag, idx in train_dataset.label_map.items()}
model = AutoModelForTokenClassification.from_pretrained(
    "emilyalsentzer/Bio_ClinicalBERT",
    num_labels=num_labels,
    id2label=id2label,
    label2id=dict(train_dataset.label_map),
)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
# Setup training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# torch.optim.AdamW instead of the deprecated transformers.AdamW; eps and
# weight_decay mirror that class's defaults (1e-6, 0.0).
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)
total_steps = len(train_loader) * EPOCHS
# Linear decay to 0 over all training steps, no warmup.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)



In [None]:
# Training loop
# One training pass per epoch, followed by token-level validation metrics.
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    train_loss = train_epoch(model, train_loader, optimizer, device, scheduler)
    print(f"Training loss: {train_loss}")

    val_metrics = evaluate(model, val_loader, device)
    accuracy, precision, recall, f1 = val_metrics
    print(f"Validation Accuracy: {accuracy:.4f}")
    print(f"Validation Precision: {precision:.4f}")
    print(f"Validation Recall: {recall:.4f}")
    print(f"Validation F1-score: {f1:.4f}")

In [None]:
# Training setup
# Continue fine-tuning for additional epochs.
# NOTE(review): this starts from the weights left by the previous training cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)

num_epochs = 10
# A fresh linear schedule sized for this run. It is passed explicitly because
# the active train_epoch definition takes the scheduler as an argument.
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
)
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, device, scheduler)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss}")
    accuracy, precision, recall, f1 = evaluate(model, val_loader, device)
    print(f"Validation Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, "
          f"Recall: {recall:.4f}, F1: {f1:.4f}")

In [None]:
# Save the model
# Save the fine-tuned weights and the matching tokenizer into one directory so
# both can be reloaded with from_pretrained().
save_dir = "./medical_ner_model"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)