In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
import random
from tqdm import tqdm
from copy import deepcopy

In [None]:
# Pick the compute device, then seed every RNG the notebook draws from
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # Ask cuDNN for deterministic kernels and seed the current CUDA device
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(seed)

print(f"Using device: {device}")

In [None]:
from torchaudio.datasets import SPEECHCOMMANDS

class SubsetSC(SPEECHCOMMANDS):
    """SPEECHCOMMANDS restricted to one of the dataset's predefined splits.

    The split is selected by rewriting ``self._walker`` (the list of clip
    paths the parent class iterates over) using ``validation_list.txt`` and
    ``testing_list.txt`` from the dataset directory.

    Args:
        subset: ``"training"``, ``"validation"``, ``"testing"``, or ``None``
            to keep every clip the parent class discovered.

    Raises:
        ValueError: If ``subset`` is any other value. Previously such a value
            silently produced the full, unsplit dataset.
    """

    def __init__(self, subset: str = None):
        # Validate first so a typo does not trigger a download and then
        # quietly hand back the whole dataset.
        if subset not in (None, "training", "validation", "testing"):
            raise ValueError(
                f"Unknown subset {subset!r}; expected 'training', 'validation', "
                "'testing' or None"
            )

        super().__init__("./", download=True)

        def load_list(filename):
            """Return the normalised clip paths listed in a split file."""
            filepath = os.path.join(self._path, filename)
            with open(filepath) as f:
                # normpath makes these entries comparable with the walker's
                # paths regardless of separator style or a leading "./";
                # blank lines are skipped so they cannot map to the dataset
                # directory itself.
                return [
                    os.path.normpath(os.path.join(self._path, line.strip()))
                    for line in f
                    if line.strip()
                ]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = set(load_list("validation_list.txt") + load_list("testing_list.txt"))
            # Compare on normalised paths so held-out clips are reliably
            # removed; the walker's original path strings are kept as-is.
            self._walker = [w for w in self._walker if os.path.normpath(w) not in excludes]

        print(f"Loaded {subset} set with {len(self._walker)} samples")

In [None]:
# Load datasets
train_set = SubsetSC("training")
val_set = SubsetSC("validation")
test_set = SubsetSC("testing")

# Get list of labels.
# Read each label from the clip's parent directory name instead of iterating
# the dataset, which would load and decode every training clip just to read
# item[2].
# NOTE(review): assumes `_walker` holds paths laid out as <root>/<label>/<clip>.wav
# (the layout SubsetSC's split lists are joined against) -- confirm for the
# installed torchaudio version.
labels = sorted({os.path.basename(os.path.dirname(path)) for path in train_set._walker})
label_to_index = {label: i for i, label in enumerate(labels)}
index_to_label = {i: label for i, label in enumerate(labels)}

print(f"Number of classes: {len(labels)}")
print(f"Labels: {labels}")

In [None]:
# Audio preprocessing
class AudioPreprocessor:
    """Convert a raw waveform into a fixed-length MFCC feature tensor.

    Calling the instance resamples to ``sample_rate`` when needed, mixes
    multi-channel audio down to one channel, pads or trims to exactly
    ``sample_rate`` samples (one second at the target rate), computes MFCCs
    and replaces any non-finite values.

    Args:
        sample_rate: Target sampling rate in Hz; also the fixed clip length
            in samples.
        n_mfcc: Number of MFCC coefficients.
        n_fft: FFT size passed to the mel spectrogram.
        hop_length: Hop between frames passed to the mel spectrogram.
    """

    def __init__(self, sample_rate=16000, n_mfcc=40, n_fft=400, hop_length=160):
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.n_fft = n_fft
        self.hop_length = hop_length

        # MFCC transform
        self.mfcc_transform = T.MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mfcc,
            melkwargs={
                'n_fft': n_fft,
                'hop_length': hop_length,
                'n_mels': 80,
                'center': False
            }
        )

        # Resample transforms keyed by source rate. They are built on first
        # use and then reused, instead of constructing a new T.Resample for
        # every clip that needs resampling.
        self._resamplers = {}

    def _resampler_for(self, orig_freq):
        """Return the cached resampler from ``orig_freq`` to the target rate."""
        resampler = self._resamplers.get(orig_freq)
        if resampler is None:
            resampler = T.Resample(orig_freq=orig_freq, new_freq=self.sample_rate)
            self._resamplers[orig_freq] = resampler
        return resampler

    def __call__(self, waveform, sample_rate):
        """Preprocess one clip.

        Args:
            waveform: Tensor indexed as ``[channel, time]`` (the code reads
                ``shape[0]`` as channels and ``shape[1]`` as samples).
            sample_rate: Sampling rate of ``waveform`` in Hz.

        Returns:
            The output of the MFCC transform with NaN replaced by 0.0,
            +inf by 1.0 and -inf by -1.0.
        """
        # Resample if needed
        if sample_rate != self.sample_rate:
            waveform = self._resampler_for(sample_rate)(waveform)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Pad (zeros on the right) or trim to exactly one second of samples
        missing = self.sample_rate - waveform.shape[1]
        if missing > 0:
            waveform = F.pad(waveform, (0, missing))
        else:
            waveform = waveform[:, :self.sample_rate]

        # Extract MFCC features
        mfcc = self.mfcc_transform(waveform)

        # nan_to_num leaves finite values untouched, so it is applied
        # unconditionally rather than scanning the tensor twice beforehand.
        return torch.nan_to_num(mfcc, nan=0.0, posinf=1.0, neginf=-1.0)

