In [None]:
import pandas as pd
import numpy as np
import os, yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from tqdm import tqdm

# Later cells use paths such as ./code/params.yaml and ./data/..., i.e. they are
# relative to the parent of the `code` directory. Compare the last path
# component exactly so that a working directory that merely *ends* in "code"
# (e.g. ".../barcode") is left alone.
if os.path.basename(os.getcwd()) == "code":
    os.chdir('../')

In [None]:
# Read the experiment configuration (data paths, device, network sizes,
# learning rate, ...) from the YAML file kept next to the notebook.
config_path = os.path.join("./code/params.yaml")
with open(config_path) as config_file:
    params = yaml.safe_load(config_file)

In [38]:
# Persist the train / validation splits as parquet files under ./data.
# NOTE(review): `train_df` and `val_df` are not defined in any cell visible
# above — presumably they come from a cell that was removed or from another
# notebook. Confirm; as written this cell fails on Restart Kernel -> Run All.
train_df.to_parquet('./data/train.parquet')
val_df.to_parquet('./data/val.parquet')

In [39]:
def make_transition(data,rolling_size,batch_size,shuffle):
    """Build a DataLoader of windowed (s, a, r1, r2, r3, s2, t) transitions.

    The parquet file at ``data`` is read and split by its ``traj`` column.
    Columns whose names start with ``s:``, ``a:`` and ``r:`` are used as state,
    action and reward columns respectively; the first three ``r:`` columns
    become ``r1``, ``r2`` and ``r3`` (fewer than three raises ``IndexError``).

    For every trajectory a window of ``rolling_size`` consecutive rows is slid
    one row at a time. Each position yields one transition:

    * ``s``  – the ``rolling_size`` state rows of the window,
    * ``a``  – the action row at the last step of the window (kept as a
      length-1 slice, so the tensor has a singleton second axis),
    * ``r1``/``r2``/``r3`` – the rewards at the last step of the window,
    * ``s2`` – the window shifted forward by one row,
    * ``t``  – 1.0 for the last transition of the trajectory, else 0.0.

    Trajectories with fewer than ``rolling_size + 1`` rows cannot provide even
    one full ``(s, s2)`` pair and are skipped (previously they produced empty,
    ragged windows that broke the tensor conversion).

    Args:
        data: Path of the parquet file to read.
        rolling_size: Number of consecutive rows in each state window (>= 1).
        batch_size: Batch size passed to the DataLoader.
        shuffle: Whether the DataLoader shuffles the transitions.

    Returns:
        ``(loader, data_len)`` where ``loader`` iterates over batches of
        ``(s, a, r1, r2, r3, s2, t)`` tensors and ``data_len`` is the total
        number of transitions.

    Raises:
        ValueError: If ``rolling_size`` is smaller than 1.
    """
    if rolling_size < 1:
        raise ValueError("rolling_size must be >= 1, got {}".format(rolling_size))

    df = pd.read_parquet(data)
    s_col = [c for c in df if c.startswith('s:')]
    a_col = [c for c in df if c.startswith('a:')]
    r_col = [c for c in df if c.startswith('r:')]

    s,a,r1,r2,r3,s2,t  = [],[],[],[],[],[],[]

    # One pass over the frame instead of re-filtering it for every trajectory.
    # sort=False keeps trajectories in order of first appearance, and rows
    # inside each group keep their original order.
    for _, df_traj in tqdm(df.groupby('traj', sort=False)):
        last_step = len(df_traj) - rolling_size - 1
        if last_step < 0:
            # Too short for a full window plus its one-step-ahead successor.
            continue

        states = df_traj[s_col].values.tolist()
        actions = df_traj[a_col].values.tolist()
        rewards_1 = df_traj[r_col[0]].values.tolist()
        rewards_2 = df_traj[r_col[1]].values.tolist()
        rewards_3 = df_traj[r_col[2]].values.tolist()

        for step in range(last_step + 1):
            end = step + rolling_size
            s.append(states[step:end])
            a.append(actions[end-1:end])
            r1.append(rewards_1[end-1])
            r2.append(rewards_2[end-1])
            r3.append(rewards_3[end-1])
            s2.append(states[step+1:end+1])
            # Only the final window of a trajectory is terminal.
            t.append(1 if step == last_step else 0)

    data_len = len(t)

    s  = torch.FloatTensor(np.float32(s))
    a  = torch.LongTensor(np.int64(a))
    r1 = torch.FloatTensor(np.float32(r1))
    r2 = torch.FloatTensor(np.float32(r2))
    r3 = torch.FloatTensor(np.float32(r3))
    s2 = torch.FloatTensor(np.float32(s2))
    t  = torch.FloatTensor(np.float32(t))

    rt = DataLoader(TensorDataset(s, a, r1, r2, r3, s2, t), batch_size=batch_size, shuffle=shuffle)
    return rt, data_len

In [40]:
# Build the training and validation loaders from the parquet files named in
# the config. Both use 24-row state windows; only the training set is shuffled.
train_loader, train_len = make_transition(
    params['train'],
    rolling_size=24,
    batch_size=64,
    shuffle=True,
)
val_loader, val_len = make_transition(
    params['val'],
    rolling_size=24,
    batch_size=256,
    shuffle=False,
)

