In [39]:
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
import torchcde

import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import json

from lsde_model import LSDE
from lnsde_model import LNSDE
from gsde_model import GSDE
from sde_model import SDE
from ode_model import ODE
from cde_model import CDE
from rnn_model import RNN

from collections import defaultdict, Counter
from sklearn.model_selection import StratifiedGroupKFold
import numpy as np
from torch.utils.data import DataLoader, SubsetRandomSampler

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from utilities import collate_fn, create_dataloaders, EarlyStopping

In [40]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Constants

In [41]:
# Fractions passed to create_dataloaders for the train and validation splits.
# NOTE(review): the remaining 0.15 is presumably the test share — confirm in
# utilities.create_dataloaders.
train_size = 0.7
val_size = 0.15

input_dim = 3 # [t, x, y]
output_dim = 2 # [x, y]
n_classes = 8 # number of target classes (the targets are subject ids)

num_epochs = 1000 # upper bound on epochs; early stopping usually ends training sooner
val_every = 1 # run a validation pass every `val_every` epochs

criterion = nn.CrossEntropyLoss()

# Seed the random number generators so that a fresh "Restart & Run All"
# reproduces the same initial weights and batch shuffling.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

### Train and test methods

In [42]:
def process_batch(batch):
    """Prepend a normalized time channel to a batch and build its spline coefficients.

    Args:
        batch: tensor indexed as (batch_size, seq_len, channels). The notebook
            passes (x, y) coordinates, so the result has 3 channels [t, x, y].

    Returns:
        Tuple ``(batch, ts, coeffs)``:
            batch: the input with the time channel concatenated in front along
                the last dimension.
            ts: 1-D tensor of ``seq_len`` evenly spaced time steps in [0, 1].
            coeffs: coefficients of the continuous (Hermite cubic with backward
                differences) representation of the time-augmented batch, as
                returned by torchcde.
    """
    # Create the time grid on the batch's own device so the concatenation below
    # never mixes devices, whatever the notebook-level `device` is.
    ts = torch.linspace(0, 1, batch.size(1), device=batch.device) # Time steps
    ts_expanded = ts.unsqueeze(0).unsqueeze(-1).expand(batch.size(0), -1, -1) # (batch_size, seq_len, 1)
    batch = torch.cat((ts_expanded, batch), dim=-1) # Concatenate time steps and batch coordinates

    # Coefficients of a continuous representation of discrete data
    coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(batch, ts)
    return batch, ts, coeffs

In [43]:
def train_val_step(model, dataloader, optimizer, scheduler, train=True, rnn=False):
    """Run one pass over ``dataloader`` in training or evaluation mode.

    Args:
        model: network returning a pair whose second element is the class logits.
        dataloader: yields ``(sbj_ids, scanpath_ids, batch, mask)`` tuples; the
            scanpath ids are ignored here.
        optimizer: optimizer stepped after every batch when ``train`` is True.
        scheduler: learning-rate scheduler; stepped once with the mean loss of
            the pass, and only when ``train`` is False (i.e. on validation loss).
        train: True to update the weights, False to only evaluate.
        rnn: True when the model consumes the raw time-augmented batch
            (``model(batch, mask)``); otherwise it is fed the spline
            coefficients and time grid (``model(coeffs, ts, mask)``).

    Returns:
        Tuple ``(epoch_loss, epoch_accuracy)``: the loss averaged over batches
        and the fraction of samples whose arg-max prediction equals its label.
    """
    if train:
        model.train()
    else:
        model.eval()

    epoch_loss = 0
    correct_fix_preds = 0
    n_fixs = 0
    # Gradients are only needed when the weights are updated. Disabling autograd
    # for the validation pass avoids building a graph for every batch.
    with torch.set_grad_enabled(train):
        for sbj_ids, _, batch, mask in dataloader:
            batch = batch.to(device)
            sbj_ids = sbj_ids.to(device)
            mask = mask.to(device)

            batch, ts, coeffs = process_batch(batch)

            if train:
                optimizer.zero_grad()

            if rnn:
                _, logits = model(batch, mask)
            else:
                _, logits = model(coeffs, ts, mask)
            loss = criterion(logits, sbj_ids)
            epoch_loss += loss.item()

            if train:
                loss.backward()
                optimizer.step()

            preds = logits.argmax(dim=1)
            correct_fix_preds += (preds == sbj_ids).sum().item()
            n_fixs += len(sbj_ids)

    epoch_loss /= len(dataloader)
    # The plateau scheduler tracks the validation loss, so it is stepped only
    # after an evaluation pass.
    if not train:
        scheduler.step(epoch_loss)
    epoch_accuracy = correct_fix_preds / n_fixs
    return epoch_loss, epoch_accuracy

