This branch is the software release for the 2019 paper: https://www.nature.com/articles/s41598-019-47795-0

See LICENSE.txt

Copyright 2019 Massachusetts Institute of Technology

In [None]:
%reset -f
import torch
import numpy as np
import matplotlib.pyplot as plt
import scipy
import scipy.signal
import scipy.io
import scipy.io.wavfile
import sklearn
import sklearn.metrics
import sklearn.preprocessing
import sklearn.feature_selection
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import os
import time
import datetime
import getpass
#import seaborn as sns
import pandas as pd
import hashlib
from importlib import reload 
from glob import glob
import subprocess
%matplotlib inline

import matplotlib as mpl
# Font embedding type for saved PDFs (42 = TrueType per matplotlib's rcParams docs),
# which keeps text editable in exported figures.
mpl.rcParams['pdf.fonttype'] = 42

# Seed the global numpy and torch RNGs so the exploratory cells below are repeatable.
np.random.seed(0) # for reproducibility tests
torch.manual_seed(0) #for reproducibility tests

## Perform auditory attention decoding using audio reconstruction or direct classification using neural networks
#
## About
#  Greg Ciccarelli
#  February 3, 2018
#  March 22, 2018
#  June 28, 2019


In [None]:
from IPython.core.debugger import set_trace

import nipype.pipeline.engine as pe
import nipype.interfaces.utility as niu
import nipype.interfaces.io as nio

import getpass
import sys

In [None]:
# Login name of the current user.
# NOTE(review): `user` is not referenced anywhere else in the visible notebook.
user = getpass.getuser()

In [None]:
# Experiment configuration. (The latest EEG and envelope files for each subject
# folder are collected in a later cell.)
# Many names below are assigned more than once ("paper" values followed by
# overrides, and one block per published configuration); only the LAST
# assignment of each name takes effect when the cell runs.
collect = 'LL_HowTo_0DegreesSeparation'
modality = 'neuroscan'
#modality = 'dsi'

#############################

save_flag = True

num_predict = 1

hidden_size = 2 
#hidden_size = 200

num_ch_output = 1
output_size = num_ch_output * num_predict 

slow_opt_flag = False
#num_batch = int(1024) # bce
num_batch = int(5)



num_epoch = 2400 #paper
# NOTE(review): this overrides the paper value above (looks like a quick-run
# setting) - restore 2400 before reproducing published results.
num_epoch = 2

learning_rate = 1e-3 #paper
weight_decay = 0 #paper


file_path_net = XXX_path_to_net 

# paper
# Each block below selects num_context, the network module, the get_data module
# and the loss for one published configuration (comment = data, loss, run
# timestamp). Only the last block ("wet ch sub, de taillez") is in effect;
# move the desired block to the end to select it.
# dry, de taillez, 20190505112434, 
num_context=26
file_name_net = '201806190841_1hid_b_tanh_b_htanh'
file_name_get_data = '201807141456_get_noscale_dsi'
loss_type = 'corr'

# Dry, bce, 20190503142452, 
num_context=1000
file_name_get_data = '201809272008_get_binary_conv_dry'
file_name_net = '201809272028_binary_conv_dsi'
loss_type = 'bce'

# wet, bce, 20190504085917 
num_context = 1000
file_name_net = '201809262034_binary_conv'
file_name_get_data = '201809262022_get_binary_conv'
loss_type = 'bce'

# wet, De taillez, 20190505101057, 
num_context = 26
file_name_get_data = '201806221952_get_noscale'
file_name_net = '201806190841_1hid_b_tanh_b_htanh'
loss_type = 'corr'

# wet ch sub, bce, 20190503211325, 
num_context = 1000
file_name_get_data = '201905031434_get_data_bce_wet2dry'
file_name_net = '201809272028_binary_conv_dsi'
loss_type = 'bce'

# wet ch sub, de taillez, 20190505110243, 
num_context = 26
file_name_net = '201806190841_1hid_b_tanh_b_htanh'
file_name_get_data = '201905031437_get_data_recon_wet2dry'
loss_type = 'corr'

