<a href="https://colab.research.google.com/github/deepakkumardk2/-Multi-Agent-AI-System/blob/main/Zero_Shot_ECG_Classification_with_Multimodal_Learning_and_Test_time_Clinical_Knowledge_Enhancement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
# Import the necessary libraries

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from transformers import AutoModel, AutoTokenizer
import wfdb
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, accuracy_score
from tqdm import tqdm
from sklearn.manifold import TSNE
import warnings
import json
import random
import seaborn as sns
warnings.filterwarnings('ignore')

# Specify the path of your database
def setup_dataset_paths():
    """
    Set up the dataset paths, mounting Google Drive when running in Colab.

    When ``google.colab`` cannot be imported, or Drive fails to mount, paths are
    resolved relative to the current working directory instead.

    Returns:
        dict with keys 'base_path', 'dataset_zip_path', 'dataset_path' and
        'processed_data_dir'. The processed-data directory is created if missing.
    """
    print("Setting up dataset paths...")

    base_path = ''
    try:
        from google.colab import drive
    except ImportError:
        # Not in Colab: use paths relative to the working directory.
        print("Running in local environment.")
    else:
        try:
            drive.mount('/content/drive')
        except Exception as exc:  # auth cancelled, timeout, ... -> continue locally
            print(f"Could not mount Google Drive ({exc}); running in local environment.")
        else:
            base_path = '/content/drive/MyDrive'
            print("Google Drive mounted successfully.")

    # Define the dataset path (using the provided link)
    dataset_zip_path = os.path.join(base_path, 'mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0.zip')
    dataset_path = os.path.join(base_path, 'mimic-iv-ecg')

    # Create directory for processed data
    processed_data_dir = os.path.join(base_path, 'processed_ecg_data')
    os.makedirs(processed_data_dir, exist_ok=True)

    return {
        'base_path': base_path,
        'dataset_zip_path': dataset_zip_path,
        'dataset_path': dataset_path,
        'processed_data_dir': processed_data_dir
    }

# Run a check if all the files exist in the specified path
def check_files_exist(paths):
    """
    Verify the extracted dataset is available, extracting the zip archive if needed.

    The zip archive is only required when ``paths['dataset_path']`` does not exist
    yet, so an already-extracted dataset is accepted even if the zip was removed.

    Args:
        paths: dict with 'dataset_zip_path' and 'dataset_path' entries.

    Returns:
        True if every required CSV exists under paths['dataset_path'], else False.
    """
    print("Checking if all files exist...")

    dataset_path = paths['dataset_path']
    zip_path = paths['dataset_zip_path']

    if not os.path.exists(dataset_path):
        # Nothing extracted yet: the archive is mandatory.
        if not os.path.exists(zip_path):
            print(f"Dataset zip file not found at {zip_path}")
            print("Please upload the dataset to your Google Drive using the provided link.")
            return False

        print(f"Extracting dataset to {dataset_path}...")
        import zipfile
        os.makedirs(dataset_path, exist_ok=True)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(dataset_path)
        print("Dataset extracted successfully.")

    # Check if key files exist
    required_files = [
        'record_csv.csv',
        'report_csv.csv'
    ]

    for file in required_files:
        file_path = os.path.join(dataset_path, file)
        if not os.path.exists(file_path):
            print(f"Required file {file} not found at {file_path}")
            return False

    print("All required files found.")
    return True

# Load record_csv and report_csv
def load_csv_files(paths):
    """
    Read the record and report tables from the extracted dataset directory.

    Args:
        paths: dict whose 'dataset_path' entry points at the extracted dataset.

    Returns:
        (record_df, report_df) loaded from 'record_csv.csv' and 'report_csv.csv'.
    """
    print("Loading CSV files...")

    root = paths['dataset_path']
    record_df, report_df = (
        pd.read_csv(os.path.join(root, csv_name))
        for csv_name in ('record_csv.csv', 'report_csv.csv')
    )

    print(f"Loaded {len(record_df)} records and {len(report_df)} reports.")
    return record_df, report_df

# Initialise the dataset object and run basic validation checks
class ECGDataset:
    """
    Holds the record table, the report table and the dataset paths.

    A lightweight validation pass (ID overlap and nulls in key columns) is printed
    as soon as the object is constructed.
    """

    def __init__(self, record_df, report_df, paths):
        """
        Store the inputs and run validate_data().

        Args:
            record_df: DataFrame with 'subject_id' and 'study_id' columns.
            report_df: DataFrame with 'subject_id', 'study_id' and 'text' columns.
            paths: dict of dataset paths, stored unchanged.
        """
        self.record_df = record_df
        self.report_df = report_df
        self.paths = paths

        self.validate_data()

    @staticmethod
    def _pair_keys(frame):
        """Return the set of '<subject_id>_<study_id>' strings found in *frame*."""
        return set(frame['subject_id'].astype(str) + '_' + frame['study_id'].astype(str))

    def validate_data(self):
        """
        Print how many subject/study pairs the two tables share and whether any
        key column contains nulls.
        """
        print("Validating dataset...")

        record_keys = self._pair_keys(self.record_df)
        report_keys = self._pair_keys(self.report_df)
        shared = record_keys & report_keys

        print(f"Found {len(shared)} matching records and reports out of {len(record_keys)} records and {len(report_keys)} reports.")

        record_nulls = self.record_df[['subject_id', 'study_id']].isnull().sum().sum()
        report_nulls = self.report_df[['subject_id', 'study_id', 'text']].isnull().sum().sum()

        if record_nulls == 0 and report_nulls == 0:
            print("No null values found in key columns.")
        else:
            print(f"Warning: Found {record_nulls} null values in record_df and {report_nulls} null values in report_df.")

# Clean report text and extract diagnostic statements
def process_reports(dataset):
    """
    Clean report text and extract one diagnostic statement per report.

    Adds two columns to ``dataset.report_df`` (modified in place):
      * 'processed_text': the report with every whitespace run (newlines, carriage
        returns, tabs, repeated spaces) collapsed to one space and stripped;
        missing text becomes ''.
      * 'diagnostic_info': the text after the first of 'DIAGNOSIS:',
        'DIAGNOSTIC:' or 'IMPRESSION:' (checked in that order) up to the end of
        that line, whitespace-normalised; "No diagnostic information found" if
        none of the markers occur.

    Args:
        dataset: object exposing a ``report_df`` DataFrame with a 'text' column.

    Returns:
        The same ``dataset`` object.
    """
    print("Processing reports...")

    # Missing reports become '' so string operations below never see NaN.
    raw_text = dataset.report_df['text'].fillna('').astype(str)

    # Collapse all whitespace runs in a single pass.
    dataset.report_df['processed_text'] = raw_text.str.replace(r'\s+', ' ', regex=True).str.strip()

    markers = ('DIAGNOSIS:', 'DIAGNOSTIC:', 'IMPRESSION:')

    def extract_diagnostic_info(text):
        # Runs on the raw text: the section ends at the next line break, which
        # no longer exists in 'processed_text'.
        for marker in markers:
            if marker in text:
                section = text.split(marker, 1)[1].split('\n', 1)[0]
                return ' '.join(section.split())
        return "No diagnostic information found"

    dataset.report_df['diagnostic_info'] = raw_text.apply(extract_diagnostic_info)

    print("Reports processed successfully.")
    return dataset

# Index records and reports by a shared '<subject_id>_<study_id>' key
def assign_indexes(dataset):
    """
    Index both tables by a shared '<subject_id>_<study_id>' key.

    The key is first written to a 'record_id' column of ``record_df`` and a
    'report_id' column of ``report_df``; each column is then promoted to its
    frame's index (and removed as a column). Both frames are modified in place.

    Returns:
        The same ``dataset`` object.
    """
    print("Assigning new indexes...")

    frames = (
        (dataset.record_df, 'record_id'),
        (dataset.report_df, 'report_id'),
    )

    for frame, key_name in frames:
        frame[key_name] = frame['subject_id'].astype(str) + '_' + frame['study_id'].astype(str)

    # Promote the keys to the index only once both have been computed.
    for frame, key_name in frames:
        frame.set_index(key_name, inplace=True)

    print("New indexes assigned.")
    return dataset

# Create a directory for processed data
def create_directory(paths):
    """
    Make sure the processed-data sub-directories exist.

    Creates 'waveforms', 'reports', 'embeddings' and 'models' under
    paths['processed_data_dir']; directories that already exist are left alone.

    Returns:
        The ``paths`` dict, unchanged.
    """
    print("Creating directory structure...")

    root = paths['processed_data_dir']
    for subdir in ('waveforms', 'reports', 'embeddings', 'models'):
        os.makedirs(os.path.join(root, subdir), exist_ok=True)

    print("Directory structure created.")
    return paths

