In [1]:
import os
import numpy as np
from loguru import logger
import pickle
import yaml

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

from DPNNMM import DPNNMM

In [2]:
def load_config(config_path="config.yml"):
    """Load a YAML configuration file and return its parsed content.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.load`` produces for the file content.

    Raises:
        Exception: If ``config_path`` does not point to an existing file.
    """
    if not os.path.isfile(config_path):
        raise Exception("Configuration file is not found in the path: "+config_path)
    # The context manager closes the handle even when parsing fails; the
    # previous bare open() never closed it.
    # NOTE(review): yaml.FullLoader is kept for compatibility; consider
    # yaml.safe_load if the config file can come from an untrusted source.
    with open(config_path) as f:
        return yaml.load(f, Loader=yaml.FullLoader)

# Read the experiment configuration and pull out its four sections.
config = load_config('config_test.yml')
nn_config, mpc_config, gym_config, dp_config = (
    config[section]
    for section in ('NN_config', 'mpc_config', 'gym_config', 'DP_config')
)

# NOTE(review): pickle.load executes arbitrary code; only open trusted files.
with open('data.pickle', 'rb') as f:
    data = pickle.load(f)

# Bucket every record under its task index; four tasks (0..3) are expected.
data_dict = {task: [] for task in range(4)}
for d in data:
    # Each record is a 4-tuple; only the task index is needed for bucketing.
    task_idx, obs, action, label = d
    data_dict[task_idx].append(d)

# Report how many records each task received.
for key, samples in data_dict.items():
    print("task %d has data %d"%(key, len(samples)))

task 0 has data 1117
task 1 has data 907
task 2 has data 1480
task 3 has data 1386


In [3]:
# Meta-fit phase: fit the prior model on every other sample of tasks 0, 1 and 2.
torch.manual_seed(0)
model = DPNNMM(dp_config, nn_config)

dataset = [sample for task in range(3) for sample in data_dict[task][::2]]
# meta_fit is invoked 20 times on the same pooled dataset.
for meta_round in range(20):
    model.meta_fit(dataset)

training epoch [99/100],loss train: -5.8096/0.5904, loss test  -5.8582/0.4848
training epoch [99/100],loss train: -7.2418/0.3547, loss test  -6.9613/0.4989
training epoch [99/100],loss train: -7.4842/0.1431, loss test  -7.3545/0.1527
training epoch [99/100],loss train: -7.9301/0.0687, loss test  -7.7426/0.0358
training epoch [99/100],loss train: -8.3418/0.0436, loss test  -8.1212/0.0431
training epoch [99/100],loss train: -8.8399/0.0246, loss test  -7.8157/0.0357
training epoch [99/100],loss train: -8.4985/0.0214, loss test  -7.7482/0.0405
training epoch [99/100],loss train: -8.7260/0.0200, loss test  -8.1746/0.0180
training epoch [99/100],loss train: -8.8476/0.0182, loss test  -7.9039/0.0376
training epoch [99/100],loss train: -8.7216/0.0173, loss test  -8.0827/0.0309
training epoch [99/100],loss train: -9.1407/0.0166, loss test  -7.8659/0.0172
training epoch [99/100],loss train: -9.0658/0.0146, loss test  -8.5367/0.0187
training epoch [99/100],loss train: -9.1676/0.0154, loss test  -

In [7]:
# Online task-inference experiment, executed once per alpha value in the sweep.
for alpha in (1,):
    print("**************** alpha:", alpha,"  ***************")
    torch.manual_seed(0)

    # Rebuild the model with this alpha, then restore the saved meta model.
    dp_config["alpha"] = alpha
    model = DPNNMM(dp_config, nn_config)

    # The pooled dataset is only consumed by the (disabled) meta_fit call below.
    dataset = [sample for task in range(4) for sample in data_dict[task]]
    model.meta_load("meta_model.pth")
    #model.meta_fit(dataset)

    pred, true_task_idx, times = [], [], []
    task_idx_old = 0
    for task_idx in range(4):
        # Stream every third sample of the task into the model.
        for d in data_dict[task_idx][::3]:
            model.add_data_point(d)
            task_changed = task_idx != task_idx_old
            task_idx_old = task_idx
            # Refit only when model.stm_is_full is set or the true task just changed.
            if not (model.stm_is_full or task_changed):
                continue
            print("stm len: ", len(model.stm))
            time_use, task_idx_pred = model.fit()
            pred.append(task_idx_pred)
            true_task_idx.append(task_idx)
            times.append(time_use)

            print("pred task: ", task_idx_pred, "true:", task_idx)

