In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from torch.optim import Adam
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, precision_score, recall_score, average_precision_score
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from collections import Counter

# 1. Load dataset and filter rare labels
print("Loading dataset...")
ds = load_dataset("rntc/mimic-icd-visit", split='train')
df = ds.to_pandas()
df = df.sample(n=17000, random_state=42).reset_index(drop=True)

label_counts = Counter(code for codes in df['icd_code'] for code in codes)
min_count = 10
common_labels = {label for label, count in label_counts.items() if count >= min_count}
df['icd_code'] = df['icd_code'].apply(lambda codes: [code for code in codes if code in common_labels])
df = df[df['icd_code'].map(len) > 0].reset_index(drop=True)

mlb = MultiLabelBinarizer()
y = mlb.fit_transform(df['icd_code'])
classes = mlb.classes_
print(f"Number of labels after filtering: {len(classes)}")

train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['cleaned_text'].tolist(), y, test_size=0.2, random_state=42
)

# 2. Tokenizer and Dataset
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
vocab = tokenizer.get_vocab()
word2idx = {word: idx for word, idx in vocab.items()}

MAX_LEN = 256
BATCH_SIZE = 32

class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=MAX_LEN):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        tokens = self.tokenizer.tokenize(self.texts[idx])[:self.max_len]
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        pad_len = self.max_len - len(ids)
        input_ids = torch.tensor(ids + [0]*pad_len, dtype=torch.long)
        label = torch.tensor(self.labels[idx], dtype=torch.float)
        return input_ids, label

train_dataset = TextDataset(train_texts, train_labels, tokenizer)
val_dataset = TextDataset(val_texts, val_labels, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# 3. Model with increased capacity
class LSTMMultiLabel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_labels, n_layers=3, bidirectional=True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True,
                            num_layers=n_layers, bidirectional=bidirectional)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, num_labels)

    def forward(self, x):
        x = self.embedding(x)
        _, (h_n, _) = self.lstm(x)
        if self.lstm.bidirectional:
            h_n = torch.cat((h_n[-2], h_n[-1]), dim=1)
        else:
            h_n = h_n[-1]
        h_n = self.dropout(h_n)
        return self.classifier(h_n)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMMultiLabel(len(word2idx), embed_dim=300, hidden_dim=512, num_labels=len(classes)).to(device)

# 4. Focal Loss implementation
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=True, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = nn.functional.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

criterion = FocalLoss()
optimizer = Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

from sklearn.metrics import f1_score, precision_score, recall_score

best_f1 = 0
for epoch in range(9):  # epochs
    model.train()
    total_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    model.eval()
    all_preds, all_true = [], []
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            outputs = torch.sigmoid(model(inputs)).cpu().numpy()
            all_preds.append(outputs)
            all_true.append(labels.numpy())

    all_preds = np.vstack(all_preds)
    all_true = np.vstack(all_true)

    # Tune threshold between 0.3 and 0.7 for best micro F1
    best_threshold = 0.5
    best_val_f1_micro = 0
    for t in np.linspace(0.3, 0.7, 9):
        preds_bin = (all_preds > t).astype(int)
        f1_micro = f1_score(all_true, preds_bin, average='micro', zero_division=0)
        if f1_micro > best_val_f1_micro:
            best_val_f1_micro = f1_micro
            best_threshold = t

    final_preds = (all_preds > best_threshold).astype(int)

    # Metrics
    f1_micro = f1_score(all_true, final_preds, average='micro', zero_division=0)
    recall_micro = recall_score(all_true, final_preds, average='micro', zero_division=0)
    precision_micro = precision_score(all_true, final_preds, average='micro', zero_division=0)

    f1_macro = f1_score(all_true, final_preds, average='macro', zero_division=0)
    recall_macro = recall_score(all_true, final_preds, average='macro', zero_division=0)
    precision_macro = precision_score(all_true, final_preds, average='macro', zero_division=0)

    f1_weighted = f1_score(all_true, final_preds, average='weighted', zero_division=0)
    recall_weighted = recall_score(all_true, final_preds, average='weighted', zero_division=0)
    precision_weighted = precision_score(all_true, final_preds, average='weighted', zero_division=0)

    # accuracy = (final_preds == all_true).all(axis=1).mean()

    # Top-k accuracy (top_k=10)
    top_k = 10
    topk_indices = np.argsort(all_preds, axis=1)[:, -top_k:]
    correct_topk = 0
    for i in range(all_true.shape[0]):
        true_indices = set(np.where(all_true[i] == 1)[0])
        pred_indices = set(topk_indices[i])
        if len(true_indices.intersection(pred_indices)) > 0:
            correct_topk += 1
    topk_accuracy = correct_topk / all_true.shape[0]

    avg_loss = total_loss / len(train_loader)

    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f} - "
          f"F1 Micro: {f1_micro:.4f}, Macro: {f1_macro:.4f}, average: {f1_weighted:.4f} - \n"
          f"Recall Micro: {recall_micro:.4f}, Macro: {recall_macro:.4f}, average: {recall_weighted:.4f} -\n "
          f"Precision Micro: {precision_micro:.4f}, Macro: {precision_macro:.4f}, average: {precision_weighted:.4f} - "
          f"Top-{top_k} Accuracy: {topk_accuracy:.4f}  "
          )

    scheduler.step(avg_loss)

    if best_val_f1_micro > best_f1:
        best_f1 = best_val_f1_micro
        torch.save(model.state_dict(), "best_lstm_model.pt")

  from .autonotebook import tqdm as notebook_tqdm


