<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/NEURAL_ARCHITECTURE_SEARCH_WITH_MULTIMODAL_FUSION_METHODS4DIAGNOSING_DEMENTIA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; later cells read the ADReSSo21
# archives from /content/drive/MyDrive/Voice/.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Workflow Overview

- Setup and Dependencies Installation
- Dataset Extraction and Exploration
- Audio Preprocessing and Feature Extraction
- Text Generation (ASR) and Linguistic Feature Extraction
- DARTS Neural Architecture Search Implementation
- BERT Text Processing
- Multimodal Fusion Implementation
- Model Training and Evaluation
- Testing and Validation

In [4]:
# AD Detection Starter Script for Google Colab
# Run this first to test dataset loading and basic setup

# Mount Google Drive (repeats the first cell; when Drive is already mounted,
# mount() only prints a notice, as the recorded output below shows).
from google.colab import drive
drive.mount('/content/drive')

# Install required packages.
# NOTE: `speechrecognition` provides the `speech_recognition` module that the
# model-training cell imports at its top, so run this cell before that one
# (running the model cell first produced the recorded ModuleNotFoundError).
!pip install transformers torch torchaudio librosa speechrecognition pydub scikit-learn

import os
import tarfile
import glob
import librosa
import numpy as np
import pandas as pd
from pathlib import Path

# Step 1: Check and extract datasets
def setup_datasets(base_path="/content/drive/MyDrive/Voice/"):
    """Report which ADReSSo21 archives exist and extract those not yet extracted.

    Each ``<name>.tgz`` under *base_path* is extracted into the sibling
    directory ``<name>``. An archive whose directory already exists is skipped.

    Args:
        base_path: Directory holding the ``.tgz`` archives. Defaults to the
            notebook's Google Drive folder.
    """
    import shutil

    files_to_check = [
        "ADReSSo21-diagnosis-train.tgz",
        "ADReSSo21-progression-test.tgz",
        "ADReSSo21-progression-train.tgz"
    ]

    print("Checking dataset files...")
    for file in files_to_check:
        full_path = os.path.join(base_path, file)
        if os.path.exists(full_path):
            print(f"✓ Found: {file}")
        else:
            print(f"✗ Missing: {file}")

    print("\nExtracting datasets...")
    for file in files_to_check:
        archive_path = os.path.join(base_path, file)
        # Drop only the trailing extension.
        extract_path = os.path.join(base_path, os.path.splitext(file)[0])

        if os.path.exists(extract_path):
            print(f"✓ Already extracted: {file}")
        elif os.path.exists(archive_path):
            print(f"Extracting {file}...")
            try:
                with tarfile.open(archive_path, 'r:gz') as tar:
                    # The 'data' filter (Python 3.12+ and security backports)
                    # rejects absolute paths and '..' members.
                    if hasattr(tarfile, 'data_filter'):
                        tar.extractall(extract_path, filter='data')
                    else:
                        tar.extractall(extract_path)
                print(f"✓ Extracted to {extract_path}")
            except (tarfile.TarError, OSError) as e:
                print(f"✗ Error extracting {file}: {e}")
                # Remove partial output; otherwise the next run would
                # wrongly report this archive as already extracted.
                shutil.rmtree(extract_path, ignore_errors=True)

