# 1. Install Required Libraries and Import Dependencies

In [None]:
# Install required libraries
!pip install transformers torch sentence-transformers fasttext scikit-learn matplotlib umap-learn

Collecting fasttext
  Downloading fasttext-0.9.3.tar.gz (73 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.4/73.4 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-

In [None]:
# Import basic dependencies
import pandas as pd
import numpy as np
import re
import os
import json
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  # Progress bars for notebooks

# Import torch and transformers
import torch
from transformers import AutoTokenizer, AutoModel, BertForMaskedLM, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset

# Import sklearn for evaluation
from sklearn.metrics import silhouette_score, adjusted_rand_score
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score, f1_score

# 2. Load ICD-11 Data

In [None]:
# Load the raw ICD-11 export
ICD11_CSV_PATH = 'icd11_data_raw.csv'
icd11_df = pd.read_csv(ICD11_CSV_PATH)

# Fail fast if a column used by the text-preparation and labelling steps is missing,
# instead of hitting a KeyError several cells later.
required_columns = {'code', 'title', 'definition', 'inclusions', 'index_terms', 'full_text'}
missing_columns = required_columns - set(icd11_df.columns)
assert not missing_columns, f"{ICD11_CSV_PATH} is missing columns: {sorted(missing_columns)}"

# Display basic information about the dataset
print(f"Dataset shape: {icd11_df.shape}")
print(f"Columns: {icd11_df.columns.tolist()}")
icd11_df.head()

Dataset shape: (28087, 19)
Columns: ['id', 'code', 'title', 'browser_url', 'class_kind', 'definition', 'parent', 'inclusions', 'foundation_children', 'foundation_child_references', 'index_terms', 'related_entities', 'full_text', 'children', 'postcoordination_scales', 'index_term_references', 'exclusions', 'exclusion_references', 'fully_specified_name']


Unnamed: 0,id,code,title,browser_url,class_kind,definition,parent,inclusions,foundation_children,foundation_child_references,index_terms,related_entities,full_text,children,postcoordination_scales,index_term_references,exclusions,exclusion_references,fully_specified_name
0,1937339080,1C22,Infections due to Chlamydia psittaci,https://icd.who.int/browse/2023-01/mms/en#1937...,category,Any condition caused by an infection with the ...,1127435854,Psittacosis; Ornithosis; Parrot fever,Pneumonia in chlamydia psittaci infection,Pneumonia in chlamydia psittaci infection: htt...,Infections due to Chlamydia psittaci; Psittaco...,1935107489,Infections due to Chlamydia psittaci Any condi...,,,,,,
1,1671640403,1F51.0,Gambiense trypanosomiasis,https://icd.who.int/browse/2023-01/mms/en#1671...,category,A disease caused by an infection with the prot...,875488052,West African sleeping sickness; Infection due ...,,,,1945127438,Gambiense trypanosomiasis A disease caused by ...,1842725899; other; unspecified,"{'axis_name': 'hasManifestation', 'required': ...",,,,
2,1528414070,1A07,Typhoid fever,https://icd.who.int/browse/2023-01/mms/en#1528...,category,A condition caused by an infection with the gr...,135352227,,,,,911707612,Typhoid fever A condition caused by an infecti...,364534567; other; unspecified,"{'axis_name': 'hasManifestation', 'required': ...",,,,
3,328097188,1A36.12,Cutaneous amoebiasis,https://icd.who.int/browse/2023-01/mms/en#3280...,category,,1777228366,,,,Cutaneous amoebiasis; Amoebiasis of skin; Amoe...,911707612,Cutaneous amoebiasis Cutaneous amoebiasis; Amo...,,,,,,
4,1483190070,1D03,Infectious abscess of the central nervous system,https://icd.who.int/browse/2023-01/mms/en#1483...,category,A focal suppurative process of the brain paren...,1585949804,,,,,911707612,Infectious abscess of the central nervous syst...,443087096; 613341872; 1147230459; 1128677700; ...,"{'axis_name': 'specificAnatomy', 'required': '...",,,,


# 3. Prepare Text for Transformer Models

In [None]:
def prepare_text_for_transformers(
    df,
    text_fields=('title', 'definition', 'inclusions', 'index_terms', 'full_text'),
):
    """
    Build a single free-text description per ICD-11 row for transformer models.

    Stemming and stop-word removal are deliberately skipped: transformer
    tokenizers and models make use of that context themselves.

    Args:
        df: DataFrame containing every column named in ``text_fields``.
        text_fields: Columns to concatenate, in order. Missing (NaN/None) and
            empty values are skipped, so no stray separators are produced.

    Returns:
        A copy of ``df`` with an added ``transformer_text`` column. The input
        frame is not modified.

    NOTE(review): in the sample rows printed by this notebook, ``full_text``
    appears to already begin with the title followed by the definition or
    index terms, so the default field list may embed that content twice and
    waste the 512-token budget. Consider ``text_fields=('full_text',)`` if that
    holds for the whole dataset — TODO confirm.
    """
    print("Preparing text data for transformer models...")

    result = df.copy()

    # Normalise every field to str, mapping missing values to '' so they can be
    # dropped when joining.
    columns = [result[field].fillna('').astype(str) for field in text_fields]
    if columns:
        combined = [' '.join(part for part in parts if part) for parts in zip(*columns)]
    else:
        combined = [''] * len(result)
    text = pd.Series(combined, index=result.index, dtype=object)

    # Strip HTML tags (including tags spanning line breaks) and collapse the
    # whitespace they leave behind. Casing and punctuation are preserved.
    text = (
        text.str.replace(r'<[^>]*>', '', regex=True)
        .str.replace(r'\s+', ' ', regex=True)
        .str.strip()
    )
    result['transformer_text'] = text

    # Print statistics about text length
    text_lengths = result['transformer_text'].str.len()
    print(f"Average text length: {text_lengths.mean():.1f} characters")
    print(f"Min text length: {text_lengths.min()} characters")
    print(f"Max text length: {text_lengths.max()} characters")

    return result

# Apply text preparation
icd11_df = prepare_text_for_transformers(icd11_df)

# Smaller random subset for fast debugging. It is capped at the dataset size so
# .sample() never requests more rows than exist (which would raise ValueError).
sample_size = min(1000, len(icd11_df))  # Adjust as needed
icd11_sample = icd11_df.sample(sample_size, random_state=42)
print(f"Created sample dataset with {len(icd11_sample)} entries")

Preparing text data for transformer models...
Average text length: 457.1 characters
Min text length: 9 characters
Max text length: 65045 characters
Created sample dataset with 1000 entries


# 4. Define Dataset and Embedding Generation Functions

In [None]:
class ICDTextDataset(Dataset):
    """
    Map-style torch Dataset of ICD texts, tokenized once at construction.

    Every text is truncated/padded to exactly ``max_length`` tokens and kept as
    PyTorch tensors. Indexing returns a dict with one row from each encoding
    field the tokenizer produced (e.g. input_ids, attention_mask).
    """

    def __init__(self, texts, tokenizer, max_length=512):
        encoded = tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt',
        )
        self.encodings = encoded

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        item = {}
        for field_name, field_tensor in self.encodings.items():
            item[field_name] = field_tensor[idx]
        return item