# Verify ECG bin file paths exist
def verify_ecg_bin_paths(dataset, paths):
    """
    Spot-check that waveform files referenced by ``dataset.record_df`` exist.

    Draws a random sample of up to 10 rows and checks for
    ``<dataset_path>/waveforms/<filename>`` for each one, printing a warning for
    every missing file.

    Returns:
        True if every sampled file exists, otherwise False.
    """
    print("Verifying ECG binary file paths...")

    n_check = min(10, len(dataset.record_df))
    picked = dataset.record_df.sample(n_check)
    waveform_dir = os.path.join(paths['dataset_path'], 'waveforms')

    missing = []
    for _, row in picked.iterrows():
        candidate = os.path.join(waveform_dir, f"{row['filename']}")
        if not os.path.exists(candidate):
            missing.append(candidate)
            print(f"Warning: File not found: {candidate}")

    if missing:
        print("Some ECG files are missing. Please check the dataset structure.")
    else:
        print("All sampled ECG files exist.")

    return not missing

# Filter records and reports
def filter_records_and_reports(dataset, max_samples=10000):
    """
    Keep only the record/report pairs whose index key appears in both tables.

    Args:
        dataset: object whose ``record_df`` and ``report_df`` are indexed by the
            shared '<subject_id>_<study_id>' key.
        max_samples: keep at most this many keys (Colab memory limit); ``None``
            disables the cap. Keys are kept in ``record_df`` order, so the
            truncated subset is the same on every run.

    Returns:
        The same ``dataset`` with ``record_df`` / ``report_df`` replaced by the
        filtered frames.
    """
    print("Filtering records and reports...")

    report_ids = set(dataset.report_df.index)
    # Walk record_df's (deduplicated) index instead of a set: set iteration order
    # of string keys changes between interpreter runs (hash randomisation).
    common_ids = [rid for rid in dict.fromkeys(dataset.record_df.index) if rid in report_ids]

    print(f"Found {len(common_ids)} matching record-report pairs.")

    if max_samples is not None and len(common_ids) > max_samples:
        print(f"Limiting to {max_samples} samples for processing in Colab.")
        common_ids = common_ids[:max_samples]

    # Filter datasets to only include common IDs (same order in both frames)
    dataset.record_df = dataset.record_df.loc[common_ids]
    dataset.report_df = dataset.report_df.loc[common_ids]

    print(f"Filtered to {len(dataset.record_df)} record-report pairs.")
    return dataset

# Reshape data and store it in NumPy arrays (signals are currently simulated)
def reshape_data_to_array(dataset, paths):
    """
    Build aligned ECG / text lists for up to 5000 records and save them to disk.

    NOTE: waveforms are *simulated*. Each record gets ``np.random.randn(5000, 12)``
    (10 s at 500 Hz, 12 leads) and every other sample is kept (2500 x 12). The
    WFDB file is not read; use ``wfdb.rdsamp(file_path)`` for real data.

    Files written under paths['processed_data_dir']:
        waveforms/ecg_signals.npy, reports/report_texts.json,
        reports/diagnostic_texts.json, reports/ids.json

    Returns:
        (ecg_signals, report_texts, diagnostic_texts, ids): ``ecg_signals`` is a
        numpy array stacking the per-record signals; the other three are lists in
        the same order. Records that raise while being processed are printed and
        skipped.
    """
    print("Reshaping data to arrays...")

    out_dir = paths['processed_data_dir']
    waveform_root = os.path.join(paths['dataset_path'], 'waveforms')

    signals, reports, diagnostics, kept_ids = [], [], [], []

    # Cap the number of records to respect Colab memory limits.
    limit = min(5000, len(dataset.record_df))
    for record_id in tqdm(list(dataset.record_df.index)[:limit]):
        try:
            filename = dataset.record_df.loc[record_id, 'filename']
            file_path = os.path.join(waveform_root, filename)  # used once real loading replaces the simulation

            # Simulated 12-lead recording, then 500 Hz -> 250 Hz decimation.
            ecg = np.random.randn(5000, 12)[::2, :]

            report_text = dataset.report_df.loc[record_id, 'processed_text']
            diagnostic_text = dataset.report_df.loc[record_id, 'diagnostic_info']
        except Exception as exc:
            print(f"Error processing record {record_id}: {exc}")
        else:
            signals.append(ecg)
            reports.append(report_text)
            diagnostics.append(diagnostic_text)
            kept_ids.append(record_id)

    ecg_signals = np.array(signals)
    np.save(os.path.join(out_dir, 'waveforms', 'ecg_signals.npy'), ecg_signals)

    text_outputs = (
        ('report_texts.json', reports),
        ('diagnostic_texts.json', diagnostics),
        ('ids.json', kept_ids),
    )
    for json_name, payload in text_outputs:
        with open(os.path.join(out_dir, 'reports', json_name), 'w') as fh:
            json.dump(payload, fh)

    print(f"Data reshaped and saved. Shape of ECG signals: {ecg_signals.shape}")

    return ecg_signals, reports, diagnostics, kept_ids

# Split the data into training and testing set
def split_data(ecg_signals, report_texts, diagnostic_texts, ids):
    """
    Split signals, texts and ids into aligned 80/20 train/test partitions.

    Report and diagnostic texts are zipped together before splitting so they stay
    paired with their signal; ``random_state=42`` makes the split reproducible.

    Returns:
        dict with 'X_train', 'X_test', 'train_reports', 'train_diagnostics',
        'test_reports', 'test_diagnostics', 'ids_train', 'ids_test'. The four
        text entries are tuples.
    """
    print("Splitting data into training and testing sets...")

    paired_texts = list(zip(report_texts, diagnostic_texts))
    X_train, X_test, texts_train, texts_test, ids_train, ids_test = train_test_split(
        ecg_signals, paired_texts, ids, test_size=0.2, random_state=42,
    )

    splits = {'X_train': X_train, 'X_test': X_test}
    for prefix, pairs in (('train', texts_train), ('test', texts_test)):
        reports, diagnostics = zip(*pairs)
        splits[f'{prefix}_reports'] = reports
        splits[f'{prefix}_diagnostics'] = diagnostics
    splits['ids_train'] = ids_train
    splits['ids_test'] = ids_test

    print(f"Training set: {len(X_train)} samples")
    print(f"Testing set: {len(X_test)} samples")

    return splits

# Visualize metadata and its shape
def visualize_metadata(data_splits, paths):
    """
    Save a per-lead plot of the first training ECG and a report-length histogram.

    Writes 'ecg_sample_visualization.png' and 'report_length_distribution.png'
    into paths['processed_data_dir'], then prints the signal shape and the mean
    report length.

    Args:
        data_splits: dict with 'X_train' (indexable ECG arrays, time x leads) and
            'train_reports' (iterable of strings).
        paths: dict with 'processed_data_dir'.

    Returns:
        True.
    """
    print("Visualizing metadata and shapes...")

    out_dir = paths['processed_data_dir']

    # First training ECG; rows are time steps, columns are leads.
    sample_ecg = data_splits['X_train'][0]
    n_leads = min(12, sample_ecg.shape[1])

    # Exactly one figure per plot, always closed, so repeated calls (or a failed
    # savefig) do not accumulate open figures.
    lead_fig = plt.figure(figsize=(15, 10))
    try:
        for i in range(n_leads):
            plt.subplot(4, 3, i + 1)
            plt.plot(sample_ecg[:, i])
            plt.title(f'Lead {i+1}')
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, 'ecg_sample_visualization.png'))
    finally:
        plt.close(lead_fig)

    # Create text length distribution
    report_lengths = [len(text) for text in data_splits['train_reports']]

    hist_fig = plt.figure(figsize=(10, 6))
    try:
        plt.hist(report_lengths, bins=50)
        plt.xlabel('Report Length (characters)')
        plt.ylabel('Count')
        plt.title('Distribution of Report Lengths')
        plt.savefig(os.path.join(out_dir, 'report_length_distribution.png'))
    finally:
        plt.close(hist_fig)

    # Print key statistics
    print(f"ECG signal shape: {data_splits['X_train'][0].shape}")
    print(f"Average report length: {np.mean(report_lengths):.1f} characters")

    return True

# Define the ECG and text encoder classes
class _BasicBlock1d(nn.Module):
    """1D counterpart of the ResNet BasicBlock: two 3-tap convolutions plus a skip."""

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        # Projection shortcut whenever the block changes resolution or width.
        if stride != 1 or in_planes != planes:
            self.downsample = nn.Sequential(
                nn.Conv1d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes),
            )
        else:
            self.downsample = None

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)


