In [1]:
# imports
import os
import yaml
import numpy as np
import pandas as pd
import theano
import lasagne
import loading
from training import *
from network import *
from architectures import *

# aliases: short handles for the lasagne / theano namespaces
L = lasagne.layers
nl = lasagne.nonlinearities
T = theano.tensor

# directories: everything is rooted in a Google Drive folder under the user's home
headdir = os.path.expanduser('~/Google Drive/Bas Zahy Gianni - Games')
# network parameter files are saved to / loaded from this folder by the cells below
paramsdir = os.path.join(headdir, 'Analysis/0_hvh/Params/nnets/temp')
# CSV inputs read via loading.default_loader and pd.read_csv
datadir = os.path.join(headdir, 'Data/model input')
# not referenced by the visible cells; presumably where log-likelihood results go -- TODO confirm
resultsdir = os.path.join(headdir, 'Analysis/0_hvh/Loglik/nnets')

In [2]:
# Aggregated experiments 1-4 (file without computer players) and experiment 0 (file with groups).
pretrain_csv = os.path.join(datadir, '1-4 (no computer).csv')
hvh_csv = os.path.join(datadir, '0 (with groups).csv')
data = loading.default_loader(pretrain_csv)
hvhdata = loading.default_loader(hvh_csv)

# Flatten fields 2, 3 and 4 of the experiment-0 data into single arrays.
# NOTE(review): presumably inputs / targets / a third per-observation field -- confirm
# against loading.default_loader.
Xs, ys, Ss = [np.concatenate(hvhdata[field]) for field in (2, 3, 4)]

# Prototyping

In [9]:
def run_full_fit(arch, archname, tune=True):
    """
    Pretrain `arch` on the notebook-level `data`, then optionally fine-tune on `hvhdata`.

    Each network's parameters are written under `paramsdir` as soon as that network
    has been trained, so nothing is returned or accumulated beyond the pretrained list.
    (Original author's TODO: break this up into two functions.)

    Parameters
    ----------
    arch : architecture object handed to the trainers' `train_all(architecture=...)`
    archname : str
        Label embedded in every saved parameter filename.
    tune : bool
        When False, stop after pretraining and saving.

    Returns
    -------
    None
    """
    # --- pretraining -------------------------------------------------------
    pretrainer = DefaultTrainer(stopthresh=50, print_interval=20)
    net_list = pretrainer.train_all(architecture=arch, data=data, seed=985227)

    # persist every pretrained network before any fine-tuning starts
    for split_idx, pretrained in enumerate(net_list):
        pre_name = '{} {} split agg fit exp 1-4'.format(archname, split_idx)
        pretrained.save_params(os.path.join(paramsdir, pre_name))

    if not tune:
        return None

    # --- fine-tuning -------------------------------------------------------
    tuner = FineTuner(stopthresh=20)
    # NOTE(review): semantics of `exclude=[-5]` are defined by FineTuner.train_all -- confirm there
    tune_options = dict(freeze=True, exclude=[-5])

    for split_idx, pretrained in enumerate(net_list):
        for tune_split in range(5):
            tuned_name = '{} {} agg fit exp 1-4 {} tune fit exp 0'.format(archname, split_idx, tune_split)
            # every tuning run starts from the pretrained network's current parameter values
            start_values = L.get_all_param_values(pretrained.net)
            tuned = tuner.train_all(
                architecture=arch, data=hvhdata, split=tune_split,
                startparams=start_values, **tune_options
            )
            tuned.save_params(os.path.join(paramsdir, tuned_name))

    return None

In [10]:
def load_arch(archname, specfile='arch_specs.yaml'):
    """
    Look up one architecture specification in a YAML spec file.

    Parameters
    ----------
    archname : key of the entry to return
    specfile : str
        Path of the YAML file; defaults to 'arch_specs.yaml' in the working directory.

    Returns
    -------
    The value stored under `archname` in the parsed file.

    Raises
    ------
    KeyError
        If the file has no entry named `archname`.
    """
    with open(specfile) as archfile:
        # Pass the Loader explicitly: omitting it warns on PyYAML >= 5.1 and raises
        # TypeError on PyYAML >= 6. yaml.Loader matches the pre-5.1 default behavior.
        # NOTE(security): yaml.Loader can construct arbitrary Python objects. Switch to
        # yaml.safe_load if the spec file is plain YAML or could come from an untrusted source.
        arch = yaml.load(archfile, Loader=yaml.Loader)
    return arch[archname]

