# Simple Virtual Rat RNN

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import copy
import cPickle
from RNN import SimpleRNN
from SimRat import SimRat

from RNNfunctions import *
from dataProcessFunctions import *
from RNN_solver import RNNsolver

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# Load the pickled rat data. The context manager closes the file even when
# unpickling raises, which the manual open()/close() pair did not guarantee.
# NOTE(review): unpickling can execute arbitrary code -- only load
# allRatData.pkl when it comes from a trusted source.
with open("allRatData.pkl", "rb") as ratFile:
    allRatData = cPickle.load(ratFile)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# N is handed to preProcess here and to SimpleRNN(N = N) in the training cell,
# so both must see the same value.
N = 1

# Pre-process the raw data, restricted to the rat named 'Z009'.
rats = preProcess(
    allRatData,
    N,
    ratnames=['Z009'],
)

In [None]:
# Train one SimpleRNN per (rat, batch length, repeat), compare the trained
# model with the real rat data and remember the configuration with the lowest
# mean bias (ties broken by the higher mean correlation coefficient).

# Result containers.
# NOTE: the filled dictionaries are keyed only by the repeat index `i`, so an
# entry is overwritten as soon as more than one rat or more than one batch
# length is processed. logical_accuracies / real_accuracies are created here
# but not filled in this cell.
RNNs = {}
solvers = {}
probabilities = {}
logical_accuracies = {}
real_accuracies = {}

learning_rate = 5e-4

# Window passed to the switch-cost helpers and to draw_3d. Loop-invariant, so
# it is defined once here instead of on every iteration.
trial_window = 3

# Best configuration seen so far. Starting from +/- infinity means the first
# trained model is always accepted; the previous sentinel (min_bias = 10 with
# best_r undefined) raised a NameError whenever the first bias was >= 10.
# The trackers are NOT reset per rat, so the summary printed after each rat
# reflects everything processed up to that point.
min_bias = np.inf
best_r = -np.inf
best_length = 0
best_init_params = None

repeat = 1
for ratname, rat in rats.iteritems():
    print(ratname)
    for length in xrange(10, 11):
        for i in xrange(repeat):
            print("Batch length is %d" % (length,))
            print("Training for %d / %d" % (i+1,repeat))

            # Build and train a fresh network for this configuration.
            RNN = SimpleRNN(N = N)
            RNNs[i] = RNN
            solver = RNNsolver(RNN, rat.trainX, rat.trainY,
                               update_rule='adam',
                               optim_config={'learning_rate': learning_rate},
                               num_epochs = 100,
                               lr_decay = 1,
                               batch_length = length,
                               verbose = True)
            solvers[i] = solver
            solver.train()
            choices, probs = rat.predict(RNN)
            probabilities[i] = probs

            # Diagnostics provided by the project helper modules
            # (presumably plots -- confirm against their definitions).
            sample_probabilities(probs, ratname, sample = 50)
            loss_history(solver, ratname)
            sample_correct_rate(rat, sample = 500)

            # Compare model and real-rat behaviour around task switches.
            # NOTE(review): both helpers receive the whole `rats` dict, not
            # just the current `rat` -- confirm this is intended when more
            # than one rat is loaded.
            real_p2a, real_a2p = realRatSwitchCost(rats, trial_window = trial_window)
            p2a, a2p = meanPerformance(rats, trial_window = trial_window)

            bias_p2a = bias(real_p2a, p2a)
            bias_a2p = bias(real_a2p, a2p)
            bias_mean = np.mean([bias_p2a, bias_a2p])

            rp2a = corr(real_p2a, p2a)
            ra2p = corr(real_a2p, a2p)
            r_mean = np.mean([rp2a, ra2p])

            # Keep this configuration when it has a strictly lower mean bias,
            # or the same mean bias with a higher mean correlation.
            if bias_mean < min_bias or (bias_mean == min_bias and r_mean > best_r):
                min_bias = bias_mean
                best_r = r_mean
                best_length = length
                best_init_params = RNN.initparams

            print("The sum of square bias between the model and real rat's data on pro to anti is %f" % (bias_p2a,))
            print("The sum of square bias between the model and real rat's data on anti to pro is %f" % (bias_a2p,))
            print("The mean of two sum of square bias is %f" % (bias_mean,))

            print("The correlation coefficient between the model and real rat's data on pro to anti is %f" % (rp2a,))
            print("The correlation coefficient between the model and real rat's data on anti to pro is %f" % (ra2p,))
            print("The mean correlation coefficients is %f" % (r_mean,))

            # Use the same window as the computations above rather than a
            # second hard-coded 3.
            draw_3d(real_p2a, real_a2p, p2a, a2p, trial_window = trial_window)

    print("The minimum bias is %f" % (min_bias,))
    print("The best correlation coefficient is %f" % (best_r,))
    print("The best batch length is %d" % (best_length,))