<a href="https://colab.research.google.com/github/derejeweyessaa/HyperBERT-/blob/main/HyperBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# Below is Python code that integrates hyperbolic embeddings with a
# BERT-based model for an interpretable ICD coding task, using PyTorch,
# Hugging Face Transformers, and Poincaré embeddings.

In [None]:
# To integrate hyperbolic embeddings and BERT-based representations for the automatic ICD coding task, we perform the following steps:
#
# 1. Preprocess the data: prepare the discharge summaries and their associated ICD codes.
#
# 2. Train hyperbolic embeddings on the hierarchical structure of the ICD codes using Poincaré embeddings.
#
# 3. Fine-tune a pre-trained BERT model on the discharge summaries for the downstream ICD coding task, framed as multi-label classification.
#
# 4. Integrate the hyperbolic embeddings and the BERT-based representations by combining them in a joint model.
#
# 5. Train the joint model on the training data.
#
# 6. Evaluate the performance of the joint model on a separate validation or test dataset.
#
# 7. Visualize the results for interpretability.

In [None]:
!pip install transformers

In [None]:
!pip install torch

In [None]:
!pip install PoincareModel

In [None]:
!pip install geomstats

In [None]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
from poincare_embedding import PoincareModel
from geomstats.optimization.optimizers import RiemannianAdam
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Step 1: Data Preprocessing
def preprocess_mimic_iv(data_path):
    """Load a discharge-summary CSV and split it into train/val/test sets.

    Only the ``discharge_summary`` and ``icd_codes`` columns are kept, and
    rows with a missing value in either column are dropped before splitting.

    Args:
        data_path: Path (or any other source accepted by ``pandas.read_csv``)
            of the CSV file to load.

    Returns:
        The ``(train_data, val_data, test_data)`` tuple produced by
        ``split_data``.

    Raises:
        KeyError: If either required column is missing from the CSV.
    """
    required_columns = ['discharge_summary', 'icd_codes']
    # Chain a non-inplace dropna rather than mutating the column-selected
    # frame in place, which is the pattern pandas flags with
    # SettingWithCopyWarning.
    mimic_data = pd.read_csv(data_path)[required_columns].dropna()
    # NOTE(review): split_data is not defined in the visible notebook;
    # presumably it is provided elsewhere and returns three splits - confirm.
    train_data, val_data, test_data = split_data(mimic_data)
    return train_data, val_data, test_data

In [None]:
# Step 2: Hyperbolic Embedding Training
def train_hyperbolic_embedding(icd_hierarchy, num_epochs=5, lr=0.001):
    """Train a Poincaré embedding model on the ICD code hierarchy.

    Args:
        icd_hierarchy: Iterable of training items; each item is passed as-is
            to ``PoincareModel.loss``. It is iterated once per epoch, so it
            must be re-iterable (a list, not a one-shot generator).
        num_epochs: Number of passes over ``icd_hierarchy``.
        lr: Learning rate handed to ``RiemannianAdam``.

    Returns:
        The trained ``PoincareModel``. Returning it keeps the learned
        embeddings reachable; callers that ignore the return value are
        unaffected.
    """
    poincare_model = PoincareModel()
    poincare_optimizer = RiemannianAdam(poincare_model.parameters(), lr=lr)

    # Training mode does not change between epochs, so set it once.
    poincare_model.train()
    for _ in range(num_epochs):
        for icd_codes in icd_hierarchy:
            poincare_optimizer.zero_grad()
            poincare_loss = poincare_model.loss(icd_codes)
            poincare_loss.backward()
            poincare_optimizer.step()

    return poincare_model

