In [29]:
import os
import pandas as pd
from pathlib import Path
import re
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
from torchsummary import summary
from architectures.TimeCNN import TimeCNN
import IPython.display as ipd

In [30]:
# RAVDESS

path = 'datasets/archive/audio_speech_actors_01-24'

train_data = []
test_data = []
val_data = []

# Speaker-independent split: every actor ID belongs to exactly one list.
# Odd actor IDs are labelled male and even IDs female (see the gender rule below).
train_speakers = ["01","03","05","07","09","11","13","15","17","02","04","06","08","10","12","14","16"]
val_speakers = ["19","20","18"]  # validation split, speakers taken from the train split
test_speakers = ["21","23","22","24"]

emotion_map = {
    "01": "neutral",
    # "02": "calm",  -> removed, not in CREMA-D
    "03": "happy",
    "04": "sad",
    "05": "angry",
    "06": "fearful",
    "07": "disgust",
    # "08": "surprised" -> removed, not in CREMA-D
}

# File names are seven dash-separated numeric fields; this cell reads the
# third field as the emotion code and the last field as the actor ID.
filename_pattern = re.compile(r"(\d+)-(\d+)-(\d+)-(\d+)-(\d+)-(\d+)-(\d+)\.wav")

# Actor IDs that appear on disk but in none of the three speaker lists.
unassigned_actors = set()

# Sorted listings keep the row order of the CSVs reproducible across machines.
for folder in sorted(os.listdir(path)):
    folder_path = os.path.join(path, folder)
    if not os.path.isdir(folder_path):
        # A stray file in the dataset root would make os.listdir() raise.
        continue

    for file in sorted(os.listdir(folder_path)):
        match = filename_pattern.match(file)
        if not match:
            continue

        indicators = match.groups()
        emotion_code = indicators[2]
        actor_id = indicators[-1]
        if emotion_code not in emotion_map:
            continue

        record = {
            "path": str(Path(path) / folder / file),
            "emotion": emotion_map[emotion_code],
            "gender": 'Female' if int(actor_id) % 2 == 0 else 'Male',
            "source": "RAVDESS"
        }

        if actor_id in train_speakers:
            train_data.append(record)
        elif actor_id in val_speakers:
            val_data.append(record)
        elif actor_id in test_speakers:
            test_data.append(record)
        else:
            # Previously any unknown actor silently landed in the test set.
            unassigned_actors.add(actor_id)

if unassigned_actors:
    print(f"Skipped actors not assigned to any split: {sorted(unassigned_actors)}")

train_df = pd.DataFrame(train_data)
val_df = pd.DataFrame(val_data)
test_df = pd.DataFrame(test_data)

# RAVDESS-only snapshot of the splits.
train_df.to_csv("train_split.csv", index=False)
val_df.to_csv("val_split.csv", index=False)
test_df.to_csv("test_split.csv", index=False)

In [31]:
# (rows, columns) of the RAVDESS-only training frame built in the previous cell.
train_df.shape

(748, 4)

In [32]:
# Peek at the CREMA-D actor demographics table: first ten rows, then the column index.
demo_df = pd.read_csv('datasets/CREMA-D/VideoDemographics.csv')
print(demo_df.head(10), demo_df.columns, sep="\n")

   ActorID  Age     Sex              Race     Ethnicity
0     1001   51    Male         Caucasian  Not Hispanic
1     1002   21  Female         Caucasian  Not Hispanic
2     1003   21  Female         Caucasian  Not Hispanic
3     1004   42  Female         Caucasian  Not Hispanic
4     1005   29    Male  African American  Not Hispanic
5     1006   58  Female         Caucasian  Not Hispanic
6     1007   38  Female  African American  Not Hispanic
7     1008   46  Female         Caucasian  Not Hispanic
8     1009   24  Female         Caucasian  Not Hispanic
9     1010   27  Female         Caucasian  Not Hispanic
Index(['ActorID', 'Age', 'Sex', 'Race', 'Ethnicity'], dtype='object')


In [33]:
# CREMA-D

path = 'datasets/CREMA-D/AudioWAV'

