In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
import random
from tqdm import tqdm

In [None]:
# Reproducibility: fix the seed of every random number generator used in the
# notebook (Python, NumPy, PyTorch), then choose the compute device.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cuda":
    # Seed the CUDA generator too and request deterministic cuDNN kernels.
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

print(f"Using device: {device}")

In [None]:
from torchaudio.datasets import SPEECHCOMMANDS

class SubsetSC(SPEECHCOMMANDS):
    """Speech Commands dataset restricted to one of its official splits.

    The dataset is downloaded to the current directory if necessary. The split
    membership is read from ``validation_list.txt`` and ``testing_list.txt``
    in the dataset root; the training split is everything not listed in
    either file.

    Args:
        subset: ``"training"``, ``"validation"``, ``"testing"``, or ``None``
            to keep every clip.

    Raises:
        ValueError: If ``subset`` is any other value.
    """

    _SUBSETS = ("training", "validation", "testing")

    def __init__(self, subset: str = None):
        # Fail before the (potentially long) download when the name is wrong,
        # instead of silently returning the full dataset.
        if subset is not None and subset not in self._SUBSETS:
            raise ValueError(
                f"Unknown subset {subset!r}; expected one of {self._SUBSETS} or None"
            )

        super().__init__("./", download=True)

        def load_list(filename):
            """Return the normalized clip paths listed in a split file."""
            filepath = os.path.join(self._path, filename)
            with open(filepath) as f:
                # normpath makes these paths comparable with the walker's
                # entries regardless of a leading "./" or the path separator.
                return [
                    os.path.normpath(os.path.join(self._path, line.strip()))
                    for line in f
                    if line.strip()
                ]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = set(load_list("validation_list.txt") + load_list("testing_list.txt"))
            # Normalize the walker side as well so held-out clips are really
            # removed and cannot leak into the training split.
            self._walker = [w for w in self._walker if os.path.normpath(w) not in excludes]

        print(f"Loaded {subset} set with {len(self._walker)} samples")

In [None]:
# Load datasets
train_set = SubsetSC("training")
val_set = SubsetSC("validation")
test_set = SubsetSC("testing")

# Get list of labels.
# Each clip's label is the name of the directory that contains it, so the
# vocabulary is read from the file paths instead of iterating the dataset,
# which would decode every training clip just to look at its label.
# NOTE(review): relies on the private `_walker` path list, as SubsetSC does.
labels = sorted({os.path.basename(os.path.dirname(path)) for path in train_set._walker})
label_to_index = {label: i for i, label in enumerate(labels)}
index_to_label = {i: label for i, label in enumerate(labels)}

print(f"Number of classes: {len(labels)}")
print(f"Labels: {labels}")