# Step 2: Explore dataset structure
def explore_dataset_structure(base_path="/content/drive/MyDrive/Voice/", max_files=5):
    """Print the directory tree under *base_path*, indenting two spaces per level.

    At most ``max_files`` file names are shown per directory, followed by a
    count of the remainder.

    Args:
        base_path: Root directory to walk.
        max_files: Maximum number of file names printed per directory.
    """
    print("Dataset structure:")
    for root, dirs, files in os.walk(base_path):
        # Depth relative to base_path: 0 for the root itself, 1 for its children, ...
        rel = os.path.relpath(root, base_path)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 2 * level
        # normpath strips a trailing separator so the root's name is printed.
        print(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 2 * (level + 1)
        for file in files[:max_files]:
            print(f"{subindent}{file}")
        if len(files) > max_files:
            print(f"{subindent}... and {len(files) - max_files} more files")

# Step 3: Find and analyze audio files
def find_audio_files(base_path="/content/drive/MyDrive/Voice/"):
    """Find audio files under *base_path*, print a summary, and probe the first one.

    Matching is by extension (.wav/.mp3/.flac/.m4a, case-insensitive). The
    result is sorted, so "the first file" is the same on every run. The first
    file is loaded with librosa (first 10 s) to report its sample rate and
    duration; a load failure is printed, not raised.

    Args:
        base_path: Directory to search recursively.

    Returns:
        Sorted list of full paths to the audio files found.
    """
    audio_extensions = ('.wav', '.mp3', '.flac', '.m4a')

    audio_files = sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(base_path)
        for name in files
        if name.lower().endswith(audio_extensions)
    )

    print(f"\nFound {len(audio_files)} audio files")
    if not audio_files:
        return audio_files

    print("\nSample audio files:")
    for i, file in enumerate(audio_files[:5], start=1):
        print(f"{i}. {file}")

    print(f"\nAnalyzing first audio file: {audio_files[0]}")
    try:
        y, sr = librosa.load(audio_files[0], duration=10)  # Load first 10 seconds
        print(f"Sample rate: {sr} Hz")
        print(f"Duration: {len(y)/sr:.2f} seconds")
        print(f"Audio shape: {y.shape}")
    except Exception as e:
        # Best-effort probe: report and keep the file list.
        print(f"Error loading audio: {e}")

    return audio_files

# Step 4: Check for label information
def check_labels(base_path="/content/drive/MyDrive/Voice/", max_files=3):
    """List candidate label/metadata files under *base_path* and preview a few.

    Files ending in .csv/.txt/.tsv/.json (case-insensitive) are listed in
    sorted order. The first ``max_files`` are previewed:
    - CSV/TSV are read with pandas, printing shape, columns and head.
    - TXT/JSON are shown as their first 500 characters.
    Read errors are printed, not raised.

    Args:
        base_path: Directory to search recursively.
        max_files: How many of the listed files to preview.
    """
    label_extensions = ('.csv', '.txt', '.tsv', '.json')
    label_files = sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(base_path)
        for name in files
        if name.lower().endswith(label_extensions)
    )

    print(f"\nFound {len(label_files)} potential label files:")
    for file in label_files:
        print(f"- {file}")

    for file in label_files[:max_files]:
        ext = os.path.splitext(file)[1].lower()
        kind = ext[1:].upper()
        try:
            if ext in ('.csv', '.tsv'):
                df = pd.read_csv(file, sep='\t' if ext == '.tsv' else ',')
                print(f"\n{file} ({kind}):")
                print(f"Shape: {df.shape}")
                print(f"Columns: {list(df.columns)}")
                print(df.head())
            else:
                # .txt / .json: plain-text preview. Reads only what is shown.
                with open(file, 'r', encoding='utf-8', errors='replace') as f:
                    content = f.read(500)
                print(f"\n{file} ({kind}):")
                print(content)
        except Exception as e:
            # Exploration helper: report unreadable files and continue.
            print(f"Error reading {file}: {e}")

# Step 5: Basic audio feature extraction test
def test_audio_processing(audio_files=None):
    """Smoke-test librosa feature extraction on the first available audio file.

    Loads up to 30 s of audio at 16 kHz, then prints the shapes of MFCC,
    spectral-centroid and zero-crossing-rate features. Failures are printed,
    not raised.

    Args:
        audio_files: Optional list of audio paths, e.g. the result of an earlier
            ``find_audio_files()`` call, so the dataset need not be rescanned.
            When omitted, ``find_audio_files()`` is called.
    """
    print("\nTesting audio processing...")

    if audio_files is None:
        audio_files = find_audio_files()
    if not audio_files:
        print("No audio files found for testing")
        return

    test_file = audio_files[0]
    print(f"Testing with: {test_file}")

    try:
        y, sr = librosa.load(test_file, sr=16000, duration=30)  # 30 seconds max

        print("Extracting features...")

        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
        print(f"MFCCs shape: {mfccs.shape}")

        spectral_centroids = librosa.feature.spectral_centroid(y=y, sr=sr)
        print(f"Spectral centroids shape: {spectral_centroids.shape}")

        zcr = librosa.feature.zero_crossing_rate(y)
        print(f"Zero crossing rate shape: {zcr.shape}")

        print("✓ Audio processing test successful!")

    except Exception as e:
        print(f"✗ Audio processing test failed: {e}")