In [None]:
# Data augmentation
class AudioAugmentation:
    """Strength-gated feature augmentation: masking plus Gaussian noise."""

    def __init__(self):
        self.time_mask = T.TimeMasking(time_mask_param=10)
        self.freq_mask = T.FrequencyMasking(freq_mask_param=10)

    def augment(self, x, strength=1.0):
        """Return an augmented version of ``x``.

        Stages are enabled by ``strength``: frequency masking above 0.5,
        time masking above 0.7, and additive Gaussian noise scaled by
        ``0.1 * strength`` above 0.3. At 0.3 or below ``x`` is returned as is.
        """
        augmented = x

        # Masking stages, frequency first and then time
        if strength > 0.5:
            augmented = self.freq_mask(augmented)
        if strength > 0.7:
            augmented = self.time_mask(augmented)

        # Noise is the last stage and has the lowest threshold
        if strength > 0.3:
            jitter = torch.randn_like(augmented) * (0.1 * strength)
            return augmented + jitter
        return augmented

In [None]:
# Custom dataset for semi-supervised learning
class SpeechCommandDataset(Dataset):
    """Wrap a raw dataset so each item is ``(features, label_index)``.

    Args:
        dataset: Indexable source whose items are 5-tuples beginning with
            ``(waveform, sample_rate, label, ...)``.
        preprocessor: Callable applied as ``preprocessor(waveform, sample_rate)``.
        label_to_index: Mapping from label string to integer class index.
    """

    def __init__(self, dataset, preprocessor, label_to_index):
        self.dataset = dataset
        self.preprocessor = preprocessor
        self.label_to_index = label_to_index

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # The last two fields of the source item are not used here
        audio, rate, word, _unused_a, _unused_b = self.dataset[idx]

        # Features first, then the integer class id for the label
        return self.preprocessor(audio, rate), self.label_to_index[word]

In [None]:
# One shared preprocessor for all splits
preprocessor = AudioPreprocessor()

# Wrap each raw split so items come out as (features, label index)
train_dataset, val_dataset, test_dataset = (
    SpeechCommandDataset(raw_split, preprocessor, label_to_index)
    for raw_split in (train_set, val_set, test_set)
)

# Carve the training set into a small labeled pool and a large unlabeled pool
num_labeled = 1000  # Number of labeled examples to use
total_samples = len(train_dataset)

# Shuffle all training indices once; the first `num_labeled` become the labeled pool
indices = list(range(total_samples))
random.shuffle(indices)
labeled_indices, unlabeled_indices = indices[:num_labeled], indices[num_labeled:]

# Index-based views over the same underlying training dataset
labeled_dataset = Subset(train_dataset, labeled_indices)
unlabeled_dataset = Subset(train_dataset, unlabeled_indices)