In [11]:
archname = 'multiconvX_ws_large'

# Each spec pairs the constant 4 with a 2-tuple.
# NOTE(review): presumably (number of filters, filter shape) -- confirm against multiconvX_ws.
# This list is not referenced by the other visible cells (the architecture is read from YAML).
_spec_shapes = [
    (1, 4), (4, 1), (4, 4), (4, 5), (4, 6), (4, 7), (4, 8),
    (1, 3), (3, 1), (3, 3),
    (1, 2), (2, 1), (2, 2),
]
subnet_specs = [(4, shape) for shape in _spec_shapes]
# subnet_specs=[(4, (1, 4)), (4, (4, 1)), (4, (4, 4))]

In [12]:
# Registry mapping an architecture name to a builder that takes only the input variable;
# the remaining keyword arguments come from the YAML spec for that name.
archname = 'multiconvX_ws_large'
archspecs = load_arch(archname)
archs = {
    archname: (lambda input_var: multiconvX_ws(input_var, **archspecs['kwargs'])),
}

In [13]:
# Run pretraining plus fine-tuning for the selected architecture.
# TODO: rewrite run_full_fit to just take archname, leaving archs to be defined elsewhere
run_full_fit(arch=archs[archname], archname=archname, tune=True)


Split Number 0
(12362, 2, 4, 9)
Epoch 0 took 13.319s
	training loss:			2.9068
	validation loss:		2.6364
	validation accuracy:		19.66%
	total time elapsed:		13.588s
Epoch 20 took 12.792s
	training loss:			2.0039
	validation loss:		1.9132
	validation accuracy:		41.96%
	total time elapsed:		277.250s
Epoch 40 took 12.863s
	training loss:			1.9798
	validation loss:		1.8992
	validation accuracy:		42.03%
	total time elapsed:		540.941s
Epoch 60 took 12.911s
	training loss:			1.9595
	validation loss:		1.8814
	validation accuracy:		42.94%
	total time elapsed:		804.411s
Epoch 80 took 13.172s
	training loss:			1.9473
	validation loss:		1.8708
	validation accuracy:		43.28%
	total time elapsed:		1072.793s
Epoch 100 took 13.023s
	training loss:			1.9385
	validation loss:		1.8701
	validation accuracy:		43.22%
	total time elapsed:		1338.954s
Epoch 120 took 12.021s
	training loss:			1.9396
	validation loss:		1.8671
	validation accuracy:		43.43%
	total time elapsed:		1592.145s
Epoch 140 took 12.159s
	tr

## Subject tuning

Usually doesn't work

In [14]:
# Load the experiment-0 file once per subject id (0..39), one loader result per subject.
dafiname = os.path.join(datadir, '0 (with groups).csv')
subject_data = list(map(lambda subj: loading.default_loader(dafiname, subject=subj), range(40)))
arch = archs[archname]


In [18]:
# Length of the first field of each subject's loaded data.
# NOTE(review): presumably the per-subject sample count -- confirm against loading.default_loader.
subject_sizes = [len(entry[0]) for entry in subject_data]
print(subject_sizes)

[161, 160, 307, 298, 92, 93, 90, 92, 36, 37, 83, 82, 153, 156, 185, 181, 130, 131, 105, 101, 138, 132, 139, 132, 109, 112, 40, 40, 202, 202, 141, 139, 202, 206, 146, 146, 201, 202, 89, 91]


In [19]:
# For each of the 5 pretrained networks, fine-tune separately on every subject's data
# (5 splits per subject) and save each tuned network's parameters under paramsdir.
for i in range(5):
    # restore the i-th pretrained network and snapshot its parameter values
    pafiname = '{} {} split agg fit exp 1-4.npz'.format(archname, i)
    prenet = Network(arch)
    prenet.load_params(os.path.join(paramsdir, pafiname))
    params = L.get_all_param_values(prenet.net)
    print('PREFIT {}\n'.format(i))

    for s in range(40):
        sdata = subject_data[s]
        # batch size is one fifth of the length of this subject's first data field
        num_obs = len(sdata[0])
        bs = num_obs // 5
        tuner = FineTuner(stopthresh=10, batchsize=bs)
        print('SUBJECT {}\n'.format(s))

        # arguments shared by all 5 splits for this subject
        tune_kwargs = dict(architecture=arch, data=sdata, startparams=params, freeze=True)
        for j in range(5):
            fname = '{} {} agg fit exp 1-4 {} subject {} tune fit exp 0'.format(archname, i, s, j)
            net = tuner.train_all(split=j, **tune_kwargs)
            outpath = os.path.join(paramsdir, fname)
            net.save_params(outpath)