# Main execution
def run_startup_checks():
    """Run every dataset/setup check in sequence and summarise the audio count.

    Order: extract archives, print the tree, locate audio, inspect label files,
    then smoke-test audio feature extraction.
    """
    print("=== AD Detection Model Setup ===\n")

    for step in (setup_datasets, explore_dataset_structure):
        step()
    found_audio = find_audio_files()
    check_labels()
    test_audio_processing()

    audio_count = len(found_audio)
    print("\n=== Setup Complete ===")
    print("Ready to proceed with model implementation!")
    print(f"Found {audio_count} audio files to work with")

# Execute the startup checks when this cell/script runs directly.
if __name__ == "__main__":
    run_startup_checks()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Collecting speechrecognition
  Downloading speechrecognition-3.14.3-py3-none-any.whl.metadata (30 kB)
Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidi

In [3]:
# Multimodal Alzheimer's Detection Model Implementation
# Based on BERT + DARTS Architecture

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import pandas as pd
import librosa
import speech_recognition as sr
from transformers import BertTokenizer, BertModel, BertForSequenceClassification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader
import tarfile
import glob
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Step 1: Setup and Installation
def install_dependencies():
    """pip-install the notebook's third-party packages one at a time.

    Each package is passed to pip as its own argument. A failed install is
    reported and the remaining packages are still attempted.
    """
    import subprocess
    import sys

    # One requirement per entry: pip treats each argv item as a single
    # requirement, so "torch torchvision torchaudio" in one string is invalid.
    packages = [
        'torch',
        'torchvision',
        'torchaudio',
        'transformers',
        'librosa',
        'SpeechRecognition',
        'pydub',
        'scikit-learn',
        'matplotlib',
        'seaborn',
        'pandas',
        'numpy',
        'tqdm'
    ]

    for package in packages:
        try:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
        except (subprocess.CalledProcessError, OSError):
            # Narrow catch: still lets KeyboardInterrupt stop the loop.
            print(f"Failed to install {package}")