def generate_bert_embeddings(texts, model_name_or_path, pooling='mean', batch_size=8, max_length=512):
    """
    Generate one fixed-size embedding per text from a BERT-style encoder.

    Texts are processed in batches with a progress bar. The model runs on GPU
    when one is available.

    Args:
        texts: Text descriptions — a list, tuple, pandas Series, array, or a
            single string (treated as one text).
        model_name_or_path: HuggingFace model name or path to a fine-tuned model.
        pooling: 'mean' (attention-mask-weighted mean of the last hidden
            states) or 'cls' (last hidden state of the first token).
        batch_size: Number of texts tokenized and encoded per forward pass.
        max_length: Token limit per text; longer texts are truncated.

    Returns:
        Numpy array with one row per input text (shape ``(len(texts), hidden_size)``).

    Raises:
        ValueError: if ``pooling`` is not 'mean' or 'cls'.
    """
    # Validate before the (slow) model download so typos fail immediately.
    # Previously any unknown value silently fell back to mean pooling.
    if pooling not in ('mean', 'cls'):
        raise ValueError(f"Unknown pooling strategy {pooling!r}; expected 'mean' or 'cls'")

    # HF tokenizers accept a str or a list/tuple of str. Slicing a pandas Series
    # would hand them a Series, so normalise to a plain list once.
    texts = [texts] if isinstance(texts, str) else list(texts)

    print(f"Generating embeddings using {model_name_or_path}...")

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModel.from_pretrained(model_name_or_path)

    # Use GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model = model.to(device)
    model.eval()

    if not texts:
        # np.vstack([]) would raise; return a correctly-shaped empty matrix instead.
        embeddings = np.empty((0, getattr(model.config, 'hidden_size', 0)), dtype=np.float32)
        print(f"Generated embeddings shape: {embeddings.shape}")
        return embeddings

    all_embeddings = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
        batch_texts = texts[start:start + batch_size]

        # Pad only to the longest text in the batch, truncate at max_length
        encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
                                  max_length=max_length, return_tensors='pt')
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}

        with torch.no_grad():
            token_embeddings = model(**encoded_input).last_hidden_state

        if pooling == 'cls':
            # [CLS] is the first token
            pooled = token_embeddings[:, 0, :]
        else:
            # Average only over real tokens: padding positions have mask 0.
            mask = encoded_input['attention_mask'].unsqueeze(-1).float()
            summed = (token_embeddings * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)  # Prevent division by zero
            pooled = summed / counts

        all_embeddings.append(pooled.cpu().numpy())

    embeddings = np.vstack(all_embeddings)

    # Release the model and cached GPU blocks once, rather than calling
    # empty_cache() after every batch (which slows inference without lowering
    # peak memory). This matters when several models are embedded in a loop.
    del model
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    print(f"Generated embeddings shape: {embeddings.shape}")
    return embeddings