# The demographics table supplies the ActorID -> Sex lookup used for gender labels.
demo_df = pd.read_csv('datasets/CREMA-D/VideoDemographics.csv')
gender_map = {actor: sex for actor, sex in zip(demo_df['ActorID'], demo_df['Sex'])}

# Report the number of actors and the male/female breakdown.
print(f"Total actors: {len(gender_map)}")
print(f"Males: {sum(sex == 'Male' for sex in gender_map.values())}")
print(f"Females: {sum(sex == 'Female' for sex in gender_map.values())}")

Total actors: 91
Males: 48
Females: 43


In [34]:
train_data = []
test_data = []
val_data = []

emotion_map = {
    "ANG": "angry",
    "DIS": "disgust",
    "FEA": "fearful",
    "HAP": "happy",
    "NEU": "neutral",
    "SAD": "sad"
}

# Get all actor IDs and split by gender
all_actors = demo_df['ActorID'].tolist()
male_actors = demo_df[demo_df['Sex'] == 'Male']['ActorID'].tolist()
female_actors = demo_df[demo_df['Sex'] == 'Female']['ActorID'].tolist()

# Speaker-level split with fixed per-gender counts (not a strict 80/20 ratio):
# the first 34 actors of each gender train, the next 4 validate, the rest test.
# With 48 male / 43 female actors (the counts printed when the demographics
# were loaded) this leaves 10 males but only 5 females for testing.
# NOTE(review): the test split is therefore not gender-balanced — confirm this is acceptable.
train_males = male_actors[:34]      # 34 males for train
val_males = male_actors[34:38]      # 4 males for validation
test_males = male_actors[38:]       # remaining males for test

train_females = female_actors[:34]  # 34 females for train
val_females = female_actors[34:38]  # 4 females for validation
test_females = female_actors[38:]   # remaining females for test

train_speakers = train_males + train_females  # 68 speakers
val_speakers = val_males + val_females        # 8 speakers
test_speakers = test_males + test_females      # all remaining speakers

print(f"\nTrain speakers: {len(train_speakers)} ({len(train_males)}M, {len(train_females)}F)")
print(f"Validation speakers: {len(val_speakers)} ({len(val_males)}M, {len(val_females)}F)")
print(f"Test speakers: {len(test_speakers)} ({len(test_males)}M, {len(test_females)}F)")

# Sets give constant-time membership tests inside the file loop.
train_ids = set(train_speakers)
val_ids = set(val_speakers)
test_ids = set(test_speakers)

# Parse audio files (sorted so the row order of the split CSVs is reproducible)
for file in sorted(os.listdir(path)):
    if not file.endswith('.wav'):
        continue

    # Filename: 1001_IEO_ANG_HI.wav
    parts = file.replace('.wav', '').split('_')

    if len(parts) < 4:
        continue

    try:
        actor_id = int(parts[0])
    except ValueError:
        # Not a CREMA-D style name (non-numeric actor field); skip instead of crashing.
        continue

    emotion_code = parts[2]

    if emotion_code not in emotion_map:
        continue

    record = {
        "path": str(Path(path) / file),
        "emotion": emotion_map[emotion_code],
        "gender": gender_map.get(actor_id, 'Unknown'),
        "source": "CREMA-D"
    }

    # Actors missing from the demographics table belong to no split and are dropped.
    if actor_id in train_ids:
        train_data.append(record)
    elif actor_id in val_ids:
        val_data.append(record)
    elif actor_id in test_ids:
        test_data.append(record)

# Create DataFrames
train_df_cremad = pd.DataFrame(train_data)
val_df_cremad = pd.DataFrame(val_data)
test_df_cremad = pd.DataFrame(test_data)

# Verify
print("\nCREMA-D TRAIN:")
print(f"  Samples: {len(train_df_cremad)}")
print(f"  By gender: {train_df_cremad['gender'].value_counts().to_dict()}")

print("\nCREMA-D VALIDATION:")
print(f"  Samples: {len(val_df_cremad)}")
print(f" by gender: {val_df_cremad['gender'].value_counts().to_dict()}")