# Step 2: Dataset Extraction and Loading
class ADReSSoDatasetLoader:
    """Extract the ADReSSo21 archives and collect labelled diagnosis-task audio.

    Args:
        base_path: Directory containing the ``ADReSSo21-*.tgz`` archives. Each
            archive is extracted into a sub-directory of it at construction.
    """

    # Directory name -> label. Labels come from exact path components, not
    # substrings: names such as "ADReSSo21" or "adrso024.wav" contain "ad" and
    # would otherwise mark every file as AD.
    # NOTE(review): ADReSSo21 diagnosis audio is expected under ".../ad/" and
    # ".../cn/" — confirm against the extracted tree.
    LABEL_DIRS = {'ad': 1, 'cn': 0}

    AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac')

    def __init__(self, base_path="/content/drive/MyDrive/Voice/"):
        self.base_path = base_path
        self.train_diagnosis_path = os.path.join(base_path, "ADReSSo21-diagnosis-train.tgz")
        self.train_progression_path = os.path.join(base_path, "ADReSSo21-progression-train.tgz")
        self.test_progression_path = os.path.join(base_path, "ADReSSo21-progression-test.tgz")

        # Extract datasets
        self.extract_datasets()

    def extract_datasets(self):
        """Extract each existing archive into ``base_path/<folder>`` unless that folder exists."""
        datasets = [
            (self.train_diagnosis_path, "diagnosis_train"),
            (self.train_progression_path, "progression_train"),
            (self.test_progression_path, "progression_test")
        ]

        for archive_path, folder_name in datasets:
            if not os.path.exists(archive_path):
                continue
            extract_path = os.path.join(self.base_path, folder_name)
            if os.path.exists(extract_path):
                print(f"{folder_name} already extracted")
                continue
            print(f"Extracting {archive_path}...")
            with tarfile.open(archive_path, 'r:gz') as tar:
                # The 'data' filter (Python 3.12+ and security backports)
                # rejects absolute paths and '..' members.
                if hasattr(tarfile, 'data_filter'):
                    tar.extractall(extract_path, filter='data')
                else:
                    tar.extractall(extract_path)
            print(f"Extracted to {extract_path}")

    def load_audio_files(self):
        """Return ``(audio_paths, labels)`` for diagnosis-train audio (1 = AD, 0 = control).

        The label is taken from the innermost directory under
        ``base_path/diagnosis_train`` whose name is a key of ``LABEL_DIRS``.
        Audio with no such directory is skipped and counted in a printed
        warning. Directories and files are visited in sorted order, so the
        result, and any seeded split of it, is reproducible.
        """
        audio_files = []
        labels = []
        skipped = 0

        diagnosis_path = os.path.join(self.base_path, "diagnosis_train")
        if not os.path.exists(diagnosis_path):
            return audio_files, labels

        for root, dirs, files in os.walk(diagnosis_path):
            dirs.sort()  # in place, so os.walk descends deterministically
            rel_parts = os.path.relpath(root, diagnosis_path).lower().split(os.sep)
            label = next(
                (self.LABEL_DIRS[part] for part in reversed(rel_parts) if part in self.LABEL_DIRS),
                None,
            )
            for file in sorted(files):
                if not file.lower().endswith(self.AUDIO_EXTENSIONS):
                    continue
                if label is None:
                    skipped += 1
                    continue
                audio_files.append(os.path.join(root, file))
                labels.append(label)

        if skipped:
            print(f"Skipped {skipped} audio files with no recognised label directory")

        return audio_files, labels

# Step 3: Audio Feature Extraction
class AudioFeatureExtractor:
    """Summarise an audio file as a fixed-length vector of acoustic statistics.

    Each librosa frame-level feature matrix is reduced to its per-row mean,
    std, max and min. The estimated tempo is appended as one value.

    Args:
        sample_rate: Rate (Hz) the audio is resampled to before extraction.
    """

    # Length of the vector from ``extract_acoustic_features``, assuming
    # librosa's default 12 chroma bins and 6 tonnetz dimensions:
    # mfcc 13*4 + (centroid, rolloff, bandwidth, zcr) 4*1*4 + tempo 1
    # + chroma 12*4 + tonnetz 6*4 = 141.
    FEATURE_DIM = 141

    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate

    def extract_acoustic_features(self, audio_path):
        """Return a 1-D feature vector for *audio_path*.

        On any extraction error, the error is printed and a zero vector of
        length ``FEATURE_DIM`` is returned. Successful and failed samples
        therefore have the same length and can be batched together.
        """
        try:
            y, sr = librosa.load(audio_path, sr=self.sample_rate)

            features = {}

            # Spectral features
            features['mfcc'] = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
            features['spectral_centroid'] = librosa.feature.spectral_centroid(y=y, sr=sr)
            features['spectral_rolloff'] = librosa.feature.spectral_rolloff(y=y, sr=sr)
            features['spectral_bandwidth'] = librosa.feature.spectral_bandwidth(y=y, sr=sr)
            features['zero_crossing_rate'] = librosa.feature.zero_crossing_rate(y)

            # Prosodic / tonal features
            features['tempo'], _ = librosa.beat.beat_track(y=y, sr=sr)
            features['chroma'] = librosa.feature.chroma_stft(y=y, sr=sr)
            features['tonnetz'] = librosa.feature.tonnetz(y=librosa.effects.harmonic(y), sr=sr)

            # Reduce each (rows, frames) matrix to per-row statistics.
            aggregated_features = []
            for key, value in features.items():
                if key == 'tempo':
                    aggregated_features.append(value)
                elif value.ndim > 1:
                    aggregated_features.extend([
                        np.mean(value, axis=1),
                        np.std(value, axis=1),
                        np.max(value, axis=1),
                        np.min(value, axis=1)
                    ])
                else:
                    aggregated_features.extend([
                        np.mean(value),
                        np.std(value),
                        np.max(value),
                        np.min(value)
                    ])

            # tempo may be a scalar or a 1-element array depending on the librosa version.
            feature_vector = np.concatenate([f.flatten() if hasattr(f, 'flatten') else [f]
                                             for f in aggregated_features])

            return feature_vector

        except Exception as e:
            print(f"Error extracting features from {audio_path}: {e}")
            # Same length as a successful extraction (was 200, which broke batching).
            return np.zeros(self.FEATURE_DIM)

