<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Detecting_dementia_from_speech_and_transcripts_using_transformers_May242.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Mount Google Drive under /content/drive.
# NOTE(review): the pipeline below looks for the archive at
# /content/ADReSSo21-diagnosis-train.tgz by default, so mounting Drive alone
# does not make the file available there - copy or upload the archive to that
# location (or point the pipeline at the Drive path) before running it.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [1]:
# Dementia Detection from Speech and Transcripts using Transformers
# Complete Implementation for Google Colab

# ============================================================================
# STEP 1: INSTALL REQUIRED PACKAGES AND SETUP
# ============================================================================

# NOTE(review): packages are installed unpinned, so results may change between
# runs as new releases appear; consider pinning versions and using `%pip`
# so the install targets the running kernel's environment.
!pip install transformers torch torchvision torchaudio librosa pandas scikit-learn matplotlib seaborn numpy

import os
import tarfile
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import librosa
import librosa.display
from transformers import BertTokenizer, BertModel, ViTModel, ViTFeatureExtractor
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings from
# librosa / scikit-learn / transformers - narrow the filter when debugging.
warnings.filterwarnings('ignore')

# Set device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is a module-level global read by the pipeline functions below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ============================================================================
# STEP 2: DATA EXTRACTION AND PREPROCESSING
# ============================================================================

class DataExtractor:
    """Unpack the ADReSSo21 training archive and index its audio/segmentation pairs.

    Args:
        archive_path: Path to the gzip-compressed tar archive.
        extract_to: Directory the archive is extracted into and later scanned.
    """

    # (class sub-directory, integer label) pairs, scanned in this order.
    _CLASSES = (('ad', 1), ('cn', 0))

    def __init__(self, archive_path, extract_to):
        self.archive_path = archive_path
        self.extract_to = extract_to

    def extract_archive(self):
        """Extract the ADReSSo21-diagnosis-train.tgz archive"""
        print("Extracting archive...")
        with tarfile.open(self.archive_path, 'r:gz') as tar:
            if hasattr(tarfile, 'data_filter'):
                # Refuse absolute paths, '..' traversal and links that point
                # outside the destination when the interpreter supports it.
                tar.extractall(self.extract_to, filter='data')
            else:
                tar.extractall(self.extract_to)
        print(f"Archive extracted to {self.extract_to}")

    def _collect_class(self, base_path, class_name, label):
        """Return one sample dict per .wav file of a class that has a matching .csv."""
        audio_dir = os.path.join(base_path, f"audio/{class_name}")
        seg_dir = os.path.join(base_path, f"segmentation/{class_name}")

        if not (os.path.isdir(audio_dir) and os.path.isdir(seg_dir)):
            return []

        samples = []
        # Sorted so the sample order (and therefore any seeded split made from
        # it) does not depend on the filesystem's directory ordering.
        for audio_file in sorted(os.listdir(audio_dir)):
            if not audio_file.endswith('.wav'):
                continue
            # Strip only the trailing extension; str.replace would also remove
            # a '.wav' that appears in the middle of the file name.
            participant_id = os.path.splitext(audio_file)[0]
            seg_path = os.path.join(seg_dir, f"{participant_id}.csv")
            if not os.path.exists(seg_path):
                continue  # audio without a segmentation file is skipped
            samples.append({
                'audio_path': os.path.join(audio_dir, audio_file),
                'transcript_path': seg_path,
                'label': label,
                'participant_id': participant_id,
                'class_name': class_name
            })
        return samples

    def collect_data(self):
        """Collect data from extracted files

        Scans ``<extract_to>/ADReSSo21/diagnosis/train`` for ``audio/<class>``
        and ``segmentation/<class>`` folders (AD first, then CN).

        Returns:
            list[dict]: one dict per paired sample with keys ``audio_path``,
            ``transcript_path``, ``label`` (AD = 1, CN = 0),
            ``participant_id`` and ``class_name``.
        """
        base_path = os.path.join(self.extract_to, "ADReSSo21/diagnosis/train")

        data_samples = []
        for class_name, label in self._CLASSES:
            data_samples.extend(self._collect_class(base_path, class_name, label))

        print(f"Collected {len(data_samples)} samples")
        ad_count = sum(1 for s in data_samples if s['label'] == 1)
        cn_count = sum(1 for s in data_samples if s['label'] == 0)
        print(f"AD samples: {ad_count}, CN samples: {cn_count}")

        return data_samples