# 5. Generate Embeddings with a Single Model (Quick Test)

In [None]:
# Smoke-test the embedding pipeline on a handful of sample texts before the
# long full-dataset run. bert-base-uncased serves as the general-domain
# baseline; the biomedical checkpoints follow in the next section.
n_smoke_texts = 50
test_texts = icd11_sample['transformer_text'].iloc[:n_smoke_texts].tolist()
test_model = 'bert-base-uncased'

# A small batch keeps the first run light on GPU memory
test_embeddings = generate_bert_embeddings(test_texts, test_model, batch_size=4)
print(f"Test successful! Generated embeddings with shape: {test_embeddings.shape}")

# Write a file now so any disk/permission problem surfaces before the long run
np.save('test_bert_embeddings.npy', test_embeddings)
print("Test embeddings saved successfully")

Generating embeddings using bert-base-uncased...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Using device: cuda


Processing batches:   0%|          | 0/13 [00:00<?, ?it/s]

Generated embeddings shape: (50, 768)
Test successful! Generated embeddings with shape: (50, 768)
Test embeddings saved successfully


# 6. Generate Embeddings with Biomedical BERT Models

In [None]:
# Generate embeddings with several (biomedical) BERT models and save each to disk.

# Models to run — delete entries you don't need to save time
models_to_run = {
    'bert': 'bert-base-uncased',  # Basic BERT (general language)
    'bioclinicalbert': 'emilyalsentzer/Bio_ClinicalBERT',  # Clinical BERT (medical focus)
    'biobert': 'dmis-lab/biobert-base-cased-v1.1',  # BioBERT (biomedical literature)
    'pubmedbert': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'  # PubMedBERT
}

# Choose which rows to embed: icd11_sample for a quick run, icd11_df for the
# full dataset (may take a long time). Texts AND codes come from this single
# frame, so the saved code list is row-aligned with every embedding matrix
# (previously the code source was guessed by comparing lengths).
embedding_source_df = icd11_df
texts_to_embed = embedding_source_df['transformer_text'].tolist()
codes = embedding_source_df['code'].tolist()

# Create a directory for saving embeddings
os.makedirs('embeddings', exist_ok=True)

