In [1]:
%load_ext autoreload

%autoreload 2

In [13]:
from matplotlib.pyplot import cm
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import pickle
import numpy as np
import random
from scipy.special import softmax as softmax
from math import ceil
from IPython.display import display, clear_output
from utils import *
sm0 = nn.Softmax(dim = 0)
sm1 = nn.Softmax(dim = 1)

In [14]:
Y = 2          #number of observations
M = 2          #number of memory states
A = 2          #number of actions
F = 1 + 1      #number of linear features (+ bias)

# optimizer parameters
lr = 0.01
n_epochs =10
n_batch = 50
threshold_act = 30
dx = 125

# trajectories limited?
maxT = 5000 

# shuffle gradient position
x_pos_shuffle = False
x_max_shuffle = 5

# percentage of training data
train_perc = 80



In [15]:
#Load the trajectories
name_traj = "./samples/trjs_N500len500test.pkl"
with open(name_traj, "rb") as f:
    trjs_dict = pickle.load(f)

# Splitting into train and test
Neps = len(trjs_dict)
Ntrain = int(Neps*train_perc/100)

trjs_train = trjs_dict[:Ntrain]
trjs_test = trjs_dict[Ntrain:]

In [16]:
# parameters for softmax policy
# dims: {Y, M, M', A} 
# F = #features + bias

#Start from a random theta ---default!
theta = torch.rand( (F, M, M, A), requires_grad=True)
theta.data = 2*theta.data-1
theta.data /= 5


#If you restart theta from a saved file
theta_is_from_restart = False

#If you start theta from a fixed set of numbers
theta_is_fixed = False

if theta_is_from_restart:
    theta_restart_file = "./results/virtual_data_antonio_FSC/theta_bacteria_FSC_M2_loglike431.11_th30_MselfconsTrue_FromRandom_FSCtrajs.dat"
    theta = torch.from_numpy(np.loadtxt(theta_restart_file).reshape(F,M,M,A))
elif theta_is_fixed:
    theta=np.array([[[[-0.61114431,-0.55987562],[ 0.59255672,0.82834098]],[[-0.88220033,0.46330714],[ 0.43192955,-0.52351326]]],
               [[[ 0.41101729,-0.31575377],[-0.64954147,0.36582982]],[[-0.24503104,0.00449758],[ 0.8396274,0.59647384]]]])
    theta=torch.from_numpy(theta.reshape(F,M,M,A))

In [17]:
#parameters for softmax of psi to get rho(m_0)
psi=torch.rand(M,requires_grad=True)
psi_uniform=False
if psi_uniform:
    psi=torch.ones(M,requires_grad=True)

In [18]:
sm0 = nn.Softmax(dim = 0)


In [23]:

#Produce trajectories
#ground_truth_theta=np.array([[[[-0.61114431,-0.55987562],[ 0.59255672,0.82834098]],[[-0.88220033,0.46330714],[ 0.43192955,-0.52351326]]], 
#[[[ 0.41101729,-0.31575377],[-0.64954147,0.36582982]],[[-0.24503104,0.00449758],[ 0.8396274,0.59647384]]]])


Ntraj=100
traj_len=500
signal=get_signal_landscape('step_like',traj_len,Ntraj)
dict_trajectories,original_theta = get_traj_from_theta(F,M,A,signal,'False',Ntraj,traj_len)

100 / 100


In [20]:
trjs_dict=dict_trajectories
# Splitting into train and test
Neps = len(trjs_dict)
Ntrain = int(Neps*train_perc/100)

trjs_train = trjs_dict[:Ntrain]
trjs_test = trjs_dict[Ntrain:]

In [21]:
theta,psi

(tensor([[[[ 0.1140, -0.1474],
           [ 0.1481, -0.0544]],
 
          [[ 0.0145, -0.1954],
           [-0.0065, -0.0348]]],
 
 
         [[[-0.1033,  0.1056],
           [ 0.1981,  0.1652]],
 
          [[ 0.0161, -0.0139],
           [ 0.0083, -0.1809]]]], requires_grad=True),
 tensor([0.8521, 0.8930], requires_grad=True))

In [None]:
optimizer_theta = torch.optim.Adam([theta], lr=lr)
optimizer_psi = torch.optim.Adam([psi], lr=lr)

count = 0

Neps = len(trjs_dict)

lr_mav = 1. / Ntrain

losses_train_theta = []
losses_test_theta = []

losses_train_psi = []
losses_test_psi = []

grad_required=True
for epochs in range(n_epochs):

    running_loss_theta = 0.
    running_loss_psi = 0.
    random.shuffle(trjs_train)

    for ibatch, batch in enumerate(batched(trjs_train, n_batch)):
        #if alternate_update:
        # Update theta while keeping psi fixed
        loss_theta = trajs_loss_eval(theta, psi.detach(), batch, trjs_train)
        loss_theta.backward()
        running_loss_theta += loss_theta.item()
            
        optimizer_theta.step()
        optimizer_theta.zero_grad()
        
        
        # Update theta while keeping psi fixed
        loss_psi= trajs_loss_eval(theta.detach(), psi, batch, trjs_train)
        loss_psi.backward()
        running_loss_psi += loss_psi.item()
            
        optimizer_psi.step()
        optimizer_psi.zero_grad()
        
        
        print(f"Batch {ibatch+1}/{ceil(Ntrain/n_batch)}: loss_psi {loss_psi.item()}")
        print(f"Batch {ibatch+1}/{ceil(Ntrain/n_batch)}: loss_theta {loss_theta.item()}", end='\r')
    
    loss_test_theta = trajs_loss_eval(theta, psi.detach(), batch, trjs_train)
    loss_test_psi = trajs_loss_eval(theta.detach(), psi, batch, trjs_train)
    
    print(f"Epoch: {epochs} \tLoss train: {running_loss_theta/ceil(Ntrain/n_batch)} \tLoss test: {loss_test_psi}")
    print(f"Epoch: {epochs} \tLoss train: {running_loss_psi/ceil(Ntrain/n_batch)} \tLoss test: {loss_test_theta}")
    losses_train_theta.append(running_loss_theta)
    losses_train_psi.append(running_loss_psi)
    losses_test_theta.append(loss_test_theta)
    losses_test_psi.append(loss_test_psi)

Batch 1/2: loss_psi 338.3437032047935
Batch 2/2: loss_psi 338.171456945359645
Epoch: 0 	Loss train: 338.4796117071954 	Loss test: 338.1714206523137
Epoch: 0 	Loss train: 338.2575800750766 	Loss test: 338.1714206523137
Batch 1/2: loss_psi 338.0321551718563
Batch 2/2: loss_psi 337.891830299064057
Epoch: 1 	Loss train: 338.1017703479613 	Loss test: 337.8917964500582
Epoch: 1 	Loss train: 337.96199273546017 	Loss test: 337.8917964500582
Batch 1/2: loss_psi 337.7478171498677
Batch 2/2: loss_psi 337.605264890144782
Epoch: 2 	Loss train: 337.81979047587583 	Loss test: 337.6052334236887
Epoch: 2 	Loss train: 337.6765410200062 	Loss test: 337.6052334236887
Batch 1/2: loss_psi 337.4696191795514
Batch 2/2: loss_psi 337.3451551954102673
Epoch: 3 	Loss train: 337.5374112084363 	Loss test: 337.3451262945051
Epoch: 3 	Loss train: 337.4073871874808 	Loss test: 337.3451262945051
Batch 1/2: loss_psi 337.23473899811853
Batch 2/2: loss_psi 337.1397484568553506
Epoch: 4 	Loss train: 337.2899188387232 	Loss