# Full path of the network module (imported by its base name later).
file_path_name_net = os.path.join(file_path_net, file_name_net)


# NOTE(review): file_name_get_data selected above is not referenced in the visible
# code; presumably this placeholder should point at that module - confirm.
file_path_name_get_data = XXX_path_and_name_to_get_data

In [None]:
# Import the data module named by file_path_name_get_data.
# Add its folder to sys.path, import it by base name, and reload it so that
# edits made since an earlier import in this kernel are picked up.
# It must provide load_data and get_data.
sys.path.append(os.path.split(file_path_name_get_data)[0])
module = __import__(os.path.split(file_path_name_get_data)[1])
reload(module)
load_data = getattr(module, 'load_data')
get_data = getattr(module, 'get_data')

In [None]:
subj_folder_list = XXX_list_of_subj_folder_paths

# For every subject folder, pick the latest (lexicographically last) envelope
# file and EEG file. Both lookups are done before either list is appended, so
# the two lists stay index-aligned. Appending the envelope first and then
# failing on a missing EEG file would shift every later audio/EEG pairing.
file_path_name_audio_list = []
file_path_name_eeg_list = []
for subj_folder in subj_folder_list[:]:  # use [:1] to run a single subject
    audio_matches = sorted(glob(os.path.join(subj_folder, '*_Envelope100Hz.*')))  # . for real data
    eeg_matches = sorted(glob(os.path.join(subj_folder, '*_EEGF*.*')))
    if not audio_matches or not eeg_matches:
        print('-- missing --')
        print(subj_folder)
        continue
    file_path_name_audio_list.append(audio_matches[-1])
    file_path_name_eeg_list.append(eeg_matches[-1])
print(file_path_name_audio_list)
print(file_path_name_eeg_list)

In [None]:
# Display the loop variable left over from the previous cell (the last folder iterated).
subj_folder

In [None]:
# Load data
# Load the first subject's files and inspect the array shapes.
audio, eeg, audio_unatt = load_data(file_path_name_audio_list[0], file_path_name_eeg_list[0])
 
print(audio.shape)
print(eeg.shape)
print(audio_unatt.shape)


# For each row of `audio` (axis 0), count the non-NaN samples along axis 1,
# stem-plot the counts and print the smallest one.
a = ~np.isnan(audio)
fig, ax = plt.subplots();
ax.stem(np.sum(a, axis=1));
print(np.min(np.sum(a, axis=1)));

In [None]:
# Sorted random subset of up to 250 indices from range(num_context) (all of them
# when num_context <= 250). This consumes the global numpy RNG.
idx_keep_audioTime = np.sort(np.random.permutation(num_context)[:250]) # 250 LL

# Debug parameter dict for the get_data calls below (also used by the eval_list cell).
dct_params = {'idx_keep_audioTime': idx_keep_audioTime}

In [None]:
# debug get data
# Build inputs X, targets y and the unattended output z_unatt for sample 0,
# using the debug dct_params. get_data may return X=None, hence the guards.
idx_sample = 0
X, y, z_unatt = get_data(audio, eeg, audio_unatt=audio_unatt, idx_eeg=None, 
                             num_batch=None, idx_sample=idx_sample, 
                             num_context=num_context, num_predict=num_predict, dct_params=dct_params)

if X is not None:
    print(X.shape)
    print(y.shape)
    print(z_unatt)

In [None]:
if X is not None:
    # Plot row 100 of X, transposed.
    fig, ax = plt.subplots();
    ax.plot(X.data.numpy()[100].T);

In [None]:
if X is not None:
    # Plot the first 100 rows of the targets.
    fig, ax = plt.subplots();
    ax.plot(y.data.numpy()[:100]);