# ============================================================================
# STEP 3: AUDIO FEATURE EXTRACTION
# ============================================================================

class AudioFeatureExtractor:
    """Turn an audio file into a fixed-width, 3-channel spectrogram array.

    Args:
        n_mels: Number of Mel bands requested from librosa.
        n_fft: FFT window size passed to the Mel spectrogram.
        hop_length: Hop between successive frames, in samples.
        target_length: Number of frames the last axis is cut or padded to.
    """

    def __init__(self, n_mels=224, n_fft=2048, hop_length=1024, target_length=224):
        self.n_mels, self.n_fft = n_mels, n_fft
        self.hop_length, self.target_length = hop_length, target_length

    def extract_log_mel_spectrogram(self, audio_path):
        """Extract Log-Mel spectrogram with delta and delta-delta features

        The audio is loaded at 22050 Hz, converted to a dB-scaled Mel
        spectrogram, and stacked with its first- and second-order deltas along
        a new leading axis. The last axis is then truncated or zero-padded to
        ``target_length`` frames.

        NOTE(review): frames beyond ``target_length`` are discarded, so with
        the default hop and length only roughly the first ten seconds of a
        recording appear to be used - confirm this is intended.

        Returns:
            numpy array with three stacked channels; on any exception the
            error is printed and an all-zero array of shape
            ``(3, n_mels, target_length)`` is returned instead.
        """
        try:
            waveform, sample_rate = librosa.load(audio_path, sr=22050)

            mel_power = librosa.feature.melspectrogram(
                y=waveform,
                sr=sample_rate,
                n_mels=self.n_mels,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
            )
            # dB scale relative to the loudest bin of this recording
            log_mel = librosa.power_to_db(mel_power, ref=np.max)

            channels = [
                log_mel,
                librosa.feature.delta(log_mel),
                librosa.feature.delta(log_mel, order=2),
            ]
            features = np.stack(channels, axis=0)

            # Positive: frames missing -> pad; negative: surplus -> truncate.
            shortfall = self.target_length - features.shape[2]
            if shortfall < 0:
                features = features[:, :, :self.target_length]
            elif shortfall > 0:
                features = np.pad(features, ((0, 0), (0, 0), (0, shortfall)), mode='constant')

            return features

        except Exception as e:
            print(f"Error processing {audio_path}: {e}")
            return np.zeros((3, self.n_mels, self.target_length))

# ============================================================================
# STEP 4: TRANSCRIPT PROCESSING
# ============================================================================