print("\nCREMA-D TEST:")
print(f"  Samples: {len(test_df_cremad)}")
print(f"  By gender: {test_df_cremad['gender'].value_counts().to_dict()}")


Train speakers: 68 (34M, 34F)
Validation speakers: 8 (4M, 4F)
Test speakers: 15 (10M, 5F)

CREMA-D TRAIN:
  Samples: 5557
  By gender: {'Male': 2782, 'Female': 2775}

CREMA-D VALIDATION:
  Samples: 655
 by gender: {'Male': 328, 'Female': 327}

CREMA-D TEST:
  Samples: 1230
  By gender: {'Male': 820, 'Female': 410}


In [35]:
# Stack the RAVDESS and CREMA-D frames split by split, with a fresh 0..n-1 index.
train_df_combined = pd.concat([train_df, train_df_cremad], ignore_index=True)
val_df_combined = pd.concat([val_df, val_df_cremad], ignore_index=True)
test_df_combined = pd.concat([test_df, test_df_cremad], ignore_index=True)

# Overwrite the RAVDESS-only CSVs written earlier with the combined splits.
for _combined, _csv_name in (
    (train_df_combined, "train_split.csv"),
    (val_df_combined, "val_split.csv"),
    (test_df_combined, "test_split.csv"),
):
    _combined.to_csv(_csv_name, index=False)

print(
    f"\nFINAL TRAIN: {len(train_df_combined)}",
    f"FINAL VALIDATION: {len(val_df_combined)}",
    f"FINAL TEST: {len(test_df_combined)}",
    sep="\n",
)


FINAL TRAIN: 6305
FINAL VALIDATION: 787
FINAL TEST: 1406


In [36]:
# Sanity check: "calm" must not appear among the emotions of the RAVDESS frames
# (train_df / val_df / test_df are the RAVDESS-only frames, not the combined ones).
for _split_df in (train_df, val_df, test_df):
    print(_split_df['emotion'].unique())


['neutral' 'happy' 'sad' 'angry' 'fearful' 'disgust']
['neutral' 'happy' 'sad' 'angry' 'fearful' 'disgust']
['neutral' 'happy' 'sad' 'angry' 'fearful' 'disgust']


In [37]:
# sample_path = train_df.iloc[0]['path']
# sample_emotion = train_df.iloc[0]['emotion']
# print (f"Sample path: {sample_path},\nemotion: {sample_emotion}")

In [38]:

# sample_rate_1 = 16000
# sample_rate_2 = 22050
# duration_1 = 3
# duration_2 = 5


# waveform_1, sr_1 = librosa.load(sample_path, sr=sample_rate_1)
# print(f"\nOriginal_1:")
# print(f"  Sample rate: {sr_1} Hz")
# print(f"  Shape: {waveform_1.shape}")
# print(f"  Duration: {len(waveform_1)/sr_1:.2f} seconds")

# waveform_2, sr_2 = librosa.load(sample_path, sr=sample_rate_2)
# print(f"\nOriginal_2:")
# print(f"  Sample rate: {sr_2} Hz")
# print(f"  Shape: {waveform_2.shape}")
# print(f"  Duration: {len(waveform_2)/sr_2:.2f} seconds")

In [39]:
# target_length_1 = sample_rate_1 * duration_1  # 3 seconds at 16kHz
# target_length_2 = sample_rate_2 * duration_2  # 5 seconds at 22.05kHz

# waveform_1 = librosa.util.fix_length(waveform_1, size=target_length_1)
# waveform_2 = librosa.util.fix_length(waveform_2, size=target_length_2)

# print(f"Final waveform_1 shape: {waveform_1.shape}")
# print(f"Final waveform_2 shape: {waveform_2.shape}")

In [40]:
# mel_spec_1 = librosa.feature.melspectrogram(
#     y=waveform_1,
#     sr=sample_rate_1,
#     n_mels = 128,
#     n_fft=2048,
#     hop_length=512
# )
# print(f"\nMel-spectrogram shape: {mel_spec_1.shape}")

# mel_spec_2 = librosa.feature.melspectrogram(
#     y=waveform_2,
#     sr=sample_rate_2,
#     n_mels = 128,
#     n_fft=2048,
#     hop_length=512
# )
# print(f"\nMel-spectrogram shape: {mel_spec_2.shape}")

