In [None]:
import os
import twa
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch import nn
import torch.optim as optim

from twa.utils import ensure_dir, write_yaml
from twa.data.ode import FlowSystemODE
import random
from twa.train import VecTopoDataset, train_model_alt, predict_model

# Reproducibility: seed every RNG source this notebook imports (random, numpy, torch).
SEED = 2
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Deterministic cuBLAS kernels require a fixed workspace config; set it (plain Python,
# equivalent to the `%set_env` magic) before turning on deterministic mode so no CUDA
# op can run without it.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
torch.use_deterministic_algorithms(True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): num_lattice is not referenced elsewhere in this notebook — confirm it is still needed.
num_lattice = 64

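# Classify point vs. cycle attractors

Train a classifier on the configured training vector-field data (`train_data_descs`), then evaluate it (accuracy and AUC) on held-out systems and on noisy and masked variants of the training system.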

In [None]:
# ---- training configuration ------------------------------------------------
outdir = '../output/'
data_dir = os.path.join(outdir, 'data')

# NOTE(review): dim and batch_size are not passed to any call visible in this notebook.
dim = 2
batch_size = 64

# Training systems; when several are listed, the dataset-loading cell combines them.
train_data_descs = ['simple_oscillator_nsfcl']
train_data_desc_ = '_'.join(train_data_descs)

kwargs_train = {}

with_attention = False  # flip to True for the attention variant
datatype = 'angle'
model_type = None

save = False
save_dir = None

# Experiment tag: <train data>_<datatype>_<model_type, or the attention flag if model_type is None>.
with_attention_str = 'atten' if with_attention else 'noatten'
exp_desc = '_'.join([
    train_data_desc_,
    datatype,
    with_attention_str if model_type is None else model_type,
])

datasize = 10000

In [None]:
# Load the training data. When several training systems are listed, their datasets
# are combined with `+=`. Only the first one is plotted.
# NOTE(review): `+=` goes through VecTopoDataset's `__add__` (torch's Dataset.__add__
# returns a ConcatDataset); later cells use `.sysp` / `.plot_data()`, which may not
# exist on a combined object — confirm before training on multiple systems.
assert train_data_descs, 'train_data_descs must name at least one training dataset'
train_data_dirs = [os.path.join(data_dir, desc) for desc in train_data_descs]

train_dataset = VecTopoDataset(train_data_dirs[0], datatype=datatype, datasize=datasize, filter_outbound=True)
train_dataset.plot_data()
for extra_train_dir in train_data_dirs[1:]:
    train_dataset += VecTopoDataset(extra_train_dir, datatype=datatype, datasize=datasize, filter_outbound=True)
        
    

In [None]:
# Train the classifier. kwargs_train (configuration cell) is forwarded so extra
# training options can be set there; it is empty by default, so this is a no-op.
model, losses = train_model_alt(train_dataset, model_type=model_type, with_attention=with_attention,
                                device=device, verbose=True, **kwargs_train)

In [None]:
# Training-loss curve for this experiment; the y-axis is clipped to [0, 1].
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('iteration')
ax.set_ylabel('loss')
ax.set_title(exp_desc)
ax.set_ylim(0, 1)
plt.show()

In [None]:
# Visualize the fit on the training data: sample attention masks, then each
# class logit plotted over columns 0 and 1 of `train_dataset.sysp`.
model.plot_attention(train_dataset, n_samples=9)

correct, auc, output = predict_model(model, train_dataset)
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
sysp = train_dataset.sysp

for logit_col, panel_title in enumerate(('pt logit', 'cycle logit')):
    twa.utils.plot_diverge_scale(sysp[:, 0], sysp[:, 1], output[:, logit_col],
                                 ax=ax[logit_col], title=panel_title)

plt.show()

# Examine test data

## Accuracy

In [None]:

# Evaluate the trained model on the 'test' split of each listed system.
# Systems whose data directory is missing are reported and skipped.
results_dir = os.path.join(outdir, 'results')
exp_results_dir = os.path.join(results_dir, exp_desc)
# Set `save = True` (here or in the configuration cell) to write per-dataset predictions.
if save:
    # Only create output directories when predictions are actually written.
    ensure_dir(exp_results_dir)

test_data_descs = [
    'simple_oscillator_noaug',
    'simple_oscillator_nsfcl',
    'suphopf',
    'lienard_poly',
    'lienard_sigmoid',
    'vanderpol',
    'bzreaction',
    'selkov',
    'selkov2',
    'repressilator',
    'pancreas_clusters_random_bin',
]


tt = 'test'
res = []
for test_data_desc in test_data_descs:
    print(test_data_desc)
    test_data_dir = os.path.join(data_dir, test_data_desc)
    if not os.path.isdir(test_data_dir):
        # Previously skipped silently, which hid missing datasets from the table below.
        print(f'  skipped: {test_data_dir} not found')
        continue
    test_dataset = VecTopoDataset(test_data_dir, tt=tt, datatype=datatype)

    if save:
        save_dir = os.path.join(exp_results_dir, test_data_desc)
        ensure_dir(save_dir)

    correct, auc, _ = predict_model(model, test_dataset, verbose=False, save=save, save_dir=save_dir)
    res.append({'data': os.path.basename(test_data_dir),
                'correct': correct,
                'auc': auc})

pd.DataFrame(res)

In [None]:
# Compare the current Selkov test set with the one stored under <outdir>/dataprev.
# Each dataset gets its own save_dir. Previously both reused the save_dir left over
# from the last iteration of the evaluation loop, so with save=True they overwrote
# another dataset's results.
selkov_compare_dirs = {
    'selkov': os.path.join(data_dir, 'selkov'),
    'dataprev_selkov_new': os.path.join(outdir, 'dataprev', 'selkov_new'),
}
for compare_desc, test_data_dir in selkov_compare_dirs.items():
    if not os.path.isdir(test_data_dir):
        print(f'{compare_desc}: skipped, {test_data_dir} not found')
        continue
    test_dataset = VecTopoDataset(test_data_dir, tt=tt, datatype=datatype)

    save_dir = None
    if save:
        save_dir = os.path.join(exp_results_dir, compare_desc)
        ensure_dir(save_dir)

    correct, auc, _ = predict_model(model, test_dataset, verbose=False, save=save, save_dir=save_dir)
    print(compare_desc, correct, auc)



In [None]:
# Inspect the selkov2 test set: attention masks, then four panels over the two
# parameter columns given by twa.dt.Selkov.plot_param_idx — label column 0,
# thresholded pt logit, and both logits on a diverging scale.
test_data_desc = 'selkov2'
test_data_dir = os.path.join(data_dir, test_data_desc)
test_dataset = VecTopoDataset(test_data_dir, tt=tt, datatype=datatype)

model.plot_attention(test_dataset, n_samples=9)
plt.show()

correct, auc, output = predict_model(model, test_dataset)
fig, ax = plt.subplots(1, 4, figsize=(20, 5))
sysp = test_dataset.sysp
idx0 = twa.dt.Selkov.plot_param_idx[0]
idx1 = twa.dt.Selkov.plot_param_idx[1]

ax[0].scatter(sysp[:, idx0], sysp[:, idx1], c=test_dataset.label[:, 0])
ax[0].set_title('label[:, 0]')
# (A duplicated, identical scatter onto ax[1] was removed here.)
ax[1].scatter(sysp[:, idx0], sysp[:, idx1], c=output[:, 0] > 0)
ax[1].set_title('pt logit > 0')
twa.utils.plot_diverge_scale(sysp[:, idx0], sysp[:, idx1], output[:, 0], ax=ax[2], title='pt logit')
twa.utils.plot_diverge_scale(sysp[:, idx0], sysp[:, idx1], output[:, 1], ax=ax[3], title='cycle logit')

plt.show()

In [None]:
# Robustness checks on the training system's test split:
#   (a) Gaussian noise of scale `noise`, (b) entries masked with probability `mask_prob`.
# Both results are appended to `res` under distinct names and the table is shown at the end.
tt = 'test'
noise = 0.5
mask_prob = 0.25
test_data_dir = os.path.join(data_dir, train_data_desc_)


def _eval_perturbed(desc, **perturb_kwargs):
    """Evaluate `model` on a perturbed load of the training system's test split.

    Parameters
    ----------
    desc : str
        Name of the perturbed test set. Used as the `res` row label and, when
        `save` is set, as the sub-directory of `exp_results_dir`.
    **perturb_kwargs
        Forwarded to `VecTopoDataset` (here `noise=` or `mask_prob=`).

    Returns
    -------
    tuple
        `(dataset, correct, auc)`, where `correct` and `auc` are as returned by `predict_model`.
    """
    dataset = VecTopoDataset(test_data_dir, tt=tt, datatype=datatype, **perturb_kwargs)

    out_dir = None
    if save:
        out_dir = os.path.join(exp_results_dir, desc)
        ensure_dir(out_dir)

    n_correct, auc_score, _ = predict_model(model, dataset, verbose=False, save=save, save_dir=out_dir)
    # Label rows with `desc`. Previously the masked run was recorded under the plain
    # training-data name and the noise run was not recorded at all.
    # NOTE: re-running this cell appends duplicate rows to `res`.
    res.append({'data': desc, 'correct': n_correct, 'auc': auc_score})
    return dataset, n_correct, auc_score


noise_desc = f'{train_data_desc_}_noise{noise:.2f}'
noise_dataset, correct, auc = _eval_perturbed(noise_desc, noise=noise)
print(f'{train_data_desc_}, with Gaussian noise of scale {noise}, {correct:.2f} correct, {auc:.2f} auc')
noise_dataset.plot_data()

mask_desc = f'{train_data_desc_}_masked{mask_prob:.2f}'
mask_dataset, correct, auc = _eval_perturbed(mask_desc, mask_prob=mask_prob)
mask_dataset.plot_data()
print(f'{train_data_desc_}, with mask probability {mask_prob}, {correct:.2f} correct, {auc:.2f} auc')

pd.DataFrame(res)


In [None]:
# When saving is enabled, write the trained weights (model.state_dict()) to
# <outdir>/models/<exp_desc>/model.pt.
if save:
    models_dir = os.path.join(outdir, 'models')
    exp_models_dir = os.path.join(models_dir, exp_desc)
    model_path = os.path.join(exp_models_dir, 'model.pt')

    ensure_dir(exp_models_dir)
    torch.save(model.state_dict(), model_path)