# Step 4: Speech-to-Text and Linguistic Feature Extraction
class SpeechToTextProcessor:
    """Transcribe audio with SpeechRecognition's Google recognizer and compute text statistics."""

    def __init__(self):
        self.recognizer = sr.Recognizer()

    def audio_to_text(self, audio_path):
        """Transcribe *audio_path*, or return ``"unable to transcribe audio"`` on failure.

        The audio is resampled to 16 kHz mono and written as a temporary 16-bit
        PCM WAV, which ``sr.AudioFile`` can read. The temp file is always
        removed.
        """
        import tempfile
        import wave

        temp_path = None
        try:
            # Local name must not be `sr`: that would shadow the
            # speech_recognition module used below.
            audio_data, sample_rate = librosa.load(audio_path, sr=16000)

            # librosa.output.write_wav no longer exists; write PCM WAV with the stdlib.
            pcm = (np.clip(audio_data, -1.0, 1.0) * 32767).astype('<i2')
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
                temp_path = tmp.name
            with wave.open(temp_path, 'wb') as wav:
                wav.setnchannels(1)
                wav.setsampwidth(2)
                wav.setframerate(sample_rate)
                wav.writeframes(pcm.tobytes())

            with sr.AudioFile(temp_path) as source:
                audio = self.recognizer.record(source)
            return self.recognizer.recognize_google(audio)

        except Exception as e:
            print(f"Speech recognition failed for {audio_path}: {e}")
            return "unable to transcribe audio"

        finally:
            if temp_path is not None and os.path.exists(temp_path):
                os.remove(temp_path)

    def extract_linguistic_features(self, text):
        """Return simple counts/averages describing *text*.

        Sentences are the non-blank fragments between '.' characters. Fillers
        are matched after lower-casing and stripping surrounding punctuation,
        so "Um," counts. Blank or empty text yields all-zero features.
        """
        import string

        if not text or len(text.strip()) == 0:
            return {
                'word_count': 0,
                'sentence_count': 0,
                'avg_word_length': 0,
                'avg_sentence_length': 0,
                'pause_count': 0,
                'filler_count': 0
            }

        words = text.split()
        # Drop empty fragments, e.g. the one after a trailing '.'.
        sentences = [sent for sent in text.split('.') if sent.strip()]

        fillers = {'um', 'uh', 'er', 'ah', 'hmm'}
        filler_count = sum(1 for word in words
                           if word.lower().strip(string.punctuation) in fillers)

        features = {
            'word_count': len(words),
            'sentence_count': len(sentences),
            'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
            'avg_sentence_length': np.mean([len(sent.split()) for sent in sentences]) if sentences else 0,
            'pause_count': text.count('...') + text.count(','),
            'filler_count': filler_count
        }

        return features

