# DATA LOAD

In [None]:
from utils.data_load import load_all_participants
from utils.cm_plot import plot_binary_confusion_matrix,plot_ternary_confusion_matrix,plot_continous_perclos
from sklearn.ensemble import RandomForestClassifier,RandomForestRegressor
from sklearn.model_selection import GroupKFold, cross_val_predict
from sklearn.metrics import accuracy_score,mean_squared_error

# Pooled table with one row per sample, as returned by the project loader.
all_participants_data = load_all_participants()

# Participant identifier per row; used as the grouping key so that every
# cross-validation fold holds out the rows of a single participant.
groups = all_participants_data["Participant Name"]
cv_loso = GroupKFold(n_splits=len(groups.unique()))

# Targets. The continuous target is the raw "perclos" column, the ternary
# target is the "quantized_perclos" column, and the binary target marks rows
# whose "perclos" value is strictly above 0.5.
y_continuous = all_participants_data["perclos"]
y_ternary_encoded = all_participants_data["quantized_perclos"]
y_binary_encoded = (all_participants_data["perclos"] > 0.5).astype(int)

# Feature matrix: everything except the two target columns and the
# participant identifier.
X = all_participants_data.drop(
    columns=["perclos", "quantized_perclos", "Participant Name"]
)



# LOADER + MODEL

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import GroupKFold
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt


class SEEDViGDataset(Dataset):
    """Index-based dataset pairing feature rows with their labels.

    Args:
        X: Indexable container of feature rows (``X[idx]`` is one sample).
        y: Indexable container of labels aligned with ``X``.
        is_regression: When true, labels are returned as ``float32`` tensors;
            otherwise they are returned as ``long`` (integer class ids).
    """

    def __init__(self, X, y, is_regression=False):
        self.X, self.y = X, y
        self.is_regression = is_regression

    def __len__(self):
        # The number of samples is defined by the feature container.
        return len(self.X)

    def __getitem__(self, idx):
        # Features are always float32; only the label dtype depends on the task.
        label_dtype = torch.float32 if self.is_regression else torch.long
        features = torch.tensor(self.X[idx], dtype=torch.float32)
        label = torch.tensor(self.y[idx], dtype=label_dtype)
        return features, label


class SimpleCNN(nn.Module):
    """Small 1-D convolutional network applied to a flat feature vector.

    The feature vector is treated as a single-channel signal, passed through
    two ``Conv1d`` + ``ReLU`` stages, averaged over its length, and mapped to
    ``output_dim`` values by a linear layer.

    Args:
        input_dim: Number of input features. Not used by this network.
        output_dim: Number of output units (class logits or regression value).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        conv_stack = [
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            # Collapse the length axis to a single averaged value per channel.
            nn.AdaptiveAvgPool1d(1),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.classifier = nn.Linear(32, output_dim)

    def forward(self, x):
        # Insert a channel axis at position 1 before the convolutions, then
        # drop the pooled length-1 axis before the linear head.
        pooled = self.features(x.unsqueeze(1)).squeeze(-1)
        return self.classifier(pooled)

class SimpleTransformer(nn.Module):
    """Transformer-encoder model over a flat feature vector.

    Each input row is linearly embedded to 64 dimensions, passed through a
    stack of ``nn.TransformerEncoderLayer`` blocks, averaged over the leading
    axis, and projected to ``output_dim`` values.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of output units (class logits or regression value).
        nhead: Number of attention heads in every encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, input_dim, output_dim, nhead=8, num_layers=3):
        super().__init__()
        d_model = 64
        self.embedding = nn.Linear(input_dim, d_model)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_layers,
        )
        self.fc_out = nn.Linear(d_model, output_dim)

    def forward(self, x):
        # A new leading axis of size 1 is added before the encoder.
        # NOTE(review): with the encoder layer's default (sequence-first)
        # layout this presumably makes every sample a length-1 sequence, so
        # attention never mixes information across features -- confirm that
        # this is intended.
        tokens = self.embedding(x.unsqueeze(0))
        encoded = self.transformer_encoder(tokens)
        # Averaging over the size-1 leading axis removes it again.
        return self.fc_out(encoded.mean(dim=0))


# TRAINING LOOP


In [None]:

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module to train; it is switched to training mode.
        dataloader: Iterable of ``(X_batch, y_batch)`` tensor pairs that also
            supports ``len()``.
        optimizer: Optimizer updating the parameters of ``model``.
        criterion: Loss callable invoked as ``criterion(outputs, y_batch)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The mean of the per-batch loss values over the epoch.
    """
    model.train()
    total_loss = 0.0
    for X_batch, y_batch in dataloader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(X_batch)
        # A single-unit head yields (batch, 1) while scalar float targets
        # collate to (batch,). Element-wise losses such as MSELoss would
        # broadcast that pair to (batch, batch) and optimise the wrong
        # objective, so drop the trailing axis first. Integer class targets
        # are left alone.
        if (
            y_batch.is_floating_point()
            and outputs.dim() == y_batch.dim() + 1
            and outputs.size(-1) == 1
        ):
            outputs = outputs.squeeze(-1)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def evaluate(model, dataloader, criterion, device, task_type):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        dataloader: Iterable of ``(X_batch, y_batch)`` tensor pairs that also
            supports ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, y_batch)``.
        device: Device the batches are moved to before the forward pass.
        task_type: ``"regression"`` collects the raw (flattened) outputs as
            predictions; any other value collects ``argmax`` over axis 1.

    Returns:
        A tuple ``(mean_batch_loss, preds, truths)`` where ``preds`` and
        ``truths`` are 1-D NumPy arrays concatenated over all batches.
    """
    model.eval()
    preds = []
    truths = []
    total_loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            # A single-unit head yields (batch, 1) while scalar float targets
            # collate to (batch,). Element-wise losses such as MSELoss would
            # broadcast that pair to (batch, batch) and report a wrong loss,
            # so drop the trailing axis first. Integer class targets are left
            # alone.
            if (
                y_batch.is_floating_point()
                and outputs.dim() == y_batch.dim() + 1
                and outputs.size(-1) == 1
            ):
                outputs = outputs.squeeze(-1)
            loss = criterion(outputs, y_batch)
            total_loss += loss.item()
            if task_type == "regression":
                preds.append(outputs.cpu().numpy().flatten())
                truths.append(y_batch.cpu().numpy().flatten())
            else:
                # For classification, take argmax of logits
                preds.append(outputs.argmax(dim=1).cpu().numpy())
                truths.append(y_batch.cpu().numpy())
    preds = np.concatenate(preds)
    truths = np.concatenate(truths)
    return total_loss / len(dataloader), preds, truths


