# Minimal TSMNet demo notebook for inter-session/-subject source-free (SF) offline and online unsupervised domain adaptation (UDA)

In [None]:
import torch
import sklearn
import pandas as pd


from moabb.datasets import Kalunga2016, MAMEM1, MAMEM2, MAMEM3, Lee2019_SSVEP
from moabb.paradigms import SSVEP


# from library.utils.torch import StratifiedDomainDataLoader
from spdnets.utils.data import StratifiedDomainDataLoader, DomainDataset

from spdnets.trainer import Trainer
from spdnets.callbacks import EarlyStopping, MomentumBatchNormScheduler
from spdnets.model import GyroNet
import os
import numpy as np
from scipy.linalg import sqrtm, inv 
import spdnets.functionals as fn
from copy import deepcopy
import spdnets.batchnorm as bn

In [None]:
def euler_align(raw_array):
    """
    Whiten a stack of trials with the inverse square root of their mean covariance.

    For every trial in `raw_array` (indexed as [trials, channels, samples]) the
    channel covariance is estimated with `np.cov`. The covariances are averaged
    over trials, and each trial is left-multiplied by the inverse matrix square
    root of that average.

    Returns the transformed array; the input is not modified.
    """
    # One channels-by-channels covariance per trial (rows are the variables).
    trial_covs = [np.cov(epoch, rowvar=True) for epoch in raw_array]
    reference_cov = np.asarray(trial_covs).mean(axis=0)

    # Inverse square root of the reference covariance.
    whitener = inv(sqrtm(reference_cov))

    # The 2-D whitener broadcasts over the leading trial axis.
    return whitener @ raw_array

In [None]:
# Experiment settings: training schedule, evaluation protocol and model options.
cfg = {
    'epochs': 100,
    'batch_size_train': 50,
    'domains_per_batch': 5,
    'validation_size': 0.2,
    'evaluation': 'inter-subject',  # or 'inter-session'
    'dtype': torch.float64,
    'training': True,
    'swd_loss_weight': 0,
    'euler_align': False,
    'lr': 0.001,
    'weight_decay': 1e-4,
    # keyword arguments forwarded to the model constructor
    'mdl_kwargs': {
        'bnorm_dispersion': bn.BatchNormDispersion.SCALAR,
        'domain_adaptation': True,
    },
}

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if not torch.cuda.is_available():
    device = torch.device('cpu')
    print('CPU')
else:
    device = torch.device('cuda:0')
    print('GPU')