In [44]:
def train(model, train_loader, val_loader, optimizer, scheduler, early_stopping, model_name, rnn=False):
    """Train ``model`` with periodic validation, checkpointing and early stopping.

    Runs up to ``num_epochs`` epochs and validates every ``val_every`` epochs
    (both are notebook-level constants). Whenever the validation loss improves,
    the model state is saved. Training stops early once ``early_stopping``
    reports ``early_stop``.

    Files are written to a directory named after the part of ``model_name``
    before its first underscore:
        ``<dir>/<model_name>_best_model.pth``  best checkpoint
        ``<dir>/<model_name>_losses.png``      train/validation loss curves

    Args:
        model: network to train.
        train_loader: loader used for the optimization passes.
        val_loader: loader used for validation and model selection.
        optimizer: optimizer updating ``model``.
        scheduler: scheduler stepped on the validation loss by ``train_val_step``.
        early_stopping: callable taking the validation loss and exposing an
            ``early_stop`` flag.
        model_name: tag used for the output directory and file names.
        rnn: forwarded to ``train_val_step`` to select the model call signature.

    Returns:
        Tuple ``(best_epoch, best_val_loss, best_accuracy)`` for the epoch with
        the lowest validation loss.
    """
    out_dir = model_name.split('_')[0]
    os.makedirs(out_dir, exist_ok=True)
    train_losses = []
    val_losses = []
    val_epochs = []  # epoch number of every entry in val_losses
    best_epoch = 0
    best_val_loss = float("inf")
    best_accuracy = 0.0

    epoch_bar = tqdm(range(1, num_epochs + 1), desc="Epochs")
    for epoch in epoch_bar:
        # Train step
        epoch_train_loss, _ = train_val_step(model, train_loader, optimizer, scheduler, train=True, rnn=rnn)
        train_losses.append(epoch_train_loss)

        # Val step
        if epoch % val_every == 0:
            epoch_val_loss, epoch_accuracy = train_val_step(model, val_loader, optimizer, scheduler, train=False, rnn=rnn)
            val_losses.append(epoch_val_loss)
            val_epochs.append(epoch)

            # Save best model
            if epoch_val_loss < best_val_loss:
                best_epoch = epoch
                best_val_loss = epoch_val_loss
                best_accuracy = epoch_accuracy
                epoch_bar.set_postfix({"Best Epoch": best_epoch, "Best Val Loss": round(best_val_loss, 3), "Accuracy": round(best_accuracy, 3)})
                torch.save({
                    'model': model.state_dict(),
                    'epoch': best_epoch,
                    'val_loss': best_val_loss,
                    'val_accuracy': best_accuracy
                }, f"{out_dir}/{model_name}_best_model.pth")

            early_stopping(epoch_val_loss)
            if early_stopping.early_stop:
                print(f"\tEarly stopping triggered at epoch {epoch}")
                break

    # Draw on a dedicated figure and close it afterwards, so that no empty
    # figure is left open for the notebook to render after the cell finishes.
    fig, ax = plt.subplots()
    ax.plot(train_losses, label='Train Loss')
    # Place each validation loss at the (0-based) epoch it was measured in, so
    # both curves stay aligned when val_every != 1.
    ax.plot([e - 1 for e in val_epochs], val_losses, label='Val Loss')
    ax.legend()
    fig.savefig(f"{out_dir}/{model_name}_losses.png")
    plt.close(fig)

    return best_epoch, best_val_loss, best_accuracy