In [None]:
# debug get data
# Same as the earlier debug cell but with dct_params=None.
# Unlike that cell, X is not checked for None here.
idx_sample = 0
X, y, z_unatt = get_data(audio, eeg, audio_unatt=audio_unatt, idx_eeg=None, 
                             num_batch=None, idx_sample=idx_sample, 
                             num_context=num_context, num_predict=num_predict, dct_params=None)

print(X.shape)
print(y.shape)
print(z_unatt)

fig, ax = plt.subplots();
ax.plot(y.data.numpy()[:100, 0:9]);
fig, ax = plt.subplots();
ax.plot(X.data.numpy()[:200, :10, 0]);


# Visualize differences
# NaN-aware spread of X[:, 0, 0].
print(np.nanstd(X[:, 0, 0].data.numpy(), axis=0))

# eeg: mean over axis 2, then over axis 0 -> one value per axis-1 index.
fig, ax = plt.subplots();
ax.stem(np.nanmean(np.nanmean(eeg, axis=2), axis=0));

# eeg: std over axis 2, averaged over axis 0.
fig, ax = plt.subplots();
ax.stem(np.nanmean(np.nanstd(eeg, axis=2), axis=0));

# eeg: std over axis 2 for axis-1 index 26 (hard-coded).
fig, ax = plt.subplots();
ax.stem(np.nanstd(eeg, axis=2)[:, 26]);

# Per-row std of the attended and unattended audio.
fig, ax = plt.subplots();
ax.stem(np.nanstd(audio, axis=1));

fig, ax = plt.subplots();
ax.stem(np.nanstd(audio_unatt, axis=1));

# First 500 samples of the first and last audio rows.
fig, ax = plt.subplots();
ax.plot(audio[0][:500]);
ax.plot(audio[-1][:500]);

In [None]:
# Check availability of data after removing nan's
# For each row of eeg (axis 0), count the non-NaN samples of channel index 0.
# Slice without np.squeeze: squeezing a single-row array collapses it to 1-D,
# and the axis=1 reduction below would then fail.
eeg_1ch = eeg[:, 0, :]

num_dur = np.sum(~np.isnan(eeg_1ch), axis=1)
print(num_dur)
# Rows too short to provide even one context window.
print(np.where(num_dur < num_context))
# Mean / std of usable lengths scaled by 0.01.
# NOTE(review): 0.01 presumably converts 100 Hz samples to seconds - confirm.
print(np.mean(num_dur[num_dur >= num_context] * 0.01))
print(np.std(num_dur[num_dur >= num_context] * 0.01))


# Required: Define the main processing function