# Save the code list once, before the long model loop. The named column gives
# the CSV a meaningful header (readers use the first column).
pd.Series(codes, name='code').to_csv('embeddings/code_list.csv', index=False)
print("Saved code list for reference")

# Generate and save embeddings for each model
for model_name, model_path in models_to_run.items():
    print(f"\nGenerating {model_name} embeddings...")

    embeddings = generate_bert_embeddings(
        texts_to_embed,
        model_path,
        batch_size=8  # Adjust based on your GPU memory
    )
    # One row per text is the alignment the code list relies on
    assert len(embeddings) == len(codes), (
        f"{model_name}: {len(embeddings)} embeddings for {len(codes)} codes")

    output_path = f"embeddings/{model_name}_embeddings.npy"
    np.save(output_path, embeddings)
    print(f"Saved {model_name} embeddings to {output_path}")


Generating bert embeddings...
Generating embeddings using bert-base-uncased...
Using device: cuda


Processing batches:   0%|          | 0/3511 [00:00<?, ?it/s]

Generated embeddings shape: (28087, 768)
Saved bert embeddings to embeddings/bert_embeddings.npy
Saved code list for reference

Generating bioclinicalbert embeddings...
Generating embeddings using emilyalsentzer/Bio_ClinicalBERT...


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Using device: cuda


Processing batches:   0%|          | 0/3511 [00:00<?, ?it/s]

Generated embeddings shape: (28087, 768)
Saved bioclinicalbert embeddings to embeddings/bioclinicalbert_embeddings.npy

Generating biobert embeddings...
Generating embeddings using dmis-lab/biobert-base-cased-v1.1...


config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Using device: cuda


Processing batches:   0%|          | 0/3511 [00:00<?, ?it/s]

Generated embeddings shape: (28087, 768)
Saved biobert embeddings to embeddings/biobert_embeddings.npy

Generating pubmedbert embeddings...
Generating embeddings using microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext...


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Using device: cuda


Processing batches:   0%|          | 0/3511 [00:00<?, ?it/s]

Generated embeddings shape: (28087, 768)
Saved pubmedbert embeddings to embeddings/pubmedbert_embeddings.npy


# 7. Evaluate Embeddings