100%|██████████| 11569/11569 [00:27<00:00, 413.66it/s]
100%|██████████| 771/771 [00:01<00:00, 760.89it/s]


In [31]:
class imvt(torch.jit.ScriptModule):
    """TorchScript recurrent network with one hidden state per input variable.

    Every one of the ``input_dim`` input variables owns its own gate weights
    and its own ``n_units``-wide hidden/cell state, so the recurrent state has
    shape ``(batch, input_dim, n_units)``. After the sequence has been
    processed, two normalised weightings are computed: ``alphas`` over time
    steps (per variable) and ``betas`` over variables.
    NOTE(review): this looks like the IMV-Tensor LSTM design — confirm.

    Constructor arguments:
        input_dim:  number of input variables (last axis of ``x``).
        output_dim: size of the prediction produced for each sample.
        n_units:    hidden units kept per input variable.
        device:     device the initial zero states are moved to in ``forward``.
        init_std:   factor multiplied into the ``torch.randn`` initial values
                    of the hand-written parameters.

    ``forward(x)`` indexes ``x`` as ``(batch, time, input_dim)`` and returns
    ``(mean, alphas, betas)``:
        mean:   ``(batch, output_dim)`` – beta-weighted sum of the per-variable
                predictions.
        alphas: ``(batch, time, input_dim, 1)`` – weights that sum to 1 over
                the time axis.
        betas:  ``(batch, input_dim, 1)`` – weights that sum to 1 over the
                variable axis.
    """
    __constants__ = ['input_dim', 'n_units']
    def __init__(self, input_dim, output_dim, n_units, device, init_std=0.02):
        super().__init__()
        # Input weights: one (1, n_units) row per variable and per gate
        # (j = candidate, i = input, f = forget, o = output).
        self.U_j = nn.Parameter(torch.randn(input_dim, 1, n_units)*init_std)
        self.U_i = nn.Parameter(torch.randn(input_dim, 1, n_units)*init_std)
        self.U_f = nn.Parameter(torch.randn(input_dim, 1, n_units)*init_std)
        self.U_o = nn.Parameter(torch.randn(input_dim, 1, n_units)*init_std)
        # Recurrent weights: one (n_units, n_units) matrix per variable and gate.
        self.W_j = nn.Parameter(torch.randn(input_dim, n_units, n_units)*init_std)
        self.W_i = nn.Parameter(torch.randn(input_dim, n_units, n_units)*init_std)
        self.W_f = nn.Parameter(torch.randn(input_dim, n_units, n_units)*init_std)
        self.W_o = nn.Parameter(torch.randn(input_dim, n_units, n_units)*init_std)
        # Gate biases, one n_units vector per variable.
        self.b_j = nn.Parameter(torch.randn(input_dim, n_units)*init_std)
        self.b_i = nn.Parameter(torch.randn(input_dim, n_units)*init_std)
        self.b_f = nn.Parameter(torch.randn(input_dim, n_units)*init_std)
        self.b_o = nn.Parameter(torch.randn(input_dim, n_units)*init_std)
        # Per-variable projection (and bias) that scores each time step for alphas.
        self.F_alpha_n = nn.Parameter(torch.randn(input_dim, n_units, 1)*init_std)
        self.F_alpha_n_b = nn.Parameter(torch.randn(input_dim, 1)*init_std)
        # Both heads read the concatenation [g_n, h_T], hence 2*n_units inputs.
        self.F_beta = nn.Linear(2*n_units, 1)
        self.Phi = nn.Linear(2*n_units, output_dim)
        self.n_units = n_units
        self.input_dim = input_dim
        self.device = device
    
    @torch.jit.script_method
    def forward(self, x):
        # Per-variable hidden and cell state, both (batch, input_dim, n_units), start at zero.
        h_tilda_t = torch.zeros(x.shape[0], self.input_dim, self.n_units).to(self.device)
        c_tilda_t = torch.zeros(x.shape[0], self.input_dim, self.n_units).to(self.device)
        # TorchScript needs the element type of the initially empty list spelled out.
        outputs = torch.jit.annotate(List[Tensor], [])
        for t in range(x.shape[1]):
            # Each gate = recurrent term (hidden state times that variable's W)
            # + input term (the scalar x[:, t, j] times that variable's U row) + bias.
            # Every gate tensor has shape (batch, input_dim, n_units).
            j_tilda_t = torch.tanh(torch.einsum("bij,ijk->bik", h_tilda_t, self.W_j) + \
                                   torch.einsum("bij,jik->bjk", x[:,t,:].unsqueeze(1), self.U_j) + self.b_j)
            i_tilda_t = torch.sigmoid(torch.einsum("bij,ijk->bik", h_tilda_t, self.W_i) + \
                                torch.einsum("bij,jik->bjk", x[:,t,:].unsqueeze(1), self.U_i) + self.b_i)
            f_tilda_t = torch.sigmoid(torch.einsum("bij,ijk->bik", h_tilda_t, self.W_f) + \
                                torch.einsum("bij,jik->bjk", x[:,t,:].unsqueeze(1), self.U_f) + self.b_f)
            o_tilda_t = torch.sigmoid(torch.einsum("bij,ijk->bik", h_tilda_t, self.W_o) + \
                                torch.einsum("bij,jik->bjk", x[:,t,:].unsqueeze(1), self.U_o) + self.b_o)
            # Cell and hidden state update.
            c_tilda_t = c_tilda_t*f_tilda_t + i_tilda_t*j_tilda_t
            h_tilda_t = (o_tilda_t*torch.tanh(c_tilda_t))
            outputs += [h_tilda_t]
        # Stacked as (time, batch, input_dim, n_units); permuted to (batch, time, input_dim, n_units).
        outputs = torch.stack(outputs)
        outputs = outputs.permute(1, 0, 2, 3)
        
        # Time weighting: a tanh-bounded score per (time step, variable),
        # exponentiated and normalised over the time axis (dim=1).
        alphas = torch.tanh(torch.einsum("btij,ijk->btik", outputs, self.F_alpha_n)+self.F_alpha_n_b)
        alphas = torch.exp(alphas)
        alphas = alphas/torch.sum(alphas, dim=1, keepdim=True)
        # Alpha-weighted sum of hidden states over time: (batch, input_dim, n_units).
        g_n = torch.sum(alphas*outputs, dim=1)
        # Concatenate with the last hidden state: (batch, input_dim, 2*n_units).
        hg = torch.cat([g_n, h_tilda_t], dim=2)
        # Per-variable prediction: (batch, input_dim, output_dim).
        mu = self.Phi(hg)
        # Variable weighting: tanh-bounded score per variable, exponentiated
        # and normalised over the variable axis (dim=1).
        betas = torch.tanh(self.F_beta(hg))
        betas = torch.exp(betas)
        betas = betas/torch.sum(betas, dim=1, keepdim=True)
        # Beta-weighted sum of the per-variable predictions: (batch, output_dim).
        mean = torch.sum(betas*mu, dim=1)

        return mean, alphas, betas