**************** alpha: 1   ***************
stm len:  100
rho old : [] rho_new:  [21301.4765625]
training epoch [99/100],loss train: 17.0688/0.0125, loss test  6.8810/0.0124
pred task:  0 true: 0
stm len:  100
rho old : [41333.8125] rho_new:  [21178.857421875]
training epoch [99/100],loss train: -11.7707/0.0030, loss test  2.9556/0.0645
pred task:  0 true: 0
stm len:  100
rho old : [61394.09375] rho_new:  [25549.64453125]
training epoch [99/100],loss train: -8.5057/0.1282, loss test  -7.4906/0.0331
pred task:  0 true: 0
stm len:  74
rho old : [6980.49365234375] rho_new:  [18677.193359375]
training epoch [99/100],loss train: -9.6876/0.0034, loss test  -8.8114/0.0136
pred task:  1 true: 1
stm len:  100
rho old : [9543.2099609375, 13724.0546875] rho_new:  [34300.203125]
training epoch [99/100],loss train: -12.6563/0.0000, loss test  -2.1655/0.0042
pred task:  2 true: 1
stm len:  100
rho old : [8536.224609375, 16291.501953125, 62574.80859375] rho_new:  [36914.68359375]
training epoch [99/1

In [8]:
# Second streaming pass over every third sample, reusing the current `model`.
pred, true_task_idx, times = [], [], []
task_idx_old = 0
for task_idx in range(4):
    for d in data_dict[task_idx][::3]:
        model.add_data_point(d)
        # Refit when model.stm_is_full is set or the true task index changed.
        need_refit = model.stm_is_full or task_idx != task_idx_old
        task_idx_old = task_idx
        if need_refit:
            print("stm len: ", len(model.stm))
            time_use, task_idx_pred = model.fit()
            # Keep prediction, ground truth and timing aligned by position.
            pred.append(task_idx_pred)
            true_task_idx.append(task_idx)
            times.append(time_use)

            print("pred task: ", task_idx_pred, "true:", task_idx)

stm len:  100
rho old : [4456.3427734375, 6100.25, 8694.8486328125, 3148.8525390625, 2424.306884765625] rho_new:  [10137.5615234375]
training epoch [99/100],loss train: -10.5626/0.0099, loss test  42.7360/0.0567
pred task:  5 true: 0
stm len:  100
rho old : [13761.3955078125, 15854.2333984375, 6922.10498046875, 6073.41259765625, 12035.2470703125, 24664.046875] rho_new:  [21863.71484375]
training epoch [99/100],loss train: -9.5219/0.0108, loss test  2.1573/0.0568
pred task:  5 true: 0
stm len:  100
rho old : [10069.25390625, 9543.2783203125, 595.9694213867188, 3955.014404296875, 11283.4296875, 19680.705078125] rho_new:  [17385.03125]
training epoch [99/100],loss train: -9.8160/0.0119, loss test  -5.2260/0.0591
pred task:  5 true: 0
stm len:  100
rho old : [11200.978515625, 17109.599609375, 3433.751708984375, 7880.57861328125, 11870.7548828125, 36862.5703125] rho_new:  [24329.392578125]
training epoch [99/100],loss train: -10.1058/0.0173, loss test  -5.4354/0.0265
pred task:  5 true: 0
s

In [None]:
# Streaming pass that skips the first 200 samples of each task.
times, true_task_idx, pred = [], [], []
task_idx_old = 0
for task_idx in range(4):
    for d in data_dict[task_idx][200:]:
        model.add_data_point(d)
        task_changed = task_idx != task_idx_old
        task_idx_old = task_idx
        # Nothing to do unless model.stm_is_full is set or the task switched.
        if not (model.stm_is_full or task_changed):
            continue
        time_use, task_idx_pred = model.fit()
        pred.append(task_idx_pred)
        true_task_idx.append(task_idx)
        times.append(time_use)

        print("pred task: ", task_idx_pred, "true:", task_idx)

In [None]:
# Meta-training phase: pool the samples of all four tasks into one NNComponent.
# NOTE(review): NNComponent is defined in a cell further down this notebook;
# that cell has to be executed first or this cell raises NameError on a
# fresh kernel.
torch.manual_seed(0)

model = NNComponent(NN_config=nn_config)
for i in range(4):
    task_samples = data_dict[i]
    for d in task_samples:
        model.add_data_point(d)

In [None]:
# Train on the samples accumulated through add_data_point (no dataset is passed).
# NOTE(review): assumes `model` is the NNComponent built in the meta-training cell — confirm.
model.fit() 

In [None]:
# Per-task evaluation of the current model on each task's complete data.
for i in range(4):
    # make_dataset returns a pair of loaders; only the first one is evaluated.
    test_data_loader, unused_loader = model.make_dataset(data_dict[i])
    task_metrics = model.validate_model(test_data_loader)
    nll, mse = task_metrics
    print("task ",i,"nll: ", nll, "mse: ", mse)

In [None]:
# Fine-tune on the first 200 samples of task 2 with a 500-epoch budget.
finetune_samples = data_dict[2][:200]
model.n_epochs = 500
model.fit(finetune_samples)

In [None]:
# Per-task evaluation that leaves out the first 100 samples of every task.
for i in range(4):
    held_out = data_dict[i][100:]
    # make_dataset returns a pair of loaders; only the first one is evaluated.
    test_data_loader, unused_loader = model.make_dataset(held_out)
    nll, mse = model.validate_model(test_data_loader)
    print("task ",i,"nll: ", nll, "mse: ", mse)

In [None]:
# Query the model one task-0 sample at a time and keep the second returned value.
# NOTE(review): `nn_meta_model` is not an attribute of the NNComponent class in
# this notebook, so `model` presumably has to be a DPNNMM instance here — confirm.
var_list = []
for d in data_dict[0]:
    # Wrap state and action in a list so predict receives a batch of size one.
    s, a = [d[1]], [d[2]]
    mean, var = model.nn_meta_model.predict(s, a)
    var_list.append(var)

In [None]:
# Scratch check: copying the first n_comps entries of an empty list is a no-op.
comps = []
n_comps = len(comps)
rho_old = list(comps[:n_comps])
n_comps

In [None]:


def CUDA(var):
    """Return ``var.cuda()`` when CUDA is available, otherwise ``var`` unchanged."""
    if torch.cuda.is_available():
        return var.cuda()
    return var

class MLP(nn.Module):
    """Fully connected network with two linear output heads: mean and variance.

    The input layer is followed by ``n_h - 1`` hidden layers of width
    ``size_h``, each followed by a ReLU. The variance head output is passed
    through a ReLU and shifted by 0.001, so it is never below 0.001.
    """

    def __init__(self, n_input=7, n_output=6, n_h=2, size_h=128):
        """Create the layers and draw all weights from N(0, 0.02).

        Args:
            n_input: Width of one input row.
            n_output: Width of each output head.
            n_h: One more than the number of hidden-to-hidden layers.
            size_h: Width of every hidden layer.
        """
        super().__init__()
        self.n_input, self.n_output = n_input, n_output
        self.fc_in = nn.Linear(n_input, size_h)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.fc_list = nn.ModuleList(
            [nn.Linear(size_h, size_h) for _ in range(n_h - 1)]
        )

        #self.fc_out = nn.Linear(size_h, n_output)
        self.fc_out_mean = nn.Linear(size_h, n_output)
        self.fc_out_var = nn.Linear(size_h, n_output)

        # Weight initialisation: input layer and both heads first, hidden stack last.
        for layer in (self.fc_in, self.fc_out_mean, self.fc_out_var):
            nn.init.normal_(layer.weight, 0.0, 0.02)
        self.fc_list.apply(self.init_normal)

    def init_normal(self, m):
        """Re-draw the weights of ``m`` from N(0, 0.02) if it is exactly ``nn.Linear``."""
        if type(m) is not nn.Linear:
            return
        nn.init.normal_(m.weight, 0.0, 0.02)

    def forward(self, x):
        """Map ``x`` (reshaped to rows of ``n_input``) to ``(mean, var)``."""
        hidden = self.relu(self.fc_in(x.view(-1, self.n_input)))
        for layer in self.fc_list:
            hidden = self.relu(layer(hidden))
        out_mean = self.fc_out_mean(hidden)
        # ReLU clamps at zero; the small offset keeps the variance away from 0.
        out_var = self.relu(self.fc_out_var(hidden)) + 0.001
        return (out_mean, out_var)

class NNComponent(object):
    """MLP model trained by Gaussian negative log-likelihood.

    The wrapped network returns a ``(mean, var)`` pair per input row and the
    training loss is the NLL of the label under a Gaussian with diagonal
    covariance ``var``. Samples are lists of
    ``[task_idx, state, action, next_state-state]``.
    """
    # output: [state mean, state var]
    name = "NN"

    def __init__(self, NN_config):
        """Build (or load) the network, the loss and the optimizer.

        Args:
            NN_config: Dict with a ``"model_config"`` section (``state_dim``,
                ``action_dim``, ``load_model``, ``model_path``, ``hidden_dim``,
                ``hidden_size``) and a ``"training_config"`` section
                (``n_epochs``, ``learning_rate``, ``batch_size``,
                ``save_model_flag``, ``save_model_path``, ``validation_flag``,
                ``validation_freq``, ``validation_ratio``).
        """
        super().__init__()
        model_config = NN_config["model_config"]
        training_config = NN_config["training_config"]

        self.state_dim = model_config["state_dim"]
        self.action_dim = model_config["action_dim"]
        self.input_dim = self.state_dim+self.action_dim

        self.n_epochs = training_config["n_epochs"]
        self.lr = training_config["learning_rate"]
        self.batch_size = training_config["batch_size"]

        self.save_model_flag = training_config["save_model_flag"]
        self.save_model_path = training_config["save_model_path"]

        self.validation_flag = training_config["validation_flag"]
        self.validate_freq = training_config["validation_freq"]
        self.validation_ratio = training_config["validation_ratio"]

        if model_config["load_model"]:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            self.model = CUDA(torch.load(model_config["model_path"]))
        else:
            self.model = CUDA(MLP(self.input_dim, self.state_dim, model_config["hidden_dim"], model_config["hidden_size"]))

        self.mse = nn.MSELoss(reduction='mean')
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        # Tensor buffers used only by the legacy *_old methods. They must
        # exist (as None) because add_data_point_old tests `self.data is None`;
        # without them those methods raised AttributeError.
        self.data = None
        self.label = None
        self.dataset = []

    def criterion(self, output, label):
        """Return the mean negative log-likelihood of ``label`` over the batch."""
        nll = -self.log_likelihood(output, label) # [batch]
        return torch.mean(nll)

    def log_likelihood(self, output, label):
        """Return the per-row log-probability of ``label`` under ``output``.

        Args:
            output: ``(mean, var)`` pair, each ``[batch, state_dim]``.
            label: Tensor of targets, ``[batch, state_dim]``.

        Returns:
            Tensor of shape ``[batch]``.
        """
        mu = output[0] # [batch, state_dim]
        var = output[1] # [batch, state_dim]
        cov = torch.diag_embed(var) # [batch, state_dim, state_dim]
        m = MultivariateNormal(mu, cov)
        ll = m.log_prob(label) # [batch]
        return ll

    def predict(self, s, a):
        """Return the predicted mean for a batch of states and actions.

        Args:
            s: Batch of states, convertible by ``torch.tensor``.
            a: Batch of actions, convertible by ``torch.tensor``.

        Returns:
            numpy array holding the first (mean) output of the network; the
            variance output is discarded.
        """
        # convert to torch format
        s = CUDA(torch.tensor(s).float())
        a = CUDA(torch.tensor(a).float())
        inputs = torch.cat((s, a), axis=1)
        state_next = self.model(inputs)[0].cpu().detach().numpy()
        return state_next

    def add_data_point(self, data):
        """Append one sample to the internal dataset list."""
        # data format: [task_idx, state, action, next_state-state]
        self.dataset.append(data)

    def reset_dataset(self, new_dataset = None):
        """Replace the internal dataset by ``new_dataset``, or empty it if None."""
        # dataset format: list of [task_idx, state, action, next_state-state]
        if new_dataset is not None:
            self.dataset = new_dataset
        else:
            self.dataset = []

    def make_dataset(self, dataset, make_test_set = False):
        """Convert raw samples into DataLoader objects.

        Args:
            dataset: List of ``[task_idx, state, action, next_state-state]``.
            make_test_set: When True, a random ``validation_ratio`` fraction
                of the samples is split off into a test loader.

        Returns:
            ``(train_loader, test_loader)``; ``test_loader`` is None when
            ``make_test_set`` is False.
        """
        # dataset format: list of [task_idx, state, action, next_state-state]
        num_data = len(dataset)
        data_list = []
        for data in dataset:
            s = data[1] # state
            a = data[2] # action
            label = data[3] # here label means the next state [state dim]
            data = np.concatenate((s, a), axis=0) # [state dim + action dim]
            data_torch = CUDA(torch.Tensor(data))
            label_torch = CUDA(torch.Tensor(label))
            data_list.append([data_torch, label_torch])

        if make_test_set:
            indices = list(range(num_data))
            split = int(np.floor(self.validation_ratio * num_data))
            np.random.shuffle(indices)
            train_idx, test_idx = indices[split:], indices[:split]
            train_set = [data_list[idx] for idx in train_idx]
            test_set = [data_list[idx] for idx in test_idx]
            train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=self.batch_size)
            test_loader = torch.utils.data.DataLoader(test_set, shuffle=True, batch_size=self.batch_size)
        else:
            train_loader = torch.utils.data.DataLoader(data_list, shuffle=True, batch_size=self.batch_size)
            test_loader = None
        return train_loader, test_loader

    def fit(self, dataset=None):
        """Train the network for ``n_epochs`` epochs.

        Args:
            dataset: Samples to train on; when None the samples accumulated
                through ``add_data_point`` / ``reset_dataset`` are used.

        Returns:
            Mean training loss over the batches of the last epoch, or NaN
            when no batch was processed.
        """
        if dataset is not None:
            train_loader, test_loader = self.make_dataset(dataset, make_test_set=self.validation_flag)
        else: # use its own accumulated data
            train_loader, test_loader = self.make_dataset(self.dataset, make_test_set=self.validation_flag)

        # Defined before the loop so that n_epochs == 0 no longer raises NameError.
        loss_this_epoch = []
        for epoch in range(self.n_epochs):
            loss_this_epoch = []
            for datas, labels in train_loader:
                self.optimizer.zero_grad()
                outputs = self.model(datas)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                loss_this_epoch.append(loss.item())

            if self.save_model_flag:
                torch.save(self.model, self.save_model_path)

            if self.validation_flag and (epoch+1) % self.validate_freq == 0:
                loss_test, mse_test = self.validate_model(test_loader)
                loss_train, mse_train = self.validate_model(train_loader)
                print(f"training epoch [{epoch}/{self.n_epochs}],loss train: {loss_train:.4f}/{mse_train:.4f}, loss test  {loss_test:.4f}/{mse_test:.4f}")

        if not loss_this_epoch:
            return float("nan")
        return np.mean(loss_this_epoch)

    def validate_model(self, testloader):
        """Evaluate the network on every batch of ``testloader``.

        Returns:
            ``(mean NLL, mean MSE)`` averaged over the batches; the MSE
            compares the mean output with the labels.
        """
        loss_list = []
        mse_list = []
        # Evaluation only reads .item() values, so no autograd graph is needed.
        with torch.no_grad():
            for datas, labels in testloader:
                outputs = self.model(datas)
                loss = self.criterion(outputs, labels)
                mse_loss = self.mse(outputs[0], labels)
                loss_list.append(loss.item())
                mse_list.append(mse_loss.item())
        return np.mean(loss_list), np.mean(mse_list)

    def split_train_validation_old(self):
        """Legacy splitter working on the ``self.data`` / ``self.label`` tensors.

        Returns:
            ``(train_loader, test_loader)``; ``test_loader`` is None when
            ``validation_flag`` is False.

        Raises:
            ValueError: If no sample was stored with ``add_data_point_old``.
        """
        if self.data is None:
            raise ValueError("no data stored; call add_data_point_old first")
        num_data = len(self.data)
        # use validation
        if self.validation_flag:
            indices = list(range(num_data))
            split = int(np.floor(self.validation_ratio * num_data))
            np.random.shuffle(indices)
            train_idx, test_idx = indices[split:], indices[:split]

            train_set = [[self.data[idx], self.label[idx]] for idx in train_idx]
            test_set = [[self.data[idx], self.label[idx]] for idx in test_idx]

            train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=self.batch_size)
            test_loader = torch.utils.data.DataLoader(test_set, shuffle=True, batch_size=self.batch_size)
        else:
            train_set = [[self.data[idx], self.label[idx]] for idx in range(num_data)]
            train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=self.batch_size)
            test_loader = None

        return train_loader, test_loader

    def add_data_point_old(self, data):
        """Legacy buffer: append one sample to the ``self.data`` / ``self.label`` tensors."""
        s = data[1]
        a = data[2]
        label = data[3][None] # here label means the next state
        data = np.concatenate((s, a), axis=0)[None]

        # add new data point to data buffer
        if self.data is None:
            self.data = CUDA(torch.Tensor(data))
            self.label = CUDA(torch.Tensor(label))
        else:
            self.data = torch.cat((self.data, CUDA(torch.tensor(data).float())), dim=0)
            self.label = torch.cat((self.label, CUDA(torch.tensor(label).float())), dim=0)

In [None]:
# Sanity check: a batch of three 2-D Gaussians built from per-dimension variances.
mean = torch.tensor([[0, 1], [1, 0], [1, 0]], dtype=torch.float32)
var = torch.tensor([[1, 1], [1, 1], [5, 5]], dtype=torch.float32)
print(mean.shape, var.shape)
# diag_embed places each variance row on the diagonal of its own 2x2 matrix.
cov = torch.diag_embed(var)
m = torch.distributions.MultivariateNormal(mean, cov)