<a href="https://colab.research.google.com/github/gyasifred/clinical-valence-testing/blob/main/clinical_valence_testing_train_script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install datasets

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading x

In [None]:
import torch
import pandas as pd
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AdamW,
    get_scheduler
)
from sklearn.metrics import accuracy_score, f1_score, classification_report
from typing import Dict, List, Tuple
from datasets import load_dataset, Dataset as HFDataset

class TransformerDataset(Dataset):
    """Dataset class for Hugging Face transformer models."""

    def __init__(self, encodings, labels=None):
        """
        Args:
            encodings: The encoded inputs from tokenizer
            labels: Optional list of labels (multi-hot encoded)
        """
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels is not None:
            item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)  # Float for BCE loss
        return item

    def __len__(self):
        return len(self.encodings['input_ids'])

def train_one_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer,
    lr_scheduler,
    device: str,
    threshold: float = 0.5
) -> Tuple[float, float, float]:
    """
    Train the model for one epoch.

    Args:
        model: The transformer model.
        train_loader: DataLoader for training data.
        optimizer: Optimizer.
        lr_scheduler: Learning rate scheduler.
        device: Device string ('cuda' or 'cpu').
        threshold: Threshold for binary prediction.

    Returns:
        Tuple of (average training loss, training accuracy, f1 score) for the epoch.
    """
    model.train()
    total_loss = 0.0
    all_preds = []
    all_labels = []

    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()

        # Calculate predictions using sigmoid for multi-label
        logits = outputs.logits
        predictions = (torch.sigmoid(logits) > threshold).int()
        all_preds.append(predictions.detach().cpu().numpy())
        all_labels.append(batch['labels'].detach().cpu().numpy())

    avg_loss = total_loss / len(train_loader)

    # Convert lists to numpy arrays for metrics calculation
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    # Calculate metrics (sample-wise)
    accuracy = np.mean(np.all(all_preds == all_labels, axis=1))
    f1 = f1_score(all_labels, all_preds, average='samples', zero_division=0)

    # Return training metrics
    return avg_loss, accuracy, f1

def evaluate_model(
    model: nn.Module,
    val_loader: DataLoader,
    device: str,
    threshold: float = 0.5,
    detailed_report: bool = False
) -> Tuple[float, float, float]:
    """
    Evaluate the model on the validation set.

    Args:
        model: The transformer model.
        val_loader: DataLoader for validation data.
        device: Device string ('cuda' or 'cpu').
        threshold: Threshold for binary prediction.
        detailed_report: Whether to print a detailed classification report.

    Returns:
        Tuple of (validation loss, validation accuracy, validation f1 score).
    """
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0

    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.item()

            # Calculate predictions using sigmoid for multi-label
            logits = outputs.logits
            predictions = (torch.sigmoid(logits) > threshold).int()
            all_preds.append(predictions.detach().cpu().numpy())
            all_labels.append(batch['labels'].detach().cpu().numpy())

    avg_loss = total_loss / len(val_loader)

    # Convert lists to numpy arrays for metrics calculation
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    # Calculate metrics (sample-wise)
    accuracy = np.mean(np.all(all_preds == all_labels, axis=1))
    f1 = f1_score(all_labels, all_preds, average='samples', zero_division=0)

    # Print detailed metrics if requested
    if detailed_report:
        print("\nDetailed classification report:")
        print(classification_report(all_labels, all_preds, target_names=[f"Code_{i}" for i in range(all_labels.shape[1])], zero_division=0))

    return avg_loss, accuracy, f1

def codes_that_occur_n_times_in_dataset(n: int, train_data: pd.DataFrame, label_column: str = "short_codes"):
    """
    Get codes that appear at least n times in the dataset.

    Args:
        n: Minimum number of occurrences.
        train_data: DataFrame containing the training data.
        label_column: Name of the column containing codes.

    Returns:
        List of codes that appear at least n times.
    """
    df = train_data.copy()
    code_count = {}
    for i, row in df.iterrows():
        codes = row[label_column].split(",")
        for code in codes:
            code = code.strip()  # Remove any whitespace
            if code:  # Only count non-empty codes
                if code in code_count:
                    code_count[code] += 1
                else:
                    code_count[code] = 1

    del df
    # Get codes that appear at least n times
    frequent_codes = [code for code, count in code_count.items() if count >= n]

    # Sort by frequency (most frequent first)
    frequent_codes = sorted(frequent_codes, key=lambda x: code_count[x], reverse=True)

    print(f"Found {len(frequent_codes)} codes that appear at least {n} times")
    return frequent_codes