In [41]:
# mel_spec_db_1 = librosa.power_to_db(mel_spec_1, ref=np.max)
# print(f"Mel-spectrogram (dB) shape: {mel_spec_db_1.shape}")
# mel_spec_db_2 = librosa.power_to_db(mel_spec_2, ref=np.max)
# print(f"Mel-spectrogram (dB) shape: {mel_spec_db_2.shape}")

In [42]:
# # ============== PLOT COMPARISON ==============
# fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# # Row 1: Waveforms
# axes[0, 0].plot(waveform_1)
# axes[0, 0].set_title(f'Waveform - 16kHz, 3 sec\nSamples: {len(waveform_1):,}')
# axes[0, 0].set_xlabel('Samples')
# axes[0, 0].set_ylabel('Amplitude')

# axes[0, 1].plot(waveform_2)
# axes[0, 1].set_title(f'Waveform - 22kHz, 5 sec\nSamples: {len(waveform_2):,}')
# axes[0, 1].set_xlabel('Samples')
# axes[0, 1].set_ylabel('Amplitude')

# # Row 2: Mel-Spectrograms (dB)
# img1 = librosa.display.specshow(
#     mel_spec_db_1, x_axis='time', y_axis='mel', sr=sample_rate_1, ax=axes[1, 0]
# )
# axes[1, 0].set_title(f'Mel-Spectrogram (YOUR SETTINGS)\nShape: {mel_spec_db_1.shape}')
# fig.colorbar(img1, ax=axes[1, 0], format='%+2.0f dB')

# img2 = librosa.display.specshow(
#     mel_spec_db_2, x_axis='time', y_axis='mel', sr=sample_rate_2, ax=axes[1, 1]
# )
# axes[1, 1].set_title(f'Mel-Spectrogram (REFERENCE SETTINGS)\nShape: {mel_spec_db_2.shape}')
# fig.colorbar(img2, ax=axes[1, 1], format='%+2.0f dB')

# plt.suptitle(f'Comparison: {sample_emotion.upper()}', fontsize=14, fontweight='bold')
# plt.tight_layout()
# plt.savefig('audio_comparison.png', dpi=150)
# plt.show()

# # ============== SUMMARY ==============
# print("\n" + "=" * 60)
# print("COMPARISON SUMMARY")
# print("=" * 60)
# print(f"\n{'Setting':<25} {'Yours':<20} {'Reference':<20}")
# print("-" * 60)
# print(f"{'Sample Rate':<25} {sample_rate_1:,} Hz{'':<10} {sample_rate_2:,} Hz")
# print(f"{'Duration':<25} {duration_1} sec{'':<15} {duration_2} sec")
# print(f"{'Total Samples':<25} {len(waveform_1):,}{'':<13} {len(waveform_2):,}")
# print(f"{'Spectrogram Shape':<25} {mel_spec_db_1.shape}{'':<11} {mel_spec_db_2.shape}")
# print(f"{'Time Frames':<25} {mel_spec_db_1.shape[1]}{'':<17} {mel_spec_db_2.shape[1]}")
# print("-" * 60)
# print(f"\nReference has {mel_spec_db_2.shape[1] / mel_spec_db_1.shape[1]:.1f}x more temporal information!")

128 mel bands on the frequency axis; the 3-second clip at 16 kHz is 48,000 samples long (48,000 is a sample count, not a frequency in Hz)

94 time windows

hence the shape (128, 94)

Notice the power-in-dB plot (a scale adjusted to how humans perceive loudness): the loudest point (highest amplitude) is set to 0 dB. Values close to 0 dB are loud and shown in bright colors, while values further away are quieter, down to around -80 dB.

In [43]:
# def extract_time_features(waveform, sr):
#     # Zero Crossing Rate
#     zcr = librosa.feature.zero_crossing_rate(
#         waveform, frame_length=2048, hop_length=512
#     )
    
#     # RMS Energy
#     energy = librosa.feature.rms(
#         y=waveform, frame_length=2048, hop_length=512
#     )
    