print(f"Labeled samples: {len(labeled_dataset)}")
print(f"Unlabeled samples: {len(unlabeled_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

In [None]:
# Create data loaders
batch_size = 64
labeled_batch_size = 16  # Labeled share of each combined training step

# Worker / pinned-memory options shared by every loader
_loader_opts = {"num_workers": 2, "pin_memory": True}

# Training loaders: labeled and unlabeled batches together add up to `batch_size`
labeled_loader = DataLoader(
    labeled_dataset, batch_size=labeled_batch_size, shuffle=True, **_loader_opts
)
unlabeled_loader = DataLoader(
    unlabeled_dataset, batch_size=batch_size - labeled_batch_size, shuffle=True, **_loader_opts
)

# Evaluation loaders keep dataset order
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **_loader_opts)

# Check a batch from each loader to verify shapes.
# The batch variables deliberately avoid the name `labels`: that global holds
# the list of class names, and rebinding it to a batch tensor here would make
# later `len(labels)` report the batch size instead of the class count.
for name, loader in [("Labeled", labeled_loader), ("Unlabeled", unlabeled_loader), ("Val", val_loader)]:
    batch_inputs, batch_targets = next(iter(loader))
    print(f"{name} batch - inputs: {batch_inputs.shape}, labels: {batch_targets.shape}")