# Step 5: DARTS Implementation
class DARTSCell(nn.Module):
    """DARTS-style mixed operation over a fixed set of 1-D candidate ops.

    ``forward`` returns ``sum_k softmax(alpha)_k * op_k(x)``. ``alpha`` is a
    learnable vector with one logit per candidate op.

    Args:
        input_dim: Channel count of the input, shaped ``(batch, channels,
            length)`` as ``nn.Conv1d`` requires.
        output_dim: Channel count of the output. Must equal ``input_dim``: the
            identity, ReLU and pooling candidates keep the input channel count,
            and all candidate outputs are summed.

    Raises:
        ValueError: If ``input_dim != output_dim``.
    """

    def __init__(self, input_dim, output_dim):
        super(DARTSCell, self).__init__()
        if input_dim != output_dim:
            # Fail at construction instead of with a shape error inside forward().
            raise ValueError(
                f"DARTSCell requires input_dim == output_dim, got {input_dim} and {output_dim}"
            )
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Candidate operations. Every one preserves the sequence length.
        self.operations = nn.ModuleList([
            nn.Identity(),
            nn.ReLU(),
            nn.Conv1d(input_dim, output_dim, 1),
            nn.Conv1d(input_dim, output_dim, 3, padding=1),
            nn.MaxPool1d(3, stride=1, padding=1),
            nn.AvgPool1d(3, stride=1, padding=1)
        ])

        # Architecture parameters (alpha), one logit per operation.
        self.alpha = nn.Parameter(torch.randn(len(self.operations)))

    def forward(self, x):
        weights = F.softmax(self.alpha, dim=0)
        # Weighted combination of all candidate operations.
        return sum(w * op(x) for w, op in zip(weights, self.operations))

