In [1]:
# Cell 0: GPU Verification
import torch

print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"CUDA Device Count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current Device: {torch.cuda.current_device()}")
    print(f"Device Name: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU detected - using CPU")

CUDA Available: True
CUDA Device Count: 1
Current Device: 0
Device Name: NVIDIA GeForce RTX 3050 Laptop GPU


In [2]:
# Cell 1: Core Setup

import os
import torch
import numpy as np
import pandas as pd
import librosa
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2ForSequenceClassification,
    TrainingArguments,
    Trainer
)

In [3]:
# Cell 2: Global Settings
class Config:
    # Data settings
    expected_labels = ['angry', 'fear', 'happy', 'neutral', 'sad']
    audio_max_duration = 3  # seconds
    sample_rate = 16000
    
    # Model settings
    model_name = "facebook/wav2vec2-base"
    batch_size = 4
    learning_rate = 3e-5
    num_epochs = 30
    
    # Path handling
    base_path = Path("dataset")
    
config = Config()

In [4]:
# Cell 3: Data Loading & Cleaning
def load_and_validate_dataset(csv_path):
    """Load dataset with comprehensive validation"""
    try:
        # Detect header presence
        with open(csv_path, 'r') as f:
            first_line = f.readline().strip().lower()
            has_header = any(label in first_line for label in ['path', 'audio', 'label', 'emotion'])
        
        df = pd.read_csv(
            csv_path,
            header=0 if has_header else None,
            names=["audio_path", "label"]
        )
        
        # Clean paths
        df["audio_path"] = df["audio_path"].apply(
            lambda x: str(Path(x.replace("\\", os.sep).replace("/", os.sep)))
        )
        
        # Clean labels
        df["label"] = df["label"].str.strip().str.lower()
        df["label"] = df["label"].replace({'emotion': 'neutral'})  # Fix observed error
        
        # Validate labels
        invalid_labels = set(df["label"]) - set(config.expected_labels)
        if invalid_labels:
            raise ValueError(f"Invalid labels found: {invalid_labels}")
            
        # Check file existence
        missing_files = [p for p in df["audio_path"] if not Path(p).exists()]
        if missing_files:
            raise FileNotFoundError(f"Missing {len(missing_files)} audio files")
            
        return df
    
    except Exception as e:
        print(f"Error loading {csv_path}: {str(e)}")
        raise

# Load datasets
try:
    train_df = load_and_validate_dataset("train_dataset.csv")
    test_df = load_and_validate_dataset("test_dataset.csv")
    
    print("Train dataset:")
    print(train_df["label"].value_counts())
    print("\nTest dataset:")
    print(test_df["label"].value_counts())
    
except Exception as e:
    print("Failed to load datasets:")
    raise

Train dataset:
label
happy      134
neutral    134
sad        133
angry      127
fear        70
Name: count, dtype: int64

Test dataset:
label
neutral    42
happy      42
sad        42
angry      40
fear       22
Name: count, dtype: int64


In [5]:
# Cell 4: Dataset Pipeline
class TamilEmotionDataset(torch.utils.data.Dataset):
    def __init__(self, df, processor):
        self.df = df
        self.processor = processor
        self.max_length = config.audio_max_duration * config.sample_rate
        
        # Create label map
        self.label_map = {label: idx for idx, label in enumerate(config.expected_labels)}
        self.inverse_map = {v: k for k, v in self.label_map.items()}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        try:
            # Load metadata
            audio_path = self.df.iloc[idx]["audio_path"]
            label = self.df.iloc[idx]["label"]
            
            # Validate label
            if label not in self.label_map:
                raise ValueError(f"Invalid label {label}")
                
            # Load audio
            waveform, sr = librosa.load(
                audio_path,
                sr=config.sample_rate,
                mono=True,
                duration=config.audio_max_duration
            )
            
            # Validate audio
            if len(waveform) < 0.5 * sr:  # Minimum 0.5s
                raise ValueError("Audio too short")
                
            # Process features
            inputs = self.processor(
                waveform,
                sampling_rate=sr,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
                truncation=True
            )
            
            return {
                "input_values": inputs["input_values"].squeeze(),
                "labels": torch.tensor(self.label_map[label], dtype=torch.long)
            }
            
        except Exception as e:
            print(f"Skipping {audio_path}: {str(e)}")
            return None

def collate_fn(batch):
    """Handle invalid samples"""
    batch = [b for b in batch if b is not None]
    return {
        "input_values": torch.stack([b["input_values"] for b in batch]),
        "labels": torch.stack([b["labels"] for b in batch])
    }

In [6]:
# Cell 5 (Revised): Model Setup with Explicit GPU Handling
try:
    # Initialize processor
    processor = Wav2Vec2Processor.from_pretrained(config.model_name)
    
    # Create datasets
    train_dataset = TamilEmotionDataset(train_df, processor)
    test_dataset = TamilEmotionDataset(test_df, processor)
    
    # Model config
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}")
    
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        config.model_name,
        num_labels=len(config.expected_labels)
    ).to(device)  # Explicit device placement
    
    print("\nModel device:", next(model.parameters()).device)
    
except Exception as e:
    print("Model initialization failed:")
    raise




Using device: cuda


Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Model device: cuda:0


In [7]:
# Cell 6: Training Setup
training_args = TrainingArguments(
    output_dir="./ser_results",
    evaluation_strategy="epoch",
    learning_rate=config.learning_rate,
    per_device_train_batch_size=config.batch_size,
    per_device_eval_batch_size=config.batch_size,
    num_train_epochs=config.num_epochs,
    logging_steps=50,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    fp16=torch.cuda.is_available(),  # Enable mixed precision if GPU available
    report_to="none",  # Disable external logging
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted")
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)



In [None]:
# Cell 7: Start Training
try:
    print("Starting training...")
    train_result = trainer.train()
    print("\nTraining completed!")
    print(f"Final metrics: {train_result.metrics}")
    
except RuntimeError as e:
    if "CUDA out of memory" in str(e):
        print("Memory error! Reduce batch size or model size")
    else:
        print("Training failed:")
    raise
    
except Exception as e:
    print("Unexpected error during training:")
    raise

Starting training...


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,1.5763,1.536637,0.303191,0.266273
2,1.4725,1.47891,0.324468,0.198373
3,1.5624,1.650946,0.255319,0.153715
4,1.5943,1.578707,0.228723,0.092748
5,1.4902,1.452146,0.319149,0.194253
6,1.4178,1.44338,0.303191,0.157717
7,1.4131,1.729651,0.207447,0.110654
8,1.4064,1.425386,0.303191,0.164089
9,1.633,1.610565,0.223404,0.081591
10,1.5809,1.528393,0.260638,0.139141