In [45]:
def test(model, test_loader, rnn=False):
    """Evaluate ``model`` per fixation and per scanpath (majority vote).

    Each sample yielded by ``test_loader`` is classified individually. The
    predictions of all samples sharing a scanpath id are then combined by
    majority vote into a single scanpath-level prediction.

    Args:
        model: trained network returning a pair whose second element is the
            class logits.
        test_loader: yields ``(sbj_ids, scanpath_ids, batch, mask)`` tuples.
        rnn: True when the model is called as ``model(batch, mask)``, False for
            ``model(coeffs, ts, mask)``.

    Returns:
        Tuple ``(fix_accuracy, scanpath_accuracy)``. Both are also printed.
    """
    model.eval()

    scanpath_preds = defaultdict(list)  # scanpath_id -> [pred1, pred2, ...]
    scanpath_labels = {}                # scanpath_id -> label

    correct_fix_preds = 0
    n_fixs = 0
    with torch.no_grad():
        for sbj_ids, scanpaths_ids, batch, mask in test_loader:
            batch = batch.to(device)
            sbj_ids = sbj_ids.to(device)
            mask = mask.to(device)

            batch, ts, coeffs = process_batch(batch)

            if rnn:
                _, logits = model(batch, mask)
            else:
                _, logits = model(coeffs, ts, mask)
            preds = logits.argmax(dim=1)
            correct_fix_preds += (preds == sbj_ids).sum().item()
            n_fixs += len(sbj_ids)

            # Move labels and predictions to Python lists once per batch
            # instead of converting device tensors element by element.
            batch_labels = sbj_ids.cpu().tolist()
            batch_preds = preds.cpu().tolist()
            for scanpath_id, true, pred in zip(scanpaths_ids, batch_labels, batch_preds):
                scanpath_id = int(scanpath_id)
                scanpath_preds[scanpath_id].append(int(pred))
                scanpath_labels[scanpath_id] = int(true)

    # Majority voting
    correct_scanpath_preds = 0
    for scanpath_id, preds in scanpath_preds.items():
        majority_pred = Counter(preds).most_common(1)[0][0]
        if majority_pred == scanpath_labels[scanpath_id]:
            correct_scanpath_preds += 1

    # Divide by the number of samples actually evaluated. len(loader.dataset)
    # differs from it when the loader iterates over only part of its dataset
    # (e.g. through a sampler), which would understate the accuracy.
    fix_accuracy = correct_fix_preds / n_fixs
    scanpath_accuracy = correct_scanpath_preds / len(scanpath_preds)

    print(f"Fixations classification: {correct_fix_preds}/{n_fixs} -> {fix_accuracy:.3f}")
    print(f"Scanpath classification: {correct_scanpath_preds}/{len(scanpath_preds)} -> {scanpath_accuracy:.3f}")

    return fix_accuracy, scanpath_accuracy