#     # MFCCs
#     mfccs = librosa.feature.mfcc(
#         y=waveform, sr=sr, n_fft=2048, hop_length=512, n_mfcc=13
#     )
    
#     return zcr, energy, mfccs

In [44]:
# angry_sample  = train_df[train_df['emotion'] == 'angry'].iloc[0]
# sad_sample = train_df[train_df['emotion'] == 'sad'].iloc[0]

# sample_rate = 22050
# duration = 5
# target_length = sample_rate * duration

# waveform_angry, sample_rate_angry = librosa.load(angry_sample['path'], sr=sample_rate, mono=True)
# waveform_angry = librosa.util.fix_length(waveform_angry, size=target_length)

# waveform_sad, sample_rate_sad = librosa.load(sad_sample['path'], sr=sample_rate, mono=True)
# waveform_sad = librosa.util.fix_length(waveform_sad, size=target_length)

# zcr_angry, energy_angry, mfccs_angry = extract_time_features(waveform_angry, sample_rate)
# zcr_sad, energy_sad, mfccs_sad = extract_time_features(waveform_sad, sample_rate)

In [45]:
# print("ANGRY:")
# ipd.display(ipd.Audio(angry_sample['path']))

# # Play sad sample
# print("SAD:")
# ipd.display(ipd.Audio(sad_sample['path']))

In [46]:
# print(f"\nFeature Shapes:")
# print(f"  ZCR:    {zcr_angry.shape}")
# print(f"  Energy: {energy_angry.shape}")
# print(f"  MFCCs:  {mfccs_angry.shape}")
# print(f"  Combined: ({1 + 1 + 13}, {zcr_angry.shape[1]}) = (15, {zcr_angry.shape[1]})")

In [47]:
# # ============== PLOT COMPARISON ==============
# fig, axes = plt.subplots(4, 2, figsize=(16, 14))

# # Column 0: ANGRY | Column 1: SAD

# # Row 0: Waveform
# axes[0, 0].plot(waveform_angry, color='red', alpha=0.7)
# axes[0, 0].set_title('ANGRY - Waveform', fontsize=12, fontweight='bold')
# axes[0, 0].set_xlabel('Samples')
# axes[0, 0].set_ylabel('Amplitude')

# axes[0, 1].plot(waveform_sad, color='blue', alpha=0.7)
# axes[0, 1].set_title('SAD - Waveform', fontsize=12, fontweight='bold')
# axes[0, 1].set_xlabel('Samples')
# axes[0, 1].set_ylabel('Amplitude')

# # Row 1: Zero Crossing Rate
# axes[1, 0].plot(zcr_angry[0], color='red', alpha=0.7)
# axes[1, 0].set_title(f'ANGRY - Zero Crossing Rate\nMean: {zcr_angry.mean():.4f}', fontsize=11)
# axes[1, 0].set_xlabel('Time Frames')
# axes[1, 0].set_ylabel('ZCR')
# axes[1, 0].set_ylim(0, 0.3)

# axes[1, 1].plot(zcr_sad[0], color='blue', alpha=0.7)
# axes[1, 1].set_title(f'SAD - Zero Crossing Rate\nMean: {zcr_sad.mean():.4f}', fontsize=11)
# axes[1, 1].set_xlabel('Time Frames')
# axes[1, 1].set_ylabel('ZCR')
# axes[1, 1].set_ylim(0, 0.3)

# # Row 2: RMS Energy
# axes[2, 0].plot(energy_angry[0], color='red', alpha=0.7)
# axes[2, 0].fill_between(range(len(energy_angry[0])), energy_angry[0], alpha=0.3, color='red')
# axes[2, 0].set_title(f'ANGRY - RMS Energy\nMean: {energy_angry.mean():.4f}', fontsize=11)
# axes[2, 0].set_xlabel('Time Frames')
# axes[2, 0].set_ylabel('Energy')

# axes[2, 1].plot(energy_sad[0], color='blue', alpha=0.7)
# axes[2, 1].fill_between(range(len(energy_sad[0])), energy_sad[0], alpha=0.3, color='blue')
# axes[2, 1].set_title(f'SAD - RMS Energy\nMean: {energy_sad.mean():.4f}', fontsize=11)
# axes[2, 1].set_xlabel('Time Frames')
# axes[2, 1].set_ylabel('Energy')