In [None]:
def big_node(train, test, file_path_name_audio, file_path_name_eeg, dct_params):
    """Train a network on one data split, evaluate it, and optionally save results.

    1. Unpack parameters, define model, define data
    2. Training loop (random mini-batches from the assembled training set)
    3. Evaluation on the first training part and on the test part
    4. Save the model state (.pt) and outputs (.mat) when ``save_flag`` is set

    This function is run as a nipype ``Function`` node, so it is kept fully
    self-contained: every import and helper it uses is defined in its body.

    Arguments
    ---------
    train : list
        Integer list of training parts. The last two are held out as the
        validation set; only the first held-out part is used for validation.

    test : list
        Single-element list holding the integer test part.

    file_path_name_audio : string
        Full path and name of the audio mat file. The subject ID saved in the
        output is parsed from a ``Subj_<digits>_`` fragment of this name
        (NaN when absent).

    file_path_name_eeg : string
        Full path and name of the eeg mat file

    dct_params: dict
        Collection of auxiliary parameters (keys unpacked below). It is also
        merged into the saved .mat output.

    Returns
    -------
    None
        Results are written to disk when ``save_flag`` is set.

    Raises
    ------
    ValueError
        If fewer than 3 training parts are given, ``loss_type`` is unknown,
        or get_data yields no training or validation data.
    """

    import numpy as np
    import scipy
    import scipy.io
    import sklearn
    import sklearn.preprocessing
    import torch
    from torch.autograd import Variable
    import torch.nn as nn
    import torch.nn.functional as F
    import datetime
    import time
    import os
    import matplotlib.pyplot as plt
    import sys
    from importlib import reload
    import hashlib
    from glob import glob
    import re
    import nipype

    ################################################################
    #      Unpack parameters, define model, define data
    ################################################################

    def closs(x, y):
        """Return the negative Pearson correlation of 1-D tensors x and y.

        The covariance and both standard deviations use the same 1/N
        normalization, so the value lies in [-1, 1]. The previous mix of a
        1/N covariance with the unbiased 1/(N-1) torch.std scaled the result
        by (N-1)/N, which is noticeable for small batches.
        """
        xc = x - torch.mean(x)
        yc = y - torch.mean(y)
        num = torch.mean(xc * yc)
        denom = torch.sqrt(torch.mean(xc * xc)) * torch.sqrt(torch.mean(yc * yc))
        return -num / denom

    def _to_float(value):
        """Convert a single-element tensor/Variable (CPU or GPU) to a Python float."""
        return float(value.data.cpu().numpy().reshape(-1)[0])

    def _load_module(file_path_name):
        """Import the module named by the path's last component, reloading edits."""
        folder, module_name = os.path.split(file_path_name)
        if folder not in sys.path:
            sys.path.append(folder)
        module = __import__(module_name)
        reload(module)  # pick up changes made since an earlier import
        return module

    num_context = dct_params['num_context']
    num_predict = dct_params['num_predict']
    num_epoch = dct_params['num_epoch']
    save_flag = dct_params['save_flag']
    file_path_save = dct_params['file_path_save']
    file_path_name_net = dct_params['file_path_name_net']
    input_size = dct_params['input_size']
    hidden_size = dct_params['hidden_size']
    output_size = dct_params['output_size']
    num_batch = dct_params['num_batch']
    learning_rate = dct_params['learning_rate']
    weight_decay = dct_params['weight_decay']
    loss_type = dct_params['loss_type']
    idx_split = dct_params['idx_split']
    random_seed_flag = dct_params['random_seed_flag']

    # Seed with the split index (different seed per split) or with 0 (identical seeds).
    seed = idx_split if random_seed_flag else 0
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    # Load the data-loading module.
    data_module = _load_module(dct_params['file_path_name_get_data'])
    get_data = getattr(data_module, 'get_data')
    load_data = getattr(data_module, 'load_data')

    # Hold out the last two training parts for validation. `train` is not
    # shuffled, so repeated runs share the same validation and training sets.
    if len(train) < 3:
        raise ValueError('big_node needs at least 3 training parts '
                         '(1 train + 2 validation), got %d' % len(train))
    valset = train[-2:]
    print(valset)
    train = train[:-2]
    print(train)

    # The network module must define a class NN(input_size, hidden_size, output_size).
    NN = getattr(_load_module(file_path_name_net), 'NN')
    model = NN(input_size, hidden_size, output_size)

    if loss_type == 'mse':
        loss_fn = nn.MSELoss()  # mean over elements (same as the old size_average=True)
    elif loss_type == 'corr':
        loss_fn = closs
    elif loss_type == 'bce':
        loss_fn = nn.BCEWithLogitsLoss()
    else:
        raise ValueError("Unknown loss_type %r (expected 'mse', 'corr' or 'bce')" % (loss_type,))

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # CUDA is deliberately disabled; set to torch.cuda.is_available() to enable it.
    cuda_flag = False
    if cuda_flag:
        model.cuda()
        print('Using CUDA')
    else:
        print('No CUDA')

    loss_history = np.full(num_epoch, np.nan)
    loss_val_history = np.full(num_epoch, np.nan)
    model.train()  # Turn on dropout, batchnorm

    audio, eeg, audio_unatt = load_data(file_path_name_audio, file_path_name_eeg, train=train)

    # dct_params['idx_eeg'] is intentionally not used: get_data always receives
    # idx_eeg=None (the value is still saved via **dct_params).
    idx_eeg = None

    def _get(idx_sample, batch=None):
        """Call get_data for one part with this split's settings."""
        return get_data(audio, eeg, audio_unatt=audio_unatt, idx_eeg=idx_eeg,
                        num_batch=batch, idx_sample=idx_sample,
                        num_context=num_context, num_predict=num_predict,
                        dct_params=dct_params)

    # Assemble the full training set, skipping parts for which get_data returns
    # no data. Concatenate once at the end instead of once per part.
    X_parts, y_parts = [], []
    for idx_sample in train:
        print(idx_sample)
        X, y, _ = _get(idx_sample)
        if X is not None:
            X_parts.append(X)
            y_parts.append(y)
    if not X_parts:
        raise ValueError('get_data returned no training data for parts %s' % (train,))
    X_all = torch.cat(X_parts, dim=0)
    y_all = torch.cat(y_parts, dim=0)
    print(X_all.shape)

    # Build the validation data once, outside the training loop.
    Xval, yval, _ = _get(valset[0])
    if Xval is None:
        raise ValueError('get_data returned no validation data for part %s' % (valset[0],))

    ################################################################
    #              Training loop
    ################################################################
    # Iterate a fixed number of epochs, drawing a new random batch each time.

    example_val_y = np.nan
    example_val_z_unatt = np.nan  # unattended validation output is not computed
    example_val_yhat = np.nan
    idx_sample_list = np.full(num_epoch, np.nan)  # never filled; kept for the output API
    log_every = max(1, num_epoch // 10)  # report timing about 10 times per run
    start = time.perf_counter()
    t_start = datetime.datetime.now()
    print(t_start)
    for idx_train in range(num_epoch):
        if idx_train % log_every == 0:
            print('epoch %d ' % idx_train)
            end = time.perf_counter()
            t_end = datetime.datetime.now()
            print('Time per epoch %2.5f ticks' % ((end - start) / log_every))
            print((t_end - t_start) / log_every)
            start = time.perf_counter()
            t_start = datetime.datetime.now()
            print(t_start)

        idx_keep = torch.from_numpy(
            np.random.permutation(X_all.size(0))[:num_batch]).type('torch.LongTensor')
        X_batch = X_all[idx_keep]
        y_batch = y_all[idx_keep]
        if cuda_flag:
            X_batch = X_batch.cuda()
            y_batch = y_batch.cuda()

        output = model(X_batch)
        loss = loss_fn(output.view(-1), y_batch.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_history[idx_train] = _to_float(loss)

        # Validation check every epoch on a random, order-preserving subset of rows.
        if valset:
            model.eval()
            idx_keep = torch.from_numpy(
                np.sort(np.random.permutation(Xval.size(0))[:num_batch])).type('torch.LongTensor')
            X_v = Xval[idx_keep]
            y_v = yval[idx_keep]
            y_att = model(X_v.cuda() if cuda_flag else X_v)
            y_v_dev = y_v.cuda() if cuda_flag else y_v
            # _to_float copies to host memory, so this also works with CUDA enabled.
            loss_val_history[idx_train] = _to_float(loss_fn(y_att.view(-1), y_v_dev.view(-1)))
            model.train()

            example_val_y = y_v.cpu().data.numpy()
            example_val_yhat = y_att.cpu().data.numpy()

    print('-- done training --')
    print(datetime.datetime.now())

    ################################################################
    #              Evaluation
    ################################################################
    # Evaluate on the first training part, then on the test part.
    model.eval()

    example_tr_y = []
    example_tr_yhat = []
    example_tr_unatt = []
    for idx_tr in train[:1]:
        X, y, z_unatt = _get(idx_tr, batch=num_batch)
        if X is not None:
            y_att = model(X.cuda() if cuda_flag else X)
            example_tr_y.append(y.cpu().data.numpy())
            example_tr_yhat.append(y_att.cpu().data.numpy())
        if z_unatt is None:
            example_tr_unatt.append(np.array(np.nan))
        else:
            example_tr_unatt.append(z_unatt.data.numpy())

    X, y, z_unatt = _get(test[0])
    if X is not None:
        y_att = model(X.cuda() if cuda_flag else X)
        example_te_y = y.cpu().data.numpy()[None, :]
        example_te_yhat = y_att.cpu().data.numpy()[None, :]
    else:
        example_te_y = np.nan
        example_te_yhat = np.nan
    if z_unatt is None:
        example_te_unatt = np.array([np.nan])
    else:
        example_te_unatt = z_unatt.data.numpy()[None, :]

    ################################################################
    #              Save
    ################################################################
    # Save network parameters and outputs

    ver_list = [v.__name__ + "_" + v.__version__ for v in [torch, np, scipy, nipype]]
    ver_list.append('python_' + sys.version)

    if save_flag:
        # Raw-string regex. If the name has no subject ID, store NaN instead of
        # crashing here and losing the trained model.
        subj_match = re.search(r'Subj_(\d+)_', file_path_name_audio)
        subj_id = subj_match.group(1) if subj_match else np.nan

        dct_all = {**{'loss': loss_history, 'train': train, 'test': test,
                      'file_path_name_audio': file_path_name_audio,
                      'file_path_name_eeg': file_path_name_eeg,
                      'valset': valset,
                      'loss_val_history': loss_val_history,
                      'idx_sample_list': idx_sample_list,
                      'yValAtt': example_val_y,
                      'yValHat': example_val_yhat,
                      'yValUna': example_val_z_unatt,
                      'yTrainAtt': example_tr_y,
                      'yTrainHat': example_tr_yhat,
                      'yTrainUna': example_tr_unatt,
                      'yTestAtt': example_te_y,
                      'yTestHat': example_te_yhat,
                      'yTestUna': example_te_unatt,
                      'envTestAtt': example_te_y,  # output api compatible
                      'envHatAtt': example_te_yhat,  # output api compatible
                      'envTestUna': example_te_unatt,  # output api compatible
                      'subjID': subj_id,  # output api compatible
                      'ver_list': ver_list,
                      },
                   **dct_params}

        # File-name stamp from the split's parameters. The exact type() checks
        # (not isinstance) are intentional: bools and numpy scalars stay out of
        # the hash, so stamps match those of earlier runs.
        hash_parts = []
        for key, val in {**{'train': train}, **dct_params}.items():
            if type(val) is str:
                hash_parts.append(key + val)
            elif type(val) in [float, int]:
                hash_parts.append(key + str(val))
            elif type(val) is list and val:
                if type(val[0]) is str:
                    hash_parts.append(key + ','.join(val))
                elif type(val[0]) in [float, int]:
                    hash_parts.append(key + ','.join(str(i) for i in val))
        hexstamp = hashlib.md5(''.join(hash_parts).encode('utf')).hexdigest()

        now_str = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
        file_path_name_checkpoint = os.path.join(file_path_save,
                                                 'checkpoint_eeg2env_%s_%s.pt'
                                                 % (hexstamp, now_str))
        torch.save({'state_dict': model.state_dict()}, file_path_name_checkpoint)

        print(file_path_name_checkpoint)
        # savemat cannot store None, so replace None values with NaN.
        dct_all = {key: (np.nan if val is None else val) for key, val in dct_all.items()}
        scipy.io.savemat(os.path.join(file_path_save,
                                      'checkpoint_eeg2env_%s_%s.mat'
                                      % (hexstamp, now_str)),
                         dct_all)

    return None

# Required: Define all data splits

In [None]:
timestamp_time = '{:%Y%m%d%H%M%S}'.format(datetime.datetime.now())
# NOTE(review): original note ":9 same AM or PM run" presumably means that
# timestamp_time[:9] (YYYYMMDD + tens digit of the hour) identifies one session.

In [None]:
eval_list = []
# For each subject, build one leave-one-part-out split per part: part `test` is
# the test part and all remaining parts (sorted) are the training parts.
# Each entry is [train, [test], audio file, eeg file, input_size].
# NOTE: `train`, `test`, `eeg`, `file_path_name_audio` and `file_path_name_eeg`
# remain bound to the last iteration's values and are read by later cells.
for file_path_name_audio, file_path_name_eeg in zip(file_path_name_audio_list, file_path_name_eeg_list):
    print(file_path_name_audio)
    audio, eeg, audio_unatt = load_data(file_path_name_audio, file_path_name_eeg)

    # exhaustive
    # One split per index along axis 0 of `audio`.
    full_set = audio.shape[0]
    #full_set = 10 # debug, 4
    
    # Only used to infer the network input size from part 0. This uses the
    # debug dct_params defined in an earlier cell.
    # NOTE(review): assumes get_data returns a non-None X for part 0.
    X, y, z_unatt = get_data(audio, eeg, audio_unatt=audio_unatt, idx_eeg=None, 
                                 num_batch=None, idx_sample=0, 
                                 num_context=num_context, num_predict=num_predict, dct_params=dct_params)    

    # Product of all non-batch dimensions of X (a numpy integer).
    input_size = np.prod(X.shape[1:]) 
    print(input_size)
    for test in range(full_set): 
    #for test in np.random.permutation(full_set).tolist(): #If running less than a full set of splits and want to see different test partitions
        train = sorted(list(set(range(full_set)) - set([test])))
        eval_list.append([train, [test], file_path_name_audio, file_path_name_eeg, input_size])

In [None]:
# Optional: Test stability of training
## Can the identical network with identical inputs recover the same performance with/without different random seeds during initialization/training/optimization?
## DEBUG for Stability check
## Take eval list, first item, copy N times
## These should be identical runs of the network

# WARNING: running this cell replaces EVERY split with eval_list[0] (first
# subject, first test part). All entries reference the same list object.
# Skip this cell for a real cross-validation run.
eval_list = [eval_list[0] for i in range(len(eval_list))]

# Required: Define how many of the data splits to actually run

In [None]:
# Number of eval_list entries turned into workflow nodes; the last assignment wins.
n_splits = len(eval_list)
#n_splits = 5
#n_splits = 2
#n_splits = 3
n_splits = 1

# True: big_node seeds numpy/torch with the split index; False: seeds with 0.
random_seed_flag = True
#random_seed_flag = False
#random_seed_flag = False

# Create workflow

In [None]:
# Build a nipype workflow with one big_node per split (the first n_splits entries of eval_list).
wf = pe.Workflow(name="wf")
for idx_b in range(n_splits):         
    # Run timestamp plus an md5 of the split's audio and eeg paths. Joining a
    # string with '' just reproduces it, so this hashes audio_path + eeg_path.
    timestamp = '%s_%s' % (timestamp_time, 
                                 hashlib.md5((('').join(eval_list[idx_b][2]+eval_list[idx_b][3])).encode('utf')).hexdigest())

    file_path_save = XXX_file_path_save_with_timestamp
    
    # Create the file_path_save here to avoid race conditions in the workflow
    if not os.path.exists(file_path_save):
        os.makedirs(file_path_save)
        
    # Remember, it is MUCH faster to submit lightweight arguments to a node than to submit the entire dataset.
    # That's why the dataset is loaded inside big_node.         
    node_big = pe.Node(niu.Function(input_names=['train', 'test', 
                                                 'file_path_name_audio', 
                                                 'file_path_name_eeg', 
                                                 'dct_params'],
                                    output_names=['outputs'],
                                    function=big_node),
                                    name='big_node_%03d' % idx_b)
   
    # Parameters for this split. NOTE: 'idx_eeg' is sized from `eeg` left over
    # from the eval_list cell (the last subject loaded). big_node itself always
    # passes idx_eeg=None to get_data.
    dct_params = {'idx_eeg': np.nan * np.ones(eeg.shape[1]), 
                  'num_context': num_context,
                  'num_predict' : num_predict,
                  'idx_split': idx_b,
                  'timestamp': timestamp, 
                  'file_path_save': file_path_save,
                  'file_path_name_get_data': file_path_name_get_data,
                  'save_flag':save_flag,  
                  'num_epoch': num_epoch,
                  'file_path_name_net': file_path_name_net,
                  'input_size': eval_list[idx_b][4],
                  'hidden_size': hidden_size,
                  'output_size': output_size,
                  'num_batch': num_batch,
                  'learning_rate': learning_rate,
                  'weight_decay': weight_decay,
                  'loss_type': loss_type,
                  'num_ch_output': num_ch_output,
                  'collect': collect,
                  'idx_keep_audioTime': idx_keep_audioTime, 
                  'random_seed_flag': random_seed_flag, 
                  'slow_opt_flag': slow_opt_flag}   
    
    node_big.inputs.train = eval_list[idx_b][0] #train
    node_big.inputs.test = eval_list[idx_b][1] #test
    
    node_big.inputs.file_path_name_audio = eval_list[idx_b][2] #file_path_name_audio
    node_big.inputs.file_path_name_eeg = eval_list[idx_b][3] #file_path_name_eeg

    node_big.inputs.dct_params = dct_params    
    wf.add_nodes([node_big])

In [None]:
print(file_path_save)

# Optional: Test main processing function
## Don't use nipype, just run the function
# Run the split that matches the most recently built `dct_params` (split
# idx_b), so input_size, the timestamp, the data files and the train/test
# parts all agree. The loop globals `train`/`test`/`file_path_name_*` left by
# the eval_list cell belong to the last subject's last split, which need not match.
train_dbg, test_dbg, file_path_name_audio_dbg, file_path_name_eeg_dbg = eval_list[idx_b][:4]
stats = big_node(train_dbg, test_dbg, file_path_name_audio_dbg, file_path_name_eeg_dbg, dct_params)


# Required: Main Proc

In [None]:
# nipype execution settings: crash-file and working directories, plus
# polling/completion timing used while waiting on submitted jobs
# (see the nipype execution configuration docs for exact semantics).
wf.config['execution']['crashdump_dir'] = XXX_path_to_crashdumpdir
wf.base_dir = XXX_path_to_base_dir

wf.config['execution']['parameterize_dirs'] = False
wf.config['execution']['poll_sleep_duration'] = 10
wf.config['execution']['job_finished_timeout'] = 30

In [None]:
# Choose between nipype's default plugin (True) and SLURM submission (False); the last assignment wins.
run_local_flag = True
run_local_flag = False

In [None]:
# Execute the workflow: with nipype's default plugin, or via the SLURM plugin
# with the given sbatch arguments (the GPU variant is kept commented out).
if run_local_flag:
    eg = wf.run() 
else: 
    #eg = wf.run('SLURM', plugin_args={'sbatch_args': '-p gpu --gres=gpu:tesla:2 --constraint=xeon-e5 --mem=15G'})
    eg = wf.run('SLURM', plugin_args={'sbatch_args': '--constraint=xeon-e5 --exclusive -O'})


In [None]:
print('Done successfully')

# Optional: Look at network parameters from a saved output file

# Look at params
# Make the network module importable even if big_node has not run in this
# kernel (big_node only adds this folder to sys.path when it runs here).
net_folder, net_module_name = os.path.split(file_path_name_net)
if net_folder not in sys.path:
    sys.path.append(net_folder)
module = __import__(net_module_name)
reload(module)
NN = getattr(module, 'NN')

file_path_name_checkpoint = XXX_path_to_checkpoint

# NOTE(review): input_size is the value left by the eval_list cell (last
# subject); it must match the network that produced the checkpoint.
model = NN(input_size, hidden_size, output_size)
# Map all storages to CPU so GPU-saved checkpoints also load on CPU-only machines.
checkpoint = torch.load(file_path_name_checkpoint, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# Flatten all parameters into one vector and display the first 100 values.
p = nn.utils.parameters_to_vector(model.parameters())

p[:100]



# 