In [None]:
def evaluate_embeddings(embeddings, labels, evaluation_tasks=('clustering', 'classification'),
                        silhouette_sample_size=None, random_state=42):
    """
    Evaluate embeddings with an unsupervised (K-means clustering) and a
    supervised (linear SVM classification) probe against reference labels.

    Args:
        embeddings: Array-like with one row per sample.
        labels: Array-like of length n_samples (e.g. ICD chapter per row). It may be
            a list, numpy array or pandas Series (used positionally). Rows whose
            label is missing (NaN/None) are dropped.
        evaluation_tasks: Any of 'clustering', 'classification'.
        silhouette_sample_size: If set, the silhouette score is estimated on a
            random subset of this many points (the exact score is quadratic in
            n_samples). None uses every point.
        random_state: Seed for K-means, the train/test split, LinearSVC and
            silhouette subsampling.

    Returns:
        Dict with an optional 'clustering' entry ({'silhouette',
        'adjusted_rand_index'}) and an optional 'classification' entry
        ({'accuracy', 'macro_f1'}). A task is omitted, with a message, when the
        data cannot support it (e.g. fewer than two label categories).

    Raises:
        ValueError: if embeddings and labels differ in length.
    """
    results = {}

    embeddings = np.asarray(embeddings)
    # An ndarray supports boolean-mask indexing whether the caller passed a list,
    # an array or a Series (a plain list would raise TypeError on the mask).
    labels = np.asarray(labels)
    if len(embeddings) != len(labels):
        raise ValueError(
            f"embeddings ({len(embeddings)}) and labels ({len(labels)}) must have the same length")

    # Make sure we have valid labels
    valid_indices = ~pd.isna(labels)
    if not valid_indices.all():
        print(f"Warning: {(~valid_indices).sum()} labels are NaN, filtering these out")
        embeddings = embeddings[valid_indices]
        labels = labels[valid_indices]

    n_samples = len(labels)
    unique_labels = np.unique(labels)
    print(f"Evaluating with {len(unique_labels)} unique label categories")

    if 'clustering' in evaluation_tasks:
        print("Performing clustering evaluation...")
        n_clusters = min(len(unique_labels), 100)  # Cap at 100 clusters for efficiency
        # silhouette_score requires 2 <= n_clusters <= n_samples - 1
        if n_clusters < 2 or n_clusters >= n_samples:
            print("Skipping clustering: need at least 2 label categories and more samples than clusters")
        else:
            kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
            clusters = kmeans.fit_predict(embeddings)

            silhouette = silhouette_score(embeddings, clusters,
                                          sample_size=silhouette_sample_size,
                                          random_state=random_state)
            rand_index = adjusted_rand_score(labels, clusters)

            results['clustering'] = {
                'silhouette': float(silhouette),
                'adjusted_rand_index': float(rand_index)
            }
            print(f"Clustering results: Silhouette={silhouette:.3f}, ARI={rand_index:.3f}")

    if 'classification' in evaluation_tasks:
        print("Performing classification evaluation...")
        if len(unique_labels) < 2:
            print("Skipping classification: need at least 2 label categories")
        else:
            X_train, X_test, y_train, y_test = train_test_split(
                embeddings, labels, test_size=0.2, random_state=random_state)

            # LinearSVC raises if the training split contains a single class
            if len(np.unique(y_train)) < 2:
                print("Skipping classification: training split has fewer than 2 classes")
            else:
                clf = LinearSVC(random_state=random_state)
                clf.fit(X_train, y_train)

                y_pred = clf.predict(X_test)
                accuracy = accuracy_score(y_test, y_pred)
                # Rare classes are often absent from the test split; score them
                # as 0 explicitly instead of emitting a warning per call.
                macro_f1 = f1_score(y_test, y_pred, average='macro', zero_division=0)

                results['classification'] = {
                    'accuracy': float(accuracy),
                    'macro_f1': float(macro_f1)
                }
                print(f"Classification results: Accuracy={accuracy:.3f}, Macro-F1={macro_f1:.3f}")

    return results

# Create a directory for results
os.makedirs('results', exist_ok=True)

# Discover every saved embedding matrix. Sorting makes the evaluation order
# (and the JSON written below) deterministic; os.listdir order is arbitrary.
embedding_files = {}
for file in sorted(os.listdir('embeddings')):
    if file.endswith('_embeddings.npy'):
        model_name = file.replace('_embeddings.npy', '')
        embedding_files[model_name] = f'embeddings/{file}'

print(f"Found {len(embedding_files)} embedding files to evaluate: {list(embedding_files.keys())}")

# Load the code list saved next to the embeddings (row-aligned with them).
# dtype=str keeps codes as text rather than letting pandas infer types.
try:
    codes = pd.read_csv('embeddings/code_list.csv', header=0, dtype=str).iloc[:, 0].tolist()
    print(f"Loaded {len(codes)} codes from file")
except (FileNotFoundError, pd.errors.EmptyDataError):
    # If the code list is not saved, fall back to the full dataset
    codes = icd11_df['code'].tolist()
    print(f"Using {len(codes)} codes from dataframe")

# Top-level labels = ICD-11 chapter. In ICD-11 the first character of a code
# identifies its chapter (e.g. '1...' -> Chapter 01, 'A...' -> Chapter 10).
# Splitting on '.' only drops the subcategory digits and leaves the stem code.
# That produced ~17.8k labels for ~28k rows, which made ARI and classification
# meaningless and LinearSVC extremely slow.
LABEL_PREFIX_LEN = 1
labels = pd.Series(codes, dtype=object).str[:LABEL_PREFIX_LEN]
labels = labels.mask(labels == '')  # empty codes -> NaN, filtered by evaluate_embeddings
print(f"Created {len(labels)} labels, with {labels.nunique()} unique values")