Loading dataset...
Number of labels after filtering: 275


Epoch 1 Training: 100%|██████████| 25/25 [00:10<00:00,  2.31it/s]


Epoch 1 - Loss: 0.0575 - F1 Micro: 0.2064, Macro: 0.0134, average: 0.0874 - 
Recall Micro: 0.2123, Macro: 0.0339, average: 0.2123 -
 Precision Micro: 0.2008, Macro: 0.0109, average: 0.0576 - Top-10 Accuracy: 0.8477  


Epoch 2 Training: 100%|██████████| 25/25 [00:09<00:00,  2.57it/s]


Epoch 2 - Loss: 0.0401 - F1 Micro: 0.2187, Macro: 0.0105, average: 0.0837 - 
Recall Micro: 0.1886, Macro: 0.0251, average: 0.1886 -
 Precision Micro: 0.2602, Macro: 0.0069, average: 0.0551 - Top-10 Accuracy: 0.8528  


Epoch 3 Training: 100%|██████████| 25/25 [00:11<00:00,  2.26it/s]


Epoch 3 - Loss: 0.0396 - F1 Micro: 0.2180, Macro: 0.0153, average: 0.0982 - 
Recall Micro: 0.2192, Macro: 0.0352, average: 0.2192 -
 Precision Micro: 0.2168, Macro: 0.0113, average: 0.0684 - Top-10 Accuracy: 0.8579  


Epoch 4 Training: 100%|██████████| 25/25 [00:12<00:00,  2.01it/s]


Epoch 4 - Loss: 0.0392 - F1 Micro: 0.2145, Macro: 0.0175, average: 0.1032 - 
Recall Micro: 0.2133, Macro: 0.0343, average: 0.2133 -
 Precision Micro: 0.2156, Macro: 0.0152, average: 0.0783 - Top-10 Accuracy: 0.8477  


Epoch 5 Training: 100%|██████████| 25/25 [00:12<00:00,  1.97it/s]


Epoch 5 - Loss: 0.0382 - F1 Micro: 0.2155, Macro: 0.0151, average: 0.0952 - 
Recall Micro: 0.2026, Macro: 0.0300, average: 0.2026 -
 Precision Micro: 0.2302, Macro: 0.0207, average: 0.0849 - Top-10 Accuracy: 0.8629  


Epoch 6 Training: 100%|██████████| 25/25 [00:09<00:00,  2.54it/s]


Epoch 6 - Loss: 0.0374 - F1 Micro: 0.2104, Macro: 0.0174, average: 0.0983 - 
Recall Micro: 0.1891, Macro: 0.0292, average: 0.1891 -
 Precision Micro: 0.2370, Macro: 0.0229, average: 0.0887 - Top-10 Accuracy: 0.8376  


Epoch 7 Training: 100%|██████████| 25/25 [00:09<00:00,  2.54it/s]


Epoch 7 - Loss: 0.0367 - F1 Micro: 0.1982, Macro: 0.0320, average: 0.1342 - 
Recall Micro: 0.3111, Macro: 0.0751, average: 0.3111 -
 Precision Micro: 0.1454, Macro: 0.0230, average: 0.0920 - Top-10 Accuracy: 0.8426  