class ECGEncoder(nn.Module):
    """
    ECG encoder: a ResNet18-layout network built from 1D convolutions.

    Maps a batch of multi-lead ECGs to L2-normalised embeddings. The network is
    randomly initialised; ImageNet Conv2d filters have no direct 1D equivalent, so
    no pretrained weights are loaded.
    """

    def __init__(self, in_channels=12, embedding_dim=768):
        """
        Args:
            in_channels: number of ECG leads.
            embedding_dim: size of the output embedding.
        """
        super(ECGEncoder, self).__init__()

        # Stem: 15-tap stride-2 convolution, then stride-2 max-pooling.
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=15, stride=2, padding=7, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # ResNet18 stage layout: two BasicBlocks per stage, widths 64/128/256/512.
        self.layer1 = self._make_layer(64, 64, blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, blocks=2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, embedding_dim)

    @staticmethod
    def _make_layer(in_planes, planes, blocks, stride):
        """Stack `blocks` BasicBlocks; only the first may change stride or width."""
        layers = [_BasicBlock1d(in_planes, planes, stride)]
        layers.extend(_BasicBlock1d(planes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch, sequence_length, in_channels) — leads last.

        Returns:
            (batch, embedding_dim) tensor whose rows have unit L2 norm.
        """
        x = x.permute(0, 2, 1)  # -> (batch, in_channels, sequence_length) for Conv1d

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return F.normalize(x, p=2, dim=1)  # L2 normalization

class TextEncoder(nn.Module):
    """
    Clinical text encoder (Med-CPT / BioClinicalBERT style).

    Runs a Hugging Face transformer, takes the first-token ([CLS]) hidden state,
    projects it to ``embedding_dim`` when the widths differ, and L2-normalises it.
    """

    def __init__(self, model_name="emilyalsentzer/Bio_ClinicalBERT", embedding_dim=768):
        """
        Args:
            model_name: Hugging Face model id for both tokenizer and backbone.
            embedding_dim: size of the returned embeddings.
        """
        super(TextEncoder, self).__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

        hidden_size = self.model.config.hidden_size
        # A linear projection is only needed when the backbone width differs.
        self.projection = (
            nn.Identity()
            if hidden_size == embedding_dim
            else nn.Linear(hidden_size, embedding_dim)
        )

    def forward(self, texts):
        """
        Args:
            texts: a string or a list of strings (each truncated to 512 tokens).

        Returns:
            (batch, embedding_dim) tensor whose rows have unit L2 norm.
        """
        batch = self.tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
        target_device = next(self.model.parameters()).device
        batch = {name: tensor.to(target_device) for name, tensor in batch.items()}

        cls_state = self.model(**batch).last_hidden_state[:, 0, :]
        projected = self.projection(cls_state)

        return F.normalize(projected, p=2, dim=1)

# Define the multimodal model that combines the ECG and text encoders
class MERL(nn.Module):
    """
    Multimodal ECG Representation Learning Framework.

    Couples an ECG encoder and a text encoder that share one embedding space and
    produces logits for two contrastive objectives:
      * CMA (cross-modal alignment): ECG-vs-text similarities / temperature.
      * UMA (uni-modal alignment): similarities between two dropout-perturbed
        views of the same ECG embeddings / temperature.

    Config keys read: 'embedding_dim', 'temperature', 'dropout_prob'.
    """

    def __init__(self, config):
        super(MERL, self).__init__()
        self.config = config

        # ECG and text encoders
        self.ecg_encoder = ECGEncoder(in_channels=12, embedding_dim=config['embedding_dim'])
        self.text_encoder = TextEncoder(embedding_dim=config['embedding_dim'])

        # Temperature parameter for contrastive loss
        self.temperature = config['temperature']

    def forward(self, ecg_signals, texts=None, return_embeddings=False):
        """
        Args:
            ecg_signals: batch of ECGs in the layout ECGEncoder expects.
            texts: optional list of report strings, one per ECG.
            return_embeddings: when texts are given, return the two embedding
                matrices instead of logits.

        Returns:
            * texts is None: the ECG embeddings.
            * return_embeddings: (ecg_embeddings, text_embeddings).
            * otherwise: (cma_similarities, uma_similarities), each (batch, batch).
        """
        ecg_embeddings = self.ecg_encoder(ecg_signals)

        if texts is None:
            # For inference or embedding extraction
            return ecg_embeddings

        text_embeddings = self.text_encoder(texts)
        if return_embeddings:
            return ecg_embeddings, text_embeddings

        # Cross-Modal Alignment (CMA) logits
        similarities = torch.matmul(ecg_embeddings, text_embeddings.T) / self.temperature

        # UMA views: element-wise feature dropout (active only in training mode),
        # re-normalised so the logits stay cosine similarities like the CMA ones.
        # A per-sample (batch, 1) mask would zero whole embeddings and leave
        # all-zero logit rows with no usable positive pair.
        p = self.config['dropout_prob']
        view1 = F.normalize(F.dropout(ecg_embeddings, p=p, training=self.training), p=2, dim=1)
        view2 = F.normalize(F.dropout(ecg_embeddings, p=p, training=self.training), p=2, dim=1)

        uma_similarities = torch.matmul(view1, view2.T) / self.temperature

        return similarities, uma_similarities

# Visualize the model
def visualize_model(model):
    """
    Print the module tree followed by total and trainable parameter counts.

    Returns:
        The same ``model`` object.
    """
    print("Model Architecture:")
    print(model)

    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    return model

# Check using dummy input
def check_dummy_input(model, device):
    """
    Smoke-test ``model`` on a random two-sample ECG batch with two short reports.

    The model is switched to eval mode and called under ``torch.no_grad()``; it
    must return a (similarities, uma_similarities) pair.

    Args:
        model: callable module taking (ecg_batch, texts).
        device: device the dummy ECG tensor is moved to.

    Returns:
        True if the forward pass succeeds; False (after printing the error) if it raises.
    """
    print("Testing model with dummy input...")

    # Two recordings of 2500 time points x 12 leads.
    ecg = torch.randn(2, 2500, 12).to(device)
    texts = ["Normal sinus rhythm", "Atrial fibrillation with rapid ventricular response"]

    model.eval()
    with torch.no_grad():
        try:
            sims, uma_sims = model(ecg, texts)
            print(f"Similarities shape: {sims.shape}")
            print(f"UMA similarities shape: {uma_sims.shape}")
            print("Model forward pass successful with dummy data.")
        except Exception as err:
            print(f"Error during model forward pass: {err}")
            return False
    return True

# Define the contrastive loss functions (CMA and UMA)
def define_loss_functions():
    """
    Build the contrastive objectives used during training.

    Returns:
        (cma_loss, uma_loss):
          * cma_loss(similarities, labels=None): average of the cross-entropy
            over rows (ECG -> text) and over the transposed matrix (text -> ECG).
          * uma_loss(similarities, labels=None): cross-entropy over rows only.
        With labels=None, row i's positive is column i (the diagonal).
    """
    print("Defining contrastive loss functions...")

    def _targets(logits, labels):
        # Default positives: the matching index on the diagonal.
        if labels is not None:
            return labels
        return torch.arange(logits.size(0)).to(logits.device)

    def cma_loss(similarities, labels=None):
        """Cross-Modal Alignment loss"""
        targets = _targets(similarities, labels)
        ecg_to_text = F.cross_entropy(similarities, targets)
        text_to_ecg = F.cross_entropy(similarities.T, targets)
        return (ecg_to_text + text_to_ecg) / 2

    def uma_loss(similarities, labels=None):
        """Uni-Modal Alignment loss"""
        return F.cross_entropy(similarities, _targets(similarities, labels))

    print("Loss functions defined successfully.")
    return cma_loss, uma_loss

# Train the model for config['epochs'] epochs and save it as trained_model.pth
def train_model(model, data_splits, config, paths):
    """
    Train the MERL model with the CMA + UMA objectives and save checkpoints.

    Args:
        model: module returning (cma_similarities, uma_similarities) when called
            as model(ecg_batch, text_batch).
        data_splits: dict with 'X_train' (ECG arrays) and 'train_diagnostics'
            (one text per ECG, same order).
        config: dict with 'batch_size', 'device', 'learning_rate',
            'weight_decay' and 'epochs'.
        paths: dict with 'processed_data_dir'. Checkpoints (every 10 epochs) and
            'trained_model.pth' go to its 'models' subfolder; the loss curve is
            saved as 'training_loss.png'.

    Returns:
        (model, train_losses), where train_losses holds the mean batch loss of
        each epoch.

    Raises:
        ValueError: if X_train is empty or its length differs from train_diagnostics.
    """
    print("Starting model training...")

    X_train = data_splits['X_train']
    train_diagnostics = list(data_splits['train_diagnostics'])

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    if len(X_train_tensor) == 0:
        raise ValueError("X_train is empty; nothing to train on.")
    if len(train_diagnostics) != len(X_train_tensor):
        raise ValueError(
            f"X_train has {len(X_train_tensor)} samples but train_diagnostics has "
            f"{len(train_diagnostics)}; they must be aligned."
        )

    # Carry each sample's row index through the loader so every shuffled ECG is
    # paired with its own diagnostic text. Slicing the text list by batch
    # position would mismatch pairs because the loader shuffles.
    train_dataset = torch.utils.data.TensorDataset(X_train_tensor, torch.arange(len(X_train_tensor)))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True
    )

    device = config['device']
    model = model.to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    cma_loss_fn, uma_loss_fn = define_loss_functions()

    model.train()
    train_losses = []
    models_dir = os.path.join(paths['processed_data_dir'], 'models')

    for epoch in range(config['epochs']):
        epoch_loss = 0.0

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}") as pbar:
            for batch_idx, (ecg_batch, index_batch) in enumerate(pbar):
                # Texts belonging to exactly the ECGs in this (shuffled) batch.
                text_batch = [train_diagnostics[i] for i in index_batch.tolist()]

                ecg_batch = ecg_batch.to(device)

                similarities, uma_similarities = model(ecg_batch, text_batch)

                # Combined objective: cross-modal + uni-modal alignment.
                loss = cma_loss_fn(similarities) + uma_loss_fn(uma_similarities)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                pbar.set_postfix(loss=epoch_loss / (batch_idx + 1))

        mean_epoch_loss = epoch_loss / len(train_loader)
        train_losses.append(mean_epoch_loss)

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint_path = os.path.join(models_dir, f'merl_checkpoint_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': mean_epoch_loss,
            }, checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch+1}")

    # Save final model
    model_path = os.path.join(models_dir, 'trained_model.pth')
    torch.save(model.state_dict(), model_path)
    print(f"Model saved at {model_path}")

    # Plot training loss
    loss_fig = plt.figure(figsize=(10, 6))
    plt.plot(train_losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.savefig(os.path.join(paths['processed_data_dir'], 'training_loss.png'))
    plt.close(loss_fig)

    return model, train_losses


ModuleNotFoundError: No module named 'wfdb'

In [None]:
# Visualize the ECG and text embeddings with t-SNE
def _categorize_diagnostic(diagnostic):
    """
    Map a diagnostic string to a coarse label used to colour the t-SNE plots.

    Rules are checked in order: Normal, MI, AFIB, Block, HYP, otherwise Other.
    'normal' must be a whole word, so e.g. "Abnormal ECG" is not labelled Normal.
    """
    import re

    text = str(diagnostic).lower()
    if re.search(r'\bnormal\b', text):
        return 'Normal'
    if 'infarct' in text or 'ischemia' in text:
        return 'MI'
    if 'fibrillation' in text:
        return 'AFIB'
    if 'block' in text:
        return 'Block'
    if 'hypertrophy' in text:
        return 'HYP'
    return 'Other'


def visualize_embeddings(model, data_splits, config, paths):
    """
    Embed (a sample of) the test set and save side-by-side t-SNE plots.

    Up to 500 randomly chosen test ECGs and their diagnostic texts are encoded
    with ``model.ecg_encoder`` / ``model.text_encoder`` in batches of
    config['batch_size'] on config['device']. Each modality is reduced to 2D
    with t-SNE and coloured by a keyword-based diagnostic category. The figure is
    saved as 'tsne_visualization.png' in paths['processed_data_dir'].

    Returns:
        (ecg_embeddings, text_embeddings) as numpy arrays whose rows follow the
        sampled test items.

    Raises:
        ValueError: if fewer than 2 test samples are available (t-SNE needs >= 2).
    """
    print("Visualizing embeddings with t-SNE...")

    X_test = data_splits['X_test']
    test_diagnostics = data_splits['test_diagnostics']

    # Sample for visualization (t-SNE becomes slow with too many points)
    max_samples = 500
    if len(X_test) > max_samples:
        indices = np.random.choice(len(X_test), max_samples, replace=False)
        X_test_sample = X_test[indices]
        test_diagnostics_sample = [test_diagnostics[i] for i in indices]
    else:
        X_test_sample = X_test
        test_diagnostics_sample = list(test_diagnostics)

    n_points = len(X_test_sample)
    if n_points < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_points}.")

    device = config['device']
    batch_size = config['batch_size']
    model = model.to(device)
    model.eval()

    # Encode in batches: feeding up to 500 reports through the text transformer
    # at once can exhaust GPU memory.
    ecg_chunks, text_chunks = [], []
    with torch.no_grad():
        for start in range(0, n_points, batch_size):
            stop = start + batch_size
            ecg_batch = torch.tensor(np.asarray(X_test_sample[start:stop]), dtype=torch.float32).to(device)
            ecg_chunks.append(model.ecg_encoder(ecg_batch).cpu().numpy())
            text_chunks.append(model.text_encoder(list(test_diagnostics_sample[start:stop])).cpu().numpy())

    ecg_embeddings = np.concatenate(ecg_chunks, axis=0)
    text_embeddings = np.concatenate(text_chunks, axis=0)

    # sklearn requires perplexity < n_samples.
    tsne = TSNE(n_components=2, random_state=42, perplexity=float(min(30, n_points - 1)))
    ecg_tsne = tsne.fit_transform(ecg_embeddings)
    text_tsne = tsne.fit_transform(text_embeddings)

    categories = np.array([_categorize_diagnostic(d) for d in test_diagnostics_sample])
    # Sorted so colours and legend order are stable between runs.
    unique_categories = sorted(set(categories.tolist()))
    colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_categories)))

    fig = plt.figure(figsize=(16, 6))
    panels = (
        (ecg_tsne, 't-SNE of ECG Embeddings'),
        (text_tsne, 't-SNE of Text Embeddings'),
    )
    for panel_idx, (coords, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_idx)
        for color, category in zip(colors, unique_categories):
            mask = categories == category
            plt.scatter(coords[mask, 0], coords[mask, 1], label=category, color=color, alpha=0.7)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(paths['processed_data_dir'], 'tsne_visualization.png'))
    plt.close(fig)

    print("Embeddings visualization saved.")
    return ecg_embeddings, text_embeddings

# Evaluate model and check for precision recall and accuracy
def evaluate_model(model, data_splits, config, paths):
    """
    Evaluate the model on the test split by matching ECG embeddings against
    one short text prompt per diagnostic category.

    Args:
        model: Model callable as ``model(ecg, texts, return_embeddings=True)``
            returning ``(ecg_embeddings, text_embeddings)`` and exposing
            ``model.text_encoder(list_of_str)``.
        data_splits: Dict with ``'X_test'`` (ECG array) and
            ``'test_diagnostics'`` (one diagnostic string per ECG, same order).
        config: Dict with ``'batch_size'`` and ``'device'``.
        paths: Dict with ``'processed_data_dir'``; the confusion-matrix plot is
            written there as ``confusion_matrix.png``.

    Returns:
        Dict with
          - ``'accuracy'``: arg-max accuracy on samples with a known category,
          - ``'auc_scores'``: per-category one-vs-rest ROC-AUC (NaN when the
            category has no positives or no negatives),
          - ``'precision_scores'``: per-category area under the
            precision-recall curve (NaN when the category has no positives),
          - ``'recall_scores'``: per-category recall of the arg-max prediction
            (NaN when the category has no positives).
        An empty dict is returned when nothing can be evaluated.
    """
    import re
    from sklearn.metrics import confusion_matrix

    print("Evaluating model performance...")

    def _label_for(diagnostic):
        """Map a free-text diagnosis to a category index (0-4), or -1 if none."""
        text = str(diagnostic).lower()
        # Word-boundary match so that "abnormal" is NOT counted as "normal".
        if re.search(r'\bnormal\b', text):
            return 0  # Normal
        if 'infarct' in text or 'ischemia' in text:
            return 1  # MI
        if 'fibrillation' in text:
            return 2  # AFIB
        if 'block' in text:
            return 3  # Block
        if 'hypertrophy' in text:
            return 4  # HYP
        return -1  # Unknown

    # Prepare data
    X_test = data_splits['X_test']
    test_diagnostics = data_splits['test_diagnostics']
    batch_size = config['batch_size']
    device = config['device']

    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.tensor(X_test, dtype=torch.float32)),
        batch_size=batch_size,
        shuffle=False
    )

    model = model.to(device)
    model.eval()
    all_embeddings = []
    all_texts = []

    # Extract ECG embeddings, keeping each batch paired with its diagnoses
    with torch.no_grad():
        for batch_idx, (ecg_batch,) in enumerate(tqdm(test_loader, desc="Extracting embeddings")):
            start = batch_idx * batch_size
            stop = min(start + batch_size, len(test_diagnostics))
            text_batch = [test_diagnostics[i] for i in range(start, stop)]

            if len(text_batch) != ecg_batch.size(0):
                # ECGs and diagnoses are out of step here; skip rather than mis-pair
                continue

            ecg_embeddings, _ = model(ecg_batch.to(device), text_batch, return_embeddings=True)
            all_embeddings.append(ecg_embeddings.cpu().numpy())
            all_texts.extend(text_batch)

    if not all_embeddings:
        print("No test embeddings could be extracted.")
        return {}

    all_embeddings = np.concatenate(all_embeddings, axis=0)

    # Diagnostic categories and their prompts (same order as _label_for)
    categories = ['Normal', 'MI', 'AFIB', 'Block', 'HYP']
    category_prompts = [
        "An ECG showing normal sinus rhythm.",
        "An ECG showing myocardial infarction.",
        "An ECG showing atrial fibrillation.",
        "An ECG showing heart block.",
        "An ECG showing ventricular hypertrophy."
    ]
    n_classes = len(categories)

    with torch.no_grad():
        category_embeddings = model.text_encoder(category_prompts).cpu().numpy()

    # Similarity of every ECG embedding to every category prompt
    similarities = np.matmul(all_embeddings, category_embeddings.T)
    predictions = np.argmax(similarities, axis=1)

    # Only evaluate samples whose diagnosis maps to a known category
    true_labels = np.array([_label_for(d) for d in all_texts], dtype=int)
    valid_mask = true_labels != -1
    if not valid_mask.any():
        print("No valid labels found for evaluation.")
        return {}

    valid_predictions = predictions[valid_mask]
    valid_true_labels = true_labels[valid_mask]
    valid_similarities = similarities[valid_mask]
    n_valid = len(valid_true_labels)

    accuracy = accuracy_score(valid_true_labels, valid_predictions)

    # One-hot ground truth for per-class metrics
    y_true = np.eye(n_classes)[valid_true_labels]

    # Numerically stable softmax: subtracting the row max avoids exp overflow
    shifted = valid_similarities - valid_similarities.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    probabilities = exp_scores / exp_scores.sum(axis=1, keepdims=True)

    auc_scores = []
    precisions = []
    recalls = []
    for i in range(n_classes):
        n_pos = int(y_true[:, i].sum())

        # ROC-AUC is only defined when both positives and negatives exist
        if 0 < n_pos < n_valid:
            auc_scores.append(roc_auc_score(y_true[:, i], probabilities[:, i]))
        else:
            auc_scores.append(np.nan)

        if n_pos > 0:
            precision, recall, _ = precision_recall_curve(y_true[:, i], probabilities[:, i])
            # Area under the precision-recall curve (sklearn.metrics.auc)
            precisions.append(auc(recall, precision))
            # Recall of the arg-max prediction for this class
            recalls.append(float(np.mean(valid_predictions[valid_true_labels == i] == i)))
        else:
            precisions.append(np.nan)
            recalls.append(np.nan)

    # Print results
    print(f"Overall accuracy: {accuracy:.4f}")
    print("AUC scores by category:")
    for i, category in enumerate(categories):
        print(f"  {category}: {auc_scores[i]:.4f}")
    print("Precision scores by category:")
    for i, category in enumerate(categories):
        print(f"  {category}: {precisions[i]:.4f}")

    # Fixed label order keeps the matrix 5x5 even when a class is absent
    cm = confusion_matrix(valid_true_labels, valid_predictions, labels=list(range(n_classes)))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=categories, yticklabels=categories)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(os.path.join(paths['processed_data_dir'], 'confusion_matrix.png'))
    plt.close()

    return {
        'accuracy': accuracy,
        'auc_scores': auc_scores,
        'precision_scores': precisions,
        'recall_scores': recalls
    }

# Perform zero-shot and calculate accuracy
def perform_zero_shot(model, data_splits, config, paths):
    """
    Zero-shot classify test ECGs against clinically enriched category prompts
    (a simplified Clinical Knowledge Enhanced Prompt Engineering, CKPE).

    Each ECG embedding is compared with the text embedding of one prompt per
    category. The similarities are divided by ``model.temperature`` and
    soft-maxed, and the arg-max is the prediction. Test diagnoses that match
    no category keyword are excluded from scoring.

    Args:
        model: Model exposing ``ecg_encoder``, ``text_encoder`` and
            ``temperature``.
        data_splits: Dict with ``'X_test'`` and ``'test_diagnostics'``
            (one diagnostic string per ECG, same order).
        config: Dict with ``'batch_size'`` and ``'device'``.
        paths: Dict with ``'processed_data_dir'``; the plot
            ``zero_shot_confusion_matrix.png`` is written there.

    Returns:
        Dict with ``'accuracy'`` and ``'auc_scores'`` (per-category
        one-vs-rest ROC-AUC, NaN where it is undefined), or ``{}`` when
        nothing can be scored.
    """
    import re
    from sklearn.metrics import confusion_matrix

    print("Performing zero-shot classification...")

    class SimplifiedCKPE:
        """Simplified Clinical Knowledge Enhanced Prompt Engineering"""
        def __init__(self):
            # Pre-defined templates for common ECG diagnoses with clinical knowledge enhancement
            self.templates = {
                'Normal': "An ECG showing normal sinus rhythm with regular P waves, normal PR interval, and normal QRS complex.",
                'MI': "An ECG showing myocardial infarction with ST segment elevation, pathological Q waves, and T wave inversion.",
                'AFIB': "An ECG showing atrial fibrillation with irregular RR intervals, absence of P waves, and fibrillatory waves.",
                'Block': "An ECG showing heart block with prolonged PR interval, dropped QRS complexes, or complete dissociation of P waves and QRS complexes.",
                'HYP': "An ECG showing ventricular hypertrophy with increased QRS amplitude, prolonged QRS duration, and secondary ST-T changes."
            }

        def generate_prompts(self):
            """Generate prompts for all categories"""
            return self.templates

    def _label_for(diagnostic):
        """Map a free-text diagnosis to a category index (0-4), or -1 if none."""
        text = str(diagnostic).lower()
        # Word-boundary match so that "abnormal" is NOT counted as "normal".
        if re.search(r'\bnormal\b', text):
            return 0  # Normal
        if 'infarct' in text or 'ischemia' in text:
            return 1  # MI
        if 'fibrillation' in text:
            return 2  # AFIB
        if 'block' in text:
            return 3  # Block
        if 'hypertrophy' in text:
            return 4  # HYP
        return -1  # Unknown

    # Category order follows the dict insertion order: Normal, MI, AFIB, Block, HYP
    category_prompts = SimplifiedCKPE().generate_prompts()
    categories = list(category_prompts.keys())
    prompt_texts = list(category_prompts.values())
    n_classes = len(categories)

    # Prepare test data
    X_test = data_splits['X_test']
    test_diagnostics = data_splits['test_diagnostics']
    device = config['device']

    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.tensor(X_test, dtype=torch.float32)),
        batch_size=config['batch_size'],
        shuffle=False
    )

    model = model.to(device)
    model.eval()
    all_predictions = []
    all_probabilities = []

    with torch.no_grad():
        category_embeddings = model.text_encoder(prompt_texts).to(device)

        for ecg_batch, in tqdm(test_loader, desc="Zero-shot classification"):
            ecg_embeddings = model.ecg_encoder(ecg_batch.to(device))

            # Temperature-scaled similarity with each category prompt
            similarities = torch.matmul(ecg_embeddings, category_embeddings.T) / model.temperature
            probabilities = F.softmax(similarities, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

            all_predictions.append(predictions.cpu().numpy())
            all_probabilities.append(probabilities.cpu().numpy())

    if not all_predictions:
        print("No valid labels found for zero-shot evaluation.")
        return {}

    all_predictions = np.concatenate(all_predictions)
    all_probabilities = np.concatenate(all_probabilities)

    # Only evaluate samples whose diagnosis maps to a known category
    true_labels = [_label_for(d) for d in test_diagnostics]
    valid_indices = [i for i, label in enumerate(true_labels) if label != -1]

    if not valid_indices:
        print("No valid labels found for zero-shot evaluation.")
        return {}

    valid_predictions = all_predictions[valid_indices]
    valid_true_labels = np.array([true_labels[i] for i in valid_indices], dtype=int)
    valid_probabilities = all_probabilities[valid_indices]
    n_valid = len(valid_true_labels)

    accuracy = accuracy_score(valid_true_labels, valid_predictions)

    # One-vs-rest ROC-AUC per category
    auc_scores = []
    for i in range(n_classes):
        is_pos = (valid_true_labels == i).astype(int)
        n_pos = int(is_pos.sum())
        # ROC-AUC is only defined when both positives and negatives exist
        if 0 < n_pos < n_valid:
            auc_scores.append(roc_auc_score(is_pos, valid_probabilities[:, i]))
        else:
            auc_scores.append(np.nan)

    # Print results
    print(f"Zero-shot classification accuracy: {accuracy:.4f}")
    print("Zero-shot AUC scores by category:")
    for i, category in enumerate(categories):
        if not np.isnan(auc_scores[i]):
            print(f"  {category}: {auc_scores[i]:.4f}")

    # Fixed label order keeps the matrix aligned with the tick labels
    cm = confusion_matrix(valid_true_labels, valid_predictions, labels=list(range(n_classes)))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=categories, yticklabels=categories)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Zero-Shot Confusion Matrix')
    plt.savefig(os.path.join(paths['processed_data_dir'], 'zero_shot_confusion_matrix.png'))
    plt.close()

    return {
        'accuracy': accuracy,
        'auc_scores': auc_scores
    }

# Load pretrained model
def load_pretrained_model(model, paths):
    """
    Load weights from ``<processed_data_dir>/models/trained_model.pth`` into
    ``model`` if that file exists; otherwise leave the model as initialised.

    Args:
        model: A ``torch.nn.Module`` whose ``state_dict`` matches the checkpoint.
        paths: Dict with ``'processed_data_dir'``.

    Returns:
        The same ``model`` object, with weights loaded when the file was found.
    """
    print("Loading pretrained model...")

    model_path = os.path.join(paths['processed_data_dir'], 'models', 'trained_model.pth')

    if os.path.exists(model_path):
        # Load tensors onto the CPU first, so a checkpoint saved on a GPU can be
        # restored on a CPU-only machine. load_state_dict then copies them into
        # the model's parameters on whatever device the model already lives on.
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        print(f"Model loaded from {model_path}")
    else:
        print(f"Pretrained model not found at {model_path}, using initialized model.")

    return model

# Preprocess ECG signals (per-lead normalization) and save them to disk
def preprocess_and_generate_memory(data_splits, paths):
    """
    Z-score normalise every ECG lead using training-set statistics, replace any
    remaining NaNs with 0, and save the results.

    The lead axis is taken to be axis 2 of ``X_train``/``X_test``, matching
    the original per-lead ``[:, :, i]`` indexing (so the arrays are assumed to
    be 3-D). Each lead's statistics are computed over axes 0 and 1 of the
    training set while ignoring NaN entries. This way a few missing samples
    do not turn an entire lead into zeros.

    Args:
        data_splits: Dict with ``'X_train'`` and ``'X_test'`` NumPy arrays.
            Floating-point arrays are normalised in place; integer arrays are
            first converted to float32 so values are not truncated.
        paths: Dict with ``'processed_data_dir'``; the arrays are saved under
            its ``waveforms`` sub-directory (created if missing).

    Returns:
        ``data_splits`` with ``'X_train'`` and ``'X_test'`` replaced by the
        normalised arrays.
    """
    print("Preprocessing ECG signals...")

    # In a real implementation, this would involve signal processing steps.
    # For this demo, we only normalize the signals.
    X_train = data_splits['X_train']
    X_test = data_splits['X_test']

    # Writing normalised values back into an integer array would truncate them.
    if not np.issubdtype(X_train.dtype, np.floating):
        X_train = X_train.astype(np.float32)
    if not np.issubdtype(X_test.dtype, np.floating):
        X_test = X_test.astype(np.float32)

    # Per-lead statistics from the training set only (no test leakage).
    # nan-aware reductions: a single NaN must not poison the whole lead.
    mean = np.nanmean(X_train, axis=(0, 1), keepdims=True)
    std = np.nanstd(X_train, axis=(0, 1), keepdims=True)
    scale = std + 1e-8

    X_train -= mean
    X_train /= scale
    X_test -= mean
    X_test /= scale

    # Any NaN that survived (e.g. missing samples) becomes 0
    X_train = np.nan_to_num(X_train)
    X_test = np.nan_to_num(X_test)

    # Save preprocessed data
    waveforms_dir = os.path.join(paths['processed_data_dir'], 'waveforms')
    os.makedirs(waveforms_dir, exist_ok=True)
    np.save(os.path.join(waveforms_dir, 'preprocessed_train_signals.npy'), X_train)
    np.save(os.path.join(waveforms_dir, 'preprocessed_test_signals.npy'), X_test)

    print("ECG signals preprocessed and saved.")

    data_splits['X_train'] = X_train
    data_splits['X_test'] = X_test

    return data_splits

# Check for 0 or null value, just in case
def check_for_null_values(data_splits):
    """
    Report NaN entries and all-zero ECG records in the train/test signals.

    The function only prints diagnostics and never modifies the data.
    A record counts as all-zero when the absolute values summed over
    axes 1 and 2 equal 0.

    Args:
        data_splits: Dict with ``'X_train'`` and ``'X_test'`` NumPy arrays.

    Returns:
        The same ``data_splits`` dict, unchanged.
    """
    print("Checking for null or zero values...")

    arrays = (data_splits['X_train'], data_splits['X_test'])

    nan_train, nan_test = (np.isnan(arr).sum() for arr in arrays)
    if nan_train <= 0 and nan_test <= 0:
        print("No NaN values found in ECG data.")
    else:
        print(f"Warning: Found {nan_train} NaN values in training data and {nan_test} NaN values in test data.")

    zero_train, zero_test = ((np.abs(arr).sum(axis=(1, 2)) == 0).sum() for arr in arrays)
    if zero_train <= 0 and zero_test <= 0:
        print("No all-zero signals found in ECG data.")
    else:
        print(f"Warning: Found {zero_train} all-zero signals in training data and {zero_test} all-zero signals in test data.")

    return data_splits

# Generate ECG and text embeddings
def generate_embeddings(model, data_splits, config, paths):
    """
    Generate and save ECG and text embeddings for both the train and test splits.

    Args:
        model: Model exposing ``ecg_encoder(tensor)`` and
            ``text_encoder(list_of_str)``, each returning a tensor.
        data_splits: Dict with ``'X_train'``, ``'X_test'``,
            ``'train_diagnostics'`` and ``'test_diagnostics'``.
        config: Dict with ``'batch_size'`` and ``'device'``. It may also
            contain ``'text_batch_size'`` (default 64).
        paths: Dict with ``'processed_data_dir'``. Embeddings are saved as
            ``.npy`` files in its ``embeddings`` sub-directory, which is
            created if missing.

    Returns:
        Dict with ``'train_ecg_embeddings'``, ``'test_ecg_embeddings'``,
        ``'train_text_embeddings'`` and ``'test_text_embeddings'``
        (NumPy arrays, rows in input order).
    """
    print("Generating embeddings...")

    device = config['device']
    text_batch_size = config.get('text_batch_size', 64)

    model = model.to(device)
    model.eval()

    embeddings = {
        'train_ecg_embeddings': _encode_ecg_in_batches(
            model.ecg_encoder, data_splits['X_train'], config['batch_size'], device,
            "Generating train embeddings"),
        'test_ecg_embeddings': _encode_ecg_in_batches(
            model.ecg_encoder, data_splits['X_test'], config['batch_size'], device,
            "Generating test embeddings"),
        'train_text_embeddings': _encode_texts_in_batches(
            model.text_encoder, data_splits['train_diagnostics'], text_batch_size),
        'test_text_embeddings': _encode_texts_in_batches(
            model.text_encoder, data_splits['test_diagnostics'], text_batch_size),
    }

    # Save embeddings
    embeddings_dir = os.path.join(paths['processed_data_dir'], 'embeddings')
    os.makedirs(embeddings_dir, exist_ok=True)
    for name, array in embeddings.items():
        np.save(os.path.join(embeddings_dir, f'{name}.npy'), array)

    print("Embeddings generated and saved.")
    return embeddings


def _encode_ecg_in_batches(ecg_encoder, signals, batch_size, device, desc):
    """Run ``ecg_encoder`` over ``signals`` in order; return the stacked NumPy embeddings."""
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.tensor(signals, dtype=torch.float32)),
        batch_size=batch_size,
        shuffle=False
    )
    chunks = []
    with torch.no_grad():
        for ecg_batch, in tqdm(loader, desc=desc):
            chunks.append(ecg_encoder(ecg_batch.to(device)).cpu().numpy())
    return np.concatenate(chunks, axis=0)


def _encode_texts_in_batches(text_encoder, texts, batch_size):
    """Run ``text_encoder`` over ``texts`` in chunks; return the stacked NumPy embeddings."""
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            # Pass a plain list: slices of arrays/Series are not what tokenizers expect
            batch_texts = list(texts[start:start + batch_size])
            chunks.append(text_encoder(batch_texts).cpu().numpy())
    return np.concatenate(chunks, axis=0)

# Generate category embeddings
def generate_category_embeddings(model, config, paths):
    """
    Embed one clinically enriched prompt per diagnostic category and save the
    results.

    Args:
        model: Model exposing ``text_encoder(list_of_str)`` returning a tensor.
        config: Dict with ``'device'``.
        paths: Dict with ``'processed_data_dir'``. Two files are written to its
            ``embeddings`` sub-directory (created if missing):
            ``category_embeddings.npy`` and ``category_mapping.json``. The
            mapping goes from row index to category name; JSON stores the
            keys as strings.

    Returns:
        NumPy array of category embeddings, one row per category, in the
        order Normal, MI, AFIB, Block, HYP.
    """
    print("Generating category embeddings...")

    categories = ['Normal', 'MI', 'AFIB', 'Block', 'HYP']

    # Prompts enhanced with clinical knowledge; order matches `categories`
    category_prompts = [
        "An ECG showing normal sinus rhythm with regular P waves, normal PR interval, and normal QRS complex.",
        "An ECG showing myocardial infarction with ST segment elevation, pathological Q waves, and T wave inversion.",
        "An ECG showing atrial fibrillation with irregular RR intervals, absence of P waves, and fibrillatory waves.",
        "An ECG showing heart block with prolonged PR interval, dropped QRS complexes, or complete dissociation of P waves and QRS complexes.",
        "An ECG showing ventricular hypertrophy with increased QRS amplitude, prolonged QRS duration, and secondary ST-T changes."
    ]

    model = model.to(config['device'])
    model.eval()
    with torch.no_grad():
        category_embeddings = model.text_encoder(category_prompts).cpu().numpy()

    # np.save / open() fail if the target directory does not exist yet
    embeddings_dir = os.path.join(paths['processed_data_dir'], 'embeddings')
    os.makedirs(embeddings_dir, exist_ok=True)

    np.save(os.path.join(embeddings_dir, 'category_embeddings.npy'), category_embeddings)

    category_mapping = dict(enumerate(categories))
    with open(os.path.join(embeddings_dir, 'category_mapping.json'), 'w') as f:
        json.dump(category_mapping, f)

    print("Category embeddings generated and saved.")
    return category_embeddings


In [None]:
# Fine tune the model using hierarchical loss function
def fine_tune_model(model, data_splits, embeddings, config, paths):
    """
    Fine-tune the model against fixed category text embeddings using a
    hierarchy-aware classification loss.

    The loss is cross-entropy plus ``config['hierarchy_weight']`` times a
    penalty. The penalty is the probability mass placed on wrong classes,
    weighted by how unrelated each wrong class is to the true one
    (``1 - hierarchy_matrix[true, j]``). Training samples whose diagnosis
    matches no category keyword are excluded, consistent with how the
    evaluation functions treat them.

    Args:
        model: Model exposing ``ecg_encoder``, ``temperature`` and trainable
            parameters.
        data_splits: Dict with ``'X_train'`` (NumPy ECG array) and
            ``'train_diagnostics'`` (one diagnostic string per ECG).
        embeddings: Dict with ``'category_embeddings'``, one row per category
            in the order Normal, MI, AFIB, Block, HYP.
        config: Dict with ``'device'``, ``'batch_size'``, ``'learning_rate'``,
            ``'weight_decay'``, ``'fine_tune_epochs'`` and ``'hierarchy_weight'``.
        paths: Dict with ``'processed_data_dir'``. The model is saved to
            ``models/fine_tuned_model.pth`` and the loss plot to
            ``fine_tuning_loss.png``.

    Returns:
        The fine-tuned model. It is returned unchanged when no training
        sample maps to a known category.
    """
    import re

    print("Fine-tuning model with hierarchical loss...")

    device = config['device']

    def _label_for(diagnostic):
        """Map a free-text diagnosis to a category index (0-4), or -1 if none."""
        text = str(diagnostic).lower()
        # Word-boundary match so that "abnormal" is NOT counted as "normal".
        if re.search(r'\bnormal\b', text):
            return 0  # Normal
        if 'infarct' in text or 'ischemia' in text:
            return 1  # MI
        if 'fibrillation' in text:
            return 2  # AFIB
        if 'block' in text:
            return 3  # Block
        if 'hypertrophy' in text:
            return 4  # HYP
        return -1  # Unknown

    def hierarchical_loss(similarities, labels, hierarchy_matrix):
        """
        Cross-entropy plus a hierarchy penalty: probability assigned to each
        wrong class j, weighted by (1 - hierarchy_matrix[true, j]).
        """
        ce_loss = F.cross_entropy(similarities, labels)
        pred_probs = F.softmax(similarities, dim=1)

        # (batch, n_classes) penalty weights for the true class of each row
        penalty = 1.0 - hierarchy_matrix[labels]
        # The true class itself is never penalised
        penalty = penalty.scatter(1, labels.unsqueeze(1), 0.0)

        hier_loss = (pred_probs * penalty).sum(dim=1).mean()
        return ce_loss + config['hierarchy_weight'] * hier_loss

    # Hierarchy matrix (1.0 = closely related, 0.0 = unrelated).
    # Simplified example; in practice it would come from medical knowledge.
    hierarchy_matrix = np.array([
        [1.0, 0.2, 0.2, 0.2, 0.2],  # Normal
        [0.2, 1.0, 0.3, 0.3, 0.4],  # MI
        [0.2, 0.3, 1.0, 0.4, 0.2],  # AFIB
        [0.2, 0.3, 0.4, 1.0, 0.2],  # Block
        [0.2, 0.4, 0.2, 0.2, 1.0],  # HYP
    ])
    hierarchy_matrix_tensor = torch.tensor(hierarchy_matrix, dtype=torch.float32).to(device)

    # Map diagnoses to categories and drop unlabelled samples
    X_train = data_splits['X_train']
    all_labels = [_label_for(d) for d in data_splits['train_diagnostics']]
    keep = [i for i, label in enumerate(all_labels) if label != -1]

    n_dropped = len(all_labels) - len(keep)
    if n_dropped:
        print(f"Skipping {n_dropped} training samples whose diagnosis matches no category.")
    if not keep:
        print("No labelled training samples available; skipping fine-tuning.")
        return model

    X_train_tensor = torch.tensor(np.asarray(X_train)[keep], dtype=torch.float32)
    train_labels_tensor = torch.tensor([all_labels[i] for i in keep], dtype=torch.long)

    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X_train_tensor, train_labels_tensor),
        batch_size=config['batch_size'],
        shuffle=True
    )

    model = model.to(device)

    # The category embeddings are fixed targets: build the tensor once
    category_embeddings = torch.tensor(embeddings['category_embeddings'], dtype=torch.float32).to(device)

    # Lower learning rate for fine-tuning
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'] * 0.1,
        weight_decay=config['weight_decay']
    )

    model.train()
    fine_tune_losses = []

    for epoch in range(config['fine_tune_epochs']):
        epoch_loss = 0.0

        with tqdm(train_loader, desc=f"Fine-tune Epoch {epoch+1}/{config['fine_tune_epochs']}") as pbar:
            for batch_idx, (ecg_batch, labels) in enumerate(pbar):
                ecg_batch = ecg_batch.to(device)
                labels = labels.to(device)

                ecg_embeddings = model.ecg_encoder(ecg_batch)
                similarities = torch.matmul(ecg_embeddings, category_embeddings.T) / model.temperature

                loss = hierarchical_loss(similarities, labels, hierarchy_matrix_tensor)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                pbar.set_postfix(loss=epoch_loss / (batch_idx + 1))

        fine_tune_losses.append(epoch_loss / len(train_loader))

    # Save fine-tuned model
    models_dir = os.path.join(paths['processed_data_dir'], 'models')
    os.makedirs(models_dir, exist_ok=True)
    model_path = os.path.join(models_dir, 'fine_tuned_model.pth')
    torch.save(model.state_dict(), model_path)
    print(f"Fine-tuned model saved at {model_path}")

    # Plot fine-tuning loss
    plt.figure(figsize=(10, 6))
    plt.plot(fine_tune_losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Fine-tuning Loss')
    plt.savefig(os.path.join(paths['processed_data_dir'], 'fine_tuning_loss.png'))
    plt.close()

    return model

# Validate the model
def validate_model(model, data_splits, config, paths):
    """
    Validate the fine-tuned model on the test samples whose diagnosis maps to
    a known category.

    Args:
        model: Model exposing ``ecg_encoder``, ``text_encoder`` and
            ``temperature``.
        data_splits: Dict with ``'X_test'`` (NumPy ECG array) and
            ``'test_diagnostics'`` (one diagnostic string per ECG).
        config: Dict with ``'batch_size'`` and ``'device'``.
        paths: Dict with ``'processed_data_dir'``; the plot
            ``validation_confusion_matrix.png`` is written there.

    Returns:
        Dict with ``'accuracy'``, ``'auc_scores'`` (per-category one-vs-rest
        ROC-AUC, NaN where undefined) and ``'categories'``. Returns ``None``
        when no test diagnosis maps to a known category.
    """
    import re
    from sklearn.metrics import confusion_matrix

    print("Validating model...")

    def _label_for(diagnostic):
        """Map a free-text diagnosis to a category index (0-4), or -1 if none."""
        text = str(diagnostic).lower()
        # Word-boundary match so that "abnormal" is NOT counted as "normal".
        if re.search(r'\bnormal\b', text):
            return 0  # Normal
        if 'infarct' in text or 'ischemia' in text:
            return 1  # MI
        if 'fibrillation' in text:
            return 2  # AFIB
        if 'block' in text:
            return 3  # Block
        if 'hypertrophy' in text:
            return 4  # HYP
        return -1  # Unknown

    X_test = data_splits['X_test']
    test_labels = [_label_for(d) for d in data_splits['test_diagnostics']]

    # Filter out unknown categories
    valid_indices = [i for i, label in enumerate(test_labels) if label != -1]
    if len(valid_indices) == 0:
        print("No valid samples for validation.")
        return None

    valid_X_test = X_test[valid_indices]
    valid_test_labels = [test_labels[i] for i in valid_indices]

    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(
            torch.tensor(valid_X_test, dtype=torch.float32),
            torch.tensor(valid_test_labels, dtype=torch.long)
        ),
        batch_size=config['batch_size'],
        shuffle=False
    )

    device = config['device']
    model = model.to(device)
    model.eval()

    categories = ['Normal', 'MI', 'AFIB', 'Block', 'HYP']
    n_classes = len(categories)

    # Prompts with clinical knowledge; order matches `categories`
    category_prompts = [
        "An ECG showing normal sinus rhythm with regular P waves, normal PR interval, and normal QRS complex.",
        "An ECG showing myocardial infarction with ST segment elevation, pathological Q waves, and T wave inversion.",
        "An ECG showing atrial fibrillation with irregular RR intervals, absence of P waves, and fibrillatory waves.",
        "An ECG showing heart block with prolonged PR interval, dropped QRS complexes, or complete dissociation of P waves and QRS complexes.",
        "An ECG showing ventricular hypertrophy with increased QRS amplitude, prolonged QRS duration, and secondary ST-T changes."
    ]

    all_predictions = []
    all_labels = []
    all_probabilities = []

    with torch.no_grad():
        category_embeddings = model.text_encoder(category_prompts).to(device)

        for ecg_batch, labels in tqdm(test_loader, desc="Validating"):
            ecg_embeddings = model.ecg_encoder(ecg_batch.to(device))

            # Temperature-scaled similarity with each category prompt
            similarities = torch.matmul(ecg_embeddings, category_embeddings.T) / model.temperature
            probabilities = F.softmax(similarities, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

            all_predictions.append(predictions.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            all_probabilities.append(probabilities.cpu().numpy())

    all_predictions = np.concatenate(all_predictions)
    all_labels = np.concatenate(all_labels)
    all_probabilities = np.concatenate(all_probabilities)
    n_samples = len(all_labels)

    accuracy = accuracy_score(all_labels, all_predictions)

    # One-vs-rest ROC-AUC per category
    auc_scores = []
    for i in range(n_classes):
        y_true = (all_labels == i).astype(int)
        n_pos = int(y_true.sum())
        # ROC-AUC is only defined when both positives and negatives exist
        if 0 < n_pos < n_samples:
            auc_scores.append(roc_auc_score(y_true, all_probabilities[:, i]))
        else:
            auc_scores.append(np.nan)

    print(f"Validation accuracy: {accuracy:.4f}")
    print("AUC scores by category:")
    for i, category in enumerate(categories):
        if not np.isnan(auc_scores[i]):
            print(f"  {category}: {auc_scores[i]:.4f}")

    # Fixed label order keeps the matrix aligned with the tick labels
    cm = confusion_matrix(all_labels, all_predictions, labels=list(range(n_classes)))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=categories, yticklabels=categories)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Validation Confusion Matrix')
    plt.savefig(os.path.join(paths['processed_data_dir'], 'validation_confusion_matrix.png'))
    plt.close()

    return {
        'accuracy': accuracy,
        'auc_scores': auc_scores,
        'categories': categories
    }

def _format_accuracy(results):
    """Return the 'accuracy' entry of a results mapping formatted for printing.

    Returns 'N/A' when ``results`` is falsy (e.g. ``None`` from a pipeline step
    that produced nothing) or has no 'accuracy' entry. Numeric values are
    shown with four decimals, matching the "Validation accuracy" printout
    elsewhere in this file; anything non-numeric is shown via ``str``.
    """
    if not results:
        return 'N/A'
    accuracy = results.get('accuracy')
    if accuracy is None:
        return 'N/A'
    try:
        return f"{float(accuracy):.4f}"
    except (TypeError, ValueError):
        return str(accuracy)


# Main function to run the entire workflow
def main():
    """Run the full MERL pipeline end to end, printing progress and a summary.

    The stages run in order:
    1. Dataset setup and validation.
    2. Report processing, filtering and array conversion.
    3. Train/test split.
    4. Model construction.
    5. Training.
    6. t-SNE visualisation and evaluation.
    7. Zero-shot classification.
    8. Checkpoint loading.
    9. Embedding generation.
    10. Hierarchical fine-tuning.
    11. Validation.

    Each stage is delegated to a helper defined elsewhere in this file.
    Returns early, before anything else runs, if ``check_files_exist``
    returns a falsy value. Always returns ``None``.
    """
    # Configuration shared by the model and the training/fine-tuning helpers.
    config = {
        'batch_size': 32,
        'learning_rate': 2e-4,
        'weight_decay': 1e-5,
        'epochs': 30,
        'fine_tune_epochs': 10,
        'temperature': 0.07,
        'dropout_prob': 0.1,
        'embedding_dim': 768,
        'hierarchy_weight': 0.5,
        'device': torch.device("cuda" if torch.cuda.is_available() else "cpu")
    }

    print(f"Using device: {config['device']}")

    # 1. Setup dataset paths
    paths = setup_dataset_paths()

    # 2. Check if files exist; abort the whole pipeline if not.
    if not check_files_exist(paths):
        return

    # 3. Create directory for processed data
    paths = create_directory(paths)

    # 4. Load CSV files
    record_df, report_df = load_csv_files(paths)

    # 5. Initialize the dataset object and validate
    dataset = ECGDataset(record_df, report_df, paths)

    # 6. Process reports and concatenate them
    dataset = process_reports(dataset)

    # 7. Assign new indexes to the reports
    dataset = assign_indexes(dataset)

    # 8. Verify ECG bin file paths
    verify_ecg_bin_paths(dataset, paths)

    # 9. Filter records and reports
    dataset = filter_records_and_reports(dataset)

    # 10. Reshape data to numpy arrays
    ecg_signals, report_texts, diagnostic_texts, ids = reshape_data_to_array(dataset, paths)

    # 11. Split data into training and testing sets
    data_splits = split_data(ecg_signals, report_texts, diagnostic_texts, ids)

    # 12. Visualize metadata and shapes
    visualize_metadata(data_splits, paths)

    # 13. Initialize the MERL model
    model = MERL(config)

    # 14. Visualize the model
    model = visualize_model(model)

    # 15. Test with dummy input
    check_dummy_input(model, config['device'])

    # 16. Define loss functions.
    # NOTE(review): the returned (CMA, UMA) losses are not passed to
    # train_model below; presumably train_model builds its own - confirm.
    define_loss_functions()

    # 17. Train the model (the per-epoch losses are not used further here).
    model, _ = train_model(model, data_splits, config, paths)

    # 18. Visualize embeddings with t-SNE
    visualize_embeddings(model, data_splits, config, paths)

    # 19. Evaluate model
    evaluate_model(model, data_splits, config, paths)

    # 20. Perform zero-shot classification
    zero_shot_results = perform_zero_shot(model, data_splits, config, paths)

    # 21. Load pretrained model (in case we want to start from a checkpoint)
    # NOTE(review): this rebinds `model` after training; if a checkpoint is
    # found it may replace the weights trained above - confirm intended.
    model = load_pretrained_model(model, paths)

    # 22. Preprocess and generate memory for ECG signals
    data_splits = preprocess_and_generate_memory(data_splits, paths)

    # 23. Check for null values
    data_splits = check_for_null_values(data_splits)

    # 24. Generate embeddings
    embeddings = generate_embeddings(model, data_splits, config, paths)

    # 25. Generate category embeddings
    category_embeddings = generate_category_embeddings(model, config, paths)
    embeddings['category_embeddings'] = category_embeddings

    # 26. Fine-tune the model using hierarchical loss function
    model = fine_tune_model(model, data_splits, embeddings, config, paths)

    # 27. Validate the fine-tuned model
    validation_results = validate_model(model, data_splits, config, paths)

    print("MERL model implementation complete!")

    # Report overall performance. Both results are guarded the same way so a
    # missing/None result prints 'N/A' instead of crashing after a long run.
    print("\nPerformance Summary:")
    print(f"Zero-shot classification accuracy: {_format_accuracy(zero_shot_results)}")
    print(f"Fine-tuned model accuracy: {_format_accuracy(validation_results)}")

# Run the main function if this script is executed directly
if __name__ == "__main__":
    main()


NameError: name 'torch' is not defined