# # Row 3: MFCCs
# img1 = librosa.display.specshow(mfccs_angry, x_axis='time', sr=sample_rate, hop_length=512, ax=axes[3, 0])
# axes[3, 0].set_title('ANGRY - MFCCs (13 coefficients)', fontsize=11)
# axes[3, 0].set_ylabel('MFCC Coefficient')
# fig.colorbar(img1, ax=axes[3, 0], format='%+2.0f')

# img2 = librosa.display.specshow(mfccs_sad, x_axis='time', sr=sample_rate, hop_length=512, ax=axes[3, 1])
# axes[3, 1].set_title('SAD - MFCCs (13 coefficients)', fontsize=11)
# axes[3, 1].set_ylabel('MFCC Coefficient')
# fig.colorbar(img2, ax=axes[3, 1], format='%+2.0f')

# plt.suptitle('Time-Domain Features: ANGRY vs SAD', fontsize=14, fontweight='bold')
# plt.tight_layout()
# plt.savefig('time_domain_comparison.png', dpi=150)
# plt.show()

In [48]:
class EmotionDataset(Dataset):
    """Dataset that turns audio clips into stacked time-domain features.

    Every row of the CSV must provide a ``path`` column (audio file) and an
    ``emotion`` column whose value is a key of ``emotion_to_idx``. A clip is
    loaded as mono at ``sample_rate``, centre-padded or truncated to
    ``duration`` seconds, optionally augmented, and converted to zero-crossing
    rate, RMS energy and MFCC features stacked along the first axis.

    Args:
        csv_path: Path (or file-like object) of the split CSV.
        augment: If True, apply a random time shift, additive Gaussian noise
            and a random gain to each waveform before feature extraction.
        sample_rate: Sample rate in Hz that audio is resampled to.
        duration: Clip length in seconds after padding/truncation.
        n_mels: Stored for reference; not used by the features computed here.
        n_fft: Frame length used for ZCR, RMS and MFCC.
        hop_length: Hop between frames used for ZCR, RMS and MFCC.
        n_mfcc: Number of MFCC coefficients.

    The defaults reproduce the previously hard-coded settings.
    """

    def __init__(self, csv_path, augment=False, sample_rate=22050, duration=5,
                 n_mels=128, n_fft=2048, hop_length=512, n_mfcc=13):
        self.df = pd.read_csv(csv_path)
        self.augment = augment

        # Emotion to class-index mapping
        self.emotion_to_idx = {
            'angry': 0,
            'disgust': 1,
            'fearful': 2,
            'happy': 3,
            'neutral': 4,
            'sad': 5
        }

        # Audio settings
        self.sample_rate = sample_rate
        self.duration = duration
        self.target_length = int(self.sample_rate * self.duration)  # 110250 by default

        # Framing / feature settings
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mfcc = n_mfcc

    def __len__(self):
        """Return the number of clips listed in the CSV."""
        return len(self.df)

    def _fix_length(self, waveform):
        """Centre-pad with zeros, or keep the leading samples, to reach ``target_length``."""
        if len(waveform) < self.target_length:
            padding = self.target_length - len(waveform)
            offset = padding // 2
            return np.pad(waveform, (offset, padding - offset), 'constant')
        return waveform[:self.target_length]

    def _augment(self, waveform):
        """Apply a circular time shift, additive Gaussian noise and a random gain."""
        # Time shift (np.roll wraps samples around the clip boundary)
        shift = np.random.randint(-8000, 8000)
        waveform = np.roll(waveform, shift)

        # Add noise
        noise = np.random.normal(0, 0.005, waveform.shape)
        waveform = waveform + noise

        # Random volume
        volume = np.random.uniform(0.8, 1.2)
        return waveform * volume

    def _extract_features(self, waveform):
        """Return ZCR (1 row), RMS energy (1 row) and MFCCs (n_mfcc rows) stacked row-wise."""
        zcr = librosa.feature.zero_crossing_rate(
            waveform, frame_length=self.n_fft, hop_length=self.hop_length
        )
        energy = librosa.feature.rms(
            y=waveform, frame_length=self.n_fft, hop_length=self.hop_length
        )
        mfccs = librosa.feature.mfcc(
            y=waveform, sr=self.sample_rate,
            n_fft=self.n_fft, hop_length=self.hop_length, n_mfcc=self.n_mfcc
        )
        return np.vstack([zcr, energy, mfccs])

    def __getitem__(self, idx):
        """Return ``(features, label)`` for the clip at position ``idx``.

        ``features`` is a float tensor with ``2 + n_mfcc`` rows and one column
        per frame ((15, 216) with the default settings); there is no channel
        dimension because the model is a 1D CNN. ``label`` is the integer class
        index. Raises ``KeyError`` if the row's emotion is not in ``emotion_to_idx``.
        """
        row = self.df.iloc[idx]

        # 1. Load audio (resampled to self.sample_rate, mono)
        waveform, _ = librosa.load(row['path'], sr=self.sample_rate, mono=True)

        # 2. Fix length to exactly self.target_length samples
        waveform = self._fix_length(waveform)

        # 3. Optional augmentation
        if self.augment:
            waveform = self._augment(waveform)

        # 4. Extract and stack time-domain features
        features = self._extract_features(waveform)

        # 5. Convert to tensor and look up the label
        features_tensor = torch.FloatTensor(features)
        label = self.emotion_to_idx[row['emotion']]

        return features_tensor, label