class DARTSNetwork(nn.Module):
    """Linear in-projection -> stack of DARTSCells -> linear out-projection.

    The flat ``(batch, input_dim)`` input is projected to ``hidden_dim`` and
    treated as a length-1 sequence with ``hidden_dim`` channels while it passes
    through the cells. The output is ``(batch, hidden_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=3):
        super(DARTSNetwork, self).__init__()

        self.input_projection = nn.Linear(input_dim, hidden_dim)
        stacked = [DARTSCell(hidden_dim, hidden_dim) for _ in range(num_layers)]
        self.cells = nn.ModuleList(stacked)
        self.output_projection = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        # (batch, hidden) -> (batch, hidden, 1) so Conv1d/pooling cells apply.
        h = self.input_projection(x).unsqueeze(-1)
        for layer in self.cells:
            h = layer(h)
        return self.output_projection(h.squeeze(-1))

# Step 6: BERT Text Processing
class BERTProcessor:
    """Wrap a pretrained BERT encoder and expose the [CLS] vector of a text.

    The model is switched to eval mode once, at construction. Encoding runs
    under ``torch.no_grad()``, so no autograd graph is built.
    """

    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        encoder = BertModel.from_pretrained(model_name)
        encoder.eval()
        self.model = encoder

    def encode_text(self, text, max_length=512):
        """Return the final-layer hidden state at token position 0 for *text*.

        The text is truncated/padded to exactly ``max_length`` tokens. The
        result has shape ``(batch, hidden_size)``.
        """
        encoded = self.tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        with torch.no_grad():
            last_hidden = self.model(**encoded).last_hidden_state
        # Position 0 holds the [CLS] token.
        return last_hidden[:, 0, :]

# Step 7: Multimodal Fusion Model
class MultimodalFusionModel(nn.Module):
    """Classifier over fused audio (DARTS network) and text (linear projection) features.

    Args:
        audio_dim: Length of each audio feature vector.
        text_dim: Length of each text embedding.
        hidden_dim: Width of both modality representations and the classifier.
        num_classes: Number of output logits.
        fusion_method: One of ``'concatenation'`` (default, the previous
            hard-coded behaviour), ``'tucker'``, ``'mfb'`` or ``'block'``.

    Raises:
        ValueError: If ``fusion_method`` is not one of the supported names.

    Note:
        The 'tucker' and 'mfb' fusions here are parameter-free simplifications.
        They contain none of the learnable projections that full Tucker/MFB
        fusion layers use.
    """

    # Width of the fused vector, as a multiple of hidden_dim.
    _FUSION_WIDTH = {'concatenation': 2, 'tucker': 1, 'mfb': 1, 'block': 1}

    def __init__(self, audio_dim, text_dim, hidden_dim=256, num_classes=2,
                 fusion_method='concatenation'):
        super(MultimodalFusionModel, self).__init__()
        if fusion_method not in self._FUSION_WIDTH:
            raise ValueError(
                f"Unknown fusion_method {fusion_method!r}; "
                f"expected one of {sorted(self._FUSION_WIDTH)}"
            )
        self.fusion_method = fusion_method

        # Audio processing with DARTS
        self.audio_darts = DARTSNetwork(audio_dim, hidden_dim)

        # Text processing
        self.text_projection = nn.Linear(text_dim, hidden_dim)

        fusion_dim = hidden_dim * self._FUSION_WIDTH[fusion_method]

        # Final classification layers
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, audio_features, text_features):
        """Return ``(batch, num_classes)`` logits."""
        audio_repr = self.audio_darts(audio_features)
        text_repr = self.text_projection(text_features)

        if self.fusion_method == 'tucker':
            fused = self.tucker_fusion(audio_repr, text_repr)
        elif self.fusion_method == 'mfb':
            fused = self.mfb_fusion(audio_repr, text_repr)
        elif self.fusion_method == 'block':
            fused = self.block_fusion(audio_repr, text_repr)
        else:
            fused = torch.cat([audio_repr, text_repr], dim=1)

        return self.classifier(fused)

    def tucker_fusion(self, audio, text):
        """Simplified Tucker-style fusion; returns ``(batch, hidden)``.

        Equals the (batch, hidden, hidden) outer product ``audio ⊗ text``
        averaged over the text axis, computed without materialising it. The
        old version averaged over both axes and returned ``(batch,)``, which
        did not fit the classifier.
        """
        return audio * text.mean(dim=1, keepdim=True)

    def mfb_fusion(self, audio, text):
        """Simplified MFB-style fusion; returns ``(batch, hidden)``.

        Equals the outer product ``audio ⊗ text`` summed over the text axis,
        computed without the (batch, hidden, hidden) intermediate.
        """
        return audio * text.sum(dim=1, keepdim=True)

    def block_fusion(self, audio, text):
        """Element-wise fusion ``audio * text + audio + text``."""
        return audio * text + audio + text

# Step 8: Dataset Class
class ADReSSoDataset(Dataset):
    """In-memory dataset of (acoustic features, transcript embedding, label) triples.

    All features are computed once, at construction. For each audio file, the
    acoustic vector comes from *audio_extractor*. The file is also transcribed
    by *text_processor* and the transcript encoded by *bert_processor*.
    """

    def __init__(self, audio_files, labels, audio_extractor, text_processor, bert_processor):
        self.audio_files = audio_files
        self.labels = labels
        self.audio_extractor = audio_extractor
        self.text_processor = text_processor
        self.bert_processor = bert_processor

        print("Extracting features...")
        per_file = [self._featurize(path) for path in tqdm(audio_files)]
        self.audio_features = [acoustic for acoustic, _ in per_file]
        self.text_features = [embedding for _, embedding in per_file]

    def _featurize(self, path):
        """Compute the (acoustic vector, transcript embedding) pair for one file."""
        acoustic = self.audio_extractor.extract_acoustic_features(path)
        transcript = self.text_processor.audio_to_text(path)
        embedding = self.bert_processor.encode_text(transcript).squeeze(0)
        return acoustic, embedding

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        # Label is returned with shape (1,); training code flattens it.
        return (
            torch.FloatTensor(self.audio_features[idx]),
            self.text_features[idx],
            torch.LongTensor([self.labels[idx]]),
        )

# Step 9: Training Function
def train_model(model, train_loader, val_loader, num_epochs=50, learning_rate=0.001):
    """Train *model* with Adam + StepLR and measure validation accuracy each epoch.

    Each batch is ``(audio_feat, text_feat, labels)``. Labels are flattened to
    ``(batch,)``, which accepts ``(batch, 1)`` and ``(batch,)`` label tensors
    and keeps a size-1 final batch 1-D, as CrossEntropyLoss needs.

    Args:
        model: Module called as ``model(audio_feat, text_feat)`` and returning logits.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        num_epochs: Number of passes over *train_loader*.
        learning_rate: Initial Adam learning rate (halved every 10 epochs).

    Returns:
        ``(train_losses, val_accuracies)``: per-epoch mean training loss and
        validation accuracy in percent. Each is 0.0 for an empty loader.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    train_losses = []
    val_accuracies = []

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        num_batches = 0

        for audio_feat, text_feat, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            audio_feat = audio_feat.to(device)
            text_feat = text_feat.to(device)
            # reshape(-1), not squeeze(): squeeze() turns a size-1 batch into a
            # 0-d tensor that CrossEntropyLoss rejects.
            labels = labels.to(device).reshape(-1)

            optimizer.zero_grad()
            outputs = model(audio_feat, text_feat)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for audio_feat, text_feat, labels in val_loader:
                audio_feat = audio_feat.to(device)
                text_feat = text_feat.to(device)
                labels = labels.to(device).reshape(-1)

                outputs = model(audio_feat, text_feat)
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_accuracy = 100 * val_correct / val_total if val_total else 0.0
        avg_train_loss = train_loss / num_batches if num_batches else 0.0

        train_losses.append(avg_train_loss)
        val_accuracies.append(val_accuracy)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        scheduler.step()

    return train_losses, val_accuracies