In [None]:
# Train the Q-network with a Double-DQN style target and report the
# per-sample train / validation loss after every epoch.
device = params['device']
network = imvt(params['state_dim'], params['num_actions'], n_units=32, device=params['device']).to(device)
target_network = imvt(params['state_dim'], params['num_actions'], n_units=32, device=params['device']).to(device)
# Start the target network from the online weights; otherwise the first
# `update_freq` updates bootstrap from an unrelated random initialisation.
target_network.load_state_dict(network.state_dict())

epoch = 200
gamma = 1.0
optimizer = optim.Adam(network.parameters(), lr=params['learning_rate'], amsgrad=True)

# Number of optimiser steps between two copies of the online weights into the
# target network (the counter is reset at the start of every epoch).
update_freq = 2


def _double_dqn_loss(network, target_network, batch, gamma, device):
    """Return ``(loss, n_samples)`` for one batch of transitions.

    ``batch`` is the ``(s, a, r1, r2, r3, s2, t)`` tuple produced by the
    loaders; only ``r1`` is used as the reward. The next action is chosen by
    the online network and evaluated by the target network. Both the reward
    and the bootstrapped value are clamped to [0, 1], and the bootstrap term
    is dropped for terminal transitions (``t == 1``). The loss is the
    batch-mean smooth-L1 distance between prediction and target.
    """
    s, a, r1, _, _, s2, t = batch
    s = s.to(device)
    a = a.to(device)
    r1 = r1.to(device)
    s2 = s2.to(device)
    t = t.to(device)

    q, _, _ = network(s)
    # squeeze(1) rather than squeeze(): the batch axis must survive even when
    # the last batch contains a single sample.
    q_pred = q.gather(1, a.squeeze(1)).squeeze(1)

    with torch.no_grad():
        q2_target, _, _ = target_network(s2)
        q2_online, _, _ = network(s2)
        best_action = torch.max(q2_online, dim=1)[1].unsqueeze(1)
        q2_max = q2_target.gather(1, best_action).squeeze(1)
        bellman_target = torch.clamp(r1, max=1.0, min=0.0) + gamma * torch.clamp(q2_max, max=1.0, min=0.0)*(1-t)

    return F.smooth_l1_loss(q_pred, bellman_target), s.size(0)


for i in range(epoch):
    loss_train = 0.0
    update_counter = 0
    for batch in tqdm(train_loader):
        loss, n_samples = _double_dqn_loss(network, target_network, batch, gamma, device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # smooth_l1_loss returns the batch mean; weight it by the batch size so
        # that dividing by the number of transitions gives a true per-sample
        # average (train and val use different batch sizes).
        loss_train += loss.item() * n_samples

        update_counter += 1
        if update_counter == update_freq:
            target_network.load_state_dict(network.state_dict())
            update_counter = 0

    loss_val = 0.0
    with torch.no_grad():
        for batch in val_loader:
            loss, n_samples = _double_dqn_loss(network, target_network, batch, gamma, device)
            loss_val += loss.item() * n_samples

    print("Iter: ", i, "train: ",loss_train/train_len, "val: ",loss_val/val_len)