PREFIT 0

SUBJECT 0

Epoch 0 took 0.059s
	training loss:			2.0268
	validation loss:		1.8872
	validation accuracy:		50.00%
	total time elapsed:		0.060s
Abandon ship!

TEST PERFORMANCE
	Stopped in epoch:		11
	Test loss:			2.2935
	Test accuracy:			31.25%

Epoch 0 took 0.056s
	training loss:			1.9694
	validation loss:		2.3003
	validation accuracy:		31.25%
	total time elapsed:		0.057s
Abandon ship!

TEST PERFORMANCE
	Stopped in epoch:		24
	Test loss:			1.9184
	Test accuracy:			50.00%

Epoch 0 took 0.053s
	training loss:			2.2484
	validation loss:		2.1174
	validation accuracy:		40.62%
	total time elapsed:		0.055s
Abandon ship!

TEST PERFORMANCE
	Stopped in epoch:		36
	Test loss:			1.6034
	Test accuracy:			46.88%

Epoch 0 took 0.052s
	training loss:			2.2492
	validation loss:		1.7420
	validation accuracy:		40.62%
	total time elapsed:		0.053s
Abandon ship!

TEST PERFORMANCE
	Stopped in epoch:		11
	Test loss:			2.1279
	Test accuracy:			34.38%

Epoch 0 took 0.052s
	training loss:			2.1747
	valid

## Data aggregation

Doesn't need to be run more than once.

In [None]:
# Build the aggregated experiment 1-4 CSVs ('1-4.csv' and '1-4 (no computer).csv') in datadir.
datafilenames = ['0 (with groups)', '1 (with computer)', '2 (with computer)', '3 (with computer)', '4']
datafilenames = [os.path.join(datadir, fname + '.csv') for fname in datafilenames]
colnames = ['subject', 'color', 'bp', 'wp', 'zet', 'rt']

# The files have no header row; each experiment carries its own extra trailing columns.
e0 = pd.read_csv(datafilenames[0], names=colnames+['splitno'])
e1 = pd.read_csv(datafilenames[1], names=colnames)
e2 = pd.read_csv(datafilenames[2], names=colnames)
e3 = pd.read_csv(datafilenames[3], names=colnames+['task', 'taskorder', 'session'])
e4 = pd.read_csv(datafilenames[4], names=colnames+['timecondition'])
Es = [e1, e2, e3, e4]

# Offset each experiment's subject ids by the largest id (< 1000) of the experiment
# before it. Pairing each frame with its predecessor fixes an off-by-one: the previous
# `Es[i-1]` under `enumerate(Es[1:])` picked e4 for e2, e1 for e3 and e2 for e4.
# Because `prev` has already been shifted when it is read, the offsets accumulate.
# Ids >= 1000 are left out of the max; presumably they mark computer players (they are
# the rows dropped from the "no computer" file below).
# NOTE(review): if subject ids are 0-based, `max()` alone makes the first id of one
# experiment equal the last id of the previous one -- confirm whether `+ 1` is needed.
for prev, e in zip(Es[:-1], Es[1:]):
    e['subject'] = e['subject'] + prev.loc[prev['subject']<1000, 'subject'].max()

# ignore_index gives the combined frame a unique row index; the CSVs are written
# without the index, so the files are unaffected.
A = pd.concat([e[colnames] for e in Es], ignore_index=True)

# Assign every row to one of 5 equally sized groups (1..5) in a seeded random order.
groups = np.arange(len(A))%5 + 1
np.random.seed(100001)
np.random.shuffle(groups)
A['group'] = groups

A.to_csv(os.path.join(datadir, '1-4.csv'), encoding='ASCII', header=False, index=False)
A.loc[A['subject']<1000, :].to_csv(
    os.path.join(datadir, '1-4 (no computer).csv'), 
    encoding='ASCII', header=False, index=False
)