def run_pytorch_model(X, y, groups, model_type="cnn", task_type="binary", epochs=5,
                      batch_size=32, lr=1e-3):
    """Train and evaluate a PyTorch model with one held-out group per fold.

    ``GroupKFold`` is configured with as many splits as there are distinct
    values in ``groups``. A fresh model is trained for every fold and the
    held-out predictions of all folds are pooled before the final metric is
    printed.

    Args:
        X: 2-D array-like of features, indexed as ``(sample, feature)``.
        y: 1-D array-like of targets aligned with ``X``.
        groups: 1-D array-like of group labels aligned with ``X``.
        model_type: ``"cnn"`` builds ``SimpleCNN``; any other value builds
            ``SimpleTransformer``.
        task_type: ``"binary"`` (2 logits), ``"ternary"`` (3 logits) or
            ``"regression"`` (1 output trained with MSE).
        epochs: Number of training epochs per fold.
        batch_size: Mini-batch size for the training and test loaders.
        lr: Learning rate of the Adam optimizer.

    Returns:
        A dict with ``"task_type"``, the pooled ``"preds"`` and ``"truths"``
        arrays, and either ``"rmse"`` (regression) or ``"accuracy"`` and
        ``"confusion_matrix"`` (classification).

    Raises:
        ValueError: If ``task_type`` is not one of the supported values.
    """
    # Validate up front: an unrecognised task type used to fall through to
    # the regression set-up here while evaluation and reporting treated it
    # as classification, producing meaningless metrics.
    task_settings = {
        "binary": (2, False, nn.CrossEntropyLoss),
        "ternary": (3, False, nn.CrossEntropyLoss),
        "regression": (1, True, nn.MSELoss),
    }
    if task_type not in task_settings:
        raise ValueError(
            f"Unknown task_type {task_type!r}; "
            "expected one of 'binary', 'ternary', 'regression'"
        )
    output_dim, is_regression, criterion_cls = task_settings[task_type]
    criterion = criterion_cls()

    # Index-array selection below requires NumPy semantics.
    X, y, groups = np.asarray(X), np.asarray(y), np.asarray(groups)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    group_kfold = GroupKFold(n_splits=len(np.unique(groups)))
    splits = list(group_kfold.split(X, y, groups=groups))

    all_preds = []
    all_truths = []

    for fold_idx, (train_idx, test_idx) in enumerate(splits, start=1):
        print(f"\nStarting fold {fold_idx}/{len(splits)}")

        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        train_dataset = SEEDViGDataset(X_train, y_train, is_regression=is_regression)
        test_dataset = SEEDViGDataset(X_test, y_test, is_regression=is_regression)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        # A new model and optimizer per fold so no weights leak between folds.
        input_dim = X_train.shape[1]
        if model_type == "cnn":
            model = SimpleCNN(input_dim, output_dim)
        else:
            model = SimpleTransformer(input_dim, output_dim)

        model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)

        for _ in tqdm(range(epochs), desc=f"Fold {fold_idx} training"):
            train_one_epoch(model, train_loader, optimizer, criterion, device)

        _, preds, truths = evaluate(model, test_loader, criterion, device, task_type)
        all_preds.append(preds)
        all_truths.append(truths)

    all_preds = np.concatenate(all_preds)
    all_truths = np.concatenate(all_truths)

    results = {"task_type": task_type, "preds": all_preds, "truths": all_truths}
    if task_type == "regression":
        mse = np.mean((all_preds - all_truths) ** 2)
        rmse = np.sqrt(mse)
        print("Regression RMSE:", rmse)
        results["rmse"] = rmse
    else:
        acc = accuracy_score(all_truths, all_preds)
        print("Classification Accuracy:", acc)
        cm = confusion_matrix(all_truths, all_preds)
        print("Confusion Matrix:\n", cm)
        results["accuracy"] = acc
        results["confusion_matrix"] = cm
    return results
    

# MAIN

In [None]:
# Binary classification: transformer model, 3 epochs per fold.
run_pytorch_model(
    X.values,
    y_binary_encoded.values,
    groups.values,
    model_type="transformer",
    task_type="binary",
    epochs=3,
)

In [None]:
# Ternary classification: transformer model, 5 epochs per fold.
run_pytorch_model(
    X.values,
    y_ternary_encoded.values,
    groups.values,
    model_type="transformer",
    task_type="ternary",
    epochs=5,
)

In [None]:

# Continuous regression: transformer model, 5 epochs per fold.
run_pytorch_model(
    X.values,
    y_continuous.values,
    groups.values,
    model_type="transformer",
    task_type="regression",
    epochs=5,
)