In [None]:
# Wavelet neural network (WNN)-inspired MLP classifier trained on the processed_data2 dataset

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pywt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight

# === CONFIGURATION ===
# Directory of .npz archives; each archive provides an 'X' (data) and a 'y' (labels) array.
data_dir = '/storage/projects1/e19-4yp-mi-eeg-for-bci/ashan/processed_data2'
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 20      # full passes over the training files
batch_size = 64  # samples per optimizer step

# === Load File List and Labels ===
# Sorted so the file order (and therefore the seeded split below) is deterministic.
npz_files = sorted(
    os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.npz')
)
if not npz_files:
    # Fail early with a clear message instead of an obscure error further down.
    raise FileNotFoundError(f"No .npz files found in {data_dir}")

# Collect the labels of every file; they are only needed to derive class weights.
label_chunks = []
for file in npz_files:
    # The context manager closes the archive as soon as the labels are read.
    with np.load(file) as data:
        label_chunks.append(data['y'])
all_labels = np.concatenate(label_chunks)

classes = np.unique(all_labels)
# 'balanced' weights are inversely proportional to each class's frequency.
class_weights_array = compute_class_weight(class_weight='balanced', classes=classes, y=all_labels)
# NOTE(review): weight i belongs to classes[i]; a loss that indexes weights by
# label value needs the labels to be exactly 0..K-1 - confirm for this dataset.
class_weights_tensor = torch.tensor(class_weights_array, dtype=torch.float32).to(device)

# === Train-Test Split ===
# The split is made at file level, so all trials of one file share a partition.
train_files, test_files = train_test_split(npz_files, test_size=0.2, random_state=42, shuffle=True)

# === Apply Wavelet Transform ===
def apply_wavelet(X):
    """Return the single-level 'db4' DWT approximation coefficients of every signal.

    Args:
        X: Array-like of shape (n_trials, n_channels, n_times).

    Returns:
        np.ndarray of shape (n_trials, n_channels, n_coeffs) holding only the
        approximation coefficients; the detail coefficients are discarded.
        An empty 1-D array is returned when X contains no trials.
    """
    X = np.asarray(X)
    if len(X) == 0:
        # No trials to transform.
        return np.array([])
    # A single call along the time axis replaces a Python-level loop over
    # every (trial, channel) pair and yields the same coefficients.
    approx, _ = pywt.dwt(X, 'db4', axis=-1)
    return approx

# === Fit Scaler on Sample Training Data ===
# The scaler statistics are estimated from the first five training files only.
# NOTE(review): presumably a memory/time trade-off - confirm that five files
# are representative of the whole training set.
scaler = StandardScaler()
sample_data = []
for file in train_files[:5]:
    # Close each archive as soon as its 'X' array has been read.
    with np.load(file) as data:
        X = data['X']
    X = apply_wavelet(X)
    # Flatten every trial into a single feature vector.
    X = X.reshape(X.shape[0], -1)
    sample_data.append(X)
scaler.fit(np.vstack(sample_data))

# === WNN-Inspired MLP Model ===
class WaveletNet(nn.Module):
    """Fully connected classifier applied to flattened wavelet coefficients.

    Every hidden stage is Linear -> ReLU -> Dropout; a final Linear layer maps
    to ``num_classes`` raw logits (no softmax is applied).

    Args:
        input_dim: Number of input features per sample.
        num_classes: Number of output logits.
        hidden_dims: Width of each hidden layer, in order. Defaults to (128, 64).
        dropouts: Dropout probability applied after each hidden layer; must
            have the same length as ``hidden_dims``. Defaults to (0.4, 0.3).

    Raises:
        ValueError: If ``hidden_dims`` and ``dropouts`` differ in length.
    """

    def __init__(self, input_dim, num_classes, hidden_dims=(128, 64), dropouts=(0.4, 0.3)):
        super().__init__()
        if len(hidden_dims) != len(dropouts):
            raise ValueError(
                f"hidden_dims and dropouts must have the same length, "
                f"got {len(hidden_dims)} and {len(dropouts)}"
            )

        layers = []
        in_features = input_dim
        for width, drop_p in zip(hidden_dims, dropouts):
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(drop_p)]
            in_features = width
        layers.append(nn.Linear(in_features, num_classes))

        # Kept as one Sequential called ``net`` so that, with the default
        # arguments, the state_dict keys stay net.0 / net.3 / net.6.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map inputs of shape (..., input_dim) to logits of shape (..., num_classes)."""
        return self.net(x)

