# 🔧 MFCC Model - Fixed Retraining

This notebook fixes the issue where the model predicts only Malayalam.

**Fixes:**
- Proper class balancing
- Better training parameters
- Validation checks
- Debug outputs

In [None]:
# Mount Google Drive so the project folder is reachable, then work from it.
from google.colab import drive
drive.mount('/content/drive')

import os
PROJECT_DIR = '/content/drive/MyDrive/IndicAccent_Project'
# os.chdir raises FileNotFoundError if the project folder does not exist on Drive.
os.chdir(PROJECT_DIR)
print(f'✅ Working directory: {os.getcwd()}')

In [None]:
!pip install -q datasets==3.0.1 torch torchaudio librosa soundfile scikit-learn matplotlib tqdm
print('‚úÖ Dependencies installed!')

In [None]:
import glob
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

# Use the GPU when the runtime provides one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

# Class id -> language name; ids are the positions in this list (0 = Telugu).
label_map = dict(enumerate([
    "Telugu",
    "Tamil",
    "Malayalam",
    "Kannada",
    "Hindi",
    "Gujarati",
]))

## 📊 Step 1: Load and Check Data

In [None]:
# Load the per-chunk MFCC feature files and merge them into single arrays.
# Each chunk is a pickled dict with keys "X" (features) and "y" (labels).
# NOTE: pickle.load can execute arbitrary code — only point this at chunk files
# produced by this project's own feature-extraction step.
files = sorted(glob.glob(f"{PROJECT_DIR}/mfcc_chunks/*.pkl"))
print(f'Found {len(files)} chunk files')
if not files:
    # np.vstack([]) would otherwise fail with a vague "need at least one array" error.
    raise FileNotFoundError(f'No MFCC chunk files found in {PROJECT_DIR}/mfcc_chunks')

X_chunks, y_chunks = [], []
for file in files:
    with open(file, "rb") as f:
        data = pickle.load(f)
    X_chunks.append(data["X"])
    y_chunks.append(data["y"])

X_all = np.vstack(X_chunks)
y_all = np.concatenate(y_chunks)

if len(X_all) != len(y_all):
    raise ValueError(
        f'Feature/label count mismatch: {len(X_all)} feature rows vs {len(y_all)} labels'
    )
if len(y_all) == 0:
    raise ValueError('MFCC chunk files contain no samples')

print('\n✅ Data loaded')
print(f'   Features: {X_all.shape}')
print(f'   Labels: {y_all.shape}')

# CHECK CLASS DISTRIBUTION
print('\n📊 Class Distribution:')
for label, name in label_map.items():
    count = int(np.sum(y_all == label))
    print(f'   {label} ({name:12}): {count:4d} samples ({count/len(y_all)*100:.1f}%)')

# Check for issues: compare the actual label sets rather than just their sizes,
# so an unexpected label value cannot hide a missing class.
present = set(np.unique(y_all).tolist())
expected = set(label_map)
missing = sorted(expected - present)
unexpected = sorted(present - expected, key=str)
if missing:
    print('\n⚠️ WARNING: Not all 6 classes present in data!')
    print(f'   Missing: {[label_map[lbl] for lbl in missing]}')
if unexpected:
    print(f'\n⚠️ WARNING: Labels not in label_map: {unexpected}')
if not missing and not unexpected:
    print('\n✅ All 6 classes present')

## 🧠 Step 2: Define Model