# Evaluate each embedding file
evaluation_results = {}
for model_name, file_path in embedding_files.items():
    print(f"\nEvaluating {model_name} embeddings...")
    embeddings = np.load(file_path)

    # A length mismatch means the file is stale relative to the code list.
    # Truncating would silently pair embeddings with the wrong labels, so skip it.
    if len(embeddings) != len(labels):
        print(f"Warning: Embeddings length ({len(embeddings)}) doesn't match labels ({len(labels)})")
        print(f"Skipping {model_name}: its rows cannot be aligned with the code list")
        continue

    try:
        evaluation_results[model_name] = evaluate_embeddings(embeddings, labels)
    except Exception as e:
        # Keep evaluating the remaining models, but report the failure
        print(f"Error evaluating {model_name}: {e}")

# Save evaluation results
with open('results/evaluation_results.json', 'w') as f:
    json.dump(evaluation_results, f, indent=2)
print("Saved evaluation results")

Found 4 embedding files to evaluate: ['bert', 'pubmedbert', 'biobert', 'bioclinicalbert']
Loaded 28087 codes from file
Created 28087 labels, with 17776 unique values

Evaluating bert embeddings...
Evaluating with 17776 unique label categories
Performing clustering evaluation...
Clustering results: Silhouette=0.009, ARI=0.010
Performing classification evaluation...


# 8. Visualize Embeddings

In [None]:
def visualize_embeddings(embeddings, labels, method='tsne', sample_size=5000,
                         random_state=42, max_categories=20):
    """
    Project embeddings to 2-D and scatter-plot them, coloured by label.

    Args:
        embeddings: Array-like with one row per point.
        labels: Array-like of length n_samples (list, array or Series; used
            positionally). Missing labels are plotted as 'Unknown'.
        method: Dimensionality reduction method, 'tsne' or 'umap'
            (case-insensitive).
        sample_size: Maximum number of points to plot. Larger inputs are
            randomly subsampled.
        random_state: Seed for subsampling and for the reducer, so the figure
            is reproducible.
        max_categories: Maximum number of legend entries. If there are more
            labels, only the ``max_categories - 1`` most frequent are kept and
            the rest are merged into 'Other'.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as the current
        figure (callers use ``.savefig()``, ``.show()`` and ``.close()`` on it).

    Raises:
        ValueError: on an unknown ``method`` or mismatched input lengths.
    """
    import matplotlib.pyplot as plt

    method = method.lower()
    if method not in ('tsne', 'umap'):
        raise ValueError(f"Unknown method {method!r}; expected 'tsne' or 'umap'")

    embeddings = np.asarray(embeddings)
    # Positional string Series: works for list/array/Series input alike. The
    # placeholder for missing labels stops np.unique comparing str with float NaN.
    labels = pd.Series(np.asarray(labels, dtype=object)).fillna('Unknown').astype(str)
    if len(embeddings) != len(labels):
        raise ValueError(
            f"embeddings ({len(embeddings)}) and labels ({len(labels)}) must have the same length")

    # Sample data if it's too large (seeded, so reruns plot the same points)
    if len(embeddings) > sample_size:
        print(f"Sampling {sample_size} points from {len(embeddings)} total")
        rng = np.random.default_rng(random_state)
        indices = rng.choice(len(embeddings), sample_size, replace=False)
        embeddings = embeddings[indices]
        labels = labels.iloc[indices].reset_index(drop=True)

    print(f"Reducing dimensions using {method.upper()}...")
    if method == 'tsne':
        from sklearn.manifold import TSNE
        # scikit-learn requires perplexity < n_samples; keep its default of 30 when possible
        perplexity = min(30.0, max(len(embeddings) - 1, 1))
        reducer = TSNE(n_components=2, perplexity=perplexity, random_state=random_state, verbose=1)
    else:
        import umap
        reducer = umap.UMAP(n_components=2, random_state=random_state)

    reduced_embeddings = reducer.fit_transform(embeddings)

    fig, ax = plt.subplots(figsize=(12, 10))

    unique_labels = np.unique(labels)
    print(f"Plotting {len(unique_labels)} unique categories")

    # If too many labels, group the less common ones
    if len(unique_labels) > max_categories:
        common_labels = labels.value_counts().nlargest(max_categories - 1).index
        plot_labels = labels.where(labels.isin(common_labels), 'Other')
        unique_plot_labels = np.unique(plot_labels)
        print(f"Grouped less common categories into 'Other', plotting {len(unique_plot_labels)} categories")
    else:
        plot_labels = labels
        unique_plot_labels = unique_labels

    colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_plot_labels)))
    plot_label_values = plot_labels.to_numpy()
    for color, label in zip(colors, unique_plot_labels):
        mask = plot_label_values == label
        ax.scatter(
            reduced_embeddings[mask, 0],
            reduced_embeddings[mask, 1],
            c=[color],
            label=label,
            alpha=0.7,
            s=30  # Point size
        )

    # After grouping at most max_categories entries remain, so one legend suffices
    ax.legend(fontsize=8)
    ax.set_title(f'ICD2Vec Embeddings Visualization ({method.upper()})')
    ax.set_xlabel(f'{method.upper()} 1')
    ax.set_ylabel(f'{method.upper()} 2')
    fig.tight_layout()

    return plt