In [None]:
def split_eeg_data(X, labels, metadata, original_duration=2, target_duration=1):
    """Cut every trial into consecutive, non-overlapping segments of equal length.

    Parameters
    ----------
    X : array-like with shape (n_trials, n_channels, n_samples)
        Trial data; assumed to be uniformly sampled over `original_duration`.
    labels : sequence of length n_trials
        One label per trial. Every segment inherits the label of its trial.
    metadata : pandas.DataFrame with n_trials rows
        One row per trial, addressed by position. Not modified.
    original_duration : float
        Duration of one trial, in the same unit as `target_duration`.
    target_duration : float
        Duration of one output segment.

    Returns
    -------
    X_split : np.ndarray, shape (n_trials * n_segments, n_channels, target_samples)
        Segments ordered trial by trial, then by position within the trial.
    labels_split : np.ndarray, length n_trials * n_segments
    metadata_split : pandas.DataFrame
        The row of the originating trial repeated for each segment, with three
        extra columns: 'original_trial_idx', 'segment_idx' and 'trial_segment'
        (formatted as "<trial>_<segment>").

    Samples that do not fill a whole segment at the end of a trial are dropped.
    """
    X = np.asarray(X)
    n_trials, n_channels, n_samples = X.shape

    # Tolerance so that ratios which are integers up to floating-point error
    # (e.g. 0.3 / 0.1 == 2.9999999999999996) are not truncated by int().
    eps = 1e-9

    # Sampling rate, assuming the samples cover `original_duration` uniformly.
    sampling_rate = n_samples / original_duration
    target_samples = max(int(target_duration * sampling_rate + eps), 0)

    # Number of whole segments that fit into one trial.
    n_segments = max(int(original_duration / target_duration + eps), 0)

    # Split the sample axis into (segment, sample-within-segment) and move the
    # segment axis next to the trial axis, so the flattened order is
    # trial 0 / seg 0, trial 0 / seg 1, ..., trial 1 / seg 0, ...
    used_samples = n_segments * target_samples
    per_trial = X[:, :, :used_samples].reshape(
        n_trials, n_channels, n_segments, target_samples
    )
    # order='C' forces a fresh contiguous copy, so the result never aliases X.
    X_split = np.array(per_trial.transpose(0, 2, 1, 3), order='C').reshape(
        n_trials * n_segments, n_channels, target_samples
    )

    # Positional index of the source trial / segment for every output row.
    trial_idx = np.repeat(np.arange(n_trials), n_segments)
    segment_idx = np.tile(np.arange(n_segments), n_trials)

    # All segments of a trial share that trial's label.
    labels_split = np.asarray(labels)[trial_idx]

    # Repeat the metadata rows in one vectorized selection (keeps column
    # dtypes) and record where every segment came from.
    metadata_split = metadata.iloc[trial_idx].reset_index(drop=True)
    metadata_split['original_trial_idx'] = trial_idx
    metadata_split['segment_idx'] = segment_idx
    metadata_split['trial_segment'] = [
        f"{i}_{j}" for i, j in zip(trial_idx, segment_idx)
    ]

    return X_split, labels_split, metadata_split

In [None]:
def sfuda_offline(dataset : DomainDataset, model : GyroNet):
    """Run the model's domain-adaptation fine-tuning on one dataset, in eval mode.

    Switches `model` to evaluation mode, moves the dataset's features (cast to
    cfg['dtype']) and labels to the global `device`, and hands them together
    with the dataset's domains to `model.domainadapt_finetune`. The model is
    updated in place; nothing is returned.
    """
    model.eval()
    features = dataset.features.to(dtype=cfg['dtype'], device=device)
    targets = dataset.labels.to(device=device)
    model.domainadapt_finetune(features, targets, dataset.domains, None)

## load a MOABB dataset

In [None]:
# Name used in output paths and result records.
dataset = 'Lee2019_SSVEP'
n_classes = 4

# Trial length and desired segment length.
# NOTE(review): presumably in seconds and meant for split_eeg_data - confirm.
original_duration = 4
target_duration = 1

# MOABB dataset object and the SSVEP paradigm used to extract epochs from it.
moabb_ds = Lee2019_SSVEP()
moabb_paradigm = SSVEP(n_classes=n_classes, fmin=1, fmax=50, resample=256)

## fit and evaluate the model for all domains