In [None]:
# Audio preprocessing
class AudioPreprocessor:
    """Convert a raw waveform into a fixed-size MFCC feature tensor.

    Pipeline: resample to ``sample_rate`` when the input rate differs, average
    the channels down to mono, zero-pad or truncate to exactly ``sample_rate``
    samples (one second), compute MFCCs, and replace non-finite values.

    Args:
        sample_rate: Target sampling rate in Hz; also the clip length in samples.
        n_mfcc: Number of MFCC coefficients to keep.
        n_fft: FFT size used by the underlying mel spectrogram.
        hop_length: Hop between analysis frames, in samples.
        n_mels: Number of mel filterbanks (defaults to the previous fixed value, 80).
    """

    def __init__(self, sample_rate=16000, n_mfcc=40, n_fft=400, hop_length=160, n_mels=80):
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels

        # Resamplers keyed by source rate, built lazily and reused so the
        # resampling kernel is not recomputed for every clip.
        self._resamplers = {}

        # MFCC transform
        self.mfcc_transform = T.MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mfcc,
            melkwargs={
                'n_fft': n_fft,
                'hop_length': hop_length,
                'n_mels': n_mels,
                'center': False
            }
        )

    def _get_resampler(self, orig_freq):
        """Return a cached resampler from ``orig_freq`` to the target rate."""
        resampler = self._resamplers.get(orig_freq)
        if resampler is None:
            resampler = T.Resample(orig_freq=orig_freq, new_freq=self.sample_rate)
            self._resamplers[orig_freq] = resampler
        return resampler

    def __call__(self, waveform, sample_rate):
        """Return the MFCC features of ``waveform``.

        Args:
            waveform: Tensor indexed as (channels, samples).
            sample_rate: Sampling rate of ``waveform`` in Hz.

        Returns:
            The output of the MFCC transform for the mono, one-second clip,
            with NaN mapped to 0 and +/-Inf mapped to +/-1.
        """
        # Resample if needed
        if sample_rate != self.sample_rate:
            waveform = self._get_resampler(sample_rate)(waveform)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Pad or trim to 1 second (self.sample_rate samples)
        missing = self.sample_rate - waveform.shape[1]
        if missing > 0:
            waveform = F.pad(waveform, (0, missing))
        else:
            waveform = waveform[:, :self.sample_rate]

        # Extract MFCC features
        mfcc = self.mfcc_transform(waveform)

        # Replace NaN / Inf so a single bad clip cannot poison a batch
        if not torch.isfinite(mfcc).all():
            mfcc = torch.nan_to_num(mfcc, nan=0.0, posinf=1.0, neginf=-1.0)

        return mfcc

In [None]:
# Data augmentation
class AudioAugmentation:
    """Weak and strong augmentations applied to feature tensors.

    Both levels use masking transforms with a maximum mask width of 10; the
    strong level additionally masks along time and adds Gaussian noise.
    """

    def __init__(self):
        self.time_mask = T.TimeMasking(time_mask_param=10)
        self.freq_mask = T.FrequencyMasking(freq_mask_param=10)

    def weak_augment(self, x):
        """Mild augmentation: frequency masking only."""
        return self.freq_mask(x)

    def strong_augment(self, x):
        """Stronger augmentation: frequency mask, time mask, then additive noise."""
        masked = self.time_mask(self.freq_mask(x))
        # Gaussian noise with standard deviation 0.1, drawn after masking.
        return masked + torch.randn_like(masked) * 0.1

In [None]:
# Custom dataset for semi-supervised learning
class SpeechCommandDataset(Dataset):
    """Wrap a raw dataset so each item becomes ``(features, label_index)``.

    Args:
        dataset: Indexable collection whose items are 5-tuples starting with
            ``(waveform, sample_rate, label, ...)``.
        preprocessor: Callable ``(waveform, sample_rate) -> features``.
        label_to_index: Mapping from label string to integer class index.
    """

    def __init__(self, dataset, preprocessor, label_to_index):
        self.dataset = dataset
        self.preprocessor = preprocessor
        self.label_to_index = label_to_index

    def __len__(self):
        """Number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the preprocessed features and integer label of item ``idx``."""
        # The last two fields of the raw item are not used here.
        audio, rate, word, _, _ = self.dataset[idx]
        return self.preprocessor(audio, rate), self.label_to_index[word]

In [None]:
# Create preprocessor
preprocessor = AudioPreprocessor()

# Create datasets with preprocessing
train_dataset = SpeechCommandDataset(train_set, preprocessor, label_to_index)
val_dataset = SpeechCommandDataset(val_set, preprocessor, label_to_index)
test_dataset = SpeechCommandDataset(test_set, preprocessor, label_to_index)

# Split training set into labeled and unlabeled
num_labeled = 1000  # Number of labeled examples to use
total_samples = len(train_dataset)
if not 0 < num_labeled < total_samples:
    raise ValueError(
        f"num_labeled must be between 1 and {total_samples - 1}, got {num_labeled}"
    )

# Create indices for the split.
# A dedicated generator seeded with `seed` makes this cell idempotent: it
# produces the same split every time it is run, no matter how much of the
# global `random` state earlier cells (or earlier runs of this cell) consumed.
indices = list(range(total_samples))
random.Random(seed).shuffle(indices)
labeled_indices = indices[:num_labeled]
unlabeled_indices = indices[num_labeled:]