class TranscriptProcessor:
    """Read transcript text from a segmentation CSV and tokenize it for BERT.

    Args:
        tokenizer_name: Name or path handed to ``BertTokenizer.from_pretrained``.
        max_length: Length every tokenized sequence is padded/truncated to.
    """

    def __init__(self, tokenizer_name='bert-base-uncased', max_length=512):
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length

    def extract_transcript_from_csv(self, csv_path):
        """Extract transcript text from segmentation CSV

        Three strategies are tried in order until one yields text:

        1. the first of the columns ``transcript``, ``text``, ``utterance``,
           ``speech``, ``content`` that exists;
        2. the last column of the rows whose ``speaker`` value matches
           ``participant|patient|PAR`` (only when there are more than 3 columns);
        3. every text-typed column, joined together.

        NOTE(review): if the CSV only carries speaker/timing columns, strategy
        3 returns the speaker labels rather than spoken content - confirm the
        file schema before relying on the text branch.

        Returns:
            str: the joined text, ``"No transcript available"`` when nothing
            was found, or ``"Error reading transcript"`` when reading failed.
        """
        try:
            df = pd.read_csv(csv_path)

            # The CSV might contain different column names
            # Common patterns: 'transcript', 'text', 'utterance', etc.
            text_columns = ['transcript', 'text', 'utterance', 'speech', 'content']
            transcript_text = ""

            for col in text_columns:
                if col in df.columns:
                    transcript_text = " ".join(df[col].dropna().astype(str))
                    break

            # If no text column found, try to construct from available data
            if not transcript_text and 'speaker' in df.columns:
                # Cast to str first: the .str accessor raises on a numeric
                # speaker column, which used to abort the whole extraction.
                speaker = df['speaker'].astype(str)
                participant_rows = df[speaker.str.contains('participant|patient|PAR', case=False, na=False)]
                if not participant_rows.empty and len(df.columns) > 3:
                    # Try the last column as it might contain text
                    last_col = df.columns[-1]
                    transcript_text = " ".join(participant_rows[last_col].dropna().astype(str))

            # Fallback: use all non-numeric content
            if not transcript_text:
                text_parts = []
                for col in df.columns:
                    # Accept both classic object columns and pandas' dedicated
                    # string dtype, which does not compare equal to 'object'.
                    if df[col].dtype == 'object' or pd.api.types.is_string_dtype(df[col]):
                        text_parts.extend(df[col].dropna().astype(str).tolist())
                transcript_text = " ".join(text_parts)

            return transcript_text if transcript_text else "No transcript available"

        except Exception as e:
            print(f"Error processing transcript {csv_path}: {e}")
            return "Error reading transcript"

    def tokenize_text(self, text):
        """Tokenize text using BERT tokenizer

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``token_type_ids``,
            each with the tokenizer's leading batch axis removed.
        """
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # collapse the sequence axis whenever max_length == 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'token_type_ids': encoding['token_type_ids'].squeeze(0)
        }

# ============================================================================
# STEP 5: DATASET CLASS
# ============================================================================

class DementiaDataset(Dataset):
    """Map-style dataset yielding paired audio and transcript tensors.

    Args:
        data_samples: Sequence of dicts with at least ``audio_path``,
            ``transcript_path``, ``label`` and ``participant_id``.
        audio_extractor: Object exposing ``extract_log_mel_spectrogram(path)``.
        transcript_processor: Object exposing ``extract_transcript_from_csv(path)``
            and ``tokenize_text(text)``.
    """

    _TOKEN_KEYS = ('input_ids', 'attention_mask', 'token_type_ids')

    def __init__(self, data_samples, audio_extractor, transcript_processor):
        self.data_samples = data_samples
        self.audio_extractor = audio_extractor
        self.transcript_processor = transcript_processor

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.data_samples)

    def __getitem__(self, idx):
        """Build the feature dict for sample ``idx``.

        Features are computed from the files on every access. The label is
        returned as a one-element LongTensor.
        """
        record = self.data_samples[idx]

        # Audio branch: spectrogram array -> float tensor
        spectrogram = torch.FloatTensor(
            self.audio_extractor.extract_log_mel_spectrogram(record['audio_path'])
        )

        # Text branch: CSV -> raw text -> token tensors
        text = self.transcript_processor.extract_transcript_from_csv(record['transcript_path'])
        tokens = self.transcript_processor.tokenize_text(text)

        item = {'audio_features': spectrogram}
        for key in self._TOKEN_KEYS:
            item[key] = tokens[key]
        item['label'] = torch.LongTensor([record['label']])
        item['participant_id'] = record['participant_id']
        return item

# ============================================================================
# STEP 6: MODEL ARCHITECTURES
# ============================================================================

class CrossModalAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm.

    Args:
        hidden_size: Embedding width of query/key/value; eight attention
            heads are used.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=8, batch_first=True
        )
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, query, key, value):
        """Attend from ``query`` over ``key``/``value`` and normalise the sum
        of the attention output and the original query."""
        context = self.attention(query, key, value)[0]  # attention weights are discarded
        residual = context + query
        return self.norm(residual)

class GatedMultimodalUnit(nn.Module):
    """Blend a text vector and an audio vector through a learned sigmoid gate.

    Args:
        input_size: Width of both input vectors and of the fused output.
    """

    def __init__(self, input_size):
        super().__init__()
        self.linear_text = nn.Linear(input_size, input_size)
        self.linear_audio = nn.Linear(input_size, input_size)
        # The gate sees both hidden vectors side by side, hence twice the width.
        self.linear_gate = nn.Linear(input_size * 2, input_size)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, text_features, audio_features):
        """Return ``z * h_text + (1 - z) * h_audio`` where ``h_*`` are
        tanh-activated projections and ``z`` is the element-wise gate."""
        text_hidden = self.tanh(self.linear_text(text_features))
        audio_hidden = self.tanh(self.linear_audio(audio_features))

        gate_input = torch.cat((text_hidden, audio_hidden), dim=-1)
        z = self.sigmoid(self.linear_gate(gate_input))

        return z * text_hidden + (1 - z) * audio_hidden

class MultimodalDementiaClassifier(nn.Module):
    """BERT (text) + ViT (spectrogram) classifier with a selectable fusion head.

    Both encoders' pooled outputs are projected to 512 dimensions and fused by
    one of three strategies before classification.

    Args:
        fusion_method: ``'concatenation'``, ``'gmu'`` or ``'crossmodal'``.
        bert_model: Name/path passed to ``BertModel.from_pretrained``.
        vit_model: Name/path passed to ``ViTModel.from_pretrained``.
        num_classes: Number of output logits.

    Raises:
        ValueError: if ``fusion_method`` is not one of the supported names.
    """

    SUPPORTED_FUSION_METHODS = ('concatenation', 'gmu', 'crossmodal')

    def __init__(self, fusion_method='crossmodal', bert_model='bert-base-uncased',
                 vit_model='google/vit-base-patch16-224-in21k', num_classes=2):
        super(MultimodalDementiaClassifier, self).__init__()

        # Fail fast, before any pretrained weights are loaded: an unknown name
        # used to build a model without a classifier that only crashed later
        # inside forward().
        if fusion_method not in self.SUPPORTED_FUSION_METHODS:
            raise ValueError(
                f"Unknown fusion_method {fusion_method!r}; "
                f"expected one of {self.SUPPORTED_FUSION_METHODS}"
            )

        self.fusion_method = fusion_method

        # Text encoder (BERT)
        self.bert = BertModel.from_pretrained(bert_model)
        self.bert_hidden_size = self.bert.config.hidden_size

        # Audio encoder (ViT)
        self.vit = ViTModel.from_pretrained(vit_model)
        self.vit_hidden_size = self.vit.config.hidden_size

        # Projection layers to match dimensions
        self.text_projection = nn.Linear(self.bert_hidden_size, 512)
        self.audio_projection = nn.Linear(self.vit_hidden_size, 512)

        # Fusion layers
        if fusion_method == 'concatenation':
            self.classifier = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(512, num_classes)
            )
        elif fusion_method == 'gmu':
            self.gmu = GatedMultimodalUnit(512)
            self.classifier = nn.Linear(512, num_classes)
        elif fusion_method == 'crossmodal':
            self.cross_attention_text_to_audio = CrossModalAttention(512)
            self.cross_attention_audio_to_text = CrossModalAttention(512)
            self.classifier = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(512, num_classes)
            )

    def forward(self, input_ids, attention_mask, token_type_ids, audio_features):
        """Compute class logits for a batch.

        Args:
            input_ids, attention_mask, token_type_ids: BERT inputs.
            audio_features: Spectrogram batch; it is reshaped to
                ``(batch, 3, 224, 224)`` before being fed to ViT.

        Returns:
            Tensor of logits with ``num_classes`` entries per sample.
        """
        # Text encoding
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        text_features = bert_outputs.pooler_output  # [CLS] token
        text_features = self.text_projection(text_features)

        # Audio encoding
        # Reshape audio features for ViT (batch_size, channels, height, width);
        # reshape (rather than view) also accepts non-contiguous input.
        batch_size = audio_features.size(0)
        audio_features = audio_features.reshape(batch_size, 3, 224, 224)

        vit_outputs = self.vit(pixel_values=audio_features)
        audio_features = vit_outputs.pooler_output
        audio_features = self.audio_projection(audio_features)

        # Fusion
        if self.fusion_method == 'concatenation':
            fused_features = torch.cat([text_features, audio_features], dim=-1)
            logits = self.classifier(fused_features)

        elif self.fusion_method == 'gmu':
            fused_features = self.gmu(text_features, audio_features)
            logits = self.classifier(fused_features)

        elif self.fusion_method == 'crossmodal':
            # Add sequence dimension for attention
            # NOTE(review): each modality contributes a single position, so the
            # attention softmax runs over one key and its weight is always 1.
            text_seq = text_features.unsqueeze(1)  # (batch_size, 1, hidden_size)
            audio_seq = audio_features.unsqueeze(1)  # (batch_size, 1, hidden_size)

            # Cross-modal attention
            text_attended = self.cross_attention_text_to_audio(text_seq, audio_seq, audio_seq)
            audio_attended = self.cross_attention_audio_to_text(audio_seq, text_seq, text_seq)

            # Concatenate and classify
            fused_features = torch.cat([
                text_attended.squeeze(1),
                audio_attended.squeeze(1)
            ], dim=-1)
            logits = self.classifier(fused_features)

        else:
            # Only reachable if fusion_method was changed after construction.
            raise ValueError(f"Unknown fusion_method {self.fusion_method!r}")

        return logits