In [46]:
def grid_search(data, data_type):
    """Train and test every model family over a grid of batch and hidden sizes.

    For each batch size in [32, 64, 128] the data loaders are rebuilt. For each
    hidden size in [32, 64, 128] every model family is trained, its best
    checkpoint is reloaded and evaluated on the test loader, and the results
    are stored as JSON.

    All output is written below ``<data_type>_models/``:
        ``<model>/<tag>_best_model.pth``, ``<model>/<tag>_losses.png`` (by ``train``)
        ``<model>/<tag>_stats.json`` (here)
    with ``tag = "<model>_bs<batch_size>_hd<hidden_dim>"``.

    Args:
        data: dataset forwarded to ``create_dataloaders``.
        data_type: forwarded to ``create_dataloaders`` and used as the prefix
            of the output directory.
    """
    batch_sizes = [32, 64, 128]
    hidden_dims = [32, 64, 128]

    # Constructors rather than instances: each model is built only when its
    # turn comes, so the nine models never sit on the device at the same time.
    model_builders = {
        "rnn": lambda hd: RNN(input_dim, hd, output_dim, n_classes, rnn_type="rnn"),
        "gru": lambda hd: RNN(input_dim, hd, output_dim, n_classes, rnn_type="gru"),
        "lstm": lambda hd: RNN(input_dim, hd, output_dim, n_classes, rnn_type="lstm"),
        "ncde": lambda hd: CDE(input_dim, hd, output_dim, n_classes),
        "node": lambda hd: ODE(input_dim, hd, output_dim, n_classes),
        "nsde": lambda hd: SDE(input_dim, hd, output_dim, n_classes),
        "lsde": lambda hd: LSDE(input_dim, hd, output_dim, n_classes),
        "lnsde": lambda hd: LNSDE(input_dim, hd, output_dim, n_classes),
        "gsde": lambda hd: GSDE(input_dim, hd, output_dim, n_classes),
    }

    start_dir = os.getcwd()
    os.makedirs(f"{data_type}_models", exist_ok=True)
    os.chdir(f"{data_type}_models")
    # Restore the working directory even if a run fails. Otherwise every later
    # relative path in the notebook would resolve inside the models directory.
    try:
        for batch_size in batch_sizes:

            train_loader, val_loader, test_loader = create_dataloaders(data, train_size, val_size, batch_size, data_type)

            for hidden_dim in hidden_dims:

                for model_name, build_model in model_builders.items():
                    model = build_model(hidden_dim).to(device)
                    os.makedirs(model_name, exist_ok=True)
                    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
                    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=10)
                    early_stopping = EarlyStopping(patience=20, delta=1e-3)

                    tag = f"{model_name}_bs{batch_size}_hd{hidden_dim}"

                    print(f"Training {tag}...")
                    rnn = model_name in {'rnn', 'gru', 'lstm'}
                    best_epoch, best_val_loss, best_accuracy = train(model, train_loader, val_loader,
                                                                     optimizer, scheduler, early_stopping,
                                                                     tag, rnn)
                    print()

                    print(f"Testing {tag}...")
                    # Evaluate the best checkpoint, not the weights of the last epoch.
                    checkpoint = torch.load(f"{model_name}/{tag}_best_model.pth", map_location=device)
                    model.load_state_dict(checkpoint['model'])
                    test_fix_accuracy, test_scanpath_accuracy = test(model, test_loader, rnn)
                    print()

                    stats = {
                        "tag": tag,
                        "epoch": best_epoch,
                        "val_loss": best_val_loss,
                        "val_accuracy": best_accuracy,
                        "test_accuracy": test_fix_accuracy,
                        "scanpath_accuracy": test_scanpath_accuracy,
                    }

                    with open(f"{model_name}/{tag}_stats.json", "w") as f:
                        json.dump(stats, f, indent=2)

                    print()
    finally:
        os.chdir(start_dir)

