# Minimal SPDIM incorporating TSMNet demo notebook for inter-session/-subject source-free unsupervised domain adaptation (SFUDA) under label shifts

In [1]:
import torch
import sklearn
import pandas as pd
from copy import deepcopy
from moabb.datasets import BNCI2015_001
from moabb.paradigms import MotorImagery
from spdnets.dataloader import StratifiedDomainDataLoader, DomainDataset 
from spdnets.models import TSMNet
import spdnets.batchnorm as bn
import spdnets.functionals as fn
from spdnets.trainer import Trainer
from spdnets.callbacks import MomentumBatchNormScheduler, EarlyStopping

## Parameters for experiments
### Notice: define the evaluation setting (i.e., inter-session/inter-subject) and the label ratio (label shifts level in the target domain) here.
### We have provided pre-trained source models. If you wish to train the model from scratch, please set 'pretrained_model' to False. It usually takes 5/30 mins for inter-session/inter-subject to train on standard PCs with a single GPU.

In [2]:
# Network and training configuration
cfg = dict(
    # parameters for experiments
    epochs = 100,
    batch_size_train = 50,
    domains_per_batch = 5,
    validation_size = 0.2,
    evaluation = 'inter-session', # 'inter-subject' or 'inter-session'
    label_ratio = 0.2,        # we set 0.2 in the paper
    dtype = torch.float32,
    pretrained_model = True,
    # parameters for the TSMNet model
    mdl_kwargs = dict(
        temporal_filters=4,
        spatial_filters=40,
        subspacedims=20, 
        bnorm_dispersion=bn.BatchNormDispersion.SCALAR,
        spd_device='cpu',
        spd_dtype=torch.double,
        domain_adaptation=True
    )
)

if torch.cuda.is_available():
    device = torch.device('cuda')
    print('GPU')
else:
    device = torch.device('cpu')
    print('CPU')

GPU


## load a MOABB dataset. 
### Notice: there is no need to manually download and preprocess the datasets. This is done automatically in MOABB pipeline

In [3]:
moabb_ds = BNCI2015_001()
n_classes = 2
moabb_paradigm = MotorImagery(n_classes=n_classes, events=['right_hand', 'feet'], fmin=4, fmax=36, tmin=1.0, tmax=4.0, resample=256)

## fit and evaluat the model for all domains

In [4]:
records = []

