In [10]:
import json
from pathlib import Path

import numpy as np

from ezmsg.util.messagelogger import MessageDecoder

from hololight.shallowfbcspnet import ShallowFBCSPNet, get_output_shape

from typing import List


In [11]:

# Load the recorded training samples: the log holds one JSON document per line,
# each decoded with MessageDecoder.
data_fname = Path( '..' ) / 'recordings' / 'traindata.txt'
with open( data_fname, 'r' ) as data_file:
    # Skip blank lines (e.g. a trailing newline at the end of the log) so that
    # json.loads is never handed an empty document.
    samples = [
        json.loads( line, cls = MessageDecoder )
        for line in data_file
        if line.strip()
    ]

# Every decoded sample carries its data under sample -> data and its label
# under trigger -> value.
# NOTE(review): presumably eeg_trials ends up ( trial, time, channel ) or
# ( trial, channel, time ) -- confirm against the recorder before training on it.
eeg_trials = np.array( [ s[ 'sample' ][ 'data' ] for s in samples ] )
eeg_labels = np.array( [ s[ 'trigger' ][ 'value' ] for s in samples ] )

In [12]:
# Hyperparameters of the demo network, collected in one mapping so they can be
# inspected or tweaked before the definition object is built.
net_config = dict(
    # input / output geometry
    in_chans = 8,
    n_classes = 2,
    input_time_length = 1000,
    cropped_training = True,
    # temporal + spatial convolution
    n_filters_time = 40,
    filter_time_length = 25,
    n_filters_spat = 40,
    split_first_layer = True,
    conv_nonlin = 'square',
    # pooling
    pool_time_length = 75,
    pool_time_stride = 15,
    pool_mode = 'mean',
    pool_nonlin = 'safe_log',
    # regularisation
    batch_norm = True,
    batch_norm_alpha = 0.1,
    drop_prob = 0.5,
)
net = ShallowFBCSPNet( **net_config )

# Build the torch module and show its layer stack
model = net.construct()
print( model )

# Report the output shape the model produces for a 700-sample input window
probe_shape = get_output_shape(
    model,
    in_chans = net.in_chans,
    input_window_samples = 700
)
print( probe_shape )