# Create labeled and unlabeled datasets
labeled_dataset = Subset(train_dataset, labeled_indices)
unlabeled_dataset = Subset(train_dataset, unlabeled_indices)

# The two subsets must partition the training set exactly.
assert len(labeled_dataset) + len(unlabeled_dataset) == total_samples

print(f"Labeled samples: {len(labeled_dataset)}")
print(f"Unlabeled samples: {len(unlabeled_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

In [None]:
# Create data loaders
batch_size = 64
labeled_batch_size = 16  # Per batch

labeled_loader = DataLoader(
    labeled_dataset,
    batch_size=labeled_batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

# The unlabeled batch fills the rest of the combined batch of `batch_size`.
unlabeled_loader = DataLoader(
    unlabeled_dataset,
    batch_size=batch_size - labeled_batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

# Check a batch from each loader to verify shapes.
# The loop variables are deliberately NOT called `labels`: that name holds the
# list of class names, and rebinding it to a batch tensor here would make
# every later `len(labels)` return the batch size instead of the class count.
for loader_name, loader in [("Labeled", labeled_loader), ("Unlabeled", unlabeled_loader), ("Val", val_loader)]:
    batch_inputs, batch_targets = next(iter(loader))
    print(f"{loader_name} batch - inputs: {batch_inputs.shape}, labels: {batch_targets.shape}")

In [None]:
# Improved CNN model with residual connections
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by a skip connection.

    When the block changes the spatial resolution (``stride != 1``) or the
    channel count, the skip path is a 1x1 convolution followed by batch norm;
    otherwise it is the identity (an empty ``nn.Sequential``).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # Match the main path's output shape on the skip path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return F.relu(main + self.shortcut(x))

class AudioResNet(nn.Module):
    """Small ResNet-style classifier for single-channel feature maps.

    Architecture: a 3x3 stem convolution (1 -> 32 channels), four stages of
    two residual blocks each (32, 64, 128, 256 channels; stages 2-4 use
    stride 2), global average pooling, dropout (p=0.5) and a linear layer
    producing ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Stem
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        # Residual stages
        self.layer1 = self._make_layer(32, 32, 2, stride=1)
        self.layer2 = self._make_layer(32, 64, 2, stride=2)
        self.layer3 = self._make_layer(64, 128, 2, stride=2)
        self.layer4 = self._make_layer(128, 256, 2, stride=2)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

        # Regularization applied to the pooled features before `fc`
        self.dropout = nn.Dropout(0.5)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Build one stage; only its first block changes stride/channels."""
        blocks = [ResidualBlock(in_channels, out_channels, stride)]
        blocks.extend(
            ResidualBlock(out_channels, out_channels, 1)
            for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits for a batch of single-channel feature maps."""
        x = F.relu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # Collapse each channel to a single value, then flatten per sample.
        pooled = self.avgpool(x)
        pooled = pooled.view(pooled.size(0), -1)

        return self.fc(self.dropout(pooled))

In [None]:
# Initialize model
# Size the classifier from the label -> index mapping: it is assigned exactly
# once, so its length is the class count even if the shorter name `labels`
# has been rebound to something else (e.g. a batch of targets) in between.
model = AudioResNet(num_classes=len(label_to_index)).to(device)
print(model)

In [None]:
# Improved ReFixMatch training function
class ReFixMatch:
    """FixMatch-style semi-supervised trainer for a classifier.

    Every training step combines a supervised cross-entropy loss on a labeled
    batch with a consistency loss on an unlabeled batch: pseudo-labels are
    predicted from a weakly augmented view, kept only where the top softmax
    probability reaches ``threshold``, and used as targets for a strongly
    augmented view of the same inputs.

    The class relies on the notebook-level names ``device``,
    ``AudioAugmentation`` and ``index_to_label``.

    Args:
        model: Network to train (updated in place); batches are moved to
            ``device`` before being fed to it.
        num_classes: Number of classes (stored on the instance).
        lambda_u: Final weight of the unsupervised loss once ramp-up is over.
        threshold: Minimum confidence for a pseudo-label to contribute.
        rampup_length: Number of epochs over which the unsupervised weight
            grows linearly from 0 to ``lambda_u``; ``<= 0`` disables ramp-up.
        checkpoint_path: File the best model is written to by ``train`` and
            read back from by ``test``.
    """

    def __init__(self, model, num_classes, lambda_u=1.0, threshold=0.95,
                 rampup_length=50, checkpoint_path='GSC_ReFix.pt'):
        self.model = model
        self.num_classes = num_classes
        self.lambda_u = lambda_u
        self.threshold = threshold
        self.rampup_length = rampup_length
        self.checkpoint_path = checkpoint_path
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=10, T_mult=2)
        self.augmentation = AudioAugmentation()

    def train_epoch(self, labeled_loader, unlabeled_loader, epoch):
        """Run one training epoch and step the learning-rate scheduler.

        The epoch length is the number of batches of the shorter loader.

        Args:
            labeled_loader: Loader yielding ``(inputs, targets)`` batches.
            unlabeled_loader: Loader yielding ``(inputs, _)`` batches; the
                second element is ignored.
            epoch: Zero-based epoch index (drives the unsupervised ramp-up).

        Returns:
            ``(avg_loss, avg_sup_loss, avg_unsup_loss, accuracy)`` where the
            losses are averaged over the batches that were actually used for
            an update and ``accuracy`` is the labeled-batch accuracy in percent.
        """
        self.model.train()
        total_loss = 0.0
        total_sup_loss = 0.0
        total_unsup_loss = 0.0
        correct = 0
        total = 0
        processed_batches = 0  # batches that produced a finite loss and an update

        # Create iterators for the loaders
        labeled_iter = iter(labeled_loader)
        unlabeled_iter = iter(unlabeled_loader)

        # Determine number of batches
        num_batches = min(len(labeled_loader), len(unlabeled_loader))

        # Progress bar
        pbar = tqdm(range(num_batches), desc=f"Epoch {epoch+1}")

        for batch_idx in pbar:
            # Get labeled batch (restart the loader if it runs out)
            try:
                inputs_x, targets_x = next(labeled_iter)
            except StopIteration:
                labeled_iter = iter(labeled_loader)
                inputs_x, targets_x = next(labeled_iter)

            # Get unlabeled batch (restart the loader if it runs out)
            try:
                inputs_u, _ = next(unlabeled_iter)
            except StopIteration:
                unlabeled_iter = iter(unlabeled_loader)
                inputs_u, _ = next(unlabeled_iter)

            # Move to device
            inputs_x, targets_x = inputs_x.to(device), targets_x.to(device)
            inputs_u = inputs_u.to(device)

            # Generate pseudo-labels using the model and weak augmentation.
            # NOTE(review): the model is in train mode here, so dropout is
            # active and batch-norm statistics are updated during this pass.
            with torch.no_grad():
                # Augment each sample separately so every sample gets its own mask
                inputs_u_weak = torch.stack([self.augmentation.weak_augment(x.unsqueeze(0)).squeeze(0) for x in inputs_u])

                # Get model predictions
                outputs_u = self.model(inputs_u_weak)
                probs_u = torch.softmax(outputs_u, dim=1)

                # Get pseudo-labels and mask based on confidence threshold
                max_probs, pseudo_labels = torch.max(probs_u, dim=1)
                mask = max_probs.ge(self.threshold).float()

            # Apply strong augmentation to unlabeled data
            inputs_u_strong = torch.stack([self.augmentation.strong_augment(x.unsqueeze(0)).squeeze(0) for x in inputs_u])

            # Forward pass for all inputs
            logits_x = self.model(inputs_x)
            logits_u = self.model(inputs_u_strong)

            # Supervised loss
            loss_x = self.criterion(logits_x, targets_x)

            # Unsupervised loss with pseudo-labels (only if mask has positive values).
            # The mean is taken over the whole unlabeled batch, so low-confidence
            # samples contribute zero rather than being dropped.
            if mask.sum() > 0:
                loss_u = (F.cross_entropy(logits_u, pseudo_labels, reduction='none') * mask).mean()
            else:
                loss_u = torch.zeros((), device=device)

            # Ramp up unsupervised loss weight linearly over `rampup_length` epochs
            if self.rampup_length > 0:
                rampup = min(1.0, (epoch + batch_idx / num_batches) / self.rampup_length)
            else:
                rampup = 1.0
            current_lambda_u = self.lambda_u * rampup

            # Total loss
            loss = loss_x + current_lambda_u * loss_u

            # Check for NaN loss
            if torch.isnan(loss):
                print("NaN loss detected!")
                print(f"loss_x: {loss_x}, loss_u: {loss_u}")
                print(f"logits_x min/max: {logits_x.min()}/{logits_x.max()}")
                print(f"logits_u min/max: {logits_u.min()}/{logits_u.max()}")
                # Skip this batch
                continue

            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Update metrics
            processed_batches += 1
            total_loss += loss.item()
            total_sup_loss += loss_x.item()
            total_unsup_loss += loss_u.item()

            # Calculate accuracy for supervised data
            _, predicted = torch.max(logits_x, 1)
            total += targets_x.size(0)
            correct += (predicted == targets_x).sum().item()

            # Update progress bar
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'sup_loss': f"{loss_x.item():.4f}",
                'unsup_loss': f"{loss_u.item():.4f}",
                'mask': f"{mask.mean().item():.2f}",
                'acc': f"{100 * correct / total:.2f}%"
            })

        # Update learning rate
        self.scheduler.step()

        # Average only over batches that contributed; skipped (NaN) batches
        # must not dilute the means, and an epoch with no usable batch must
        # not divide by zero.
        denom = max(processed_batches, 1)
        avg_loss = total_loss / denom
        avg_sup_loss = total_sup_loss / denom
        avg_unsup_loss = total_unsup_loss / denom
        accuracy = 100 * correct / total if total > 0 else 0.0

        return avg_loss, avg_sup_loss, avg_unsup_loss, accuracy

    def _run_inference(self, loader, desc):
        """Run the model in eval mode over ``loader`` in a single pass.

        Returns:
            ``(avg_loss, accuracy, preds, targets)``: the per-sample mean
            loss, the accuracy as a fraction in [0, 1], and the predicted and
            true class indices as lists of ints.

        Raises:
            ValueError: If ``loader`` yields no samples.
        """
        self.model.eval()
        all_preds = []
        all_targets = []
        loss_sum = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in tqdm(loader, desc=desc):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                n = targets.size(0)

                # The criterion returns a batch mean; weight it by the batch
                # size so a smaller final batch does not skew the average.
                loss_sum += self.criterion(outputs, targets).item() * n

                _, preds = torch.max(outputs, 1)
                total += n
                correct += (preds == targets).sum().item()

                all_preds.extend(preds.cpu().tolist())
                all_targets.extend(targets.cpu().tolist())

        if total == 0:
            raise ValueError("Cannot evaluate: the data loader yielded no samples")

        return loss_sum / total, correct / total, all_preds, all_targets

    def evaluate(self, test_loader):
        """Evaluate the current model on ``test_loader``.

        Returns:
            ``(avg_loss, accuracy)`` with ``accuracy`` as a fraction in [0, 1].
        """
        avg_loss, accuracy, _, _ = self._run_inference(test_loader, desc="Evaluating")
        return avg_loss, accuracy

    def train(self, labeled_loader, unlabeled_loader, val_loader, num_epochs=100):
        """Train for ``num_epochs``, checkpointing the best validation accuracy.

        After training, the loss and validation-accuracy curves are plotted
        and saved to ``training_progress.png``.

        Returns:
            ``(best_accuracy, train_losses, val_accuracies)`` where the
            accuracies are fractions in [0, 1] and the lists have one entry
            per epoch.
        """
        best_accuracy = 0.0
        train_losses = []
        val_accuracies = []

        for epoch in range(num_epochs):
            # Train for one epoch
            train_loss, sup_loss, unsup_loss, train_acc = self.train_epoch(
                labeled_loader, unlabeled_loader, epoch
            )
            train_losses.append(train_loss)

            # Evaluate on validation set
            val_loss, val_accuracy = self.evaluate(val_loader)
            val_accuracies.append(val_accuracy)

            print(f"Epoch {epoch+1}/{num_epochs} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Sup Loss: {sup_loss:.4f} | "
                  f"Unsup Loss: {unsup_loss:.4f} | "
                  f"Train Acc: {train_acc:.2f}% | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val Acc: {val_accuracy*100:.2f}%")

            # Save best model
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'accuracy': val_accuracy,
                }, self.checkpoint_path)
                print(f"New best model saved with accuracy: {val_accuracy*100:.2f}%")

        # Plot training progress
        fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

        ax_loss.plot(train_losses)
        ax_loss.set_title('Training Loss')
        ax_loss.set_xlabel('Epoch')
        ax_loss.set_ylabel('Loss')

        ax_acc.plot(val_accuracies)
        ax_acc.set_title('Validation Accuracy')
        ax_acc.set_xlabel('Epoch')
        ax_acc.set_ylabel('Accuracy')

        fig.tight_layout()
        fig.savefig('training_progress.png')
        plt.show()

        return best_accuracy, train_losses, val_accuracies

    def test(self, test_loader):
        """Reload the best checkpoint and report test-set metrics.

        Prints the loss, accuracy and a per-class classification report, and
        saves the confusion matrix to ``confusion_matrix.png``.

        Returns:
            ``(test_accuracy, report)`` with the accuracy as a fraction in
            [0, 1] and the report as text.
        """
        # Load best model. map_location lets a checkpoint written on a GPU be
        # restored on whatever device this session is using.
        checkpoint = torch.load(self.checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # One pass over the test set provides both the metrics and the
        # predictions needed for the report and the confusion matrix.
        test_loss, test_accuracy, all_preds, all_targets = self._run_inference(
            test_loader, desc="Evaluating"
        )
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy*100:.2f}%")

        # Get unique labels that actually appear in the test set
        unique_labels = sorted(set(all_targets))
        label_names = [index_to_label.get(i, f"Unknown-{i}") for i in unique_labels]

        # Generate classification report
        report = classification_report(
            all_targets,
            all_preds,
            labels=unique_labels,
            target_names=label_names,
            digits=3
        )
        print("Classification Report:")
        print(report)

        # Plot confusion matrix
        cm = confusion_matrix(all_targets, all_preds, labels=unique_labels)
        fig, ax = plt.subplots(figsize=(15, 15))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=label_names,
            yticklabels=label_names,
            ax=ax
        )
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title('Confusion Matrix')
        fig.tight_layout()
        fig.savefig('confusion_matrix.png')
        plt.show()

        return test_accuracy, report

In [None]:
# Initialize ReFixMatch trainer with improved parameters
trainer = ReFixMatch(
    model=model,
    # Class count taken from the label -> index mapping, which is assigned
    # once and therefore always reflects the real number of classes.
    num_classes=len(label_to_index),
    lambda_u=3.0,  # Increased weight for unsupervised loss
    threshold=0.8  # Lowered confidence threshold for pseudo-labeling
)

In [None]:
# Run the full semi-supervised training schedule
print("Starting ReFixMatch training...")
num_epochs = 100  # Increased number of epochs
best_accuracy, train_losses, val_accuracies = trainer.train(
    labeled_loader, unlabeled_loader, val_loader, num_epochs=num_epochs
)

print(f"Training completed. Best validation accuracy: {best_accuracy*100:.2f}%")

In [None]:
# Evaluate the best checkpoint on the held-out test split
test_accuracy, test_report = trainer.test(test_loader=test_loader)
print(f"Final test accuracy: {test_accuracy*100:.2f}%")