In [49]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Args:
        patience: Number of consecutive non-improving calls after which the
            stopper fires.
        min_delta: Minimum decrease relative to the best loss that counts as
            an improvement.

    Call the instance once per epoch with the validation loss; it returns
    True once ``patience`` calls in a row failed to improve. The flag stays
    set after it has fired.
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and return whether training should stop."""
        # NaN is the only value that is not equal to itself. A NaN loss must
        # count as "no improvement": storing it as the best loss would make
        # every later comparison False and the stopper could never fire.
        if val_loss != val_loss:
            improved = False
        elif self.best_loss is None:
            # First valid loss only establishes the baseline.
            self.best_loss = val_loss
            return self.early_stop
        else:
            improved = val_loss <= self.best_loss - self.min_delta

        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop

# Testing starts here

In [50]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

using device: cuda


In [51]:
# One dataset per split CSV; all three rely on the default augment=False.
# NOTE(review): augment=True is never requested for the training set — confirm this is intentional.
train_dataset = EmotionDataset(csv_path='train_split.csv')
validation_dataset = EmotionDataset(csv_path='val_split.csv')
test_dataset = EmotionDataset(csv_path='test_split.csv')

# Only the training loader shuffles; validation and test keep the CSV order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(dataset=validation_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [52]:
# Class balance of the validation split as stored on disk.
validation_Csv = pd.read_csv('val_split.csv')
print(validation_Csv.loc[:, "emotion"].value_counts())

emotion
happy      136
sad        136
fearful    136
angry      136
disgust    136
neutral    107
Name: count, dtype: int64


In [53]:
# Number of samples in each split.
print(
    f"Train: {len(train_dataset)} samples",
    f"Validation: {len(validation_dataset)} samples",
    f"Test: {len(test_dataset)} samples",
    sep="\n",
)

Train: 6305 samples
Validation: 787 samples
Test: 1406 samples


In [54]:
# Same per-emotion count of val_split.csv as the earlier cell, loaded into `df`.
df = pd.read_csv('val_split.csv')
print(df.loc[:, "emotion"].value_counts())

emotion
happy      136
sad        136
fearful    136
angry      136
disgust    136
neutral    107
Name: count, dtype: int64


In [55]:
def train_model(model, train_loader, validation_loader, epochs=50, lr=0.001, patience=10, device='cuda', weight_decay=1e-4):
    """Train ``model`` and restore the weights of its best validation epoch.

    Uses cross-entropy loss, Adam, a ReduceLROnPlateau scheduler driven by the
    validation loss (factor 0.5, scheduler patience 5) and ``EarlyStopping``
    on the validation loss.

    Args:
        model: Network to train; it must already live on ``device``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        validation_loader: Iterable of ``(inputs, labels)`` validation batches.
        epochs: Maximum number of epochs.
        lr: Initial Adam learning rate.
        patience: Early-stopping patience in epochs.
        device: Device the batches are moved to.
        weight_decay: Adam weight decay.

    Returns:
        ``(best_acc, history)``: the best validation accuracy in percent, and a
        dict with per-epoch lists under ``train_acc``, ``val_acc``,
        ``train_loss`` and ``val_loss``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # `verbose=True` is deprecated in recent PyTorch releases, so learning-rate
    # reductions are reported manually after scheduler.step() instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
    early_stopping = EarlyStopping(patience=patience)

    best_acc = 0.0
    best_model_state = None
    history = {'train_acc': [], 'val_acc': [], 'train_loss': [], 'val_loss': []}

    for epoch in range(epochs):
        # Training
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)

        # Validation
        val_loss, val_acc = _run_epoch(model, validation_loader, criterion, device)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"  Reducing learning rate to {lr_after:.4e}")

        # Save history
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        print(f"Epoch {epoch+1}/{epochs} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}% | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Save best model. state_dict() returns references to the live
        # parameter tensors, so each tensor is cloned; a shallow dict copy
        # would keep changing as training continues. The first epoch is
        # always stored so that a checkpoint exists even at 0% accuracy.
        if best_model_state is None or val_acc > best_acc:
            best_acc = max(best_acc, val_acc)
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

        # Early stopping check
        if early_stopping(val_loss):
            print(f"  Early stopping at epoch {epoch+1}")
            break

    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return best_acc, history


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; update weights only when ``optimizer`` is given.

    Returns ``(mean_batch_loss, accuracy_percent)``.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    total_loss = 0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            seen += labels.size(0)
            correct += (predicted == labels).sum().item()

    return total_loss / len(loader), 100 * correct / seen

In [None]:
# Build the 6-class TimeCNN and move it to the selected device.
model = TimeCNN(num_classes=6)
model = model.to(device)

# Layer-by-layer summary for a single (15, 216) feature map.
stats = summary(model, input_size=(15, 216))

# Train for up to 50 epochs with early stopping on the validation loss.
best_acc, history = train_model(
    model,
    train_loader,
    validation_loader,
    epochs=50,
    lr=1e-4,
    patience=10,
    device=device,
    weight_decay=1e-4,
)

print(f"Best Accuracy: {best_acc:.2f}%")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1             [-1, 256, 216]          11,776
       BatchNorm1d-2             [-1, 256, 216]             512
              ReLU-3             [-1, 256, 216]               0
         AvgPool1d-4             [-1, 256, 108]               0
            Conv1d-5             [-1, 256, 108]         196,864
       BatchNorm1d-6             [-1, 256, 108]             512
              ReLU-7             [-1, 256, 108]               0
            Conv1d-8             [-1, 256, 108]         327,936
       BatchNorm1d-9             [-1, 256, 108]             512
             ReLU-10             [-1, 256, 108]               0
           Conv1d-11             [-1, 128, 108]          98,432
      BatchNorm1d-12             [-1, 128, 108]             256
             ReLU-13             [-1, 128, 108]               0
AdaptiveAvgPool1d-14               [-1,



Epoch 1/50 | Train Acc: 37.32% | Val Acc: 37.74% | Train Loss: 1.5072 | Val Loss: 1.5232
Epoch 2/50 | Train Acc: 43.58% | Val Acc: 38.12% | Train Loss: 1.4105 | Val Loss: 1.4333
Epoch 3/50 | Train Acc: 46.47% | Val Acc: 42.95% | Train Loss: 1.3420 | Val Loss: 1.4486
Epoch 4/50 | Train Acc: 48.75% | Val Acc: 44.35% | Train Loss: 1.2922 | Val Loss: 1.3822
Epoch 5/50 | Train Acc: 49.90% | Val Acc: 40.91% | Train Loss: 1.2616 | Val Loss: 1.4880


KeyboardInterrupt: 