In [None]:
# Seed torch; the same seed is also given to the sklearn train/validation
# splitter below so that a run is repeatable.
random_seed=42
torch.manual_seed(random_seed)
model_name = 'GyroNet'
# Hyper-parameter sweep. Each list holds a single value by default (taken from
# `cfg` where a matching entry exists); add values to a list to sweep over them.
for swd_weight in [cfg['swd_loss_weight']]:
    for ea_align in [0]:
        for lr in [cfg['lr']]:
            records = []    # one row per (fold, test domain)
            records1 = []   # train / validation scores per fold
            if 'inter-session' in cfg['evaluation']:
                # one subject at a time, sessions are the held-out groups
                subset_iter = iter([[s] for s in moabb_ds.subject_list])
                groupvarname = 'session'
                eval_fashion = 'session'
            elif 'inter-subject' in cfg['evaluation']:
                # all subjects at once, subjects are the held-out groups
                subset_iter = iter([None])
                groupvarname = 'subject'
                eval_fashion = 'subject'
            else:
                raise NotImplementedError()
            location = f'swd_weight_{swd_weight}_EA_{ea_align}_lr_{lr}'
            # first run index whose output directory does not exist yet
            # NOTE(review): with cfg['training'] == False nothing is saved under
            # this new index, so the torch.load below will not find a file -
            # confirm which run index should be evaluated in that mode.
            index = 0
            while os.path.exists(f'pretrained/{dataset}/{location}/{index}'):
                index += 1
            # counters accumulated over all folds
            global_manifold_violation_counters = {}
            global_total_forward_calls = 0

            # iterate over subject groups
            for ix_subset, subjects in enumerate(subset_iter):

                # get the data from the MOABB paradigm/dataset
                X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)
                # cut trials into segments using the durations configured for
                # this dataset (not the function's 2 s / 1 s defaults)
                X, labels, metadata = split_eeg_data(
                    X, labels, metadata,
                    original_duration=original_duration,
                    target_duration=target_duration,
                )


                # extract domains = subject/session
                metadata['label'] = labels
                metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
                domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

                # convert to torch tensors
                domain = torch.from_numpy(domain)
                if ea_align ==1:
                    # align every domain separately; X is still a numpy array
                    # here, so index it with a numpy mask
                    for i in domain.unique():
                        domain_mask = (domain == i).numpy()
                        X[domain_mask] = euler_align(X[domain_mask])
                    input_align = 'ea'
                else:
                    input_align = 'none'
                # for du in domain.unique():
                #     domain_ixs = domain == du
                #     X[domain_ixs] = fn.robust_zscore(X[domain_ixs], per_channel_variance=False, axis=-1)
                # X = np.clip(X, -5, 5)

                # stronger weight decay for the larger learning rates
                if lr >= 0.01:
                    weight_decay = 1e-3
                else:
                    weight_decay = 1e-4


                X = torch.from_numpy(X)
                y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
                y = torch.from_numpy(y)

                lab,lab_count = y.unique(return_counts=True)

                # leave one subject or session out
                if 'inter-session' in cfg['evaluation']:
                    domain_count = len(domain.unique())
                elif 'inter-subject' in cfg['evaluation']:
                    domain_count = metadata['subject'].nunique()
                # fewer than 10 groups: leave-one-group-out, else 10-fold over groups
                if domain_count <10:
                    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
                else:
                    cv_outer = sklearn.model_selection.GroupKFold(n_splits=10)
                cv_outer_group = metadata[groupvarname]

                # train/validation split stratified across domains and labels
                cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
                cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

                # add data-dependent model kwargs
                mdl_kwargs = deepcopy(cfg['mdl_kwargs'])
                mdl_kwargs['num_classes'] = n_classes
                mdl_kwargs['num_electrodes'] = X.shape[1]
                mdl_kwargs['chunk_size'] = X.shape[2]
                mdl_kwargs['domains'] = domain.unique()
                # perform outer CV
                for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

                    # split fitting data into train and validation
                    # (seeded so the split is the same on every run)
                    cv_inner = sklearn.model_selection.StratifiedShuffleSplit(
                        n_splits=1, test_size=cfg['validation_size'], random_state=random_seed)
                    train, val = next(cv_inner.split(X[fit], y[fit], cv_inner_group[fit]))

                    # cap the number of domains per batch at the number of
                    # domains present in the training split
                    du = domain[fit][train].unique()
                    if cfg['domains_per_batch'] > len(du):
                        domains_per_batch = len(du)
                    else:
                        domains_per_batch = cfg['domains_per_batch']

                    # split entire dataset into train/validation/test
                    ds_train = DomainDataset(X[fit][train], y[fit][train], domain[fit][train])
                    ds_val = DomainDataset(X[fit][val], y[fit][val], domain[fit][val])


                    # create dataloaders
                    # for training use specific loader/sampler so that
                    # batches contain a specific number of domains with equal observations per domain
                    # and stratified labels
                    loader_train = StratifiedDomainDataLoader(ds_train, cfg['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True, drop_last = False)
                    #loader_train= torch.torch.utils.data.DataLoader(ds_train,50)
                    loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))

                    # create the model

                    model = GyroNet(device=device,dtype=cfg['dtype'],**mdl_kwargs).to(device=device, dtype=cfg['dtype'])
                    es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=20, verbose=False)

                    bn_sched = MomentumBatchNormScheduler(
                        epochs=cfg['epochs']-10,
                        bs0=cfg['batch_size_train'],
                        bs=cfg['batch_size_train']/cfg['domains_per_batch'],
                        tau0=0.85
                    )

                    # create the trainer
                    trainer = Trainer(
                        max_epochs= cfg['epochs'],
                        min_epochs= 70,
                        callbacks=[es,bn_sched],
                        loss=torch.nn.CrossEntropyLoss(),
                        device=device,
                        dtype=cfg['dtype'],
                        swd_weight=swd_weight,
                        lr=lr,
                        weight_decay=weight_decay
                    )

                    # fit the model, save it, and record train/validation scores
                    if cfg['training']:
                        save_dir = f'pretrained/{dataset}/{location}/{index}'
                        trainer.fit(model, train_dataloader=loader_train, val_dataloader=loader_val)
                        model.print_manifold_violation_stats()
                        os.makedirs(save_dir, exist_ok=True)
                        torch.save(model, f'pretrained/{dataset}/{location}/{index}/{ix_subset}_{ix_fold}.pt')
                        print(f'ES best epoch={es.best_epoch}')
                        res = trainer.test(model, dataloader=loader_train)
                        records1.append(dict(input=input_align,latent='gyro',lr=lr,swd_weight=swd_weight,model=model_name,mode='train',dataset=dataset,**res))
                        res = trainer.test(model, dataloader=loader_val)
                        records1.append(dict(input=input_align,latent='gyro',lr=lr,swd_weight=swd_weight,model=model_name,mode='validation',dataset=dataset,**res))
                    # reload the saved model for adaptation on the test domains
                    # NOTE: weights_only=False unpickles arbitrary objects; only
                    # load checkpoints written by this script.
                    sfuda_offline_net=torch.load(f'pretrained/{dataset}/{location}/{index}/{ix_subset}_{ix_fold}.pt', map_location=device, weights_only=False)


                    # accumulate the per-model counters over folds
                    global_total_forward_calls += model.total_forward_calls
                    for layer_name, violation_count in model.manifold_violation_counters.items():
                        if layer_name not in global_manifold_violation_counters:
                            global_manifold_violation_counters[layer_name] = 0
                        global_manifold_violation_counters[layer_name] += violation_count

                    # evaluation: adapt to and score every held-out domain
                    # plain ints, so that the printed / saved domain id is a
                    # number rather than a 0-dim tensor
                    test_domains = domain[test].unique().tolist()
                    for test_domain in test_domains:
                        if 'inter-session' in cfg['evaluation']:
                            subject=ix_subset
                        else:
                            subject=ix_fold
                        print(f"Subject:{subject}, test domain: {test_domain}")
                        test_mask = domain[test] == test_domain
                        ds_test = DomainDataset(X[test][test_mask], y[test][test_mask], domain[test][test_mask])
                        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))
                        sfuda_offline(ds_test, sfuda_offline_net)
                        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
                        res2 = res
                        print(f"{model_name},Test results: {res2}")
                        records.append(dict(input=input_align,latent='gyro',lr=lr,swd_weight=swd_weight,model=model_name,dataset=dataset,subject=subject,domain=test_domain, **res))

            # save records (make sure the target directory exists first)
            os.makedirs(f'pretrained/{dataset}', exist_ok=True)
            resdf = pd.DataFrame(records)
            resdf.to_csv(f'pretrained/{dataset}/results_{location}_{index}.csv',index=False)
            resdf1 = pd.DataFrame(records1)
            resdf1.to_csv(f'pretrained/{dataset}/train_val_results_{location}_{index}.csv',index=False)