In [None]:
def cross_validation(model_name, data, batch_size, hidden_dim, k_folds=5):
    """Run stratified k-fold cross-validation at scanpath level for one model family.

    The folds are built over scanpaths, stratified by subject id. Every
    scanpath is its own group, so all fixations of a scanpath end up in the
    same fold. Each fold trains a fresh model on the fixations of the training
    scanpaths and evaluates it on the fixations of the held-out scanpaths.

    NOTE(review): the held-out fold is also passed to ``train`` as the
    validation loader (early stopping, LR scheduling, checkpoint selection), so
    the reported test metrics are optimistic. Carve a validation split out of
    the training scanpaths if an unbiased estimate is needed.

    Args:
        model_name: one of "rnn", "gru", "lstm", "ncde", "node", "nsde",
            "lsde", "lnsde", "gsde"; selects the architecture and the output
            directory.
        data: sequence of ``(sbj_id, scanpath, fixs_list, sacs_list)`` tuples.
        batch_size: batch size of both loaders.
        hidden_dim: hidden size passed to the model constructor.
        k_folds: number of folds.

    Raises:
        ValueError: if ``model_name`` is not a known model family.
    """
    # Build the architecture that model_name asks for.
    model_builders = {
        "rnn": lambda: RNN(input_dim, hidden_dim, output_dim, n_classes, rnn_type="rnn"),
        "gru": lambda: RNN(input_dim, hidden_dim, output_dim, n_classes, rnn_type="gru"),
        "lstm": lambda: RNN(input_dim, hidden_dim, output_dim, n_classes, rnn_type="lstm"),
        "ncde": lambda: CDE(input_dim, hidden_dim, output_dim, n_classes),
        "node": lambda: ODE(input_dim, hidden_dim, output_dim, n_classes),
        "nsde": lambda: SDE(input_dim, hidden_dim, output_dim, n_classes),
        "lsde": lambda: LSDE(input_dim, hidden_dim, output_dim, n_classes),
        "lnsde": lambda: LNSDE(input_dim, hidden_dim, output_dim, n_classes),
        "gsde": lambda: GSDE(input_dim, hidden_dim, output_dim, n_classes),
    }
    if model_name not in model_builders:
        raise ValueError(f"Unknown model_name {model_name!r}; expected one of {sorted(model_builders)}")
    rnn = model_name in {'rnn', 'gru', 'lstm'}

    def _fixation_samples(sp_idxs):
        # One (subject id, scanpath index, fixation) sample per fixation of the
        # selected scanpaths.
        return [
            (data[sp_id][0], sp_id, fix)
            for sp_id in sp_idxs
            for fix in data[sp_id][2]
        ]

    epochs = []
    losses = []
    fix_accuracies = []
    test_fix_accuracies = []
    scanpath_accuracies = []

    n_scanpaths = len(data)
    scanpath_ids = np.arange(n_scanpaths)
    scanpath_labels = np.array([sbj_id for sbj_id, _, _, _ in data])

    cv = StratifiedGroupKFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_sp_idxs, test_sp_idxs) in enumerate(cv.split(
        X=np.zeros(n_scanpaths),
        y=scanpath_labels,
        groups=scanpath_ids)):

        print(f'Fold {fold + 1}')
        model = model_builders[model_name]().to(device)

        train_loader = DataLoader(_fixation_samples(train_sp_idxs), batch_size, shuffle=True, collate_fn=collate_fn)
        test_loader = DataLoader(_fixation_samples(test_sp_idxs), batch_size, shuffle=False, collate_fn=collate_fn)

        optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=10)
        early_stopping = EarlyStopping(patience=20, delta=1e-3)

        # Per-fold tag: each fold keeps its own checkpoint and loss plot instead
        # of overwriting the previous fold's files. train() stores them in the
        # directory named by the part of the tag before the first underscore.
        fold_tag = f"{model_name}_fold{fold + 1}"
        best_epoch, best_loss, best_accuracy = train(model, train_loader, test_loader,
                                                     optimizer, scheduler, early_stopping,
                                                     fold_tag, rnn)
        epochs.append(best_epoch)
        losses.append(best_loss)
        fix_accuracies.append(best_accuracy)

        checkpoint = torch.load(f"{model_name}/{fold_tag}_best_model.pth", map_location=device)
        model.load_state_dict(checkpoint['model'])
        test_fix_accuracy, test_scanpath_accuracy = test(model, test_loader, rnn)
        test_fix_accuracies.append(test_fix_accuracy)
        scanpath_accuracies.append(test_scanpath_accuracy)

        print()

    print("Cross validation averages:")
    print("\tEpoch: ", np.mean(epochs))
    print("\tVal Loss: ", np.mean(losses))
    print("\tVal Fix Accuracy: ", np.mean(fix_accuracies))
    print("\tTest Fix Accuracy: ", np.mean(test_fix_accuracies))
    print("\tTest Scanpath Accuracy: ", np.mean(scanpath_accuracies))