# Check the evaluation type in the configuration
if 'inter-session' in cfg['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()


# iterate over groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add dataset depended model kwargs
    mdl_kwargs = deepcopy(cfg['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg['validation_size'])
        train, val = next(cv_inner.split(X[fit], y[fit], cv_inner_group[fit]))

        # adjust number of domains if necessary
        du = domain[fit][train].unique()
        if cfg['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg['domains_per_batch']

        # get the label ratio , here source domain is balanced
        source_label_ratio, target_label_ratio = fn.get_label_ratio(y, cfg['label_ratio'])
        
        # split entire dataset into train/validation
        ds_train = DomainDataset(X[fit][train], y[fit][train], domain[fit][train],label_ratio=source_label_ratio)
        ds_val = DomainDataset(X[fit][val], y[fit][val], domain[fit][val], label_ratio=source_label_ratio) 

        # create dataloaders, for training use specific loader/sampler so that 
        # batches contain a specific number of domains with equal observations per domain and stratified labels       
        loader_train = StratifiedDomainDataLoader(ds_train, cfg['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))

        # create the model
        net = TSMNet(**mdl_kwargs).to(device=device, dtype=cfg['dtype'])

        # create the momentum scheduler and early stopping callback
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg['epochs']-10,
            bs0=cfg['batch_size_train'],
            bs=cfg['batch_size_train']/cfg['domains_per_batch'], 
            tau0=0.85
        )
        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=20, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg['epochs'],
            min_epochs=50,
            callbacks=[bn_sched, es],
            loss= torch.nn.CrossEntropyLoss(weight = None),
            device=device, 
            dtype=cfg['dtype']
        )

        # fit the model extract model parameters
        parameter_t = torch.tensor(1,dtype=torch.float64,device='cpu')

        if cfg['pretrained_model']:
            if cfg['evaluation'] == 'inter-session':
                state_dict = torch.load(f"pretrained_model/session/state_dict_{ix_subset}{ix_fold}.pt", map_location=device)
            elif cfg['evaluation'] == 'inter-subject':
                state_dict = torch.load(f"pretrained_model/subject/state_dict_{ix_fold}.pt", map_location=device)
        else:
            trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val,parameter_t=parameter_t)
            state_dict = deepcopy(net.state_dict())

        # create a new model for SFUDA
        sfuda_offline_net = TSMNet(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        test_domain=domain[test].unique()


        # Evaluate over test domains in the target domain 
        for test_domain in test_domain:
            if 'inter-session' in cfg['evaluation']:
                subject=ix_subset
            else:
                subject=ix_fold
            print(f"Subject:{subject}, test domain: {test_domain}")
            
            # create test dataset, and artificially introduce the label shifts
            ds_test = DomainDataset(X[test][domain[test] == test_domain], y[test][domain[test] == test_domain], domain[test][domain[test] == test_domain], label_ratio=target_label_ratio)
            loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))


            # enable SFUDA 
            sfuda_offline_net.eval()
            sfuda_offline_net.domainadapt_finetune(ds_test.features.to(dtype=cfg['dtype'], device=device), ds_test.labels.to(device=device), ds_test.domains, 'refit')

            # SFUDA method: RCT
            res = trainer.test(sfuda_offline_net, dataloader=loader_test,parameter_t=parameter_t)
            print('RCT',res)
            records.append(dict(mode='RCT',subject=subject,domain=test_domain, **res))

            # SFUDA method: clustering refined mean [Li et al. 2024, ESANN]
            refined_mean = trainer.get_refined_mean(sfuda_offline_net, test_dataloader=loader_test,parameter_t=parameter_t)
            res = trainer.test(sfuda_offline_net, dataloader=loader_test, parameter_t=parameter_t,fm_mean=refined_mean)
            print('clustering',res)
            records.append(dict(mode="clustering",subject=subject,domain=test_domain, **res))           




 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:0, test domain: 0
RCT {'loss': 0.1862202286720276, 'score': 0.96}
clustering {'loss': 0.00427457457408309, 'score': 1.0}
Subject:0, test domain: 1
RCT {'loss': 0.2871244549751282, 'score': 0.9450000000000001}




clustering {'loss': 0.044719304889440536, 'score': 0.99}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:1, test domain: 0
RCT {'loss': 0.2395249754190445, 'score': 0.9450000000000001}
clustering {'loss': 0.014276506379246712, 'score': 1.0}
Subject:1, test domain: 1
RCT {'loss': 0.3011021018028259, 'score': 0.94}




clustering {'loss': 0.03853554651141167, 'score': 0.97}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:2, test domain: 0
RCT {'loss': 0.6154170632362366, 'score': 0.855}
clustering {'loss': 0.3749459385871887, 'score': 0.895}
Subject:2, test domain: 1
RCT {'loss': 0.47351497411727905, 'score': 0.88}




clustering {'loss': 0.17627711594104767, 'score': 0.935}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:3, test domain: 0
RCT {'loss': 0.3927690386772156, 'score': 0.885}
clustering {'loss': 0.1394817978143692, 'score': 0.87}
Subject:3, test domain: 1
RCT {'loss': 0.45832502841949463, 'score': 0.855}




clustering {'loss': 0.20754128694534302, 'score': 0.87}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:4, test domain: 0
RCT {'loss': 0.6095286011695862, 'score': 0.815}
clustering {'loss': 0.5678255558013916, 'score': 0.8300000000000001}
Subject:4, test domain: 1
RCT {'loss': 0.6144850850105286, 'score': 0.835}




clustering {'loss': 0.5092177987098694, 'score': 0.865}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:5, test domain: 0
RCT {'loss': 0.7752129435539246, 'score': 0.7250000000000001}
clustering {'loss': 0.7302314639091492, 'score': 0.73}
Subject:5, test domain: 1
RCT {'loss': 0.6805852055549622, 'score': 0.74}




clustering {'loss': 0.7560980319976807, 'score': 0.74}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:6, test domain: 0
RCT {'loss': 0.5381815433502197, 'score': 0.86}
clustering {'loss': 0.2875145375728607, 'score': 0.9}
Subject:6, test domain: 1
RCT {'loss': 0.6461504697799683, 'score': 0.825}




clustering {'loss': 0.48352208733558655, 'score': 0.815}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:7, test domain: 0
RCT {'loss': 0.7675186991691589, 'score': 0.72}
clustering {'loss': 0.5146584510803223, 'score': 0.725}
Subject:7, test domain: 1
RCT {'loss': 0.6190028786659241, 'score': 0.74}




clustering {'loss': 0.617132842540741, 'score': 0.74}
Subject:7, test domain: 2
RCT {'loss': 0.6647598147392273, 'score': 0.8}




clustering {'loss': 0.6231288313865662, 'score': 0.815}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:8, test domain: 0
RCT {'loss': 1.079963207244873, 'score': 0.69}
clustering {'loss': 0.7985784411430359, 'score': 0.695}
Subject:8, test domain: 1
RCT {'loss': 0.6267551183700562, 'score': 0.7949999999999999}




clustering {'loss': 0.4248878061771393, 'score': 0.87}
Subject:8, test domain: 2
RCT {'loss': 0.49211204051971436, 'score': 0.85}




clustering {'loss': 0.33416491746902466, 'score': 0.9099999999999999}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:9, test domain: 0
RCT {'loss': 0.7714546322822571, 'score': 0.745}
clustering {'loss': 0.48742932081222534, 'score': 0.735}
Subject:9, test domain: 1
RCT {'loss': 0.8173813223838806, 'score': 0.735}




clustering {'loss': 0.7629995942115784, 'score': 0.7150000000000001}
Subject:9, test domain: 2
RCT {'loss': 0.930731475353241, 'score': 0.625}




clustering {'loss': 0.8972203135490417, 'score': 0.6499999999999999}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:10, test domain: 0
RCT {'loss': 0.6306249499320984, 'score': 0.7849999999999999}
clustering {'loss': 0.5648812055587769, 'score': 0.815}
Subject:10, test domain: 1
RCT {'loss': 0.7752155065536499, 'score': 0.74}




clustering {'loss': 0.7502878308296204, 'score': 0.75}
Subject:10, test domain: 2
RCT {'loss': 0.8560011982917786, 'score': 0.835}




clustering {'loss': 0.7330495119094849, 'score': 0.86}


 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")
 'right_hand': 100
 'feet': 100>
  warn(f"warnEpochs {epochs}")


Subject:11, test domain: 0
RCT {'loss': 0.9370726943016052, 'score': 0.735}
clustering {'loss': 1.1962510347366333, 'score': 0.69}
Subject:11, test domain: 1
RCT {'loss': 0.6992105841636658, 'score': 0.655}




clustering {'loss': 0.6393830180168152, 'score': 0.685}


In [5]:
resdf = pd.DataFrame(records)
resdf.groupby(['mode']).agg(['mean', 'std']).round(4)

Unnamed: 0_level_0,subject,subject,domain,domain,loss,loss,score,score
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std
mode,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
RCT,5.9286,3.4526,0.7143,0.7127,0.6245,0.2174,0.8041,0.0891
clustering,5.9286,3.4526,0.7143,0.7127,0.4885,0.2986,0.8238,0.1042