# ============================================================================
# STEP 7: TRAINING LOOP
# ============================================================================

class DementiaTrainer:
    """Training/validation loop with LR scheduling, early stopping and checkpointing.

    Args:
        model: Module called as ``model(input_ids, attention_mask,
            token_type_ids, audio_features)`` and returning logits.
        train_loader: Iterable of training batches (dicts of tensors).
        val_loader: Iterable of validation batches.
        device: Device the model and every batch are moved to.
        checkpoint_path: File the best model's ``state_dict`` is written to.
    """

    def __init__(self, model, train_loader, val_loader, device,
                 checkpoint_path='best_model.pth'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.checkpoint_path = checkpoint_path

        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-5)
        self.criterion = nn.CrossEntropyLoss()
        # Halve the learning rate after 3 epochs without validation-loss improvement.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=3, factor=0.5
        )

        self.train_losses = []
        self.val_losses = []
        self.val_accuracies = []

    def _forward_batch(self, batch):
        """Move one batch to the device and return ``(logits, labels, loss)``."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        token_type_ids = batch['token_type_ids'].to(self.device)
        audio_features = batch['audio_features'].to(self.device)
        # Flatten to 1-D explicitly: squeeze() turns a batch of one sample into
        # a 0-d tensor, which breaks both the loss and the metric collection.
        labels = batch['label'].reshape(-1).to(self.device)

        logits = self.model(input_ids, attention_mask, token_type_ids, audio_features)
        loss = self.criterion(logits, labels)
        return logits, labels, loss

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``; return the mean batch loss."""
        self.model.train()
        total_loss = 0

        for batch in tqdm(self.train_loader, desc="Training"):
            self.optimizer.zero_grad()

            _, _, loss = self._forward_batch(batch)

            # Backward pass
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(self.train_loader)

    def validate_epoch(self):
        """Evaluate on ``val_loader`` without gradients.

        Returns:
            ``(mean_loss, accuracy, predictions, labels)`` where the last two
            are flat lists covering every validation sample.
        """
        self.model.eval()
        total_loss = 0
        all_predictions = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                logits, labels, loss = self._forward_batch(batch)

                total_loss += loss.item()

                # Get predictions
                predictions = torch.argmax(logits, dim=1)
                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_loss = total_loss / len(self.val_loader)
        accuracy = accuracy_score(all_labels, all_predictions)

        return avg_loss, accuracy, all_predictions, all_labels

    def train(self, num_epochs=20, early_stopping_patience=6):
        """Train with early stopping and restore the best checkpoint.

        Args:
            num_epochs: Maximum number of epochs.
            early_stopping_patience: Stop after this many consecutive epochs
                without a new best validation loss.

        Returns:
            ``(val_predictions, val_labels)`` produced by the epoch whose
            weights are restored (the lowest validation loss). If no epoch
            ever improved on the initial bound (for example a NaN loss or
            ``num_epochs=0``), the last epoch's lists - or empty lists - are
            returned and no checkpoint is loaded.
        """
        best_val_loss = float('inf')
        patience_counter = 0
        best_predictions, best_labels = None, None
        last_predictions, last_labels = [], []

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")

            # Train
            train_loss = self.train_epoch()

            # Validate
            val_loss, val_accuracy, last_predictions, last_labels = self.validate_epoch()

            # Update scheduler
            self.scheduler.step(val_loss)

            # Store metrics
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_accuracy)

            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_loss:.4f}")
            print(f"Val Accuracy: {val_accuracy:.4f}")

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Save best model together with the predictions it produced,
                # so the returned results match the restored weights.
                torch.save(self.model.state_dict(), self.checkpoint_path)
                best_predictions, best_labels = last_predictions, last_labels
            else:
                patience_counter += 1

            if patience_counter >= early_stopping_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        if best_predictions is None:
            # Nothing was saved in this run; do not load a stale or missing file.
            return last_predictions, last_labels

        # Load best model
        self.model.load_state_dict(
            torch.load(self.checkpoint_path, map_location=self.device)
        )
        return best_predictions, best_labels

    def plot_training_history(self):
        """Plot train/validation loss and validation accuracy per epoch."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        # Loss plot
        ax1.plot(self.train_losses, label='Train Loss')
        ax1.plot(self.val_losses, label='Val Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.legend()

        # Accuracy plot
        ax2.plot(self.val_accuracies, label='Val Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.set_title('Validation Accuracy')
        ax2.legend()

        plt.tight_layout()
        plt.show()

# ============================================================================
# STEP 8: EVALUATION METRICS
# ============================================================================

def evaluate_model(y_true, y_pred, class_names=['CN', 'AD']):
    """Comprehensive evaluation of model performance

    Prints overall and per-class precision/recall/F1 plus scikit-learn's
    classification report.

    Args:
        y_true: Ground-truth integer labels.
        y_pred: Predicted integer labels.
        class_names: Display name for label ``i`` at position ``i``.
            The default list is never mutated.

    Returns:
        dict with ``accuracy``, weighted ``precision``/``recall``/``f1`` and the
        per-class arrays ``precision_per_class``/``recall_per_class``/
        ``f1_per_class`` (one entry per name in ``class_names``).
    """
    # Pin the label set so the per-class arrays always have one entry per
    # class name, even when a class is missing from both y_true and y_pred;
    # otherwise the indexing below and target_names= would fail.
    label_ids = list(range(len(class_names)))
    # Classes without predictions/support score 0 instead of emitting a warning.
    shared = dict(labels=label_ids, zero_division=0)

    # Basic metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted', **shared)
    recall = recall_score(y_true, y_pred, average='weighted', **shared)
    f1 = f1_score(y_true, y_pred, average='weighted', **shared)

    # Per-class metrics
    precision_per_class = precision_score(y_true, y_pred, average=None, **shared)
    recall_per_class = recall_score(y_true, y_pred, average=None, **shared)
    f1_per_class = f1_score(y_true, y_pred, average=None, **shared)

    print("="*50)
    print("MODEL EVALUATION RESULTS")
    print("="*50)
    print(f"Overall Accuracy: {accuracy:.4f}")
    print(f"Overall Precision: {precision:.4f}")
    print(f"Overall Recall: {recall:.4f}")
    print(f"Overall F1-Score: {f1:.4f}")
    print("\nPer-Class Metrics:")
    for i, class_name in enumerate(class_names):
        print(f"{class_name} - Precision: {precision_per_class[i]:.4f}, "
              f"Recall: {recall_per_class[i]:.4f}, F1: {f1_per_class[i]:.4f}")

    # Detailed classification report
    print("\nDetailed Classification Report:")
    print(classification_report(y_true, y_pred, labels=label_ids,
                                target_names=class_names, zero_division=0))

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'precision_per_class': precision_per_class,
        'recall_per_class': recall_per_class,
        'f1_per_class': f1_per_class
    }

# ============================================================================
# STEP 9: MAIN EXECUTION PIPELINE
# ============================================================================

def _print_banner(title):
    """Print a section title framed by blank line and two 50-character rules."""
    print("\n" + "="*50)
    print(title)
    print("="*50)


def main(archive_path="/content/ADReSSo21-diagnosis-train.tgz",
         extract_to="/content/ADReSSo_extracted",
         batch_size=8, num_epochs=20, fusion_method='crossmodal', seed=42):
    """Run the full pipeline: extract, split, train, evaluate and plot.

    Args:
        archive_path: Location of the ADReSSo21 training archive.
        extract_to: Directory the archive is extracted into.
        batch_size: Batch size for both data loaders.
        num_epochs: Maximum number of training epochs.
        fusion_method: 'concatenation', 'gmu' or 'crossmodal'.
        seed: Seed for the train/validation split and for numpy/torch RNGs
            (loader shuffling, weight initialisation of the new layers).

    Returns:
        ``(model, trainer, metrics)`` on success. ``None`` when the archive
        is missing or no samples are found (a message is printed instead).
    """
    print("Starting Dementia Detection Pipeline...")

    # Step 1: Extract and collect data
    _print_banner("STEP 1: DATA EXTRACTION AND COLLECTION")

    if not os.path.exists(archive_path):
        print(f"Archive not found at {archive_path}")
        print("Please upload the ADReSSo21-diagnosis-train.tgz file to Colab")
        return None

    extractor = DataExtractor(archive_path, extract_to)

    # Also extract when the target directory exists but is empty; an empty
    # directory previously suppressed extraction and produced zero samples.
    if not os.path.isdir(extract_to) or not os.listdir(extract_to):
        extractor.extract_archive()

    data_samples = extractor.collect_data()

    if len(data_samples) == 0:
        print("No data samples found. Please check the archive and extraction.")
        return None

    # Seed the remaining sources of randomness so a re-run reproduces results.
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Step 2: Initialize processors
    _print_banner("STEP 2: INITIALIZING PROCESSORS")

    audio_extractor = AudioFeatureExtractor()
    transcript_processor = TranscriptProcessor()

    # Step 3: Split data
    _print_banner("STEP 3: DATA SPLITTING")

    # 65/35 split, stratified so both sets keep the AD/CN ratio.
    train_samples, val_samples = train_test_split(
        data_samples, test_size=0.35, random_state=seed,
        stratify=[s['label'] for s in data_samples]
    )

    print(f"Training samples: {len(train_samples)}")
    print(f"Validation samples: {len(val_samples)}")

    # Step 4: Create datasets and dataloaders
    _print_banner("STEP 4: CREATING DATASETS")

    train_dataset = DementiaDataset(train_samples, audio_extractor, transcript_processor)
    val_dataset = DementiaDataset(val_samples, audio_extractor, transcript_processor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    print("Datasets created successfully!")

    # Step 5: Initialize model
    _print_banner("STEP 5: INITIALIZING MODEL")

    model = MultimodalDementiaClassifier(fusion_method=fusion_method)
    print(f"Model initialized with {fusion_method} fusion method")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Step 6: Training
    _print_banner("STEP 6: TRAINING")

    trainer = DementiaTrainer(model, train_loader, val_loader, device)
    val_predictions, val_labels = trainer.train(num_epochs=num_epochs)

    # Step 7: Evaluation
    _print_banner("STEP 7: FINAL EVALUATION")

    metrics = evaluate_model(val_labels, val_predictions)

    # Step 8: Visualizations
    _print_banner("STEP 8: VISUALIZATION")

    trainer.plot_training_history()

    print("\nTraining completed successfully!")
    return model, trainer, metrics

# ============================================================================
# STEP 10: COMPARISON OF DIFFERENT FUSION METHODS
# ============================================================================

def compare_fusion_methods(train_loader=None, val_loader=None, num_epochs=10):
    """Compare different fusion methods

    Trains one fresh model per fusion method on the same loaders and prints
    accuracy and F1 for each.

    Args:
        train_loader: Training DataLoader. When omitted, a module-level
            ``train_loader`` is used if one exists.
        val_loader: Validation DataLoader; same fallback as above.
        num_epochs: Epoch budget per fusion method.

    The trainer returned by ``main()`` keeps its loaders as
    ``trainer.train_loader`` and ``trainer.val_loader``, which can be passed
    straight in here.

    Returns:
        dict mapping fusion method name to its metrics dict.

    Raises:
        NameError: if no loaders were passed and none exist at module level.
    """
    print("Comparing different fusion methods...")

    # The loaders built by main() are local to it, so a bare global lookup
    # fails; accept them as arguments and keep the global names as fallback.
    if train_loader is None:
        train_loader = globals().get('train_loader')
    if val_loader is None:
        val_loader = globals().get('val_loader')
    if train_loader is None or val_loader is None:
        raise NameError(
            "train_loader/val_loader are not defined; pass them explicitly, "
            "e.g. compare_fusion_methods(trainer.train_loader, trainer.val_loader)"
        )

    fusion_methods = ['concatenation', 'gmu', 'crossmodal']
    results = {}

    for method in fusion_methods:
        print(f"\nTraining with {method} fusion...")

        # Initialize new model
        model = MultimodalDementiaClassifier(fusion_method=method)
        trainer = DementiaTrainer(model, train_loader, val_loader, device)

        # Train with fewer epochs for comparison
        val_predictions, val_labels = trainer.train(num_epochs=num_epochs)

        # Evaluate
        metrics = evaluate_model(val_labels, val_predictions)
        results[method] = metrics

    # Compare results
    print("\n" + "="*50)
    print("FUSION METHOD COMPARISON")
    print("="*50)

    for method, metrics in results.items():
        print(f"{method.upper()}:")
        print(f"  Accuracy: {metrics['accuracy']:.4f}")
        print(f"  F1-Score: {metrics['f1']:.4f}")
        print()

    return results

# Run the main pipeline
# In a notebook __name__ is "__main__", so this runs as soon as the cell executes.
if __name__ == "__main__":
    # Upload your ADReSSo21-diagnosis-train.tgz file to Colab first
    # NOTE: main() returns (model, trainer, metrics) on success, but the value
    # is discarded here; assign it if later cells need the trained model.
    main()

    # Uncomment to compare fusion methods
    # NOTE(review): compare_fusion_methods() needs train/validation loaders,
    # which main() keeps local; the trainer returned by main() exposes them as
    # trainer.train_loader / trainer.val_loader.
    # compare_fusion_methods()

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5