
#....Below is a  Python code that integrates hyperbolic embeddings
with a BERT- based model for an interpretable ICD coding task using
PyTorch, Hugging Face's Transformers, and Poincaré....#

In [1]:
Pip! install transformers

SyntaxError: invalid syntax (<ipython-input-1-cc8399165827>, line 1)

In [2]:
Pip! install Poincare

SyntaxError: invalid syntax (<ipython-input-2-ab8855e126d8>, line 1)

In [3]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel
from poincare import PoincareModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score

ModuleNotFoundError: No module named 'poincare'

In [None]:
# Step 1: Data Preprocessing
def preprocess_data(file_path):
    data = pd.read_csv(file_path)
    # Perform any necessary preprocessing steps (e.g., cleaning, tokenization)
    return data

In [None]:
# Step 2: Hyperbolic Embedding Training
def train_hyperbolic_embeddings(data):
    poincare_model = PoincareModel()
    poincare_model.train(data)
    return poincare_model

In [None]:
# Step 3: BERT-based Representation
def get_bert_embeddings(texts):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    outputs = model(**inputs)
    embeddings = outputs.last_hidden_state[:, 0, :]
    return embeddings

In [None]:
# Step 4: Integration and Multi-label Classification
class HyperBertClassifier(nn.Module):
    def __init__(self, bert_model, poincare_model, num_classes):
        super(HyperBertClassifier, self).__init__()
        self.bert_model = bert_model
        self.poincare_model = poincare_model
        self.linear = nn.Linear(bert_model.config.hidden_size + poincare_model.embedding_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        bert_outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        bert_embeddings = bert_outputs.last_hidden_state[:, 0, :]
        poincare_embeddings = self.poincare_model.get_embeddings() # Assuming you have a method to get embeddings
        combined_embeddings = torch.cat((bert_embeddings, poincare_embeddings), dim=1)
        logits = self.linear(combined_embeddings)
        return logits

In [None]:
# Step 5: Fine-tuning
def fine_tune_model(model, train_loader, criterion, optimizer, device):
    model.train()
    for batch in train_loader:
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(*inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

In [None]:
# Step 6: Evaluation
def evaluate_model(model, test_loader, device):
    model.eval()
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for batch in test_loader:
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(*inputs)
            predictions = torch.sigmoid(outputs) > 0.5
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return all_labels, all_predictions

In [None]:
# Step 7: Interpretability Techniques
def visualize_attention(model, tokenizer, text):
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    outputs = model(**inputs)
    attention_weights = outputs.attentions
    # Visualize attention weights

In [None]:
# Step 8: Main Function
def main():
    # Load and preprocess data
    data = preprocess_data('mimic_iv_top50.csv')
    train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

    # Train hyperbolic embeddings
    poincare_model = train_hyperbolic_embeddings(train_data)

    # Get BERT embeddings
    train_texts = train_data['discharge_summary'].tolist()
    test_texts = test_data['discharge_summary'].tolist()
    train_embeddings = get_bert_embeddings(train_texts)
    test_embeddings = get_bert_embeddings(test_texts)

    # Define and train the classifier
    num_classes = len(data['ICD_code'].unique())
    classifier = HyperBertClassifier(BertModel.from_pretrained('bert-base-uncased'), poincare_model, num_classes)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=1e-5)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader = torch.utils.data.DataLoader(train_embeddings, batch_size=32, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_embeddings, batch_size=32, shuffle=False)

    for epoch in range(5):
        fine_tune_model(classifier, train_loader, criterion, optimizer, device)

    # Evaluation
    true_labels, predicted_labels = evaluate_model(classifier, test_loader, device)
    macro_f1 = f1_score(true_labels, predicted_labels, average='macro')
    micro_f1 = f1_score(true_labels, predicted_labels, average='micro')
    auc = roc_auc_score(true_labels, predicted_labels)

    print("Macro F1:", macro_f1)
    print("Micro F1:", micro_f1)
    print("AUC:", auc)

    # Interpretability
    visualize_attention(classifier.bert_model, tokenizer, train_texts[0])

if __name__ == "__main__":
    main()