# Visualize the best-scoring model's embeddings. Fall back to the first
# available embedding file when there are no usable evaluation results.
best_model = None
try:
    with open('results/evaluation_results.json', 'r') as f:
        eval_results = json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
    print(f"No usable evaluation results ({e})")
    eval_results = {}

# Rank by classification accuracy, considering only models whose embedding file
# still exists (avoids a KeyError below)
scores = {
    model: res.get('classification', {}).get('accuracy', 0)
    for model, res in eval_results.items()
    if model in embedding_files
}
if scores:
    best_model = max(scores, key=scores.get)
    print(f"Selected best model: {best_model} (accuracy: {scores[best_model]:.3f})")
elif embedding_files:
    # Also reached when the results file exists but is empty
    best_model = next(iter(embedding_files))
    print(f"Using first available model: {best_model}")

if best_model:
    embeddings = np.load(embedding_files[best_model])

    if len(embeddings) != len(labels):
        print(f"Skipping visualization: {best_model} has {len(embeddings)} rows "
              f"but there are {len(labels)} labels")
    else:
        # Visualize using both TSNE and UMAP
        for method in ['tsne', 'umap']:
            try:
                viz = visualize_embeddings(embeddings, labels, method=method)
                viz.savefig(f'results/{best_model}_{method}_visualization.png', dpi=300, bbox_inches='tight')
                print(f"Saved {method.upper()} visualization")
            except Exception as e:
                print(f"Error visualizing with {method}: {e}")
            finally:
                # Close even after a failure so figures don't accumulate in memory
                plt.close('all')

# 9. Fine-Tuning (Optional, Resource-Intensive)