# Step 10: Main Execution Function
def main():
    """Build datasets from the ADReSSo21 diagnosis audio, train, plot, and save the model.

    Returns:
        ``(model, train_loader, val_loader)`` on success, or ``None`` when the
        data cannot support training. That happens when no audio files are
        found, or when fewer than two label classes each have at least two
        samples, which the stratified split requires.
    """
    from collections import Counter

    print("Starting Multimodal AD Detection Model Training...")

    # Initialize components
    dataset_loader = ADReSSoDatasetLoader()
    audio_extractor = AudioFeatureExtractor()
    text_processor = SpeechToTextProcessor()
    bert_processor = BERTProcessor()

    # Load data
    print("Loading audio files...")
    audio_files, labels = dataset_loader.load_audio_files()

    if len(audio_files) == 0:
        print("No audio files found. Please check dataset paths.")
        return None

    print(f"Found {len(audio_files)} audio files")

    # A binary classifier needs both classes, and stratify needs >= 2 samples per class.
    class_counts = Counter(labels)
    if len(class_counts) < 2 or min(class_counts.values()) < 2:
        print(f"Need at least two classes with two samples each; got label counts "
              f"{dict(class_counts)}. Please check the label assignment.")
        return None

    # Split data
    train_files, val_files, train_labels, val_labels = train_test_split(
        audio_files, labels, test_size=0.2, random_state=42, stratify=labels
    )

    # Create datasets
    print("Creating datasets...")
    train_dataset = ADReSSoDataset(train_files, train_labels, audio_extractor, text_processor, bert_processor)
    val_dataset = ADReSSoDataset(val_files, val_labels, audio_extractor, text_processor, bert_processor)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    # Get feature dimensions
    audio_dim = len(train_dataset.audio_features[0])
    text_dim = train_dataset.text_features[0].shape[0]

    print(f"Audio feature dimension: {audio_dim}")
    print(f"Text feature dimension: {text_dim}")

    # Initialize model
    model = MultimodalFusionModel(audio_dim=audio_dim, text_dim=text_dim)

    # Train model
    print("Starting training...")
    train_losses, val_accuracies = train_model(model, train_loader, val_loader)

    # Plot training curves
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.subplot(1, 2, 2)
    plt.plot(val_accuracies)
    plt.title('Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')

    plt.tight_layout()
    plt.show()

    # Save model
    torch.save(model.state_dict(), '/content/drive/MyDrive/Voice/multimodal_ad_model.pth')
    print("Model saved successfully!")

    return model, train_loader, val_loader

# Usage
if __name__ == "__main__":
    # NOTE: the `import speech_recognition as sr` at the top of this cell runs
    # before install_dependencies(), so a missing package still fails with
    # ModuleNotFoundError; run the `!pip install` cell first.
    install_dependencies()

    # main() returns None when there is nothing to train on.
    result = main()
    if result is not None:
        model, train_loader, val_loader = result

ModuleNotFoundError: No module named 'speech_recognition'