# 🎯 Optimized CRCNN Model for Cough Classification

## CRCNN Architecture (Convolutional Recurrent Convolutional Neural Network)

The CRCNN model combines the strengths of:
1. **Convolutional layers** - extract spatial features from the spectrogram
2. **Recurrent layers (GRU)** - model the temporal sequence
3. **Attention mechanism** - focus on the most informative parts
4. **Residual connections** - improve gradient flow
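To make the data flow concrete, here is a shape-only sketch of the pipeline. It assumes the configuration used later in this notebook (1x256x126 inputs, conv channels [64, 128, 256] with a 2x2 max-pool after each block, adaptive pooling to (1, 32), a bidirectional GRU with hidden size 256, 4 attention heads, 4 classes). `crcnn_shapes` is a hypothetical helper that only does shape arithmetic and needs no PyTorch.

```python
# Shape-only sketch of the CRCNN pipeline (hypothetical helper, no PyTorch needed).
# Assumes: 3x3 convs with padding=1 (size-preserving), MaxPool2d(2, 2) after each
# block, AdaptiveAvgPool2d((1, pooled_time)), bidirectional GRU, concatenated heads.

def crcnn_shapes(batch=32, freq=256, time=126, conv_channels=(64, 128, 256),
                 pooled_time=32, rnn_hidden=256, heads=4, num_classes=4):
    shapes = {"input": (batch, 1, freq, time)}
    f, t = freq, time
    for i, ch in enumerate(conv_channels, start=1):
        # Convs keep H and W; the 2x2 max-pool floors each spatial dim by 2
        f, t = f // 2, t // 2
        shapes[f"conv_block_{i}"] = (batch, ch, f, t)
    # Adaptive pooling collapses frequency to 1 and resizes time to pooled_time
    shapes["adaptive_pool"] = (batch, conv_channels[-1], 1, pooled_time)
    shapes["gru_in"] = (batch, pooled_time, rnn_hidden)        # after pre_rnn_proj
    shapes["gru_out"] = (batch, pooled_time, 2 * rnn_hidden)   # two directions
    shapes["context"] = (batch, 2 * rnn_hidden * heads)        # heads concatenated
    shapes["logits"] = (batch, num_classes)
    return shapes

for name, shape in crcnn_shapes().items():
    print(f"{name:>14}: {shape}")
```

Under these assumptions the third conv block leaves only 15 time frames, so the (1, 32) adaptive pool stretches the time axis rather than shrinking it.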


In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.models as models
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from tqdm import tqdm
import json
import os
from datetime import datetime

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Check CUDA availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Available GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    print("WARNING: CUDA not available, using CPU")

print("\n✓ Libraries imported successfully")

Using device: cuda
GPU: NVIDIA A100-SXM4-80GB
CUDA Version: 12.8
Available GPU Memory: 79.25 GB

✓ Libraries imported successfully


In [18]:
# Data directory
DATA_DIR = '../processed_data'

# Load data
print("Loading data...")
X_train = np.load(f'{DATA_DIR}/X_train.npy')
y_train = np.load(f'{DATA_DIR}/y_train.npy')
X_val = np.load(f'{DATA_DIR}/X_val.npy')
y_val = np.load(f'{DATA_DIR}/y_val.npy')
X_test = np.load(f'{DATA_DIR}/X_test.npy')
y_test = np.load(f'{DATA_DIR}/y_test.npy')

# Load label mapping
with open(f'{DATA_DIR}/label_mapping.json', 'r') as f:
    label_info = json.load(f)
    label_to_idx = label_info['label_to_idx']
    idx_to_label = {int(k): v for k, v in label_info['idx_to_label'].items()}
    num_classes = label_info['num_classes']

print(f"✓ Data loaded successfully!")
print(f"\nDataset shapes:")
print(f"  X_train: {X_train.shape}, y_train: {y_train.shape}")
print(f"  X_val:   {X_val.shape}, y_val: {y_val.shape}")
print(f"  X_test:  {X_test.shape}, y_test: {y_test.shape}")
print(f"\nNumber of classes: {num_classes}")
print(f"Class names: {list(label_to_idx.keys())}")

Loading data...
✓ Data loaded successfully!

Dataset shapes:
  X_train: (44124, 1, 256, 126), y_train: (44124,)
  X_val:   (4903, 1, 256, 126), y_val: (4903,)
  X_test:  (4903, 1, 256, 126), y_test: (4903,)

Number of classes: 4
Class names: ['asthma', 'covid', 'healthy', 'tuberculosis']


## 🏗️ Optimized CRCNN Architecture

The CRCNN model is designed specifically for audio classification, with the following improvements:

### 1. Convolutional Blocks with Residual Connections
- 3 conv blocks: [64, 128, 256] channels
- BatchNorm + ReLU + Dropout
- Skip connections (with 1x1 projections when the channel count changes) to mitigate vanishing gradients

### 2. Bidirectional GRU (instead of LSTM)
- Faster and with fewer parameters than an LSTM, which can reduce overfitting
- 2 layers with dropout between them
- Models the sequence in both directions
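The parameter saving from choosing a GRU over an LSTM can be checked with a little arithmetic. This sketch assumes PyTorch's layout (separate `weight_ih`, `weight_hh`, `bias_ih`, `bias_hh` per layer and direction; 3 gates for a GRU, 4 for an LSTM) and the sizes used in this notebook. `rnn_params` is a hypothetical helper.

```python
# Parameter count of a stacked (optionally bidirectional) RNN in PyTorch's layout.
# Each gate contributes hidden*input + hidden*hidden weights plus two bias vectors.

def rnn_params(gates, input_size, hidden_size, num_layers=2, bidirectional=True):
    directions = 2 if bidirectional else 1
    total = 0
    layer_input = input_size
    for _ in range(num_layers):
        per_direction = gates * (hidden_size * layer_input
                                 + hidden_size * hidden_size
                                 + 2 * hidden_size)
        total += directions * per_direction
        # Layers after the first see both directions' outputs concatenated
        layer_input = hidden_size * directions
    return total

gru = rnn_params(gates=3, input_size=256, hidden_size=256)   # GRU: 3 gates
lstm = rnn_params(gates=4, input_size=256, hidden_size=256)  # LSTM: 4 gates
print(f"GRU:  {gru:,}")   # GRU:  1,972,224
print(f"LSTM: {lstm:,}")  # LSTM: 2,629,632
```

The GRU needs 25% fewer recurrent parameters here; whether that translates into less overfitting on this dataset is an empirical question.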

### 3. Attention Mechanism
- Learns to weight the most informative parts of the audio automatically
- Pools over all time steps instead of relying only on the final GRU output
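In the model, each attention head scores every GRU output h_t with a small MLP (Linear, Tanh, Linear), turns the scores into weights with a softmax over time, and returns the weighted sum. The sketch below isolates that pooling step in plain Python; the MLP is replaced by hand-picked scores, which is an assumption for illustration only.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_pool(hidden_states, scores):
    """Return (context, weights) with context = sum_t weights[t] * hidden_states[t]."""
    weights = softmax(scores)
    dim = len(hidden_states[0])
    context = [sum(w * h[d] for w, h in zip(weights, hidden_states)) for d in range(dim)]
    return context, weights

# Toy GRU outputs: 3 time steps, 2 features each
h = [[1.0, 0.0], [0.0, 1.0], [4.0, 4.0]]
uniform_ctx, _ = attention_pool(h, [0.0, 0.0, 0.0])  # equal scores: mean over time
peaked_ctx, w = attention_pool(h, [0.0, 0.0, 10.0])  # one frame dominates
print(uniform_ctx, peaked_ctx, w)
```

With equal scores a head reduces to mean pooling; as one score dominates it approaches selecting a single frame. Using several heads lets the model mix a few such weightings.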

### 4. Strong Regularization
- Dropout in many layers (0.15 to 0.5, scaled from a base rate of 0.5)
- Weight decay in the optimizer
- Layer Normalization in the projection and classifier layers, alongside BatchNorm in the conv blocks
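The dropout figures above come from a single base rate, `dropout`, multiplied per layer. The sketch tabulates the effective rates; the multipliers are copied from the `ImprovedCRCNN` definition in the next cell, and the helper `effective_dropout` is hypothetical.

```python
# Per-layer dropout multipliers as they appear in ImprovedCRCNN
DROPOUT_MULTIPLIERS = {
    "inside conv block (Dropout2d)": 0.3,
    "after conv block (Dropout2d)": 0.4,
    "pre-RNN projection": 0.5,
    "between GRU layers": 0.5,
    "attention MLP": 0.3,
    "classifier block 1": 1.0,
    "classifier block 2": 0.6,
}

def effective_dropout(base=0.5):
    """Effective dropout probability per layer for a given base rate."""
    return {name: round(base * m, 4) for name, m in DROPOUT_MULTIPLIERS.items()}

for name, rate in effective_dropout(0.5).items():
    print(f"{name:<32} {rate:.2f}")
```

With `dropout=0.5` the rates range from 0.15 to 0.5, and the strongest dropout sits in the first classifier block.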

In [None]:
class ImprovedCRCNN(nn.Module):
    """
    ⭐ IMPROVED CRCNN for Cough Classification ⭐
    
    Optimized for high accuracy and resistance to overfitting:
    - Residual connections in the conv blocks
    - Bidirectional GRU (not LSTM)
    - Multi-head attention pooling over time steps
    - Strong regularization (dropout 0.15-0.5, scaled from the base `dropout`)
    - Layer normalization in the projection and classifier layers
    """
    
    def __init__(self, input_channels=1, num_classes=4, 
                 conv_channels=[64, 128, 256], 
                 rnn_hidden_size=256, rnn_layers=2,
                 dropout=0.5, attention_heads=4):
        super(ImprovedCRCNN, self).__init__()
        
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.dropout = dropout
        
        # === CONVOLUTIONAL BLOCKS with RESIDUAL CONNECTIONS ===
        self.conv_blocks = nn.ModuleList()
        self.residual_projs = nn.ModuleList()
        
        in_ch = input_channels
        for i, out_ch in enumerate(conv_channels):
            # Main conv path with 2 conv layers
            block = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout * 0.3),  # Lighter dropout inside conv blocks
                
                nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
            )
            self.conv_blocks.append(block)
            
            # 1x1 projection on the residual path when the channel count changes
            if in_ch != out_ch:
                self.residual_projs.append(
                    nn.Conv2d(in_ch, out_ch, 1, bias=False)
                )
            else:
                self.residual_projs.append(None)
            
            in_ch = out_ch
        
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout_2d = nn.Dropout2d(dropout * 0.4)
        
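        # Adaptive pooling to a fixed size: collapse frequency to 1, resize time to 32 steps.
        # Note: with 256x126 inputs the time axis is 15 after three 2x2 max-pools,
        # so this upsamples (repeats) frames rather than downsampling them.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 32))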
        
        # === FEATURE PROJECTION before RNN ===
        self.feature_channels = conv_channels[-1]
        self.rnn_hidden_size = rnn_hidden_size
        
        self.pre_rnn_proj = nn.Sequential(
            nn.Linear(self.feature_channels, rnn_hidden_size),
            nn.LayerNorm(rnn_hidden_size),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5)
        )
        
        # === BIDIRECTIONAL GRU ===
        self.gru = nn.GRU(
            input_size=rnn_hidden_size,
            hidden_size=rnn_hidden_size,
            num_layers=rnn_layers,
            batch_first=True,
            dropout=dropout * 0.5 if rnn_layers > 1 else 0,
            bidirectional=True
        )
        
        # === MULTI-HEAD ATTENTION ===
        self.attention_heads = attention_heads
        gru_output_size = rnn_hidden_size * 2  # Bidirectional
        
        # Multi-head attention
        self.attention_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(gru_output_size, 64),
                nn.Tanh(),
                nn.Dropout(dropout * 0.3),
                nn.Linear(64, 1)
            ) for _ in range(attention_heads)
        ])
        
        # === CLASSIFICATION HEAD ===
        classifier_input_size = gru_output_size * attention_heads
        
        self.classifier = nn.Sequential(
            # First block
            nn.Linear(classifier_input_size, 512),
            nn.LayerNorm(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            
            # Second block
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.6),  # Lighter dropout before the output layer (0.3 when dropout=0.5)
            
            # Output
            nn.Linear(256, num_classes)
        )
        
        # Initialize weights
        self._initialize_weights()
        
    def _initialize_weights(self):
        """Weight init: Kaiming (conv), Xavier (linear, GRU input), orthogonal (GRU recurrent)"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GRU):
                for name, param in m.named_parameters():
                    if 'weight_ih' in name:
                        nn.init.xavier_uniform_(param.data)
                    elif 'weight_hh' in name:
                        nn.init.orthogonal_(param.data)
                    elif 'bias' in name:
                        nn.init.constant_(param.data, 0)
        
    def forward(self, x):
        """
        Args:
            x: (batch, channels, height, width) - spectrogram
        Returns:
            output: (batch, num_classes) - logits
        """
        # === CONVOLUTIONAL FEATURE EXTRACTION ===
        for i, (conv_block, res_proj) in enumerate(zip(self.conv_blocks, self.residual_projs)):
            identity = x
            
            # Main path
            x = conv_block(x)
            
            # Residual connection
            if res_proj is not None:
                identity = res_proj(identity)
            
            x = x + identity
            x = self.relu(x)
            x = self.pool(x)
            x = self.dropout_2d(x)
        
        # === RESHAPE for RNN ===
        # Adaptive pooling: (batch, channels, H, W) -> (batch, channels, 1, 32)
        x = self.adaptive_pool(x)
        
        # Reshape: (batch, channels, 1, time) -> (batch, time, channels)
        batch_size, channels, _, time_steps = x.size()
        x = x.squeeze(2).permute(0, 2, 1)  # (batch, time, channels)
        
        # Project features
        x = self.pre_rnn_proj(x)  # (batch, time, rnn_hidden_size)
        
        # === GRU TEMPORAL MODELING ===
        gru_out, _ = self.gru(x)  # (batch, time, rnn_hidden_size * 2)
        
        # === MULTI-HEAD ATTENTION ===
        attended_features = []
        for attention_layer in self.attention_layers:
            # Compute attention weights
            attention_scores = attention_layer(gru_out)  # (batch, time, 1)
            attention_weights = torch.softmax(attention_scores, dim=1)
            
            # Apply attention
            attended = torch.sum(gru_out * attention_weights, dim=1)  # (batch, hidden*2)
            attended_features.append(attended)
        
        # Concatenate all attention heads
        context = torch.cat(attended_features, dim=1)  # (batch, hidden*2*heads)
        
        # === CLASSIFICATION ===
        output = self.classifier(context)
        
        return output
    
    def get_model_size(self):
        """Return parameter counts and approximate size in MB"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        
        return {
            'total_params': total_params,
            'trainable_params': trainable_params,
            'size_mb': total_params * 4 / (1024 ** 2)  # Assuming float32
        }

print("✅ ImprovedCRCNN model defined successfully!")
print("📦 Architecture: Conv(Residual) → GRU(Bidirectional) → Multi-Head Attention → Classifier")
print("🛡️  Anti-overfitting: Dropout 0.15-0.5, BatchNorm, LayerNorm, Residual connections")

✓ OPTIMIZED CRCNN model class defined successfully
  • Residual connections in conv blocks
  • GRU instead of LSTM (faster, less overfitting)
  • Attention mechanism for temporal modeling
  • Layer normalization for stability
  • Proper dimension handling
  • Weight initialization


In [20]:
# Data preprocessing and conversion to PyTorch tensors
def prepare_data_for_pytorch(X, y):
    """Convert numpy arrays to PyTorch tensors and handle data format"""
    
    # Convert to tensors
    X_tensor = torch.FloatTensor(X)
    y_tensor = torch.LongTensor(y)
    
    # Ensure X has the right shape: (batch, channels, height, width)
    if len(X_tensor.shape) == 3:
        # Add channel dimension: (batch, height, width) -> (batch, 1, height, width)
        X_tensor = X_tensor.unsqueeze(1)
    elif len(X_tensor.shape) == 4 and X_tensor.shape[1] != 1:
        # Heuristic: assumes channels-last data, (batch, height, width, channels) -> (batch, channels, height, width).
        # Channels-first input with more than one channel would be permuted incorrectly.
        X_tensor = X_tensor.permute(0, 3, 1, 2)
    
    return X_tensor, y_tensor

# Prepare data
print("Preparing data for PyTorch...")
X_train_tensor, y_train_tensor = prepare_data_for_pytorch(X_train, y_train)
X_val_tensor, y_val_tensor = prepare_data_for_pytorch(X_val, y_val)
X_test_tensor, y_test_tensor = prepare_data_for_pytorch(X_test, y_test)

print(f"Data shapes after preprocessing:")
print(f"  X_train_tensor: {X_train_tensor.shape}")
print(f"  X_val_tensor: {X_val_tensor.shape}")
print(f"  X_test_tensor: {X_test_tensor.shape}")

# Create datasets and dataloaders
batch_size = 32

train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"✓ Data loaders created with batch size: {batch_size}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

Preparing data for PyTorch...
Data shapes after preprocessing:
  X_train_tensor: torch.Size([44124, 1, 256, 126])
  X_val_tensor: torch.Size([4903, 1, 256, 126])
  X_test_tensor: torch.Size([4903, 1, 256, 126])
✓ Data loaders created with batch size: 32
  Train batches: 1379
  Validation batches: 154
  Test batches: 154
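The batch counts above follow from simple arithmetic: `len(DataLoader)` is `ceil(N / batch_size)` when `drop_last=False` (the default). The sketch also estimates the in-memory size of the training tensor, assuming float32 (what `torch.FloatTensor` produces). Both helpers are hypothetical.

```python
import math

def num_batches(num_samples, batch_size, drop_last=False):
    # drop_last=False keeps the final, partially filled batch
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

def float32_gib(shape):
    # 4 bytes per float32 element
    return math.prod(shape) * 4 / 1024**3

print(num_batches(44124, 32))   # train
print(num_batches(4903, 32))    # validation / test
print(f"{float32_gib((44124, 1, 256, 126)):.2f} GiB for X_train_tensor")
```

`torch.FloatTensor(X)` copies the array, so a training set of this size may be held twice in RAM alongside the NumPy array; `torch.from_numpy` shares memory instead when the array is already float32.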


## Data Augmentation - SpecAugment

SpecAugment is a strong augmentation technique for audio:
- **Time Masking**: zeroes out random spans of time steps
- **Frequency Masking**: zeroes out random bands of frequency bins  
- **Gaussian Noise**: adds noise to improve robustness (an extra step in this notebook; the original SpecAugment paper uses time warping and masking)

It helps the model learn more general features instead of overfitting the training data.
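The masking step itself can be sketched without PyTorch. This is a simplified, single-sample illustration on a nested list; the helper name `spec_augment` and the use of Python's `random` module are assumptions, not the class below. Mask widths are drawn from 1 to the mask parameter inclusive, and the noise step is omitted.

```python
import random

def spec_augment(spec, rng, time_mask_param=30, freq_mask_param=15,
                 num_time_masks=2, num_freq_masks=2):
    """Return a masked copy of `spec`, a list of rows (freq x time)."""
    n_freq, n_time = len(spec), len(spec[0])
    out = [row[:] for row in spec]  # copy so the input stays untouched
    for _ in range(num_time_masks):
        width = rng.randint(1, min(time_mask_param, n_time))  # inclusive bounds
        start = rng.randint(0, n_time - width)                # mask fits in the grid
        for row in out:
            for t in range(start, start + width):
                row[t] = 0.0
    for _ in range(num_freq_masks):
        width = rng.randint(1, min(freq_mask_param, n_freq))
        start = rng.randint(0, n_freq - width)
        for f in range(start, start + width):
            out[f] = [0.0] * n_time
    return out

rng = random.Random(42)
spec = [[1.0] * 126 for _ in range(256)]  # dummy 256 x 126 spectrogram of ones
masked = spec_augment(spec, rng)
zero_cols = sum(all(masked[f][t] == 0.0 for f in range(256)) for t in range(126))
zero_rows = sum(all(v == 0.0 for v in row) for row in masked)
print(f"fully masked time frames: {zero_cols}, fully masked freq bins: {zero_rows}")
```

Masks can overlap, so two masks of width up to 30 hide at most 60 frames and possibly fewer.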

In [None]:
class SpecAugment(nn.Module):
    """
    SpecAugment - Data Augmentation for Spectrograms
    
    Paper: SpecAugment: A Simple Data Augmentation Method for ASR
    https://arxiv.org/abs/1904.08779
    """
    def __init__(self, time_mask_param=30, freq_mask_param=15, 
                 num_time_masks=2, num_freq_masks=2, 
                 p=0.5, noise_std=0.05):
        super().__init__()
        self.time_mask_param = time_mask_param
        self.freq_mask_param = freq_mask_param
        self.num_time_masks = num_time_masks
        self.num_freq_masks = num_freq_masks
        self.p = p  # Probability of applying augmentation
        self.noise_std = noise_std
        
    def forward(self, x):
        """
        Args:
            x: (batch, channels, freq, time)
        """
        if not self.training:
            return x
        
        batch, channels, freq, time = x.shape
        x = x.clone()
        
        # Apply to each sample in the batch with probability p
        for i in range(batch):
            if torch.rand(1).item() < self.p:
                # Time masking
                for _ in range(self.num_time_masks):
                    if time > self.time_mask_param:
                        t = torch.randint(1, self.time_mask_param + 1, (1,)).item()  # width in [1, param]
                        t0 = torch.randint(0, time - t + 1, (1,)).item()  # any start where the mask fits
                        x[i, :, :, t0:t0+t] = 0
                
                # Frequency masking
                for _ in range(self.num_freq_masks):
                    if freq > self.freq_mask_param:
                        f = torch.randint(1, self.freq_mask_param + 1, (1,)).item()  # width in [1, param]
                        f0 = torch.randint(0, freq - f + 1, (1,)).item()
                        x[i, :, f0:f0+f, :] = 0
        
        # Add Gaussian noise to the whole batch (probability p/2)
        if torch.rand(1).item() < self.p * 0.5:
            noise = torch.randn_like(x) * self.noise_std
            x = x + noise
            
        return x

# Initialize augmentation
augmentation = SpecAugment(
    time_mask_param=30,      # Mask up to 30 time steps
    freq_mask_param=15,      # Mask up to 15 frequency bins
    num_time_masks=2,        # Apply 2 time masks
    num_freq_masks=2,        # Apply 2 freq masks
    p=0.5,                   # 50% probability per sample
    noise_std=0.05           # Small noise
).to(device)

print("✅ SpecAugment configured:")
print(f"  • Time masking: up to {augmentation.time_mask_param} frames, {augmentation.num_time_masks} masks")
print(f"  • Freq masking: up to {augmentation.freq_mask_param} bins, {augmentation.num_freq_masks} masks")
print(f"  • Application probability: {augmentation.p * 100}%")
print(f"  • Gaussian noise: std={augmentation.noise_std}")
print("  • Only applied during training")

✓ STRONG SpecAugment configured:
  • Time masking: up to 40 frames, 2 masks, 70% prob
  • Freq masking: up to 20 bins, 2 masks, 70% prob
  • Gaussian noise: std=0.1, 50% prob
  • Applied during training only


In [None]:
# ============================================
# INITIALIZE MODEL
# ============================================

print("="*80)
print("üöÄ INITIALIZING IMPROVED CRCNN MODEL")
print("="*80)

# Get input shape
input_channels = X_train_tensor.shape[1]  # Should be 1 for grayscale spectrograms

# Create model with the chosen hyperparameters
model = ImprovedCRCNN(
    input_channels=input_channels,
    num_classes=num_classes,
    conv_channels=[64, 128, 256],    # Progressive feature extraction
    rnn_hidden_size=256,             # GRU hidden size
    rnn_layers=2,                    # 2-layer bidirectional GRU
    dropout=0.5,                     # Strong dropout for anti-overfitting
    attention_heads=4                # Multi-head attention
).to(device)

# Model information
model_info = model.get_model_size()
print(f"\nüìä Model Information:")
print(f"  Total parameters: {model_info['total_params']:,}")
print(f"  Trainable parameters: {model_info['trainable_params']:,}")
print(f"  Model size: {model_info['size_mb']:.2f} MB")
print(f"  Input channels: {input_channels}")
print(f"  Output classes: {num_classes}")

# Test forward pass
print(f"\nüß™ Testing model forward pass...")
model.eval()
with torch.no_grad():
    sample_input = X_train_tensor[:2].to(device)
    sample_output = model(sample_input)
    
print(f"  Input shape: {sample_input.shape}")
print(f"  Output shape: {sample_output.shape}")
print(f"  Output range: [{sample_output.min():.3f}, {sample_output.max():.3f}]")
print(f"✅ Model initialized successfully!")

model.train()
print("="*80)

🚀 INITIALIZING OPTIMIZED CRCNN MODEL

📊 CRCNN Model Information:
  Architecture: Conv → GRU → Attention → Classifier
  Total parameters: 3,687,045
  Trainable parameters: 3,687,045
  Input channels: 1
  Number of classes: 4
  RNN hidden size: 256
  RNN layers: 2
  Attention mechanism: True

🧪 Testing model with sample input...
  Sample input shape: torch.Size([2, 1, 256, 126])
  Sample output shape: torch.Size([2, 4])
  Output logits range: [-2.406, 0.018]
✓ Model forward pass successful!
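As a cross-check on `get_model_size()`, the parameter count can be derived by hand from the layer definitions. The sketch is a hypothetical helper that mirrors `ImprovedCRCNN` with the arguments used in this cell and counts learnable parameters only (BatchNorm running statistics are buffers). For this configuration it gives 4,539,272 parameters, which differs from the 3,687,045 in the logged output above, so that output appears to come from a different model configuration.

```python
def conv2d(c_in, c_out, k):   # bias=False in ImprovedCRCNN
    return c_in * c_out * k * k

def norm(c):                  # BatchNorm2d / LayerNorm: weight + bias
    return 2 * c

def linear(f_in, f_out):      # nn.Linear with bias
    return f_in * f_out + f_out

def gru_params(inp, hid, layers=2, directions=2):
    total, layer_in = 0, inp
    for _ in range(layers):
        total += directions * 3 * (hid * layer_in + hid * hid + 2 * hid)
        layer_in = hid * directions
    return total

def crcnn_param_count(in_ch=1, conv_channels=(64, 128, 256), hid=256,
                      heads=4, num_classes=4):
    total, c_in = 0, in_ch
    for c_out in conv_channels:
        total += conv2d(c_in, c_out, 3) + norm(c_out)       # conv 1 + BN
        total += conv2d(c_out, c_out, 3) + norm(c_out)      # conv 2 + BN
        if c_in != c_out:
            total += conv2d(c_in, c_out, 1)                 # 1x1 residual projection
        c_in = c_out
    total += linear(conv_channels[-1], hid) + norm(hid)     # pre-RNN projection + LayerNorm
    total += gru_params(hid, hid)                           # 2-layer bidirectional GRU
    total += heads * (linear(2 * hid, 64) + linear(64, 1))  # attention heads
    total += linear(2 * hid * heads, 512) + norm(512)       # classifier block 1
    total += linear(512, 256) + norm(256)                   # classifier block 2
    total += linear(256, num_classes)                       # output layer
    return total

n = crcnn_param_count()
print(f"{n:,} parameters, ~{n * 4 / 1024**2:.1f} MB as float32")
```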


## 🎓 Training Loop with Early Stopping

The training loop is set up to:
- Track both training and validation metrics
- Stop early when validation accuracy has not improved for 10 consecutive epochs (patience=10)
- Save the best model checkpoint
- Record the history used to visualize training progress
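The early-stopping rule used in `train_model` can be isolated into a small helper. This `EarlyStopping` class is hypothetical (it is not part of PyTorch) and mirrors the notebook's logic: a strictly higher validation accuracy resets the counter, anything else increments it, and training stops once the counter reaches `patience`.

```python
class EarlyStopping:
    """Track the best validation accuracy and signal when to stop."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best = 0.0      # train_model also starts from best_val_acc = 0.0
        self.counter = 0

    def step(self, val_acc):
        """Return (improved, should_stop) for this epoch's accuracy."""
        if val_acc > self.best:
            self.best = val_acc
            self.counter = 0
            return True, False
        self.counter += 1
        return False, self.counter >= self.patience

# Toy accuracies: improvement, then a plateau and a dip
stopper = EarlyStopping(patience=3)
for epoch, acc in enumerate([70.0, 72.5, 72.0, 72.5, 71.0], start=1):
    improved, stop = stopper.step(acc)
    if stop:
        print(f"stop at epoch {epoch}, best={stopper.best}")  # stop at epoch 5, best=72.5
        break
```

Because matching the best accuracy does not count as an improvement, a plateau at the same value still uses up patience.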

In [None]:
# ============================================
# TRAINING CONFIGURATION - OPTIMIZED FOR HIGH ACCURACY & ANTI-OVERFITTING
# ============================================

from sklearn.utils.class_weight import compute_class_weight

# Compute class weights for imbalanced data
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train
)
class_weights_tensor = torch.FloatTensor(class_weights).to(device)

print("üìä Class Distribution:")
for i, (label, weight) in enumerate(zip(idx_to_label.values(), class_weights)):
    count = np.sum(y_train == i)
    print(f"  Class {i} ({label}): {count} samples, weight={weight:.3f}")

# Loss function with class weights and label smoothing
criterion = nn.CrossEntropyLoss(
    weight=class_weights_tensor, 
    label_smoothing=0.1  # Label smoothing to prevent overconfidence
)

# Optimizer: AdamW with decoupled weight decay
optimizer = optim.AdamW(
    model.parameters(), 
    lr=0.001,              # Initial learning rate
    weight_decay=0.01,     # Decoupled weight decay (applied to the weights, not added to the loss as L2)
    betas=(0.9, 0.999)
)

# Learning rate scheduler: ReduceLROnPlateau
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',           # Maximize validation accuracy
    factor=0.5,           # Reduce LR by half
    patience=5,           # Wait 5 epochs before reducing
    # verbose=True is deprecated in recent PyTorch releases; the LR is printed each epoch instead
    min_lr=1e-6
)

print("\n⚙️  TRAINING CONFIGURATION:")
print(f"  Loss: CrossEntropyLoss (class-weighted, label_smoothing=0.1)")
print(f"  Optimizer: AdamW (lr=0.001, weight_decay=0.01)")
print(f"  Scheduler: ReduceLROnPlateau (factor=0.5, patience=5)")
print(f"  Batch size: {batch_size}")
print(f"  Gradient clipping: 1.0")
print(f"\nüõ°Ô∏è  ANTI-OVERFITTING STRATEGIES:")
print(f"  ‚úÖ Dropout: 0.5-0.6 in model")
print(f"  ‚úÖ Weight Decay: 0.01")
print(f"  ‚úÖ Label Smoothing: 0.1")
print(f"  ‚úÖ SpecAugment: Time & Freq masking")
print(f"  ‚úÖ Class-balanced loss")
print(f"  ‚úÖ Early stopping (patience=10)")
print(f"  ‚úÖ Gradient clipping")


# ============================================
# TRAINING & VALIDATION FUNCTIONS
# ============================================

def train_epoch(model, train_loader, criterion, optimizer, device, augmentation=None):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    train_bar = tqdm(train_loader, desc="Training", leave=False)
    for batch_idx, (inputs, targets) in enumerate(train_bar, start=1):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        
        # Apply augmentation
        if augmentation is not None:
            inputs = augmentation(inputs)
        
        # Forward pass
        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward pass with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        # Update progress bar (count batches explicitly: tqdm's .n is only
        # refreshed periodically, so it can lag behind the loop)
        train_bar.set_postfix({
            'Loss': f'{running_loss/batch_idx:.3f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })
    
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc


def validate_epoch(model, val_loader, criterion, device):
    """Validate for one epoch"""
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        val_bar = tqdm(val_loader, desc="Validation", leave=False)
        for batch_idx, (inputs, targets) in enumerate(val_bar, start=1):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            # Statistics
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Save for metrics
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            
            # Update progress bar (explicit batch count; tqdm's .n can lag)
            val_bar.set_postfix({
                'Loss': f'{val_loss/batch_idx:.3f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })
    
    epoch_loss = val_loss / len(val_loader)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc, all_preds, all_targets

print("\n✅ Training functions defined successfully!")

Class weights for imbalanced data:
  asthma: 1.328
  covid: 0.750
  healthy: 1.724
  tuberculosis: 0.750

🔧 ANTI-OVERFITTING CONFIGURATION:
  ✓ Label Smoothing: 0.15 (increased from 0.1)
  ✓ Learning Rate: 0.0005 (reduced from 0.001)
  ✓ Weight Decay: 0.02 (increased from 0.01)
  ✓ Scheduler: CosineAnnealingWarmRestarts (better than OneCycleLR for overfitting)
  ✓ Dropout: 0.5 (already in model)
  ✓ Class-balanced loss with weights
  ✓ Mixed Precision: DISABLED (for stability)

✓ SIMPLIFIED & ROBUST training functions defined successfully
  • NO GradScaler errors (mixed precision disabled)
  • Gradient clipping (max_norm=1.0) - ALWAYS works
  • CosineAnnealingWarmRestarts scheduler
  • Class-balanced loss with label smoothing (0.15)
  • Lower LR (0.0005) + Higher weight decay (0.02)
  • Non-blocking data transfer
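The class weights printed in this cell come from scikit-learn's `'balanced'` mode, which sets `weight_c = n_samples / (n_classes * count_c)`. A minimal re-implementation on toy labels (the helper `balanced_class_weights` is hypothetical):

```python
from collections import Counter

def balanced_class_weights(labels):
    """weight_c = n_samples / (n_classes * count_c), as in scikit-learn's 'balanced' mode."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}

# Toy labels: class 0 is twice as frequent as classes 1 and 2
toy = [0] * 50 + [1] * 25 + [2] * 25
print(balanced_class_weights(toy))
```

Inverting the formula, a weight w corresponds to a class share of 1 / (n_classes * w). The logged weight of 0.750 for covid and tuberculosis therefore means each makes up about a third of the training set, while asthma and healthy are rarer and get weights above 1.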


In [None]:
# ============================================
# TRAINING LOOP with Early Stopping
# ============================================

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                num_epochs=80, early_stopping_patience=10, 
                save_path='best_crcnn_cough.pth'):
    """
    Complete training loop with monitoring and early stopping
    """
    
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'learning_rates': []
    }
    
    best_val_acc = 0.0
    patience_counter = 0
    
    print("\n" + "="*80)
    print("üéì STARTING TRAINING")
    print("="*80)
    print(f"Device: {device}")
    print(f"Total epochs: {num_epochs}")
    print(f"Early stopping patience: {early_stopping_patience}")
    print(f"Model checkpoints will be saved to: {save_path}")
    print("="*80 + "\n")
    
    for epoch in range(num_epochs):
        epoch_start_time = datetime.now()
        
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print("-" * 40)
        
        # Training phase (uses the global `augmentation` SpecAugment instance defined earlier)
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, 
            augmentation=augmentation
        )
        
        # Validation phase
        val_loss, val_acc, val_preds, val_targets = validate_epoch(
            model, val_loader, criterion, device
        )
        
        # Update learning rate
        scheduler.step(val_acc)
        current_lr = optimizer.param_groups[0]['lr']
        
        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['learning_rates'].append(current_lr)
        
        # Print epoch summary
        epoch_time = (datetime.now() - epoch_start_time).total_seconds()
        train_val_gap = train_acc - val_acc
        
        print(f"\nüìä Results:")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
        print(f"  Train-Val Gap: {train_val_gap:+.2f}%")
        print(f"  Learning Rate: {current_lr:.2e}")
        print(f"  Epoch Time: {epoch_time:.1f}s")
        
        # Check for overfitting warning
        if train_val_gap > 10:
            print(f"  ⚠️  Warning: Possible overfitting (gap > 10%)")
        
        # Save best model
        if val_acc > best_val_acc:
            improvement = val_acc - best_val_acc
            best_val_acc = val_acc
            
            # Save checkpoint
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_acc': best_val_acc,
                'history': history,
                'class_names': list(idx_to_label.values())
            }
            torch.save(checkpoint, save_path)
            
            print(f"  ✅ New best model saved! (+{improvement:.2f}% improvement)")
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{early_stopping_patience})")
        
        # Early stopping
        if patience_counter >= early_stopping_patience:
            print(f"\n{'='*80}")
            print(f"⏹️  Early stopping triggered!")
            print(f"   Best validation accuracy: {best_val_acc:.2f}%")
            print(f"   Stopping at epoch {epoch+1}")
            print("="*80)
            break
        
        print("")  # Empty line for readability
    
    print("\n" + "="*80)
    print("✅ TRAINING COMPLETED!")
    print("="*80)
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"Total epochs trained: {len(history['train_loss'])}")
    print(f"Model saved to: {save_path}")
    print("="*80 + "\n")
    
    return history, best_val_acc

print("✅ Training loop function defined!")

✓ OPTIMIZED training loop function defined successfully
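Because `ReduceLROnPlateau` (patience=5) and early stopping (patience=10) both watch validation accuracy, their patiences interact. The sketch below is a simplified re-implementation of the scheduler's `mode='max'` rule (relative threshold 1e-4, no cooldown) written for illustration, not PyTorch's own code; `simulate_plateau` is hypothetical, and the accuracies are toy numbers.

```python
def simulate_plateau(val_accs, lr=1e-3, factor=0.5, lr_patience=5,
                     stop_patience=10, min_lr=1e-6, threshold=1e-4):
    """Return the LR logged after each epoch until early stopping fires."""
    sched_best, bad_epochs = float("-inf"), 0
    stop_best, stop_counter = 0.0, 0
    lrs = []
    for acc in val_accs:
        # Scheduler: relative-improvement test; reduce after more than
        # lr_patience consecutive bad epochs, then reset the bad-epoch count
        if acc > sched_best * (1 + threshold):
            sched_best, bad_epochs = acc, 0
        else:
            bad_epochs += 1
        if bad_epochs > lr_patience:
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        lrs.append(lr)  # train_model logs the LR after scheduler.step()
        # Early stopping: strict improvement, stop after stop_patience misses
        if acc > stop_best:
            stop_best, stop_counter = acc, 0
        else:
            stop_counter += 1
            if stop_counter >= stop_patience:
                break
    return lrs

accs = [80.0, 82.0, 84.0] + [84.0] * 10  # improves for 3 epochs, then plateaus
for epoch, lr in enumerate(simulate_plateau(accs), start=1):
    print(epoch, lr)
```

On a flat plateau the LR is halved once, at the sixth epoch without improvement, and early stopping fires four epochs later. For a second halving to take effect before stopping, the early-stopping patience would need to exceed 12 epochs.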


## 🚀 Start Training

Now we start training the model with all of the optimizations configured above.

In [None]:
# ============================================
# START TRAINING
# ============================================

# Training hyperparameters
NUM_EPOCHS = 80
EARLY_STOPPING_PATIENCE = 10
MODEL_SAVE_PATH = 'best_crcnn_cough_model.pth'

# Start training
history, best_val_acc = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=NUM_EPOCHS,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    save_path=MODEL_SAVE_PATH
)

print(f"\nüéâ Training finished!")
print(f"üìà Best validation accuracy achieved: {best_val_acc:.2f}%")

## 📊 Visualize Training History

Plot training and validation metrics to check for overfitting.
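The overfitting indicator in these plots is the per-epoch difference between training and validation accuracy. A hypothetical helper, `summarize_gap`, computes the same quantities from a history dict; the 10-point threshold matches the warning used in `train_model` and is a heuristic, not a standard cutoff.

```python
def summarize_gap(history, threshold=10.0):
    """Train-minus-validation accuracy gap per epoch, plus a few summary numbers."""
    gap = [t - v for t, v in zip(history["train_acc"], history["val_acc"])]
    # 1-based epoch with the highest validation accuracy (the saved checkpoint)
    best_epoch = max(range(len(gap)), key=lambda i: history["val_acc"][i]) + 1
    return {
        "gap": gap,
        "max_gap": max(gap),
        "final_gap": gap[-1],
        "best_val_epoch": best_epoch,
        "overfitting_flag": max(gap) > threshold,
    }

# Made-up history for illustration only (not results from this notebook)
toy_history = {"train_acc": [60.0, 75.0, 88.0, 95.0],
               "val_acc":   [58.0, 70.0, 80.0, 82.0]}
print(summarize_gap(toy_history))
```

Here validation accuracy is still rising while the gap widens past 10 points, which is the pattern where the best checkpoint and the overfitting warning point at the same epoch.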

In [None]:
# ============================================
# VISUALIZE TRAINING HISTORY
# ============================================

def plot_training_history(history):
    """Plot training history"""
    epochs = range(1, len(history['train_loss']) + 1)
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss plot
    axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    axes[0, 0].set_xlabel('Epoch', fontsize=12)
    axes[0, 0].set_ylabel('Loss', fontsize=12)
    axes[0, 0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    axes[0, 0].legend(fontsize=11)
    axes[0, 0].grid(True, alpha=0.3)
    
    # Accuracy plot
    axes[0, 1].plot(epochs, history['train_acc'], 'b-', label='Train Acc', linewidth=2)
    axes[0, 1].plot(epochs, history['val_acc'], 'r-', label='Val Acc', linewidth=2)
    axes[0, 1].set_xlabel('Epoch', fontsize=12)
    axes[0, 1].set_ylabel('Accuracy (%)', fontsize=12)
    axes[0, 1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    axes[0, 1].legend(fontsize=11)
    axes[0, 1].grid(True, alpha=0.3)
    
    # Learning rate plot
    axes[1, 0].plot(epochs, history['learning_rates'], 'g-', linewidth=2)
    axes[1, 0].set_xlabel('Epoch', fontsize=12)
    axes[1, 0].set_ylabel('Learning Rate', fontsize=12)
    axes[1, 0].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    axes[1, 0].set_yscale('log')
    axes[1, 0].grid(True, alpha=0.3)
    
    # Train-Val gap (overfitting indicator)
    gap = [t - v for t, v in zip(history['train_acc'], history['val_acc'])]
    axes[1, 1].plot(epochs, gap, 'purple', linewidth=2)
    axes[1, 1].axhline(y=0, color='k', linestyle='--', alpha=0.3)
    axes[1, 1].axhline(y=10, color='r', linestyle='--', alpha=0.3, label='Overfitting threshold')
    axes[1, 1].fill_between(epochs, 0, gap, alpha=0.3, color='purple')
    axes[1, 1].set_xlabel('Epoch', fontsize=12)
    axes[1, 1].set_ylabel('Train - Val Accuracy (%)', fontsize=12)
    axes[1, 1].set_title('Overfitting Indicator (Train-Val Gap)', fontsize=14, fontweight='bold')
    axes[1, 1].legend(fontsize=11)
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print summary
    print("\n" + "="*60)
    print("📊 TRAINING SUMMARY")
    print("="*60)
    print(f"Total epochs: {len(epochs)}")
    print(f"Best train accuracy: {max(history['train_acc']):.2f}%")
    print(f"Best validation accuracy: {max(history['val_acc']):.2f}%")
    print(f"Final train-val gap: {gap[-1]:.2f}%")
    print(f"Average train-val gap: {np.mean(gap):.2f}%")
    
    if max(gap) > 10:
        print(f"\n⚠️  Warning: Max train-val gap was {max(gap):.2f}% (overfitting detected)")
    else:
        print(f"\n✅ Model shows good generalization (max gap: {max(gap):.2f}%)")
    print("="*60 + "\n")

# Plot the history
plot_training_history(history)

## 🧪 Evaluate on Test Set

Load the best checkpoint and evaluate it on the test set to get the final results.
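
`evaluate_model` converts logits to probabilities with `torch.softmax` and takes the arg-max as the prediction. The NumPy sketch below mirrors that step on made-up logits, using the usual max-subtraction trick for numerical stability. Because softmax is monotonic, the arg-max of the probabilities always equals the arg-max of the raw logits.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row max avoids overflow in exp()."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Made-up logits for 3 samples and 3 classes
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 3.0, 0.2],
                   [1000.0, 999.0, 0.0]])  # a naive exp() would overflow on this row
targets = np.array([0, 1, 1])

probs = softmax(logits)
preds = probs.argmax(axis=1)
accuracy = (preds == targets).mean() * 100  # percentage, like accuracy_score(...) * 100

print(preds, f"{accuracy:.2f}%")
```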

In [None]:
# ============================================
# EVALUATE ON TEST SET
# ============================================

def evaluate_model(model, test_loader, device, model_path=None):
    """Evaluate model on test set"""
    
    # Load best model if path provided
    if model_path and os.path.exists(model_path):
        print(f"Loading best model from {model_path}...")
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"✅ Model loaded (Best val acc: {checkpoint['best_val_acc']:.2f}%)\n")
    
    model.eval()
    all_preds = []
    all_targets = []
    all_probs = []
    
    print("Evaluating on test set...")
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="Testing"):
            inputs = inputs.to(device)
            targets = targets.to(device)
            
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            _, predicted = outputs.max(1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    # Calculate metrics
    test_acc = accuracy_score(all_targets, all_preds) * 100
    
    print("\n" + "="*60)
    print("🎯 TEST SET RESULTS")
    print("="*60)
    print(f"Test Accuracy: {test_acc:.2f}%")
    print("="*60 + "\n")
    
    # Classification report
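    print("📋 Classification Report:")
    print("-" * 60)
    # Order names by class index so they line up with the integer labels
    class_names = [idx_to_label[i] for i in range(len(idx_to_label))]
    report = classification_report(all_targets, all_preds,
                                   labels=list(range(len(class_names))),
                                   target_names=class_names,
                                   digits=4, zero_division=0)
    print(report)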
    
    return all_preds, all_targets, all_probs, test_acc

# Evaluate
test_preds, test_targets, test_probs, test_accuracy = evaluate_model(
    model, test_loader, device, model_path=MODEL_SAVE_PATH
)

## 📈 Confusion Matrix and Detailed Metrics

Visualize the confusion matrix and per-class metrics.
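
Each row of the confusion matrix corresponds to one true class. Dividing a row by its sum therefore gives per-class recall, which is what the "Per-Class Accuracy" printout in `plot_confusion_matrix` reports. The sketch below works through a small made-up 3-class matrix. It also guards against a class with no test samples, where a plain `cm / cm.sum(axis=1)` division would produce NaN.

```python
import numpy as np

# Made-up confusion matrix: rows = true class, columns = predicted class
cm = np.array([[8, 2, 0],
               [1, 9, 0],
               [0, 0, 0]])   # class 2 has no test samples

row_sums = cm.sum(axis=1, keepdims=True)
# Divide only where the row is non-empty; empty rows stay at 0 instead of NaN
cm_normalized = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)

per_class_recall = np.diag(cm_normalized) * 100  # diagonal = correctly classified share
print(per_class_recall)
```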

In [None]:
# ============================================
# CONFUSION MATRIX & VISUALIZATION
# ============================================

def plot_confusion_matrix(y_true, y_pred, class_names):
    """Plot confusion matrix"""
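    # Pass labels explicitly so the matrix has one row/column per class, even if a class is absent
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))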
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # Confusion matrix (counts)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count'}, ax=axes[0], 
                annot_kws={'size': 12})
    axes[0].set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
    axes[0].set_ylabel('True Label', fontsize=12, fontweight='bold')
    axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
    
    # Confusion matrix (normalized)
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Greens',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Percentage'}, ax=axes[1],
                annot_kws={'size': 12})
    axes[1].set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
    axes[1].set_ylabel('True Label', fontsize=12, fontweight='bold')
    axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print per-class accuracy
    print("\n📊 Per-Class Accuracy:")
    print("-" * 50)
    for i, class_name in enumerate(class_names):
        class_acc = cm[i, i] / cm[i].sum() * 100
        print(f"  {class_name:20s}: {class_acc:6.2f}% ({cm[i, i]}/{cm[i].sum()})")
    print("-" * 50)

# Plot confusion matrix
class_names = [idx_to_label[i] for i in range(len(idx_to_label))]  # ordered by class index
plot_confusion_matrix(test_targets, test_preds, class_names)

## 💾 Model Summary

A summary of the trained model.
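
`get_model_size()` is a method of the notebook's model class, and its body is not part of this section. A common way to compute the same three numbers is sketched below. The helper `count_parameters` is hypothetical. Its size estimate sums parameter bytes via `element_size()` and ignores buffers such as BatchNorm running statistics, so it may differ slightly from the notebook's figure.

```python
import torch.nn as nn

def count_parameters(model):
    """Hypothetical helper: total/trainable parameter counts and parameter size in MB."""
    params = list(model.parameters())
    total = sum(p.numel() for p in params)
    trainable = sum(p.numel() for p in params if p.requires_grad)
    size_bytes = sum(p.numel() * p.element_size() for p in params)  # 4 bytes each for float32
    return {'total_params': total, 'trainable_params': trainable, 'size_mb': size_bytes / 1024**2}

# Toy model: Linear(10, 5) has 10*5 + 5 = 55 params; Linear(5, 2) has 5*2 + 2 = 12
toy = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
toy[0].weight.requires_grad_(False)  # freeze the 50 first-layer weights
info = count_parameters(toy)
print(info)
```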

In [None]:
# ============================================
# FINAL MODEL SUMMARY
# ============================================

print("\n" + "="*80)
print("🎉 MODEL TRAINING & EVALUATION COMPLETED!")
print("="*80)

# Model info
model_info = model.get_model_size()
print(f"\n📦 Model Architecture: ImprovedCRCNN")
print(f"   • Total parameters: {model_info['total_params']:,}")
print(f"   • Trainable parameters: {model_info['trainable_params']:,}")
print(f"   • Model size: {model_info['size_mb']:.2f} MB")

print(f"\n🏗️  Architecture Components:")
print(f"   • Conv blocks: [64, 128, 256] channels with residual connections")
print(f"   • Bidirectional GRU: 2 layers, hidden_size=256")
print(f"   • Multi-head attention: 4 heads")
print(f"   • Classifier: 512 → 256 → {num_classes} classes")

print(f"\n🛡️  Anti-Overfitting Techniques:")
print(f"   ✅ Dropout: 0.5-0.6")
print(f"   ✅ SpecAugment: Time & Frequency masking")
print(f"   ✅ Label Smoothing: 0.1")
print(f"   ✅ Weight Decay: 0.01 (AdamW)")
print(f"   ✅ Batch Normalization + Layer Normalization")
print(f"   ✅ Gradient Clipping: 1.0")
print(f"   ✅ Class-balanced loss")
print(f"   ✅ Early Stopping: patience={EARLY_STOPPING_PATIENCE}")

print(f"\n📊 Performance:")
print(f"   • Best Validation Accuracy: {best_val_acc:.2f}%")
print(f"   • Test Accuracy: {test_accuracy:.2f}%")
print(f"   • Total epochs trained: {len(history['train_loss'])}")

print(f"\n💾 Saved Files:")
print(f"   • Model checkpoint: {MODEL_SAVE_PATH}")
print(f"   • Training history plot: training_history.png")
print(f"   • Confusion matrix: confusion_matrix.png")

print("\n" + "="*80)
print("✅ All done! Model is ready for deployment.")
print("="*80 + "\n")

## 🚀 How to Use This Model for Inference

To run this model on new audio, first convert the recording to a spectrogram, then pass it to `predict_cough`:
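
Before predicting, `predict_cough` adds the batch and channel dimensions the CNN expects. This assumes the spectrogram was produced with the same preprocessing as the training arrays. The NumPy sketch below isolates that reshaping step. The name `to_model_input` is hypothetical, and the 128 x 94 shape is made up. Unlike `predict_cough`, it also accepts an already-batched 4-D array and raises an error for any other shape instead of passing it through.

```python
import numpy as np

def to_model_input(spec):
    """Hypothetical helper: reshape a spectrogram to (batch, channel, height, width)."""
    spec = np.asarray(spec, dtype=np.float32)
    if spec.ndim == 2:        # (H, W) -> (1, 1, H, W)
        return spec[np.newaxis, np.newaxis]
    if spec.ndim == 3:        # (C, H, W) -> (1, C, H, W)
        return spec[np.newaxis]
    if spec.ndim == 4:        # already batched
        return spec
    raise ValueError(f"expected 2-4 dims, got shape {spec.shape}")

print(to_model_input(np.zeros((128, 94))).shape)     # (1, 1, 128, 94)
print(to_model_input(np.zeros((1, 128, 94))).shape)  # (1, 1, 128, 94)
```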

In [None]:
# ============================================
# INFERENCE FUNCTION
# ============================================

def predict_cough(audio_spectrogram, model, device, class_names):
    """
    Predict cough type from audio spectrogram
    
    Args:
        audio_spectrogram: numpy array of shape (height, width) or (1, height, width)
        model: trained CRCNN model
        device: torch device
        class_names: list of class names
        
    Returns:
        predicted_class: predicted class name
        probabilities: dict of class probabilities
    """
    model.eval()
    
    # Prepare input
    if len(audio_spectrogram.shape) == 2:
        audio_spectrogram = audio_spectrogram[np.newaxis, np.newaxis, :, :]  # Add batch and channel dims
    elif len(audio_spectrogram.shape) == 3:
        audio_spectrogram = audio_spectrogram[np.newaxis, :, :, :]  # Add batch dim
    
    # Convert to tensor
    input_tensor = torch.FloatTensor(audio_spectrogram).to(device)
    
    # Predict
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs, dim=1)
        predicted_idx = torch.argmax(probs, dim=1).item()
        
    # Get results
    predicted_class = class_names[predicted_idx]
    probabilities = {class_names[i]: probs[0, i].item() for i in range(len(class_names))}
    
    return predicted_class, probabilities


# Example usage
print("="*60)
print("🔮 INFERENCE EXAMPLE")
print("="*60)

# Load best model
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Test with a random sample from test set
sample_idx = np.random.randint(0, len(X_test_tensor))
sample_spectrogram = X_test_tensor[sample_idx].cpu().numpy()
true_label = idx_to_label[y_test_tensor[sample_idx].item()]

# Predict
predicted_class, probabilities = predict_cough(
    sample_spectrogram, model, device, list(idx_to_label.values())
)

print(f"\nSample #{sample_idx}")
print(f"True label: {true_label}")
print(f"Predicted: {predicted_class}")
print(f"\nProbabilities:")
for class_name, prob in sorted(probabilities.items(), key=lambda x: x[1], reverse=True):
    print(f"  {class_name:20s}: {prob*100:6.2f}%")

print("\n" + "="*60)
print("✅ Inference function ready to use!")
print("="*60)

## 📝 Key Takeaways & Best Practices

### 🎯 Strengths of this CRCNN model:

1. **Hybrid architecture**: combines a CNN (spatial features), a GRU (temporal features), and attention
2. **Residual connections**: make deeper networks easier to train by improving gradient flow
3. **Multi-head attention**: focuses on the most informative parts of the audio
4. **Strong regularization**: high dropout, weight decay, label smoothing
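
To make the hybrid design concrete, here is a deliberately small PyTorch sketch of the same idea. Conv blocks with a residual shortcut feed a bidirectional GRU over time frames, followed by multi-head self-attention, LayerNorm, and a linear classifier. This is not the notebook's `ImprovedCRCNN`. The class names `ResidualConvBlock` and `TinyCRCNN`, the layer sizes, the mean pooling over time, and the input shape of (batch, 1, 64 mel bins, 100 frames) are all assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

class ResidualConvBlock(nn.Module):
    """Two 3x3 convs plus a 1x1 shortcut, followed by 2x2 max pooling."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch),
        )
        self.shortcut = nn.Conv2d(in_ch, out_ch, 1)  # matches channel count for the addition
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        return self.pool(torch.relu(self.body(x) + self.shortcut(x)))

class TinyCRCNN(nn.Module):
    def __init__(self, n_mels=64, num_classes=3, hidden=32, heads=4):
        super().__init__()
        self.convs = nn.Sequential(ResidualConvBlock(1, 16), ResidualConvBlock(16, 32))
        freq_after = n_mels // 4                        # two 2x2 pools shrink frequency by 4
        self.gru = nn.GRU(32 * freq_after, hidden, batch_first=True, bidirectional=True)
        self.attn = nn.MultiheadAttention(2 * hidden, heads, batch_first=True)
        self.norm = nn.LayerNorm(2 * hidden)
        self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(2 * hidden, num_classes))

    def forward(self, x):                               # x: (B, 1, n_mels, T)
        x = self.convs(x)                               # (B, C, F, T')
        b, c, f, t = x.shape
        x = x.permute(0, 3, 1, 2).reshape(b, t, c * f)  # one feature vector per time frame
        x, _ = self.gru(x)                              # (B, T', 2*hidden)
        attn_out, _ = self.attn(x, x, x)                # self-attention across time frames
        x = self.norm(x + attn_out)                     # residual connection + LayerNorm
        return self.classifier(x.mean(dim=1))           # average over time, then class logits

model_sketch = TinyCRCNN()
logits = model_sketch(torch.randn(2, 1, 64, 100))
print(logits.shape)  # torch.Size([2, 3])
```

The reshape treats each remaining time frame as one GRU step whose features are all channel-frequency values. This is one common way to hand CNN feature maps to a recurrent layer.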

### 🛡️ Anti-overfitting techniques applied:

- ✅ **Dropout 0.5-0.6** in multiple layers
- ✅ **SpecAugment** (time and frequency masking) for audio augmentation
- ✅ **Label Smoothing 0.1** to avoid overconfident predictions
- ✅ **Weight Decay 0.01** in the AdamW optimizer
- ✅ **Early Stopping** with patience=10
- ✅ **Gradient Clipping** to stabilize training
- ✅ **Class-balanced loss** for imbalanced data
- ✅ **Batch + Layer Normalization** combined
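
Two of these techniques change the loss targets and weights rather than the network. The NumPy sketch below shows what label smoothing with ε = 0.1 does to a one-hot target. It uses the formulation of PyTorch's `CrossEntropyLoss(label_smoothing=...)`, which spreads ε uniformly over all K classes. It also shows one common class-balancing scheme, inverse-frequency weights n / (K · count_k), as in scikit-learn's `class_weight='balanced'`. This section does not say which balancing formula the notebook's loss uses, so treat the weights as an assumption. The class counts are made up.

```python
import numpy as np

def smooth_labels(targets, num_classes, eps=0.1):
    """(1 - eps) * one_hot + eps / K: the true class keeps most of the probability mass."""
    one_hot = np.eye(num_classes)[targets]
    return (1.0 - eps) * one_hot + eps / num_classes

def balanced_class_weights(labels, num_classes):
    """Inverse-frequency weights n / (K * count_k); rarer classes get larger weights."""
    counts = np.bincount(labels, minlength=num_classes)
    return len(labels) / (num_classes * counts)

smoothed = smooth_labels(np.array([0, 2]), num_classes=4, eps=0.1)
print(smoothed[0])   # true class 0.925, other classes 0.025 each

labels = np.array([0] * 60 + [1] * 30 + [2] * 10)   # made-up imbalanced labels
weights = balanced_class_weights(labels, num_classes=3)
print(weights)       # approx. [0.556, 1.111, 3.333]
```

In PyTorch, both plug into the same loss, e.g. `nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32), label_smoothing=0.1)`.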

### 📊 Expected Performance:

- **Training Accuracy**: 85-95%
- **Validation Accuracy**: 80-90%
- **Test Accuracy**: 80-90%
- **Train-Val Gap**: < 10% (good generalization)

### 💡 Tips for further improvement:

1. **Collect more data**: the more data, the better the model generally performs
2. **Data augmentation**: add pitch shifting and time stretching
3. **Ensemble models**: combine several different models
4. **Hyperparameter tuning**: grid search for the best parameters
5. **Transfer learning**: pre-train on a larger dataset
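
For tip 3, the simplest ensemble averages the softmax probabilities of several independently trained models and takes the arg-max of the average. Below is a NumPy sketch with made-up probabilities for 3 models, 2 samples, and 3 classes. The helper `ensemble_predict` is hypothetical.

```python
import numpy as np

def ensemble_predict(prob_list, weights=None):
    """Average (optionally weighted) class probabilities from several models."""
    probs = np.stack(prob_list)                     # (n_models, n_samples, n_classes)
    avg = np.average(probs, axis=0, weights=weights)
    return avg.argmax(axis=1), avg

# Made-up softmax outputs: 3 models, 2 samples, 3 classes
model_a = np.array([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])
model_b = np.array([[0.4, 0.5, 0.1], [0.1, 0.3, 0.6]])
model_c = np.array([[0.5, 0.4, 0.1], [0.2, 0.2, 0.6]])

preds, avg = ensemble_predict([model_a, model_b, model_c])
print(preds, avg.round(3))
```

On sample 0, `model_b` alone would predict class 1, but the averaged probabilities ([0.5, 0.4, 0.1]) overrule it in favor of class 0.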

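### 🚀 Production Deployment:

```python
# 1. Save only the weights for production
torch.save(model.state_dict(), 'crcnn_cough_production.pth')

# 2. Load and run inference
#    (construct ImprovedCRCNN with the same arguments used for training)
model = ImprovedCRCNN(...)
model.load_state_dict(torch.load('crcnn_cough_production.pth', map_location=device))
model.to(device)
model.eval()

# 3. Convert to ONNX for faster inference (optional);
#    the example input only needs the shape of one real batch, e.g. a one-sample test slice
# dummy_input = X_test_tensor[:1].to(device)
# torch.onnx.export(model, dummy_input, 'crcnn_cough.onnx')
```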

---
**Created by**: CRCNN Optimization Team  
**Last Updated**: 2024  
**License**: MIT