# === Model, Optimizer, Loss ===
# The flattened feature width of the first sample file fixes the input layer size.
sample_input = sample_data[0]
model = WaveletNet(
    input_dim=sample_input.shape[1],
    num_classes=len(classes),
)
model.to(device)  # nn.Module.to() moves the parameters in place
# Cross-entropy weighted per class with the precomputed class weights.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
# Adam with a small weight decay; created after the model is on its device.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-4)

# === Accuracy Calculation Function ===
def get_accuracy(files):
    """Return the fraction of correctly classified trials across ``files``.

    Each file is loaded, wavelet-transformed, flattened and standardised with
    the module-level ``scaler`` before being scored by the module-level
    ``model``.

    Args:
        files: Iterable of paths to .npz archives holding 'X' and 'y' arrays.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when ``files`` yields no trials.

    Note:
        Switches ``model`` to eval mode and leaves it there; call
        ``model.train()`` before resuming training.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for file in files:
            # Close the archive as soon as both arrays have been read.
            with np.load(file) as data:
                X, y = data['X'], data['y']
            X = apply_wavelet(X)
            X = X.reshape(X.shape[0], -1)
            X = scaler.transform(X)
            X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
            y_tensor = torch.tensor(y, dtype=torch.long).to(device)

            outputs = model(X_tensor)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == y_tensor).sum().item()
            total += y_tensor.size(0)
    if total == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError.
        return 0.0
    return correct / total

# === Training Loop with Epoch-Wise Accuracy ===
# Data is streamed one file at a time; each file is preprocessed, moved to the
# device and consumed in mini-batches before the next file is loaded.
for epoch in range(epochs):
    # get_accuracy() leaves the model in eval mode, so re-enable dropout here.
    model.train()

    # Visit the files in a fresh random order every epoch so the optimizer
    # does not see the same fixed sequence of batches each time.
    file_order = np.random.permutation(len(train_files))
    for file in (train_files[i] for i in file_order):
        # Close the archive as soon as both arrays have been read.
        with np.load(file) as data:
            X, y = data['X'], data['y']
        X = apply_wavelet(X)
        X = X.reshape(X.shape[0], -1)
        X = scaler.transform(X)

        X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
        y_tensor = torch.tensor(y, dtype=torch.long).to(device)

        # Shuffle the trials within the file so a batch is not tied to the
        # order in which trials were stored.
        order = torch.randperm(len(X_tensor), device=device)
        for start in range(0, len(order), batch_size):
            batch_idx = order[start:start + batch_size]
            X_batch = X_tensor[batch_idx]
            y_batch = y_tensor[batch_idx]

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

    # === Print Epoch Accuracy ===
    # Emitted as CSV: epoch number, training accuracy, test accuracy.
    train_acc = get_accuracy(train_files)
    test_acc = get_accuracy(test_files)
    print(f"{epoch+1},{train_acc:.4f},{test_acc:.4f}")


1,0.5473,0.4879
2,0.5223,0.4986
3,0.5152,0.5243
4,0.5179,0.5036
5,0.5175,0.5043
6,0.5209,0.5171
7,0.5161,0.5014
8,0.5364,0.5071
9,0.5463,0.5007
10,0.5618,0.5036
11,0.5593,0.4986
12,0.5668,0.5114
13,0.5850,0.5114
14,0.5887,0.5050
15,0.6254,0.5129
16,0.6166,0.5021
17,0.6218,0.5007