def multihot_encode(codes_str: str, code_list: List[str]) -> List[int]:
    """
    Convert a comma-separated string of codes to a multi-hot encoded vector.

    Args:
        codes_str: Comma-separated string of codes.
        code_list: List of all possible codes.

    Returns:
        Multi-hot encoded vector (list of 0s and 1s).
    """
    codes = codes_str.split(",")
    codes = [code.strip() for code in codes]  #
    return [1 if code in codes else 0 for code in code_list]

def load_csv_files_as_hf_dataset(
    train_csv_path: str,
    val_csv_path: str,
    test_csv_path: str,
    text_column: str,
    label_column: str = "short_codes",
    min_occurrences: int = 100,
    seed: int = 42
) -> Tuple[HFDataset, HFDataset, HFDataset, List[str]]:
    """
    Load CSV files as Hugging Face datasets for multi-label classification.

    Args:
        train_csv_path: Path to training CSV file.
        val_csv_path: Path to validation CSV file.
        test_csv_path: Path to test CSV file.
        text_column: Name of the column containing text data.
        label_column: Name of the column containing labels.
        min_occurrences: Minimum number of occurrences for a code to be included.
        seed: Random seed for reproducibility.

    Returns:
        train_dataset: Training dataset.
        val_dataset: Validation dataset.
        test_dataset: Test dataset.
        frequent_codes: List of codes that appear at least min_occurrences times.
    """
    # Load the datasets
    print("Loading datasets...")
    train_df = pd.read_csv(train_csv_path)
    val_df = pd.read_csv(val_csv_path)
    test_df = pd.read_csv(test_csv_path)

    # Get frequent codes from training data only
    print("Identifying frequent codes from training data...")
    frequent_codes = codes_that_occur_n_times_in_dataset(
        n=min_occurrences,
        train_data=train_df,
        label_column=label_column
    )

    # Create multi-hot encoded labels
    print("Creating multi-hot encoded labels...")
    train_df['labels'] = train_df[label_column].apply(lambda x: multihot_encode(x, frequent_codes))
    val_df['labels'] = val_df[label_column].apply(lambda x: multihot_encode(x, frequent_codes))
    test_df['labels'] = test_df[label_column].apply(lambda x: multihot_encode(x, frequent_codes))

    # Filter out rows with no labels (no codes in the frequent list)
    train_df = train_df[train_df['labels'].apply(lambda x: sum(x) > 0)]
    val_df = val_df[val_df['labels'].apply(lambda x: sum(x) > 0)]
    test_df = test_df[test_df['labels'].apply(lambda x: sum(x) > 0)]

    # Convert to Hugging Face datasets
    train_dataset = HFDataset.from_pandas(train_df)
    val_dataset = HFDataset.from_pandas(val_df)
    test_dataset = HFDataset.from_pandas(test_df)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    print(f"Number of codes to predict: {len(frequent_codes)}")

    return train_dataset, val_dataset, test_dataset, frequent_codes