In [None]:
class MFCCModel(nn.Module):
    """Feed-forward classifier over a fixed-length MFCC feature vector.

    Two hidden layers (256 -> 128 units) with ReLU and dropout, followed by a
    linear layer producing one unnormalized logit per class.

    Args:
        input_dim: Length of each input feature vector.
        num_classes: Number of output logits.
        dropout: Dropout probability applied after each hidden layer.

    Note:
        Layers live in ``self.net`` at fixed positions, so state_dict keys are
        ``net.0.*``, ``net.3.*`` and ``net.6.*``. Keep this layout so checkpoints
        saved by this notebook keep loading.
    """

    def __init__(self, input_dim=80, num_classes=6, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return logits (last dim ``num_classes``) for ``x`` whose last dim is ``input_dim``."""
        return self.net(x)


print('✅ Model defined')

## 🔄 Step 3: Prepare Data with Stratification

In [None]:
# Split with stratification so train and validation keep the same class
# proportions as the full dataset. This preserves any class imbalance; it does
# not remove it.
X_train, X_val, y_train, y_val = train_test_split(
    X_all, y_all,
    test_size=0.2,
    random_state=42,
    stratify=y_all,
)

print(f'Train samples: {len(X_train)}')
print(f'Val samples: {len(X_val)}')


def _print_distribution(title, y):
    """Print the per-class count and percentage of label array ``y``."""
    print(title)
    for label, name in label_map.items():
        count = int(np.sum(y == label))
        print(f'   {name:12}: {count:4d} ({count/len(y)*100:.1f}%)')


_print_distribution('\nTrain distribution:', y_train)
_print_distribution('\nValidation distribution:', y_val)


class MFCCDataset(Dataset):
    """Map-style dataset over in-memory features ``X`` and labels ``y``.

    Both arrays are converted to tensors once up front (float32 features,
    int64 labels), so ``__getitem__`` is a plain index.
    """

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]


# DataLoaders
train_loader = DataLoader(MFCCDataset(X_train, y_train), batch_size=64, shuffle=True)
val_loader = DataLoader(MFCCDataset(X_val, y_val), batch_size=64)

print('\n✅ Data loaders ready')

## 🏋️ Step 4: Train with Monitoring

In [None]:
# Initialize
NUM_CLASSES = len(label_map)
model = MFCCModel(num_classes=NUM_CLASSES).to(device)

# Class balancing: the stratified split keeps the original class ratios, so on
# an imbalanced dataset an unweighted loss is dominated by the majority class
# (which can make the model predict a single class). Weight each class by its
# inverse frequency in the training split; a perfectly balanced split gives
# every class weight 1.0.
train_counts = np.bincount(np.asarray(y_train, dtype=np.int64), minlength=NUM_CLASSES)
class_weights = train_counts.sum() / (NUM_CLASSES * np.maximum(train_counts, 1))
loss_fn = nn.CrossEntropyLoss(
    weight=torch.tensor(class_weights, dtype=torch.float32, device=device)
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print('✅ Model initialized')
print(f'   Parameters: {sum(p.numel() for p in model.parameters()):,}')
print('   Class weights: ' + ', '.join(
    f'{label_map[i]}={w:.2f}' for i, w in enumerate(class_weights)
))


def _evaluate(net, loader, criterion, num_classes):
    """Evaluate ``net`` on ``loader`` without updating its weights.

    Returns:
        (loss, accuracy, per-class correct counts, per-class sample counts).
        The loss is the batch losses averaged with each batch weighted by its
        size; the two count lists are indexed by class id.
    """
    net.eval()
    total_loss = 0.0
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)
    with torch.no_grad():
        for feats, labels in loader:
            feats, labels = feats.to(device), labels.to(device)
            logits = net(feats)
            total_loss += criterion(logits, labels).item() * labels.size(0)
            hits = logits.argmax(dim=1) == labels
            # Vectorized per-class tallies: avoids a Python loop that does one
            # .item() device sync per sample.
            class_total += torch.bincount(labels, minlength=num_classes).cpu()
            class_correct += torch.bincount(labels[hits], minlength=num_classes).cpu()
    n_samples = int(class_total.sum())
    return (
        total_loss / n_samples,
        int(class_correct.sum()) / n_samples,
        class_correct.tolist(),
        class_total.tolist(),
    )


# Training loop with detailed monitoring
NUM_EPOCHS = 20
# Start below any reachable accuracy so epoch 1 always writes a checkpoint;
# Step 5 loads that file, and with 0.0 an all-wrong model would never save one.
best_val_acc = -1.0

print(f'\nTraining for {NUM_EPOCHS} epochs...\n')

for epoch in range(NUM_EPOCHS):
    # Training
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    for feats, labels in train_loader:
        feats, labels = feats.to(device), labels.to(device)

        optimizer.zero_grad()
        preds = model(feats)
        loss = loss_fn(preds, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size so a smaller final batch does not skew the epoch mean.
        train_loss += loss.item() * batch_size
        train_correct += (preds.argmax(dim=1) == labels).sum().item()
        train_total += batch_size

    train_loss /= train_total
    train_acc = train_correct / train_total

    # Validation (same class-weighted criterion, so losses are comparable)
    val_loss, val_acc, class_correct, class_total = _evaluate(
        model, val_loader, loss_fn, NUM_CLASSES
    )

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), f"{PROJECT_DIR}/mfcc_best_model_fixed.pt")
        best_marker = " 🌟 BEST"
    else:
        best_marker = ""

    print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}{best_marker}")

    # Show per-class accuracy every 5 epochs
    if (epoch + 1) % 5 == 0:
        print("  Per-class validation accuracy:")
        for i in range(NUM_CLASSES):
            if class_total[i] > 0:
                acc = class_correct[i] / class_total[i]
                print(f"    {label_map[i]:12}: {acc:.3f} ({class_correct[i]}/{class_total[i]})")
        print()

print('\n✅ Training complete!')
print(f'   Best validation accuracy: {best_val_acc:.4f} ({best_val_acc*100:.2f}%)')

## 🧪 Step 5: Test Model Predictions

In [None]:
# Load the best checkpoint saved during training. map_location makes the
# load work regardless of which device the tensors were saved from.
best_ckpt_path = f"{PROJECT_DIR}/mfcc_best_model_fixed.pt"
model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
model.eval()

rng = np.random.default_rng()
num_features = X_val.shape[1]

print("🧪 Testing model with random inputs...\n")

# Test 1: Random noise. Only a smoke test — these are not realistic MFCC vectors.
print("Test 1: Random noise (should give varied predictions)")
with torch.no_grad():
    noise_probs = torch.softmax(model(torch.randn(5, num_features, device=device)), dim=1)
for i, probs in enumerate(noise_probs, start=1):
    pred = int(probs.argmax())
    print(f"  {i}. Predicted: {label_map[pred]:12} ({probs[pred].item()*100:.1f}%)")

# Test 2: Real validation samples (distinct indices, one batched forward pass)
print("\nTest 2: Real validation samples")
sample_idx = rng.choice(len(X_val), size=min(5, len(X_val)), replace=False)
samples = torch.tensor(X_val[sample_idx], dtype=torch.float32, device=device)
with torch.no_grad():
    sample_probs = torch.softmax(model(samples), dim=1)
for i, (idx, probs) in enumerate(zip(sample_idx, sample_probs), start=1):
    true_label = int(y_val[idx])
    pred = int(probs.argmax())
    match = "✅" if pred == true_label else "❌"
    print(f"  {i}. True: {label_map[true_label]:12} | Pred: {label_map[pred]:12} ({probs[pred].item()*100:.1f}%) {match}")

# Test 3: Prediction distribution over the whole validation set. Five samples can
# easily miss a collapse; this shows directly whether one class dominates.
print("\nTest 3: Prediction distribution over the full validation set")
with torch.no_grad():
    val_preds = torch.cat([model(feats.to(device)).argmax(dim=1).cpu() for feats, _ in val_loader])
pred_counts = np.bincount(val_preds.numpy(), minlength=len(label_map))
for label, name in label_map.items():
    share = pred_counts[label] / len(val_preds)
    print(f"   {name:12}: {pred_counts[label]:4d} predictions ({share*100:.1f}%)")

dominant = int(pred_counts.argmax())
dominant_share = pred_counts[dominant] / len(val_preds)
if dominant_share > 0.9:
    print(f"\n⚠️ WARNING: {dominant_share*100:.1f}% of validation predictions are "
          f"{label_map[dominant]} — the model may still be collapsing to one class.")

print("\n✅ If you see varied predictions above, model is working!")
print("   If it always predicts Malayalam, there's still an issue.")

## 💾 Step 6: Save Final Model

In [None]:
import shutil

# Save with clear name. Copy the best checkpoint file instead of saving
# `model.state_dict()`: the in-memory model only holds the best weights if
# Step 5 (which reloads them) was run; otherwise it holds last-epoch weights.
best_ckpt_path = f"{PROJECT_DIR}/mfcc_best_model_fixed.pt"
final_ckpt_path = f"{PROJECT_DIR}/mfcc_best_model.pt"
shutil.copyfile(best_ckpt_path, final_ckpt_path)

print('✅ Model saved!')
print(f'   Location: {final_ckpt_path}')
print(f'   Backup: {best_ckpt_path}')
print('\n🎉 Now use this model in your Gradio demo!')