467
Sequential(
  (ensuredims): Ensure4d()
  (dimshuffle): Expression(expression=_transpose_time_to_spat)
  (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
  (conv_spat): Conv2d(40, 40, kernel_size=(1, 8), stride=(1, 1), bias=False)
  (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_nonlin): Expression(expression=square)
  (pool): AvgPool2d(kernel_size=(75, 1), stride=(1, 1), padding=0)
  (pool_nonlin): Expression(expression=safe_log)
  (drop): Dropout(p=0.5, inplace=False)
  (conv_classifier): Conv2d(40, 2, kernel_size=(30, 1), stride=(1, 1), dilation=(15, 1))
  (softmax): LogSoftmax(dim=1)
  (squeeze): Expression(expression=_squeeze_final_output)
)
torch.Size([1, 2, 167])


In [None]:
import warnings

import torch
from torch.utils.data import (
    Dataset,
    DataLoader, 
    random_split, 
    RandomSampler,
    WeightedRandomSampler
)
from torch.optim.lr_scheduler import CosineAnnealingLR

from shallowfbcspnet import ShallowFBCSPNet

from typing import (
    Optional,
    List,
    Tuple
)

# Run on a CUDA device when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print( f"Using Device: {device}" )

if device == 'cuda':
    # GPU 1 is the preferred device, but torch.cuda.set_device raises for an
    # ordinal that does not exist. Clamp to the highest available ordinal so a
    # single-GPU machine falls back to device 0 instead of crashing.
    preferred_cuda_id = 1
    cuda_id = min( preferred_cuda_id, torch.cuda.device_count() - 1 )
    torch.cuda.set_device( cuda_id )
    print( f"ID of current CUDA device:{ torch.cuda.current_device() }" )
    print( f"Name of current CUDA device: { torch.cuda.get_device_name( cuda_id ) }" )

In [None]:
from dataclasses import dataclass

@dataclass
class FBCSPDataset( Dataset ):
    """ 
    Torch Dataset wrapping an xr.DataArray of epochs to generate a multi-class problem.

    The data array is expected to have --
        * a trial dimension upon which labels exist
        * a feat/channel dimension
        * a time dimension

    Every example is a float32 tensor of shape ( n_feat, n_time ) paired with an
    integer class label taken from self.label_dict.

    Condition examples come from da.sel( time = slice( t_min, t_max ) ).  When
    add_null_class is True, a second window of the same length that ends at the
    sample closest to -t_min is taken from every trial and labelled 0 ('null');
    the condition labels are then shifted to start at 1.  Without the null class
    the condition labels start at 0.

    When crops > 0, every window is shortened by `crops` samples and presented
    ( crops + 1 ) times, each time shifted by one sample (sliding-window
    augmentation).

    len( dataset ) == n_trials * ( 2 if add_null_class else 1 ) * ( crops + 1 if crops > 0 else 1 ).
    With the null class enabled, even indices are condition examples and odd
    indices are null examples.

    All examples are materialised as tensors in __post_init__ (prefetch_data /
    prefetch_label), so indexing afterwards is a plain tensor lookup.
    """
    
    # The annotation is quoted (forward reference) so that defining the class does
    # not evaluate `xr`; only the object passed in has to behave like a DataArray.
    da: 'xr.DataArray' # 3D DataArray incl. feat_dim, trial_dim, and time_dim
    feat_dim: str = 'feat' # Features dimension (You may need to stack some dims to make this)
    trial_dim: str = 'event' # Trial dimension - Should be the dimension along which labels exist
    time_dim: str = 'time' # Time dimension -- each epoch should be a short stretch of time
    label_coord: str = 'label' # Coordinate on trial_dim that contains labels
    t_min: float = -0.01 # Time labels to sub-select the input data such that ---
    t_max: float = 0.05 # --- trials.sel( time = slice( t_min, t_max ) ) => condition
    add_null_class: bool = True # If True, a same-length window ending at -t_min => null (0)
    crops: int = 2 # Perform data augmentation by using sliding windows across epochs

    def __post_init__( self ):
        """
        Slice the condition (and optional null) windows out of `da`, build
        label_dict and prefetch every example into tensors.

        Raises:
            ValueError: if the null window would start before the first time sample.
        """
        # Fixed axis order so that every prefetched trial is ( feat, time )
        self.da = self.da.transpose( self.trial_dim, self.feat_dim, self.time_dim )

        # Index by the configured dimension name instead of a hard-coded `time`
        # keyword, so a non-default time_dim is honoured.
        self.conditions = self.da.sel( { self.time_dim: slice( self.t_min, self.t_max ) } )
        self.n_time = len( self.conditions[ self.time_dim ] )
        self.n_feat = len( self.conditions[ self.feat_dim ] )
        
        self.label_dict = { label: idx for idx, label in enumerate( np.unique( self.da[ self.label_coord ] ) ) }
        
        if self.add_null_class:
            # Positional index of the time sample closest to -t_min; the null
            # window is the n_time samples immediately before it.
            null_t_min_idx = np.abs( self.da[ self.time_dim ] - ( -self.t_min ) ).argmin( self.time_dim ).item()
            null_start_idx = null_t_min_idx - self.n_time
            if null_start_idx < 0:
                # A negative start would silently wrap around in isel and give a
                # null window shorter than the condition window.
                raise ValueError(
                    f'Null window needs { self.n_time } samples before index { null_t_min_idx } '
                    f'of { self.time_dim!r}, but only { null_t_min_idx } are available'
                )
            self.nulls = self.da.isel( { self.time_dim: slice( null_start_idx, null_t_min_idx ) } )
            # Reserve label 0 for the null class and shift the conditions up by one
            self.label_dict = { l: ( i + 1 ) for l, i in self.label_dict.items() }
            self.label_dict[ 'null' ] = 0
            
        if self.crops > 0:
            # Each crop is `crops` samples shorter so it can slide inside the window
            self.n_time = self.n_time - self.crops
        
        labels = []
        examples = []
        for flat_idx in range( len( self ) ):

            # Unravel the flat index as ( trial, crop, condition-or-null ), with
            # condition/null varying fastest.
            remainder = flat_idx
            is_null = False
            if self.add_null_class:
                remainder, parity = divmod( remainder, 2 )
                is_null = ( parity == 1 )

            crop_idx = 0
            if self.crops > 0:
                remainder, crop_idx = divmod( remainder, self.crops + 1 )
            trial_idx = remainder

            trial_source = self.nulls if is_null else self.conditions
            trial = trial_source.isel( { 
                self.trial_dim: trial_idx, 
                self.time_dim: slice( crop_idx, crop_idx + self.n_time ) 
            } )

            label = 0 if is_null else self.label_dict[ trial[ self.label_coord ].item() ]
            labels.append( label )
            examples.append( torch.tensor( trial.values.astype( np.float32 ) ) )
            
        # ( n_examples, n_feat, n_time ) data and ( n_examples, ) integer labels
        self.prefetch_data = torch.stack( examples )
        self.prefetch_label = torch.tensor( labels ) 
                        
    def __len__( self ) -> int:
        """ Number of examples: trials x ( condition + optional null ) x crops. """
        n = len( self.da[ self.trial_dim ] )
        n = n * ( 2 if self.add_null_class else 1 )
        if self.crops > 0:
            n = n * ( self.crops + 1 )
        return n

    def __getitem__( self, idx: int ) -> Tuple[ torch.Tensor, torch.Tensor ]:
        """ Return the prefetched ( data, label ) pair at position idx; both are tensors. """
        return ( 
            self.prefetch_data[ idx, ... ], 
            self.prefetch_label[ idx, ... ] 
        )
    
# Stack the ( x, y ) dims into the single 'feat' dimension that FBCSPDataset
# uses by default.
# NOTE(review): `filt_phase_vel` (and `xr` below) are not defined in any cell
# visible here -- they must come from an earlier cell / kernel state.
feats = filt_phase_vel.stack( feat = [ 'x', 'y' ] )
# To exclude a condition before splitting, e.g.:
#     feats = feats.sel( event = filt_pos.label != 'C1' )

train_ratio = 0.75
# Fixed seed so the train/test split (and everything trained on it) is
# reproducible across kernel restarts.
split_seed = 0
split_rng = np.random.default_rng( split_seed )

# Stratified split: shuffle and cut every label group separately so that each
# class keeps the same train/test proportion.
train_da, test_da = [], []
for label, label_da in feats.groupby( 'label' ):
    n_events = len( label_da.event )
    train_events = int( train_ratio * n_events )
    indices = split_rng.permutation( n_events )
    train_indices = indices[ :train_events ]
    test_indices = indices[ train_events: ]
    train_da.append( label_da.isel( event = train_indices ) )
    test_da.append( label_da.isel( event = test_indices ) )
train_da = xr.concat( train_da, 'event' )
test_da = xr.concat( test_da, 'event' )

# We have to split this way because crops can contaminate test dataset if split randomly hereafter
train_dset = FBCSPDataset( train_da, add_null_class = True )
test_dset = FBCSPDataset( test_da, add_null_class = True )


def _print_dset_summary( name, dset ):
    """ Print the size of a dataset and how many examples each label value has. """
    print( f'{ name } dset: { len( dset ) } examples of { dset.n_feat } pixels x { dset.n_time } time points' )
    for lab, count in zip( *torch.unique( dset.prefetch_label, return_counts = True ) ):
        # .item() so the count prints as a plain number rather than `tensor(N)`
        print( f'\t*{ count.item() } { lab.item() } examples' )


_print_dset_summary( 'Train', train_dset )
_print_dset_summary( 'Test', test_dset )

print( train_dset.label_dict )
print( test_dset.label_dict )

In [None]:
# --- Training hyperparameters ---
learning_rate = 0.0001
max_epochs = 300
batch_size = 32
weight_decay = 0.0

# Network sized from the dataset: one input channel per feature, one output
# unit per entry in label_dict, input length equal to the cropped window.
model_definition = ShallowFBCSPNet( 
    train_dset.n_feat, 
    len( train_dset.label_dict ), 
    input_time_length = train_dset.n_time, 
    final_conv_length = 'auto',
    split_first_layer = True,
    filter_time_length = 10,
    n_filters_time = 40,
    n_filters_spat = 40,
    pool_mode = 'mean',
    pool_time_length = 10, #25
    pool_time_stride = 2,
    drop_prob = 0.5,
    batch_norm = True,
    batch_norm_alpha = 0.1
)

model = model_definition.construct()
model = model.to( device )
print( model )

params = sum( p.numel() for p in model.parameters() if p.requires_grad )
print( f'Model has {params} trainable parameters' )

loss_fn = torch.nn.NLLLoss()
optimizer = torch.optim.AdamW( 
    model.parameters(), 
    lr = learning_rate, 
    weight_decay = weight_decay 
)

# One cosine half-period over the whole run; T_max is an iteration count, so
# pass the integer epoch count rather than a float.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max = max_epochs )

# Lowest mean test loss seen so far and the epoch at which it occurred
best_loss = None
best_loss_epoch = None

train_loss, test_loss, test_accuracy = [], [], []
lr = []

# Calculate weights for class balancing: every example is weighted by the
# inverse frequency of its class.
classes, counts = torch.unique( train_dset.prefetch_label, return_counts = True )
class_weights = { cl.item(): 1.0 / co.item() for cl, co in zip( classes, counts ) }
weights = [ class_weights[ lab.item() ] for lab in train_dset.prefetch_label ]

pin_memory = ( device == 'cuda' )

# Sampling must be WITH replacement for the weights to matter: drawing
# len( train_dset ) samples without replacement just yields a permutation of
# the whole dataset, which leaves the class imbalance untouched.
train_loader = DataLoader(
    train_dset, 
    batch_size = batch_size, 
    sampler = WeightedRandomSampler( weights, len( train_dset ), replacement = True ),
    pin_memory = pin_memory,
)
test_loader = DataLoader(
    test_dset, 
    batch_size = batch_size, 
    pin_memory = pin_memory
)

# NOTE(review): `tqdm` (and `plt` below) are not imported in any cell visible
# here -- they must come from an earlier cell / kernel state.
epoch_itr = tqdm( range( max_epochs ) )

for epoch in epoch_itr:

    # --- Train for one epoch ---
    model.train()
    train_loss_batches = []
    for train_feats, train_labels in train_loader:
        pred = model( train_feats.to( device ) )
        loss = loss_fn( pred, train_labels.to( device ) )
        train_loss_batches.append( loss.item() )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The learning rate is annealed once per epoch
    scheduler.step()

    lr.append( scheduler.get_last_lr()[0] )
    train_loss.append( np.mean( train_loss_batches ) )

    # --- Evaluate on the held-out trials ---
    model.eval()
    with torch.no_grad():
        accuracy = 0
        test_loss_batches = []
        for test_feats, test_labels in test_loader:
            output = model( test_feats.to( device ) )
            test_loss_batches.append( loss_fn( output, test_labels.to( device ) ).item() )
            accuracy += ( output.argmax( dim = 1 ).cpu() == test_labels ).sum().item()

        test_loss.append( np.mean( test_loss_batches ) )
        test_accuracy.append( accuracy / len( test_dset ) )

    if best_loss is None or test_loss[-1] < best_loss:
        best_loss = test_loss[-1]
        best_loss_epoch = epoch
        
acc_str = f'Test Accuracy: {test_accuracy[-1] * 100.0:0.2f}%'

# Learning curves: losses, accuracy and learning rate share one log-scaled axis
fig, ax = plt.subplots( dpi = 100, figsize = ( 6.0, 4.0 ) )
ax.plot( train_loss, label = 'Train' )
ax.plot( test_loss, label = 'Test' )
ax.plot( test_accuracy, label = 'Test Accuracy' )
ax.plot( lr, label = 'Learning Rate' )
ax.legend()
ax.set_yscale( 'log' )
ax.set_xlabel( 'Epoch' )
ax.axhline( 1, color = 'k' )
ax.set_title( acc_str )

# out_train = input_dir / f'{tag}_train.png'
# fig.savefig( out_train )

In [None]:
model.eval()

# Inference only: disable autograd so no graph is built (and no activations
# are retained) while predicting the test batches.
with torch.no_grad():
    output = [ 
        ( model( test_feats.to( device ) ).cpu().argmax( dim = 1 ), test_labels )
        for test_feats, test_labels 
        in DataLoader( test_dset, batch_size = batch_size ) 
    ]

# Flatten the per-batch ( prediction, label ) pairs into two 1-D tensors
decode, test_y = zip( *output )
test_y = torch.cat( test_y, dim = 0 )
decode = torch.cat( decode, dim = 0 )

# Row = true class, column = predicted class, each row normalised to sum to 1
classes = list( sorted( test_dset.label_dict.values() ) ) 
rev_dict = { v: k for k, v in test_dset.label_dict.items() }
class_labels = [ rev_dict[ c ] for c in classes ]
confusion = np.zeros( ( len( classes ), len( classes ) ) )
for true_idx, true_class in enumerate( classes ):
    class_mask = ( test_y == true_class )
    n_class_trials = int( class_mask.sum().item() )
    if n_class_trials == 0:
        # No test examples of this class: the row is undefined, not a crash
        confusion[ true_idx, : ] = np.nan
        continue
    class_decode = decode[ class_mask ]
    for pred_idx, pred_class in enumerate( classes ):
        num_preds = ( class_decode == pred_class ).sum().item()
        confusion[ true_idx, pred_idx ] = num_preds / n_class_trials


fig, ax = plt.subplots( dpi = 100 )
# Cell edges at half-integers so that every cell is centred on its class index
corners = np.arange( len( classes ) + 1 ) - 0.5
im = ax.pcolormesh( 
    corners, corners, confusion, alpha = 0.5,
    cmap = plt.cm.Blues, vmin = 0.0, vmax = 1.0
)

for row_idx, row in enumerate( confusion ):
    for col_idx, freq in enumerate( row ):
        ax.annotate( 
            f'{freq:0.2f}', ( col_idx, row_idx ), 
            ha = 'center', va = 'center' 
        )

ax.set_aspect( 'equal' )
ax.set_xticks( classes )
ax.set_yticks( classes )
ax.set_xticklabels( class_labels )
ax.set_yticklabels( class_labels )
ax.set_ylabel( 'True Class' )
ax.set_xlabel( 'Predicted Class' )
ax.invert_yaxis( )
fig.colorbar( im )
ax.set_title( acc_str )

# out_accuracy = input_dir / f'{tag}_acc.png'
# fig.savefig( out_accuracy )

In [None]:
# Bundle everything needed to rebuild the trained model and resume optimisation.
# NOTE(review): `trials` is not defined in any cell visible here -- confirm it is
# still in scope (and that `fs` is the intended value) before running this cell.
# NOTE(review): `model_definition` is stored as a pickled Python object, so
# loading this checkpoint requires the ShallowFBCSPNet class to be importable
# and must only be done with trusted files.
checkpoint = {
    'model_definition': model_definition,
    'fs': trials.fs, 
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}

# Written relative to the current working directory
out_checkpoint = 'FBCSP.checkpoint'
torch.save( checkpoint, out_checkpoint )