In [None]:
def fine_tune_bert_model(texts, model_name, output_dir, epochs=3, batch_size=8, max_length=512):
    """
    Continue masked-language-model (MLM) training of a BERT-style checkpoint
    on ICD descriptions, then save the model and tokenizer to ``output_dir``.

    Note: this can take several hours on Colab, even with a GPU.

    Args:
        texts: Iterable of text descriptions for ICD codes (a single string is
            treated as one text).
        model_name: HuggingFace model name or local path (e.g. 'bert-base-uncased').
        output_dir: Directory for checkpoints, logs and the final model.
        epochs: Number of training epochs.
        batch_size: Per-device training batch size.
        max_length: Maximum tokens per text; longer texts are truncated.

    Returns:
        Tuple ``(model, tokenizer)`` after training.

    Raises:
        ValueError: if ``texts`` is empty.
    """
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        raise ValueError("fine_tune_bert_model requires at least one training text")

    # AutoModelForMaskedLM picks the MLM head from the checkpoint config
    # (BertForMaskedLM for BERT checkpoints), so non-BERT checkpoints work too.
    from transformers import AutoModelForMaskedLM

    print(f"Fine-tuning {model_name} on {len(texts)} texts for {epochs} epochs...")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)

    # Prepare dataset - use a smaller subset if you have memory issues
    print("Preparing dataset...")
    dataset = ICDTextDataset(texts, tokenizer, max_length=max_length)

    # Data collator for MLM: randomly masks 15% of tokens per batch
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        save_steps=10_000,
        save_total_limit=2,
        logging_dir=f"{output_dir}/logs",
        logging_steps=500,
        # Mixed precision on GPU to reduce memory use and training time
        fp16=torch.cuda.is_available(),
        # Don't auto-enable external loggers (which may prompt for a login on Colab)
        report_to='none',
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset,
    )

    # Train model
    print("Starting training...")
    trainer.train()

    # Save model
    print("Saving model...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    print(f"Fine-tuning complete! Model saved to {output_dir}")
    return model, tokenizer

# IMPORTANT: fine-tuning takes several hours even on a powerful Colab GPU, so it
# is opt-in. Set RUN_FINE_TUNING = True to train, then embed with the fine-tuned
# model. (Previously this ran unconditionally, so "Run All" started training.)
RUN_FINE_TUNING = False

# Check whether a GPU is available
print(f"GPU available: {torch.cuda.is_available()}")
print(f"GPU device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

if RUN_FINE_TUNING:
    # Create directory for fine-tuned models
    os.makedirs('fine_tuned_models', exist_ok=True)

    # A subset keeps fine-tuning manageable; use more data if you have time and resources
    max_samples = 5000  # Adjust based on your resources
    fine_tune_texts = icd11_df['transformer_text'].sample(
        min(max_samples, len(icd11_df)), random_state=42).tolist()

    # Choose a model to fine-tune
    model_to_finetune = 'emilyalsentzer/Bio_ClinicalBERT'  # Clinical focus is good for ICD codes
    output_dir = 'fine_tuned_models/bioclinicalbert_ft'

    fine_tuned_model, fine_tuned_tokenizer = fine_tune_bert_model(
        fine_tune_texts,
        model_to_finetune,
        output_dir,
        epochs=2  # Start with fewer epochs for testing
    )

    # generate_bert_embeddings reloads the saved model from output_dir, so free
    # the in-memory training copy first instead of holding two copies on the GPU.
    del fine_tuned_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # NOTE(review): these embeddings cover all of icd11_df; they align with
    # embeddings/code_list.csv only if section 6 embedded the full dataset.
    fine_tuned_embeddings = generate_bert_embeddings(
        icd11_df['transformer_text'].tolist(),
        output_dir
    )

    np.save('embeddings/bioclinicalbert_ft_embeddings.npy', fine_tuned_embeddings)
    print("Saved fine-tuned embeddings")
else:
    print("Skipping fine-tuning (set RUN_FINE_TUNING = True to enable)")


# 10. Putting It All Together - Simple Test Run

In [None]:
# Simplified end-to-end run on a small sample: embed -> evaluate -> visualize.
# Useful for quick testing without running the whole pipeline. Sample-specific
# names avoid overwriting the `labels`/`embeddings`/`codes` globals that the
# evaluation and visualization sections depend on.

# 1. Create a small sample (capped so .sample() never asks for more rows than exist)
small_sample_size = min(200, len(icd11_df))  # Adjust based on your needs
small_sample = icd11_df.sample(small_sample_size, random_state=42)
sample_texts = small_sample['transformer_text'].tolist()
sample_codes = small_sample['code'].tolist()

# Top-level label = ICD-11 chapter, which is the first character of the code.
# Splitting on '.' would keep the near-unique stem code instead.
sample_labels = pd.Series(sample_codes, dtype=object).str[:1]
sample_labels = sample_labels.mask(sample_labels == '')

print(f"Working with sample of {len(sample_texts)} examples")

# 2. Generate embeddings with a single model
sample_model_name = 'emilyalsentzer/Bio_ClinicalBERT'  # A good model for medical text
sample_embeddings = generate_bert_embeddings(sample_texts, sample_model_name, batch_size=8)

# 3. Evaluate embeddings
sample_eval_results = evaluate_embeddings(sample_embeddings, sample_labels)
print("\nEvaluation results:")
print(json.dumps(sample_eval_results, indent=2))

# 4. Visualize embeddings (a separate name keeps the `plt` module binding intact)
viz = visualize_embeddings(sample_embeddings, sample_labels, method='tsne')
viz.savefig('sample_visualization.png')
viz.show()

print("\nTest run completed successfully!")