def train_transformer_multilabel(
    train_dataset: HFDataset,
    val_dataset: HFDataset,
    text_column: str,
    label_column: str,
    config: Dict,
    num_epochs: int,
    model_name: str = "bvanaken/CORe-clinical-outcome-biobert-v1",
    num_labels: int = 3,
    frequent_codes: List[str] = None
) -> Tuple[nn.Module, AutoTokenizer, Dict]:
    """
    End-to-end training function for multi-label classification with transformers.

    Args:
        train_dataset: Training dataset (Hugging Face dataset format).
        val_dataset: Validation dataset (Hugging Face dataset format).
        text_column: Name of the column containing text data.
        label_column: Name of the column containing labels (multi-hot vectors).
        config: Dictionary with hyperparameters.
        num_epochs: Number of training epochs.
        model_name: Hugging Face model name.
        num_labels: Number of possible labels.
        frequent_codes: List of the code names corresponding to each label position.

    Returns:
        model: Trained transformer model.
        tokenizer: Hugging Face tokenizer.
        metrics: Dictionary with training metrics.
    """
    # Initialize device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Initialize tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        problem_type="multi_label_classification"
    ).to(device)

    # Extract texts and labels
    train_texts = train_dataset[text_column]
    train_labels = train_dataset[label_column]
    val_texts = val_dataset[text_column]
    val_labels = val_dataset[label_column]

    # Tokenize data
    train_encodings = tokenizer(
        train_texts,
        truncation=True,
        padding=True,
        max_length=config.get("max_length", 512)
    )
    val_encodings = tokenizer(
        val_texts,
        truncation=True,
        padding=True,
        max_length=config.get("max_length", 512)
    )

    # Create datasets
    train_torch_dataset = TransformerDataset(train_encodings, train_labels)
    val_torch_dataset = TransformerDataset(val_encodings, val_labels)

    # Create dataloaders
    train_loader = DataLoader(
        train_torch_dataset,
        batch_size=config.get("batch_size", 16),
        shuffle=True
    )
    val_loader = DataLoader(
        val_torch_dataset,
        batch_size=config.get("batch_size", 16)
    )

    # Initialize optimizer
    optimizer = AdamW(
        model.parameters(),
        lr=config.get("lr", 5e-5),
        weight_decay=config.get("weight_decay", 0.01)
    )

    # Initialize learning rate scheduler
    num_training_steps = num_epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name=config.get("scheduler", "linear"),
        optimizer=optimizer,
        num_warmup_steps=int(num_training_steps * config.get("warmup_steps", 0.1)) if isinstance(config.get("warmup_steps"), float) else config.get("warmup_steps", 5),
        num_training_steps=num_training_steps
    )

    # Initialize metrics tracking
    metrics = {
        "train_loss": [],
        "train_accuracy": [],
        "train_f1": [],
        "val_loss": [],
        "val_accuracy": [],
        "val_f1": []
    }

    # Training loop
    best_val_loss = float('inf')
    best_model_state = None
    threshold = config.get("threshold", 0.5)

    print(f"Starting training for {num_epochs} epochs...")
    print(f"Multi-label classification with {num_labels} possible labels")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 40)

        # Train for one epoch
        train_loss, train_accuracy, train_f1 = train_one_epoch(
            model,
            train_loader,
            optimizer,
            lr_scheduler,
            device,
            threshold
        )

        # Evaluate model (detailed report on last epoch)
        val_loss, val_accuracy, val_f1 = evaluate_model(
            model,
            val_loader,
            device,
            threshold,
            detailed_report=(epoch == num_epochs - 1)
        )

        # Save model if it's the best so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict().copy()
            print(f"New best model saved! (Val Loss: {val_loss:.4f})")

        # Update metrics
        metrics["train_loss"].append(train_loss)
        metrics["train_accuracy"].append(train_accuracy)
        metrics["train_f1"].append(train_f1)
        metrics["val_loss"].append(val_loss)
        metrics["val_accuracy"].append(val_accuracy)
        metrics["val_f1"].append(val_f1)

        print(
            f"Epoch {epoch+1}/{num_epochs} - "
            f"Train Loss: {train_loss:.4f}, "
            f"Train Accuracy: {train_accuracy:.4f}, "
            f"Train F1: {train_f1:.4f}, "
            f"Val Loss: {val_loss:.4f}, "
            f"Val Accuracy: {val_accuracy:.4f}, "
            f"Val F1: {val_f1:.4f}"
        )

    # Load the best model state
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("Loaded best model state from training")

    return model, tokenizer, metrics