Epoch 8 Training: 100%|██████████| 25/25 [00:09<00:00,  2.58it/s]


Epoch 8 - Loss: 0.0356 - F1 Micro: 0.1950, Macro: 0.0205, average: 0.1035 - 
Recall Micro: 0.1795, Macro: 0.0308, average: 0.1795 -
 Precision Micro: 0.2136, Macro: 0.0197, average: 0.0831 - Top-10 Accuracy: 0.8426  


Epoch 9 Training: 100%|██████████| 25/25 [00:09<00:00,  2.56it/s]


Epoch 9 - Loss: 0.0345 - F1 Micro: 0.2062, Macro: 0.0273, average: 0.1173 - 
Recall Micro: 0.2117, Macro: 0.0415, average: 0.2117 -
 Precision Micro: 0.2010, Macro: 0.0366, average: 0.1074 - Top-10 Accuracy: 0.8528  


In [21]:
num_labels_training = 3813  # for example, the number from your checkpoint training
model = LSTMMultiLabel(
    vocab_size=len(word2idx),
    embed_dim=300,
    hidden_dim=512,
    num_labels=num_labels_training,
    n_layers=3,
    bidirectional=True
).to(device)


In [22]:
model.load_state_dict(torch.load('best_lstm_model.pt', map_location=device))
model.eval()


LSTMMultiLabel(
  (embedding): Embedding(28996, 300)
  (lstm): LSTM(300, 512, num_layers=3, batch_first=True, bidirectional=True)
  (dropout): Dropout(p=0.3, inplace=False)
  (classifier): Linear(in_features=1024, out_features=3813, bias=True)
)

In [26]:
import torch
import numpy as np

# Inference function
def predict_icd_codes(text, model, tokenizer, classes, max_len=256, threshold=0.5, top_k=5, device='cpu'):
    model.eval()
    tokens = tokenizer.tokenize(text)[:max_len]
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    pad_len = max_len - len(token_ids)
    input_ids = torch.tensor(token_ids + [0]*pad_len, dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_ids)
        probs = torch.sigmoid(logits).cpu().numpy()[0]

    above_threshold = [classes[i] for i in np.where(probs >= threshold)[0]]
    topk_indices = probs.argsort()[-top_k:][::-1]
    top_k_labels = [classes[i] for i in topk_indices]

    return {'above_threshold': above_threshold, 'top_k': top_k_labels}

# Example usage
sample_text = "Patient admitted with acute myocardial infarction and chest pain."

result = predict_icd_codes(sample_text, model, tokenizer, classes, max_len=256, threshold=0.5, top_k=5, device=device)

print("Labels above threshold:", result['above_threshold'])
print("Top-k labels:", result['top_k'])


Labels above threshold: ['A047', 'A0471', 'A0472', 'A0811', 'A09', 'A310', 'A408', 'A4101', 'A4102', 'A411', 'A4151', 'A4152', 'A4159', 'A419', 'A6000', 'A879', 'B009', 'B0229', 'B029', 'B181', 'B182', 'B1910', 'B258', 'B259', 'B351', 'B356', 'B369', 'B370', 'B3749', 'B377', 'B3789', 'B600', 'B952', 'B955', 'B9562', 'B957', 'B958', 'B9629', 'B966', 'B9681', 'B9789', 'C155', 'C159', 'C160', 'C189', 'C220', 'C221', 'C23', 'C251', 'C252', 'C3401', 'C3412', 'C342', 'C3491', 'C3492', 'C50912', 'C539', 'C55', 'C562', 'C641', 'C642', 'C649', 'C679', 'C710', 'C711', 'C719', 'C779', 'C7800', 'C7802', 'C782', 'C784', 'C785', 'C786', 'C787', 'C7889', 'C7911', 'C7951', 'C7B8', 'C8339', 'C8519', 'C8599', 'C9002', 'C9100', 'C9110', 'C9200', 'C9202', 'C9210', 'C946', 'D120', 'D123', 'D124', 'D125', 'D259', 'D270', 'D271', 'D3502', 'D45', 'D469', 'D473', 'D47Z1', 'D508', 'D539', 'D5700', 'D573', 'D589', 'D594', 'D599', 'D61810', 'D61818', 'D630', 'D631', 'D638', 'D6489', 'D649', 'D66', 'D680', 'D6859'