In [None]:
# Define the CNN model
class AudioCNN(nn.Module):
    """Three-block CNN classifier over single-channel feature maps.

    Each block is Conv2d(3x3, padding 1) -> BatchNorm2d -> ReLU -> MaxPool2d(2).
    The result is flattened and passed through Linear(512) -> ReLU -> Dropout
    -> Linear(num_classes).

    Every layer is created in ``__init__``. Building ``fc1`` lazily inside
    ``forward`` would replace the layer after an optimizer or a deep copy had
    already been made from the model: the optimizer would keep updating the
    discarded layer, and a copy would re-initialise its own ``fc1`` with
    different random weights.

    Args:
        num_classes: Number of output logits.
        input_shape: ``(height, width)`` of one input feature map. The default
            ``(40, 98)`` is meant to match ``AudioPreprocessor``'s defaults
            (40 MFCCs; 16000 samples with n_fft=400, hop_length=160 and
            center=False should give 1 + (16000 - 400) // 160 = 98 frames --
            confirm against the loader shape printout).

    Raises:
        ValueError: If ``input_shape`` is too small to survive three 2x2 poolings.
    """

    def __init__(self, num_classes, input_shape=(40, 98)):
        super(AudioCNN, self).__init__()

        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # Pooling and dropout
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.3)

        # The 3x3 / padding-1 convolutions keep height and width; each of the
        # three 2x2 max-pools floor-halves them, i.e. floor-divides by 8 overall.
        height, width = input_shape
        pooled_h, pooled_w = height // 8, width // 8
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"input_shape {tuple(input_shape)} is too small: both dimensions must be at least 8"
            )
        self.fc_input_size = 128 * pooled_h * pooled_w

        # Fully connected layers
        self.fc1 = nn.Linear(self.fc_input_size, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Tensor shaped ``[batch, 1, height, width]`` where
                ``(height, width)`` agrees with the ``input_shape`` given at
                construction.

        Returns:
            Tensor of logits shaped ``[batch, num_classes]``.

        Raises:
            ValueError: If the flattened feature size does not match ``fc1``.
        """
        # Convolutional layers
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))

        # Flatten everything except the batch dimension
        x = x.flatten(1)
        if x.size(1) != self.fc_input_size:
            # Fail with a clear message instead of silently rebuilding fc1
            raise ValueError(
                f"Flattened feature size {x.size(1)} does not match the expected "
                f"{self.fc_input_size}; construct AudioCNN with the matching input_shape"
            )

        # Fully connected layers
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)

        return x

In [None]:
# Initialize models.
# The class count is taken from `label_to_index`, which nothing else in the
# notebook reassigns; the name `labels` is easy to shadow with a batch of
# targets, which would turn `len(labels)` into a batch size.
student_model = AudioCNN(num_classes=len(label_to_index)).to(device)
teacher_model = deepcopy(student_model)

# The teacher is only ever updated by copying/averaging from the student,
# so it needs no gradients; keep it in evaluation mode.
for teacher_param in teacher_model.parameters():
    teacher_param.requires_grad_(False)
teacher_model.eval()

In [None]:
# Safe Student trainer class (student network trained against an EMA teacher)
class SafeStudent:
    """Semi-supervised trainer pairing a student network with an EMA teacher.

    The student is optimised on cross-entropy over labeled batches plus a
    temperature-scaled KL consistency term that pulls its predictions on
    augmented unlabeled inputs towards the teacher's predictions on the same,
    un-augmented inputs. The teacher is never optimised; after every optimizer
    step its parameters move towards the student's by exponential moving
    average and its buffers are copied from the student.

    Args:
        student_model: Network that receives gradient updates.
        teacher_model: Network with the same structure, updated only by EMA.
        num_classes: Number of classes, used for reports and plots.
        alpha: EMA decay rate for the teacher parameters.
        temperature: Temperature applied to both sets of logits in the
            consistency loss.
        lambda_consistency: Weight of the consistency loss in the total loss.
    """

    # File the best checkpoint (by validation accuracy) is written to and read from
    checkpoint_path = 'best_model_safestudent.pth'

    def __init__(self, student_model, teacher_model, num_classes, 
                 alpha=0.99, temperature=0.5, lambda_consistency=1.0):
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.num_classes = num_classes
        self.alpha = alpha  # EMA decay rate
        self.temperature = temperature  # Temperature for soft targets
        self.lambda_consistency = lambda_consistency  # Weight for consistency loss

        # Follow the student's actual placement rather than a module-level global
        self.device = next(student_model.parameters()).device

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(student_model.parameters(), lr=0.001, weight_decay=1e-4)
        # NOTE(review): T_max is fixed at 100 scheduler steps (one per epoch);
        # a run with fewer epochs stops before the cosine schedule completes.
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=100)
        self.augmentation = AudioAugmentation()

    def update_teacher(self):
        """Move the teacher towards the student.

        Parameters follow ``teacher = alpha * teacher + (1 - alpha) * student``.
        Buffers (e.g. BatchNorm running statistics) are copied outright: they
        are not returned by ``parameters()``, and the teacher always runs in
        eval mode, so without this step it would keep the statistics it was
        created with.
        """
        with torch.no_grad():
            for teacher_param, student_param in zip(self.teacher_model.parameters(),
                                                    self.student_model.parameters()):
                teacher_param.mul_(self.alpha).add_(student_param, alpha=1 - self.alpha)
            for teacher_buf, student_buf in zip(self.teacher_model.buffers(),
                                                self.student_model.buffers()):
                teacher_buf.copy_(student_buf)

    def consistency_loss(self, logits_s, logits_t):
        """KL divergence from teacher to student predictions.

        Both logit tensors are divided by ``temperature`` before the softmax,
        and the batch-mean KL is multiplied by ``temperature ** 2``.
        """
        log_probs_student = F.log_softmax(logits_s / self.temperature, dim=1)
        probs_teacher = F.softmax(logits_t / self.temperature, dim=1)
        return F.kl_div(log_probs_student, probs_teacher, reduction='batchmean') * (self.temperature ** 2)

    def train_epoch(self, labeled_loader, unlabeled_loader, epoch):
        """Run one training epoch.

        Labeled and unlabeled batches are consumed in lockstep, so the epoch
        lasts as long as the shorter of the two loaders.

        Args:
            labeled_loader: Loader yielding ``(inputs, targets)``.
            unlabeled_loader: Loader yielding ``(inputs, _)``; targets are ignored.
            epoch: Zero-based epoch index (used for the progress-bar label).

        Returns:
            ``(avg_loss, avg_sup_loss, avg_unsup_loss, accuracy)``, averaged
            over the batches that were actually applied; ``accuracy`` is a
            percentage over the labeled samples of those batches.
        """
        self.student_model.train()
        self.teacher_model.eval()  # Teacher always in eval mode

        total_loss = 0.0
        total_sup_loss = 0.0
        total_unsup_loss = 0.0
        correct = 0
        total = 0
        applied_batches = 0  # batches that produced an optimizer step

        num_batches = min(len(labeled_loader), len(unlabeled_loader))

        # zip stops at the shorter loader, which is exactly `num_batches` steps
        pbar = tqdm(zip(labeled_loader, unlabeled_loader),
                    total=num_batches, desc=f"Epoch {epoch+1}")

        for (inputs_x, targets_x), (inputs_u, _) in pbar:
            # Move to device
            inputs_x, targets_x = inputs_x.to(self.device), targets_x.to(self.device)
            inputs_u = inputs_u.to(self.device)

            # Augment each unlabeled sample independently so masks differ per sample
            inputs_u_aug = torch.stack([self.augmentation.augment(x.unsqueeze(0), strength=0.8).squeeze(0) 
                                       for x in inputs_u])

            # Forward pass through student model
            logits_x = self.student_model(inputs_x)  # Labeled data
            logits_u_aug = self.student_model(inputs_u_aug)  # Augmented unlabeled data

            # Supervised loss
            loss_x = self.criterion(logits_x, targets_x)

            # Teacher predictions on the un-augmented unlabeled inputs
            with torch.no_grad():
                logits_u_teacher = self.teacher_model(inputs_u)

            # Consistency loss between student and teacher
            loss_u = self.consistency_loss(logits_u_aug, logits_u_teacher)

            # Total loss
            loss = loss_x + self.lambda_consistency * loss_u

            # Skip batches whose loss is NaN or infinite: stepping on them
            # would corrupt the weights.
            if not torch.isfinite(loss):
                print("Non-finite loss detected; skipping batch")
                print(f"loss_x: {loss_x}, loss_u: {loss_u}")
                continue

            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(self.student_model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Update teacher model with EMA
            self.update_teacher()

            # Update metrics
            applied_batches += 1
            total_loss += loss.item()
            total_sup_loss += loss_x.item()
            total_unsup_loss += loss_u.item()

            # Accuracy on the labeled part of the batch
            predicted = logits_x.argmax(dim=1)
            total += targets_x.size(0)
            correct += (predicted == targets_x).sum().item()

            # Update progress bar
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'sup_loss': f"{loss_x.item():.4f}",
                'unsup_loss': f"{loss_u.item():.4f}",
                'acc': f"{100 * correct / total:.2f}%"
            })

        # Update learning rate
        self.scheduler.step()

        # Average over applied batches only; guard against an epoch in which
        # every batch was skipped.
        denom = max(applied_batches, 1)
        avg_loss = total_loss / denom
        avg_sup_loss = total_sup_loss / denom
        avg_unsup_loss = total_unsup_loss / denom
        accuracy = 100 * correct / total if total else 0.0

        return avg_loss, avg_sup_loss, avg_unsup_loss, accuracy

    def _predict(self, model, loader, desc):
        """Run ``model`` over ``loader`` in eval mode without gradients.

        Returns:
            ``(avg_loss, preds, targets)`` where ``avg_loss`` is the
            per-sample mean cross-entropy (NaN for an empty loader) and the
            other two are lists of integer class indices.
        """
        model.eval()

        all_preds = []
        all_targets = []
        loss_sum = 0.0
        num_samples = 0

        with torch.no_grad():
            for inputs, targets in tqdm(loader, desc=desc):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = model(inputs)

                # The criterion returns a batch mean; weight it by the batch
                # size so a smaller final batch is not over-counted.
                batch_count = targets.size(0)
                loss_sum += self.criterion(outputs, targets).item() * batch_count
                num_samples += batch_count

                all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
                all_targets.extend(targets.cpu().tolist())

        avg_loss = loss_sum / num_samples if num_samples else float('nan')
        return avg_loss, all_preds, all_targets

    def evaluate(self, test_loader, use_teacher=True):
        """Evaluate the teacher (default) or the student on ``test_loader``.

        Returns:
            ``(avg_loss, accuracy)`` with ``accuracy`` as a fraction in [0, 1].
        """
        model = self.teacher_model if use_teacher else self.student_model
        avg_loss, all_preds, all_targets = self._predict(model, test_loader, desc="Evaluating")
        accuracy = accuracy_score(all_targets, all_preds)

        return avg_loss, accuracy

    def train(self, labeled_loader, unlabeled_loader, val_loader, num_epochs=100):
        """Train for ``num_epochs`` and checkpoint the best validation accuracy.

        After each epoch the teacher is evaluated on ``val_loader``; whenever
        its accuracy improves, both models and the optimizer state are saved
        to ``checkpoint_path``. Training curves are plotted at the end.

        Returns:
            ``(best_accuracy, train_losses, val_accuracies)``.
        """
        best_accuracy = 0.0
        train_losses = []
        val_accuracies = []

        for epoch in range(num_epochs):
            # Train for one epoch
            train_loss, sup_loss, unsup_loss, train_acc = self.train_epoch(
                labeled_loader, unlabeled_loader, epoch
            )
            train_losses.append(train_loss)

            # Evaluate on validation set
            val_loss, val_accuracy = self.evaluate(val_loader)
            val_accuracies.append(val_accuracy)

            print(f"Epoch {epoch+1}/{num_epochs} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Sup Loss: {sup_loss:.4f} | "
                  f"Unsup Loss: {unsup_loss:.4f} | "
                  f"Train Acc: {train_acc:.2f}% | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val Acc: {val_accuracy*100:.2f}%")

            # Save best model
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                torch.save({
                    'epoch': epoch,
                    'student_state_dict': self.student_model.state_dict(),
                    'teacher_state_dict': self.teacher_model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    # Store a plain Python float so the checkpoint holds only
                    # tensors and builtin types.
                    'accuracy': float(val_accuracy),
                }, self.checkpoint_path)
                print(f"New best model saved with accuracy: {val_accuracy*100:.2f}%")

        self._plot_progress(train_losses, val_accuracies)

        return best_accuracy, train_losses, val_accuracies

    def _plot_progress(self, train_losses, val_accuracies):
        """Plot per-epoch training loss and validation accuracy and save the figure."""
        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(train_losses)
        plt.title('Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')

        plt.subplot(1, 2, 2)
        plt.plot(val_accuracies)
        plt.title('Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')

        plt.tight_layout()
        plt.savefig('training_progress.png')
        plt.show()

    def test(self, test_loader):
        """Load the best checkpoint and report test metrics for the teacher.

        Prints the loss, accuracy and a per-class report, and plots and saves
        a confusion matrix.

        Returns:
            ``(test_accuracy, report)`` where ``report`` is the text of the
            classification report.
        """
        # map_location lets a checkpoint written on one device load on another
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        self.student_model.load_state_dict(checkpoint['student_state_dict'])
        self.teacher_model.load_state_dict(checkpoint['teacher_state_dict'])

        # A single pass supplies the loss, the accuracy and the predictions
        # used by the report and the confusion matrix.
        test_loss, all_preds, all_targets = self._predict(
            self.teacher_model, test_loader, desc="Evaluating"
        )
        test_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy*100:.2f}%")

        # Pass explicit class ids so names stay aligned even when a class is
        # missing from both targets and predictions; fall back to the numeric
        # id if the label map has no entry for an index.
        class_ids = list(range(self.num_classes))
        class_names = [index_to_label.get(i, str(i)) for i in class_ids]

        # Generate classification report
        report = classification_report(
            all_targets, 
            all_preds, 
            labels=class_ids,
            target_names=class_names,
            digits=3
        )
        print("Classification Report:")
        print(report)

        # Plot confusion matrix
        cm = confusion_matrix(all_targets, all_preds, labels=class_ids)
        plt.figure(figsize=(15, 15))
        sns.heatmap(
            cm, 
            annot=True, 
            fmt='d', 
            cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names
        )
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')
        plt.tight_layout()
        plt.savefig('confusion_matrix.png')
        plt.show()

        return test_accuracy, report

In [None]:
# Initialize Safe Student trainer
trainer = SafeStudent(
    student_model=student_model,
    teacher_model=teacher_model,
    # Count classes from the label map, which nothing else in the notebook
    # reassigns; `labels` is easy to shadow with a batch of targets.
    num_classes=len(label_to_index),
    alpha=0.99,  # EMA decay rate
    temperature=0.5,  # Temperature for soft targets
    lambda_consistency=1.0  # Weight for consistency loss
)

In [None]:
# Run the full Safe Student training loop
num_epochs = 30  # Adjust as needed
print("Starting Safe Student training...")

# Returns the best validation accuracy plus the per-epoch curves
best_accuracy, train_losses, val_accuracies = trainer.train(
    labeled_loader, unlabeled_loader, val_loader, num_epochs=num_epochs
)

print(f"Training completed. Best validation accuracy: {best_accuracy*100:.2f}%")

In [None]:
# Evaluate the best checkpoint on the held-out test split
test_results = trainer.test(test_loader)
test_accuracy, test_report = test_results
print(f"Final test accuracy: {test_accuracy*100:.2f}%")