In [None]:
# Define multi-label classification model
class MultiLabelClassifier(nn.Module):
    """Encoder followed by a single linear layer emitting one logit per label."""

    def __init__(self, bert_model, num_labels):
        """Store the encoder and size the output layer from its config.

        Args:
            bert_model: Encoder exposing ``config.hidden_size`` and callable
                with ``input_ids`` / ``attention_mask`` keyword arguments.
            num_labels: Number of output logits.
        """
        super().__init__()
        self.bert_model = bert_model
        hidden_size = bert_model.config.hidden_size
        self.fc = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the raw (unnormalised) logits for a batch."""
        encoder_outputs = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Element 1 of the encoder output is treated as the sequence-level
        # summary vector (the original code names it the pooled output).
        sequence_summary = encoder_outputs[1]
        return self.fc(sequence_summary)

In [None]:


# Step 3: BERT-based Representation
# Cache of (tokenizer, model) pairs keyed by checkpoint name, so repeated
# calls do not reload the weights from disk every time.
_BERT_RESOURCES = {}


def get_bert_embeddings(texts, model_name='bert-base-uncased'):
    """Return the first-token ([CLS] position) BERT embedding for each text.

    Args:
        texts: A string or list of strings, passed straight to the tokenizer.
        model_name: Checkpoint name given to ``from_pretrained``. The default
            matches the checkpoint that was previously hard-coded.

    Returns:
        Tensor obtained by slicing ``last_hidden_state[:, 0, :]``, i.e. one
        row per input text. It is computed under ``torch.no_grad()`` and
        therefore carries no autograd history.
    """
    if model_name not in _BERT_RESOURCES:
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertModel.from_pretrained(model_name)
        # Inference only: make sure dropout is disabled.
        model.eval()
        _BERT_RESOURCES[model_name] = (tokenizer, model)
    tokenizer, model = _BERT_RESOURCES[model_name]

    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # The model is private to this helper, so no caller can optimise it;
    # skipping the autograd graph saves memory.
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state[:, 0, :]
    return embeddings

# Step 4: Integration and Multi-label Classification
class HyperBertClassifier(nn.Module):
    """Linear classifier over BERT first-token features joined with Poincaré embeddings."""

    def __init__(self, bert_model, poincare_model, num_classes):
        """Keep both sub-models and size the classifier for their joint width.

        Args:
            bert_model: Encoder exposing ``config.hidden_size``.
            poincare_model: Object exposing ``embedding_dim`` and
                ``get_embeddings()``.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.bert_model = bert_model
        self.poincare_model = poincare_model
        joint_width = bert_model.config.hidden_size + poincare_model.embedding_dim
        self.linear = nn.Linear(joint_width, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return logits computed from the concatenated feature vectors."""
        encoded = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        cls_vectors = encoded.last_hidden_state[:, 0, :]
        # NOTE(review): concatenating along dim=1 requires get_embeddings()
        # to return exactly one row per batch element - confirm that this is
        # what the Poincaré model provides.
        hyperbolic_vectors = self.poincare_model.get_embeddings()
        joint_features = torch.cat((cls_vectors, hyperbolic_vectors), dim=1)
        return self.linear(joint_features)

In [None]:
# Step 4: Integration and Multi-label Classification
# NOTE(review): a class with this name and the same behaviour is also defined
# in an earlier cell; this later definition is the one that stays bound.
class HyperBertClassifier(nn.Module):
    """Joint model: BERT first-token vector + Poincaré embedding -> linear logits."""

    def __init__(self, bert_model, poincare_model, num_classes):
        """Register the two sub-models and the output projection.

        Args:
            bert_model: Encoder exposing ``config.hidden_size``.
            poincare_model: Object exposing ``embedding_dim`` and
                ``get_embeddings()``.
            num_classes: Width of the logit vector.
        """
        super().__init__()
        self.bert_model = bert_model
        self.poincare_model = poincare_model
        self.linear = nn.Linear(
            bert_model.config.hidden_size + poincare_model.embedding_dim,
            num_classes,
        )

    def forward(self, input_ids, attention_mask):
        """Concatenate both representations column-wise and project to logits."""
        features = [
            self.bert_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state[:, 0, :],
            # Presumably one embedding row per batch element; torch.cat on
            # dim=1 fails otherwise - verify against the Poincaré model.
            self.poincare_model.get_embeddings(),
        ]
        return self.linear(torch.cat(features, dim=1))

In [None]:


# Step 5: Model Architecture
class HyperBertModel(nn.Module):
    """BERT encoder with a linear classification head on the pooled output."""

    def __init__(self, bert_model, num_classes=None):
        """Build the classification head.

        Args:
            bert_model: Encoder exposing ``config.hidden_size`` and returning
                an object with a ``pooler_output`` attribute.
            num_classes: Number of output logits. When omitted, the
                module-level ``num_classes`` variable is used, which is what
                the constructor implicitly relied on before this argument
                existed.

        Raises:
            NameError: If ``num_classes`` is neither passed nor defined at
                module level.
        """
        super().__init__()
        if num_classes is None:
            # Backward compatibility for HyperBertModel(bert_model) callers.
            try:
                num_classes = globals()['num_classes']
            except KeyError:
                raise NameError(
                    "name 'num_classes' is not defined; pass num_classes=... "
                    "to HyperBertModel"
                ) from None
        self.bert_model = bert_model
        self.fc = nn.Linear(bert_model.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw logits computed from the encoder's ``pooler_output``."""
        bert_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.fc(bert_output.pooler_output)
        return logits

In [None]:
# Step 6: Model Training
def train_model(model, train_loader, criterion, optimizer, num_epochs, device):
    """Run ``num_epochs`` passes of gradient-descent training over ``train_loader``.

    Each batch is a mapping with ``'input_ids'``, ``'attention_mask'`` and
    ``'labels'`` entries; all three are moved to ``device`` before use. After
    every epoch the loss, weighted by batch size and averaged over
    ``len(train_loader.dataset)``, is printed. Nothing is returned.
    """
    batch_keys = ('input_ids', 'attention_mask', 'labels')
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in train_loader:
            token_ids, mask, targets = (batch[key].to(device) for key in batch_keys)

            optimizer.zero_grad()
            batch_loss = criterion(model(token_ids, mask), targets)
            batch_loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch does not skew
            # the per-sample average.
            running_loss += batch_loss.item() * token_ids.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

In [None]:
# Step 7: Evaluation
def evaluate_model(model, dataloader, device, multi_label=False):
    """Run the model over ``dataloader`` and collect labels and predictions.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning logits.
        dataloader: Iterable of batches, each a mapping with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        device: Device the inputs are moved to before the forward pass.
        multi_label: When ``False`` (default, the original behaviour) each
            prediction is the ``argmax`` class index along dim 1. When
            ``True`` each prediction is the list of per-label sigmoid scores,
            which suits thresholding and AUC for multi-label tasks.

    Returns:
        ``(all_labels, all_preds)`` as plain Python lists, in that order.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Labels are only collected, never fed to the model, so there is
            # no need to move them to `device` first.
            labels = batch['labels']

            logits = model(input_ids, attention_mask)
            if multi_label:
                preds = torch.sigmoid(logits)
            else:
                preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    return all_labels, all_preds

In [None]:
# Evaluation option
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score

test_loader = DataLoader(test_dataset, batch_size=batch_size)
# evaluate_model returns (labels, predictions) in that order.
true_labels, predictions = evaluate_model(classifier_model, test_loader, device)

# evaluate_model returns plain lists; convert them so the element-wise
# threshold below works (a list cannot be compared against a float).
true_labels = np.asarray(true_labels)
predictions = np.asarray(predictions)
# NOTE(review): thresholding at 0.5 and the AUC below assume `predictions`
# holds per-label scores in [0, 1]; by default evaluate_model returns argmax
# class indices instead - confirm which one is intended.
binary_predictions = (predictions >= 0.5).astype(int)

# Calculate F1-score
f1_micro = f1_score(true_labels, binary_predictions, average='micro')
f1_macro = f1_score(true_labels, binary_predictions, average='macro')

# Compute AUC score
auc_score = roc_auc_score(true_labels, predictions)

print(f'Micro F1-score: {f1_micro}')
print(f'Macro F1-score: {f1_macro}')
print(f'AUC Score: {auc_score}')

# Compute document-code similarity
# Assuming bert_embeddings and hyperbolic_embeddings are computed
similarity_scores = compute_similarity(bert_embeddings, hyperbolic_embeddings)

# Utilize model interpretability techniques (e.g., attention and visualization) for code-aware document representations
# Code for attention visualization and interpretation....#

In [None]:
 # Step 8: Fine-tuning Process
def fine_tune_model(model, fine_tune_loader, criterion, optimizer, num_epochs, device):
    """Fine-tune ``model`` on ``fine_tune_loader`` for ``num_epochs`` epochs.

    Batches are mappings with ``'input_ids'``, ``'attention_mask'`` and
    ``'labels'`` tensors. One optimiser step is taken per batch, and the
    per-sample average loss is printed after each epoch. Nothing is returned.
    """

    def run_batch(batch):
        """Take one optimisation step; return the loss scaled by batch size."""
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        loss = criterion(model(input_ids, attention_mask), labels)
        loss.backward()
        optimizer.step()
        return loss.item() * input_ids.size(0)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in fine_tune_loader:
            running_loss += run_batch(batch)

        # Average over samples rather than batches.
        epoch_loss = running_loss / len(fine_tune_loader.dataset)
        print(f"Fine-tuning Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")


In [None]:
# Step 9: Document-Code Similarity Prediction (if needed)
def compute_similarity(document_embedding, code_embedding, eps=1e-5):
    """Return pairwise similarity in hyperbolic space (negative Poincaré distance).

    For points ``u`` and ``v`` of the Poincaré ball the distance is
    ``arcosh(1 + 2 * |u - v|^2 / ((1 - |u|^2) * (1 - |v|^2)))``; the
    similarity returned here is its negation, so identical points score 0
    and more distant points score lower.

    Args:
        document_embedding: Tensor or array-like of shape ``(n, d)`` or
            ``(d,)`` holding points inside the unit ball.
        code_embedding: Tensor or array-like of shape ``(m, d)`` or ``(d,)``
            holding points inside the unit ball.
        eps: Lower bound applied to ``1 - |x|^2`` so that points on or beyond
            the boundary do not cause a division by zero.

    Returns:
        Float tensor of shape ``(n, m)`` with the pairwise similarities.

    Raises:
        ValueError: If the two inputs have different embedding dimensions.
    """
    docs = torch.atleast_2d(torch.as_tensor(document_embedding, dtype=torch.float32))
    codes = torch.atleast_2d(torch.as_tensor(code_embedding, dtype=torch.float32))
    if docs.size(-1) != codes.size(-1):
        raise ValueError(
            "document and code embeddings must share the same dimension, "
            f"got {docs.size(-1)} and {codes.size(-1)}"
        )

    # Broadcast to (n, m, d) to get every document/code pair.
    squared_distance = (docs.unsqueeze(1) - codes.unsqueeze(0)).pow(2).sum(dim=-1)
    doc_gap = (1.0 - docs.pow(2).sum(dim=-1)).clamp_min(eps)
    code_gap = (1.0 - codes.pow(2).sum(dim=-1)).clamp_min(eps)
    arcosh_arg = 1.0 + 2.0 * squared_distance / (doc_gap.unsqueeze(1) * code_gap.unsqueeze(0))
    # Rounding can push the argument just below 1, where arcosh is undefined.
    similarity_score = -torch.acosh(arcosh_arg.clamp_min(1.0))
    return similarity_score

In [None]:
# Step 10: Code-wise Label Attention Visualization
def visualize_attention(model, tokenizer, text, device=None):
    """Return the last-layer attention weights of ``model.bert_model`` for ``text``.

    Args:
        model: Object exposing the encoder as ``model.bert_model`` and, when
            ``device`` is omitted, a ``parameters()`` method.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``
            and returning ``'input_ids'`` and ``'attention_mask'`` tensors.
        text: Input text to analyse.
        device: Device the inputs are moved to. When omitted it is taken from
            the model's first parameter (CPU if the model has none), so the
            function no longer depends on a global ``device`` variable.

    Returns:
        The last element of the encoder's ``'attentions'`` output.

    Note:
        The model's train/eval mode is left untouched; presumably callers
        want ``model.eval()`` first so dropout does not perturb the weights.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    inputs = tokenizer(text, return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    with torch.no_grad():
        # Attention weights are only returned when explicitly requested.
        outputs = model.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )
        attentions = outputs['attentions'][-1]  # Get attention weights from the last layer
    # Implement your attention visualization code here
    return attentions

In [None]:
# Main code
# Assuming train_loader, fine_tune_loader, and test_loader are DataLoader objects containing training, fine-tuning, and test data respectively
# NOTE(review): icd_hierarchy, tokenizer, bert_model, MultiLabelDataset and
# fine_tune_loader are not defined in the visible notebook; they must exist
# before this cell runs.
train_hyperbolic_embedding(icd_hierarchy)
train_data, _, _ = preprocess_mimic_iv('your_dataset.csv')
train_dataset = MultiLabelDataset(train_data, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): HyperBertModel reads a module-level `num_classes` when it is
# constructed like this; make sure it is defined first.
model = HyperBertModel(bert_model).to(device)
# NOTE(review): these are ordinary (Euclidean) parameters - confirm that a
# Riemannian optimiser is intended here.
optimizer = RiemannianAdam(model.parameters(), lr=0.001)
# ICD coding is multi-label: every code is an independent yes/no decision,
# so use sigmoid-based binary cross-entropy rather than softmax
# cross-entropy, which assumes exactly one correct class per sample.
# Targets must be float multi-hot vectors with the same shape as the logits.
criterion = nn.BCEWithLogitsLoss()

num_epochs = 5
train_model(model, train_loader, criterion, optimizer, num_epochs, device)

# After training, fine-tune the model for ICD code hierarchy
fine_tune_model(model, fine_tune_loader, criterion, optimizer, num_epochs, device)

# Evaluate the model
# NOTE(review): this keeps only the first (training) split of the test file
# and discards the other two - confirm that is intended.
test_data, _, _ = preprocess_mimic_iv('your_test_dataset.csv')
test_dataset = MultiLabelDataset(test_data, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
labels, preds = evaluate_model(model, test_loader, device)

# Compute document-code similarity predictions and visualize code-wise label attention
text = "Your discharge summary text"
attentions = visualize_attention(model, tokenizer, text)