def predict_with_model(
    model: nn.Module,
    tokenizer: AutoTokenizer,
    texts: List[str],
    device: str,
    threshold: float = 0.5,
    frequent_codes: List[str] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Make predictions with a trained model.

    Args:
        model: Trained model.
        tokenizer: Tokenizer.
        texts: List of texts to predict.
        device: Device string ('cuda' or 'cpu').
        threshold: Threshold for binary prediction.
        frequent_codes: List of codes for reference.

    Returns:
        Tuple of (predicted probabilities, binary predictions).
    """
    model.eval()
    encodings = tokenizer(
        texts,
        truncation=True,
        padding=True,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**encodings)
        logits = outputs.logits
        probs = torch.sigmoid(logits)
        predictions = (probs > threshold).int()

    probs_np = probs.cpu().numpy()
    predictions_np = predictions.cpu().numpy()

    # Print predictions with code names if available
    if frequent_codes is not None:
        print("\nPredictions:")
        for i, text in enumerate(texts):
            print(f"\nText {i+1}: {text[:100]}...")

            # Get predicted codes
            pred_indices = np.where(predictions_np[i] == 1)[0]
            pred_codes = [frequent_codes[idx] for idx in pred_indices]

            if len(pred_codes) > 0:
                print(f"Predicted codes: {', '.join(pred_codes)}")
            else:
                print("No codes predicted.")

    return probs_np, predictions_np


def main():
    """Main function to run the multi-label ICD code classification."""
    # 1. Set up your data paths
    train_csv_path = "/content/drive/MyDrive/data/DIA_GROUPS_3_DIGITS_adm_test.csv"
    val_csv_path = "/content/drive/MyDrive/data/DIA_GROUPS_3_DIGITS_adm_val.csv"
    test_csv_path = "/content/drive/MyDrive/data/DIA_GROUPS_3_DIGITS_adm_val.csv"

    # 2. Configure column names
    text_column = "text"
    label_column = "short_codes"

    # 3. Set the minimum frequency threshold
    min_occurrences = 100

    # 4. Configuration for the model
    config = {
        "batch_size": 4,
        "lr": 2e-5,
        "max_length": 512,
        "weight_decay": 0.01,
        "scheduler": "linear",
        "warmup_steps": 0.1,
        "threshold": 0.5
    }

    # 5. Load and prepare datasets
    train_dataset, val_dataset, test_dataset, frequent_codes = load_csv_files_as_hf_dataset(
        train_csv_path=train_csv_path,
        val_csv_path=val_csv_path,
        test_csv_path=test_csv_path,
        text_column=text_column,
        label_column=label_column,
        min_occurrences=min_occurrences
    )

    # 6. Train the model
    num_labels = len(frequent_codes)
    num_epochs = 5
    model_name = "bvanaken/CORe-clinical-outcome-biobert-v1"

    model, tokenizer, metrics = train_transformer_multilabel(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        text_column=text_column,
        label_column="labels",
        config=config,
        num_epochs=num_epochs,
        model_name=model_name,
        num_labels=num_labels,
        frequent_codes=frequent_codes
    )

    # 7. Test on the test set
    test_texts = test_dataset[text_column]
    test_labels = test_dataset["labels"]

    # Create test dataloader
    test_encodings = tokenizer(
        test_texts,
        truncation=True,
        padding=True,
        max_length=config["max_length"]
    )
    test_torch_dataset = TransformerDataset(test_encodings, test_labels)
    test_loader = DataLoader(test_torch_dataset, batch_size=config["batch_size"])

    # Evaluate on test set
    print("\nEvaluating on test set:")
    test_loss, test_accuracy, test_f1 = evaluate_model(
        model,
        test_loader,
        device="cuda" if torch.cuda.is_available() else "cpu",
        threshold=config["threshold"],
        detailed_report=True
    )

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print(f"Test F1 Score: {test_f1:.4f}")

    # 8. Example prediction
    example_texts = [
        "Patient presented with shortness of breath and chest pain. Diagnosed with acute myocardial infarction.",
        "Chronic obstructive pulmonary disease with acute exacerbation. Patient has a history of smoking."
    ]

    print("\nExample predictions:")
    probs, preds = predict_with_model(
        model=model,
        tokenizer=tokenizer,
        texts=example_texts,
        device="cuda" if torch.cuda.is_available() else "cpu",
        threshold=config["threshold"],
        frequent_codes=frequent_codes
    )

    print("\nTraining complete!")


if __name__ == "__main__":
    main()

Loading datasets...
Identifying frequent codes from training data...
Found 187 codes that appear at least 100 times
Creating multi-hot encoded labels...
Training samples: 9790
Validation samples: 4899
Test samples: 4899
Number of codes to predict: 187
Using device: cuda


config.json:   0%|          | 0.00/428 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/433M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bvanaken/CORe-clinical-outcome-biobert-v1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting training for 5 epochs...
Multi-label classification with 187 possible labels

Epoch 1/5
----------------------------------------
New best model saved! (Val Loss: 0.1703)
Epoch 1/5 - Train Loss: 0.2529, Train Accuracy: 0.0000, Train F1: 0.0199, Val Loss: 0.1703, Val Accuracy: 0.0000, Val F1: 0.0009

Epoch 2/5
----------------------------------------
New best model saved! (Val Loss: 0.1583)
Epoch 2/5 - Train Loss: 0.1671, Train Accuracy: 0.0004, Train F1: 0.1032, Val Loss: 0.1583, Val Accuracy: 0.0010, Val F1: 0.1858

Epoch 3/5
----------------------------------------