### Main

In [None]:
# Load the pickled dataset: a list of tuples (sbj_id, scanpath, fixs_list, sacs_list).
# Unpickling can execute arbitrary code, so only load files from a trusted source.
with open("scanpaths_fixs_sacs.pkl", "rb") as pkl_file:
    data = pickle.load(pkl_file)

### Cross validation

In [None]:
# Settings for the cross-validation run.
batch_size = 64
hidden_dim = 64
model_name = "lsde"
cross_validation(model_name, data, batch_size, hidden_dim, k_folds=5)

Fold 1


Epochs:  12%|█▏        | 117/1000 [4:17:22<32:22:21, 131.98s/it, Best Epoch=98, Best Val Loss=1.41, Accuracy=0.475]

	Early stopping triggered at epoch 118





Fixations classification: 1029/2179 -> 0.472
Scanpath classification: 213/320 -> 0.666

Fold 2


Epochs:  21%|██▏       | 213/1000 [7:44:50<28:37:32, 130.94s/it, Best Epoch=194, Best Val Loss=1.29, Accuracy=0.537]

	Early stopping triggered at epoch 214





Fixations classification: 1175/2184 -> 0.538
Scanpath classification: 241/320 -> 0.753

Fold 3


Epochs:  12%|█▏        | 121/1000 [4:24:08<31:58:48, 130.98s/it, Best Epoch=102, Best Val Loss=1.41, Accuracy=0.492]

	Early stopping triggered at epoch 122





Fixations classification: 1072/2179 -> 0.492
Scanpath classification: 202/320 -> 0.631

Fold 4


Epochs:  14%|█▎        | 136/1000 [4:57:17<31:28:41, 131.16s/it, Best Epoch=117, Best Val Loss=1.32, Accuracy=0.5]  

	Early stopping triggered at epoch 137





Fixations classification: 1092/2193 -> 0.498
Scanpath classification: 220/320 -> 0.688

Fold 5


Epochs:  16%|█▌        | 155/1000 [5:37:52<30:41:59, 130.79s/it, Best Epoch=136, Best Val Loss=1.3, Accuracy=0.534] 

	Early stopping triggered at epoch 156





Fixations classification: 1176/2194 -> 0.536
Scanpath classification: 234/320 -> 0.731

Cross validation averages:
	Epoch:  129.4
	Val Loss:  1.345902970177787
	Val Fix Accuracy:  0.5075093268967994
	Test Fix Accuracy:  0.5072325470459766
	Test Scanpath Accuracy:  0.69375


<Figure size 640x480 with 0 Axes>

### Grid search

In [None]:
# Sweep every model family over the batch-size / hidden-size grid on the
# "fix" data type; results are written below "fix_models/".
grid_search(data, data_type="fix")

### Test trained model

In [None]:
# NOTE(review): grid_search writes checkpoints to
# "<data_type>_models/<model>/<tag>_best_model.pth"; confirm that this bare
# file name resolves from the notebook's working directory.
model_path = "lsde_bs64_hd128_best_model.pth"
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load(model_path, map_location=device)
print(checkpoint['epoch'], checkpoint['val_loss'], checkpoint['val_accuracy'])

batch_size = 64
# Must match the hidden size the checkpoint was trained with ("hd128" in the
# file name); a different size makes load_state_dict fail on mismatched shapes.
hidden_dim = 128
model = LSDE(input_dim, hidden_dim, output_dim, n_classes).to(device)
model.load_state_dict(checkpoint['model'])

train_loader, val_loader, test_loader = create_dataloaders(data, train_size, val_size, batch_size, "fix")
test(model, val_loader, rnn=False)
test(model, test_loader, rnn=False)