In [1]:
'''
Learn models with different options, for comparison plots

'''

# import libraries
import torch
from robot_env import robot_env
import numpy as np
from pmlp import pmlp
import pgnn
from utils import get_sampleable_inds, sample_memory
from utils import to_body_frame_batch, from_body_frame_batch
from utils import state_diff_batch, state_to_fd_input, state_add_batch
from utils import divide_state, to_device, detach_list, clip_grads
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

import os
# cwd = os.path.dirname(os.path.realpath(__file__))
cwd = ''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
folder = os.path.join(cwd, 'learned_models')
if not(os.path.exists(folder)):
    os.mkdir(folder)
    print('Created folder ' + folder )
else:
    print('Using folder '  + folder)
    
start_time = datetime.now()
start_time_str = datetime.strftime(start_time, '%Y%m%d_%H%M')

RUN_MLP = True
RUN_SINGLE_GNN = True
RUN_MULTI_GNN = True

Using folder learned_models


In [16]:
# load dataset and gather module configs

urdf_names = ['wnwwnw', 'llllll', 'llwwll']
# urdf_names = ['llllll', 'llwwll', 
# 'lwllwl', 'lwwwwl', 
# 'lnllnl', 'lnwwnl',
# 'wlwwlw',  'wwllww', 
# 'wwwwww',  'wnllnw', 
#  'wllllw', 'wnwwnw']

envs = dict()
run_lens = dict()
states_memory_tensors = dict()
actions_memory_tensors = dict()
module_sa_len = dict()
modules_types = dict()
attachments = dict()
delta_fd_len = dict()

for urdf in urdf_names:

    env = robot_env(show_GUI = False)
    env.reset_terrain()
    env.reset_robot(urdf_name=urdf, randomize_start=False)
    attachments[urdf] = env.attachments
    modules_types[urdf] = env.modules_types
    print('attachments: ' + str(attachments[urdf]))
    print('modules_types: ' + str(modules_types[urdf]))
    n_modules = len(modules_types[urdf])
    envs[urdf] = env
    state = env.get_state()
    state_t = [torch.tensor(s, dtype=torch.float32).unsqueeze(0) for s in state]
    module_state_len = []
    for s in state:
        module_state_len.append(len(s))
    module_action_len = list(np.diff(env.action_indexes))
    state_len = sum(module_state_len)
    module_sa_len[urdf] = module_state_len+ module_action_len
    
#     fd_input_test, delta_fd_test = to_body_frame_batch(state_t, state_t)
#     delta_fd_len[urdf] = torch.cat(delta_fd_test,-1).shape[-1]
    delta_fd_len[urdf] = sum(module_sa_len[urdf][0:len(modules_types[urdf])])
    print('Delta fd len:' , delta_fd_len[urdf])

attachments: [[1, None, 2, 3, None, 4], [0], [0], [0], [0]]
modules_types: [0, 2, 2, 2, 2]
Delta fd len: 24
attachments: [[1, 2, 3, 4, 5, 6], [0], [0], [0], [0], [0], [0]]
modules_types: [0, 1, 1, 1, 1, 1, 1]
Delta fd len: 48
attachments: [[1, 2, 3, 4, 5, 6], [0], [0], [0], [0], [0], [0]]
modules_types: [0, 1, 1, 2, 2, 1, 1]
Delta fd len: 42


In [3]:
for urdf in urdf_names:
    # load dataset
    file_names = []
    folder = 'random_rollouts/'
    found = True
    fname_test = os.path.join(cwd, folder+urdf+'_random_rollouts.ptx')
    if os.path.isfile(fname_test):
        file_names.append(fname_test)
#     f_index = 1
#     while found:
#         fname_test = os.path.join(cwd,
#             folder+urdf+'_rollouts' + str(int(f_index)) + '.ptx')
#         found = os.path.isfile(fname_test)
#         if found: 
#             file_names.append(fname_test)
#             f_index+=1
    print('Found files ')
    print(str(file_names))

    states_memory = []
    actions_memory = []
    run_lens[urdf] = []

    for fname in file_names:
        print('loading ' + fname )
        data_in = torch.load(fname)
        states_memory += data_in['states_memory']
        actions_memory += data_in['actions_memory']
        run_lens[urdf] += data_in['run_lens']
        del data_in

    states_memory_tensors[urdf] = [torch.cat(s,0) for s in list(zip(*states_memory)) ]
    actions_memory_tensors[urdf] = [torch.cat(s,0) for s in list(zip(*actions_memory)) ]

print('loaded and merged data')
    
batch_size_default = 500 # default batch size

Found files 
['random_rollouts/wnwwnw_random_rollouts.ptx']
loading random_rollouts/wnwwnw_random_rollouts.ptx
Found files 
['random_rollouts/llllll_random_rollouts.ptx']
loading random_rollouts/llllll_random_rollouts.ptx
loaded and merged data


In [6]:
# urdf_to_train = urdf_names[1]

# select options:
n_training_steps = 30000

condition_tuples = [(100, 2),
                    (1000, 2),
                    (5000, 2)]

#                     (100, 10),
#                     (1000, 10),
#                     (5000, 10)]

# # select data regime (number of data samples to use, low mid or high)
# n_rollouts_to_use = 100 # Low
# n_rollouts_to_use = 1000 # mid
# n_rollouts_to_use = 5000 # high

# # select sequence length for multistep loss
# seq_len = 10
# # seq_len = 2

for condition_tuple in condition_tuples:
    n_rollouts_to_use = condition_tuple[0]
    seq_len = condition_tuple[1]
    condition_run_folder = os.path.join(folder,'runs', start_time_str+ '_' + str(n_rollouts_to_use) + '_' + str(seq_len))
    if not(os.path.exists(condition_run_folder)):
        os.mkdir(condition_run_folder)
        print('Created folder ' + condition_run_folder )
    else:
        print('Using folder '  + condition_run_folder)
    


Created folder random_rollouts/runs/20200710_1645_100_2
Created folder random_rollouts/runs/20200710_1645_1000_2
Created folder random_rollouts/runs/20200710_1645_5000_2


In [21]:
# train P-MLP


# urdf = urdf_to_train
if RUN_MLP:    
    for urdf in urdf_names:

        for condition_tuple in condition_tuples:
            n_rollouts_to_use = condition_tuple[0]
            seq_len = condition_tuple[1]

            print('--- Running condition: ' 
                  + 'n_rollouts=' + str(n_rollouts_to_use)
                  + 'seq_len=' + str(seq_len))

            comment_str = '_pmlp_' + urdf + str(n_rollouts_to_use) + '_' + str(seq_len)
            writer = SummaryWriter(log_dir = os.path.join(folder,'runs',
                        start_time_str+ '_' + str(n_rollouts_to_use) + '_' + str(seq_len)),
                        comment=comment_str)

            # depending on the length of the multistep sequence we want,
            # only some indexes of the full set of states collected can be sampled.
            sampleable_inds = dict()
            batch_sizes = dict()
        #     for urdf in urdf_names:

            sampleable_inds[urdf] = get_sampleable_inds(
                run_lens[urdf][:n_rollouts_to_use], seq_len)
            n_sampleable = len(sampleable_inds[urdf])
            batch_sizes[urdf] = batch_size_default
            if batch_sizes[urdf] > n_sampleable:
                batch_sizes[urdf] = n_sampleable
            print(urdf + ' using ' + str(n_rollouts_to_use) + ' out of Rollouts ' + str(len(run_lens[urdf])))



            batch_size = batch_sizes[urdf]
            n_modules = len(modules_types[urdf])
            module_state_len = module_sa_len[urdf][:n_modules]
            # initialize network and optimizer
            input_len = sum(module_sa_len[urdf]) - 3
            output_len = sum(module_state_len)
            hidden_layer_size = 300
            n_hidden_layers = 5
            fd_network = pmlp(input_len = input_len, output_len=output_len,
                n_hidden_layers = n_hidden_layers, hidden_layer_size=hidden_layer_size
                ).to(device)
            weight_decay = 1e-4
            optimizer =  torch.optim.Adam(fd_network.parameters(),lr=1e-3, weight_decay = weight_decay) 

            num_nn_params=0
            for p in fd_network.parameters():
                nn=1
                for s in list(p.size()):
                    nn = nn*s
                num_nn_params += nn
            print('Num NN params: ' + str(num_nn_params))

            for training_step in range( n_training_steps):
    #         for training_step in range(0):

                if np.mod(training_step,5000 )==0 and training_step>10000:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = param_group['lr']/2
                        print( 'LR: ' + str(param_group['lr']) )

                # sample without replacement from the full memory, depending on what is sampleable
                state_seq, action_seq, sampled_inds = sample_memory(
                                states_memory_tensors[urdf], 
                                actions_memory_tensors[urdf],
                                sampleable_inds[urdf], seq_len, batch_size)

                loss = 0
                state_approx = to_device(state_seq[0],device) # initial state input is the first in sequence
                for seq in range(seq_len-1): # for multistep loss, go through the sequence

                    # process states to move them to vehicle frame
                    fd_input_real, delta_fd_real = to_body_frame_batch(state_seq[seq], state_seq[seq+1])
                    fd_input_approx, R_t = state_to_fd_input(state_approx) # for recursive estimation

                    # pass through network
                    fd_input = torch.cat(fd_input_approx,1).to(device)
                    actions_in = torch.cat(action_seq[seq],1).to(device)
                    delta_fd = torch.cat(delta_fd_real,1).to(device)
                    state_delta_est_mean, state_delta_est_var = fd_network(fd_input, actions_in)

                    # compute loss for this step in sequence
                    loss += torch.sum(
                         (state_delta_est_mean - delta_fd)**2/state_delta_est_var
                         + torch.log(state_delta_est_var)
                         )/batch_size/(seq_len-1)   

                    # transform back to world frame advance to next sequence step
                    if seq_len>2:
                        # divide MLP output divided up into modules
                        delta_fd_approx = divide_state(state_delta_est_mean, module_state_len)
                        # update recursive state estimation for multistep loss  
                        state_approx = from_body_frame_batch(state_approx, delta_fd_approx)


                # backprop and optimizer step 
                loss_np = loss.detach().cpu().numpy()
                fd_network.zero_grad()
                loss.backward()
                optimizer.step()

                writer.add_scalar('Train' + '/Loss_pmlp', loss_np, training_step)


                # periodically save the model

                if np.mod(training_step,500)==0:
                    PATH = ('learned_models/' + urdf + '_pmlp_r' + str(int(n_rollouts_to_use)) + 
                            '_ms'+ str(int(seq_len))+'.pt')
                    PATH = os.path.join(cwd, PATH)
                    fd_network_state_dict=fd_network.state_dict()
                    torch.save({'fd_network_state_dict':fd_network_state_dict,
                        'fd_network_input_len':fd_network.input_len,
                        'fd_network_output_len':fd_network.output_len,
                        'fd_network_n_hidden_layers':fd_network.n_hidden_layers,
                        'fd_network_hidden_layer_size':fd_network.hidden_layer_size,
                        'n_rollouts_to_use':n_rollouts_to_use,
                        'seq_len':seq_len,
                        'batch_size':batch_size,
                        'urdf':urdf,
                        'num_nn_params':num_nn_params,
                        'weight_decay':weight_decay
                        },  PATH)  
                    print('Training losses at iter ' + 
                            str(training_step) + ': ' + 
                            str(np.round(loss_np,2)))
            del fd_input, actions_in, delta_fd, fd_network, loss, optimizer
            torch.cuda.empty_cache()

--- Running condition: n_rollouts=100seq_len=2
wnwwnw using 100 out of Rollouts 5000
Num NN params: 790446


NameError: name 'fd_input' is not defined

In [22]:
# train P-GNN
# urdf = urdf_to_train
if RUN_SINGLE_GNN:
    for urdf in urdf_names:
        for condition_tuple in condition_tuples:

            n_rollouts_to_use = condition_tuple[0]
            seq_len = condition_tuple[1]

            comment_str = '_pgnn_' + urdf + str(n_rollouts_to_use) + '_' + str(seq_len)
            writer = SummaryWriter(log_dir = os.path.join(folder,'runs',
                        start_time_str+ '_' + str(n_rollouts_to_use) + '_' + str(seq_len)),
                        comment=comment_str)


            # depending on the length of the multistep sequence we want,
            # only some indexes of the full set of states collected can be sampled.
            sampleable_inds = dict()
            batch_sizes = dict()
        #     for urdf in urdf_names:
            sampleable_inds[urdf] = get_sampleable_inds(run_lens[urdf][:n_rollouts_to_use], seq_len)
            n_sampleable = len(sampleable_inds[urdf])
            batch_sizes[urdf] = batch_size_default
            if batch_sizes[urdf] > n_sampleable:
                batch_sizes[urdf] = n_sampleable
            print(urdf + ' using ' + str(n_rollouts_to_use) + ' out of Rollouts ' + str(len(run_lens[urdf])))


            batch_size = batch_sizes[urdf]
            n_modules = len(modules_types[urdf])
            module_state_len = module_sa_len[urdf][:n_modules]

            # initialize network and optimizer
            internal_state_len = 100
            message_len = 50
            hidden_layer_size = 250
            weight_decay = 1e-4
            gnn_nodes = pgnn.create_GNN_nodes(internal_state_len, message_len, hidden_layer_size, 
                            device, body_input = True)
            optimizer = torch.optim.Adam(pgnn.get_GNN_params_list(gnn_nodes), 
                                         lr=1e-3,
                                weight_decay= weight_decay)# create module containers for the nodes
            modules = []
            for i in range(n_modules):
                modules.append(pgnn.Module(i, gnn_nodes[modules_types[urdf][i]], device))


            num_nn_params=0
            for p in pgnn.get_GNN_params_list(gnn_nodes):
                nn=1
                for s in list(p.size()):
                    nn = nn*s
                num_nn_params += nn

            print('Num NN params: ' + str(num_nn_params))


            for training_step in range(n_training_steps):
        #     for training_step in range(0):

                if np.mod(training_step,5000 )==0 and training_step>10000:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = param_group['lr']/2
                        print( 'LR: ' + str(param_group['lr']) )

                # sample without replacement from the full memory, depending on what is sampleable
                state_seq, action_seq, sampled_inds = sample_memory(
                                states_memory_tensors[urdf], actions_memory_tensors[urdf],
                                sampleable_inds[urdf], seq_len, batch_size)

                loss = 0
                state_approx = to_device(state_seq[0],device) # initial state input is the first in sequence
                for seq in range(seq_len-1): # for multistep loss, go through the sequence

                    for module in modules: # must reset module lstm state
                        module.reset_hidden_states(batch_size) 

                    # process states to move them to vehicle frame
                    fd_input_real, delta_fd_real = to_body_frame_batch(state_seq[seq], state_seq[seq+1])
                    fd_input_approx, R_t = state_to_fd_input(state_approx) # for recursive estimation

                    # pass through network
                    fd_input   = to_device(fd_input_approx, device) 
                    actions_in = to_device(action_seq[seq], device)
                    delta_fd   = to_device(delta_fd_real, device) 
                    node_inputs = [torch.cat([s,a],1) for (s,a) in zip(fd_input, actions_in)]
                    state_delta_est_mean, state_delta_var = pgnn.run_propagations(
                        modules, attachments[urdf], 2, node_inputs, device)

                    # compute loss for this step in sequence
                    for mm in range(len(state_delta_est_mean)):
                        loss += torch.sum(
                            (state_delta_est_mean[mm] - delta_fd[mm])**2/state_delta_var[mm] + 
                            torch.log(state_delta_var[mm]) 
                                        )/batch_size/(seq_len-1)

                    # transform back to world frame advance to next sequence step
                    if seq_len>2:
                        # update recursive state estimation for multistep loss
                        # GNN output is already divided up into modules
                        delta_fd_approx = state_delta_est_mean
                        state_approx = from_body_frame_batch(state_approx, delta_fd_approx)



                # backprop and optimizer step 
                loss_np=(loss.detach().cpu().numpy())
                pgnn.zero_grads(gnn_nodes)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                writer.add_scalar('Train' + '/Loss_pgnn', loss_np, training_step)


                # periodically save the model
                if np.mod(training_step,500)==0:
                    PATH = ('learned_models/' + urdf + '_pgnn_r' + str(int(n_rollouts_to_use)) + 
                            '_ms'+ str(int(seq_len))+'.pt')
                    PATH = os.path.join(cwd, PATH)
                    gnn_state_dicts=(pgnn.get_state_dicts(gnn_nodes))
                    save_dict = dict()
                    save_dict['gnn_state_dicts'] =  gnn_state_dicts
                    save_dict['internal_state_len'] = gnn_nodes[0].internal_state_len
                    save_dict['message_len'] = gnn_nodes[0].message_len
                    save_dict['hidden_layer_size'] = gnn_nodes[0].hidden_layer_size
                    save_dict['n_rollouts_to_use']=n_rollouts_to_use
                    save_dict['seq_len']=seq_len
                    save_dict['batch_size']=batch_size
                    save_dict['urdf']=urdf
                    save_dict['weight_decay'] = weight_decay
                    save_dict['num_nn_params'] = num_nn_params

                    torch.save(save_dict,  PATH)

                    print('Training losses at iter ' + 
                        str(training_step) + ': ' + 
                        str(np.round(loss_np,2)))

            del fd_input, actions_in, delta_fd,  loss, optimizer
            torch.cuda.empty_cache()

wnwwnw using 100 out of Rollouts 5000
Num NN params: 495784
Training losses at iter 0: 15.45
Training losses at iter 500: -77.04
Training losses at iter 1000: -94.33
Training losses at iter 1500: -107.5
Training losses at iter 2000: -105.14
Training losses at iter 2500: -117.26
Training losses at iter 3000: -130.7
Training losses at iter 3500: -124.73
Training losses at iter 4000: -125.47
Training losses at iter 4500: -139.18
Training losses at iter 5000: -132.45
Training losses at iter 5500: -146.93
Training losses at iter 6000: -149.89
Training losses at iter 6500: -155.27
Training losses at iter 7000: -160.98
Training losses at iter 7500: -164.75
Training losses at iter 8000: -165.3
Training losses at iter 8500: -170.19
Training losses at iter 9000: -165.68
Training losses at iter 9500: -168.99
Training losses at iter 10000: -172.62
Training losses at iter 10500: -180.19
Training losses at iter 11000: -180.25
Training losses at iter 11500: -182.58
Training losses at iter 12000: -182

Training losses at iter 12500: -322.71
Training losses at iter 13000: -327.61
Training losses at iter 13500: -328.43
Training losses at iter 14000: -291.37
Training losses at iter 14500: -345.26
LR: 0.0005
Training losses at iter 15000: -349.55
Training losses at iter 15500: -362.92
Training losses at iter 16000: -369.92
Training losses at iter 16500: -371.85
Training losses at iter 17000: -366.89
Training losses at iter 17500: -373.8
Training losses at iter 18000: -379.6
Training losses at iter 18500: -372.13
Training losses at iter 19000: -378.35
Training losses at iter 19500: -383.72
LR: 0.00025
Training losses at iter 20000: -374.93
Training losses at iter 20500: -393.7
Training losses at iter 21000: -394.52
Training losses at iter 21500: -398.46
Training losses at iter 22000: -398.65
Training losses at iter 22500: -398.8
Training losses at iter 23000: -395.93
Training losses at iter 23500: -404.64
Training losses at iter 24000: -398.24
Training losses at iter 24500: -406.12
LR: 0.

In [27]:
if RUN_MULTI_GNN:
    # train P-GNN, with multiple designs. this will be N_des times slower
    for condition_tuple in condition_tuples:
        n_rollouts_to_use = condition_tuple[0]
        seq_len = condition_tuple[1]


        comment_str = '_pgnn_multi_' + str(n_rollouts_to_use) + '_' + str(seq_len)
        writer = SummaryWriter(log_dir = os.path.join(folder,'runs',
                    start_time_str+ '_' + str(n_rollouts_to_use) + '_' + str(seq_len)),
                    comment=comment_str)


        # depending on the length of the multistep sequence we want,
        # only some indexes of the full set of states collected can be sampled.
        sampleable_inds = dict()
        batch_sizes = dict()
        for urdf in urdf_names:
            sampleable_inds[urdf] = get_sampleable_inds(run_lens[urdf][:n_rollouts_to_use], seq_len)
            n_sampleable = len(sampleable_inds[urdf])
            batch_sizes[urdf] = batch_size_default
            if batch_sizes[urdf] > n_sampleable:
                batch_sizes[urdf] = n_sampleable
            print(urdf + ' using ' + str(n_rollouts_to_use) + ' out of Rollouts ' + str(len(run_lens[urdf])))


        # initialize network and optimizer
        internal_state_len = 100
        message_len = 50
        hidden_layer_size = 250
        weight_decay = 1e-4
        gnn_nodes = pgnn.create_GNN_nodes(internal_state_len, message_len, hidden_layer_size, 
                        device, body_input = True)
        optimizer = torch.optim.Adam(pgnn.get_GNN_params_list(gnn_nodes), 
                                     lr=1e-3,
                            weight_decay= weight_decay)# create module containers for the nodes


        modules = dict()
        for urdf in urdf_names:
            modules[urdf] = []
            n_modules = len(modules_types[urdf])
            for i in range(n_modules):
                modules[urdf].append(pgnn.Module(i, gnn_nodes[modules_types[urdf][i]], device))

        num_nn_params=0
        for p in pgnn.get_GNN_params_list(gnn_nodes):
            nn=1
            for s in list(p.size()):
                nn = nn*s
            num_nn_params += nn
        print('Num NN params: ' + str(num_nn_params))


        for training_step in range(n_training_steps):

            if np.mod(training_step,5000 )==0 and training_step>10000:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr']/2
                    print( 'LR: ' + str(param_group['lr']) )

            # accumulate gradients accros designs but not loss
            optimizer.zero_grad()
            loss_tot_np = 0
            for urdf in urdf_names:
                batch_size = batch_sizes[urdf]

                # sample without replacement from the full memory, depending on what is sampleable
                state_seq, action_seq, sampled_inds = sample_memory(
                                states_memory_tensors[urdf], actions_memory_tensors[urdf],
                                sampleable_inds[urdf], seq_len, batch_size)

                loss = 0 # accumulate loss for a single design accross the multistep sequence
                state_approx = to_device(state_seq[0],device) # initial state input is the first in sequence
                for seq in range(seq_len-1): # for multistep loss, go through the sequence

                    for module in modules[urdf]: # must reset module lstm state
                        module.reset_hidden_states(batch_size) 

                    # process states to move them to vehicle frame
                    fd_input_real, delta_fd_real = to_body_frame_batch(state_seq[seq], state_seq[seq+1])
                    fd_input_approx, R_t = state_to_fd_input(state_approx) # for recursive estimation

                    # pass through network
                    fd_input   = to_device(fd_input_approx, device) 
                    actions_in = to_device(action_seq[seq], device)
                    delta_fd   = to_device(delta_fd_real, device) 
                    node_inputs = [torch.cat([s,a],1) for (s,a) in zip(fd_input, actions_in)]
                    state_delta_est_mean, state_delta_var = pgnn.run_propagations(
                        modules[urdf], attachments[urdf], 2, node_inputs, device)

                    # compute loss for this step in sequence
                    for mm in range(len(state_delta_est_mean)):
                        loss += torch.sum(
                            (state_delta_est_mean[mm] - delta_fd[mm])**2/state_delta_var[mm] + 
                            torch.log(state_delta_var[mm]) 
                                        )/batch_size/(seq_len-1)

                    # transform back to world frame advance to next sequence step
                    if seq_len>2:
                        # update recursive state estimation for multistep loss
                        # GNN output is already divided up into modules
                        delta_fd_approx = state_delta_est_mean
                        state_approx = from_body_frame_batch(state_approx, delta_fd_approx)

                # after multistep sequence, add loss for this design onto full loss for tracking
                loss_np=(loss.detach().cpu().numpy())
                loss_tot_np += loss_np

                # backward for each design to keep compute tree smaller
                loss.backward()

            # optimizer step once we have accumulated grads for all designs
            optimizer.step()
            writer.add_scalar('Train' + '/Loss_pgnn_multidesign', loss_tot_np, training_step)


            # periodically save the model
            if np.mod(training_step,100)==0:
                PATH = ('learned_models/' + 'multidesign_pgnn_r' + str(int(n_rollouts_to_use)) + 
                        '_ms'+ str(int(seq_len))+'.pt')
                PATH = os.path.join(cwd, PATH)
                gnn_state_dicts=(pgnn.get_state_dicts(gnn_nodes))
                save_dict = dict()
                save_dict['gnn_state_dicts'] =  gnn_state_dicts
                save_dict['internal_state_len'] = gnn_nodes[0].internal_state_len
                save_dict['message_len'] = gnn_nodes[0].message_len
                save_dict['hidden_layer_size'] = gnn_nodes[0].hidden_layer_size
                save_dict['n_rollouts_to_use']=n_rollouts_to_use
                save_dict['seq_len']=seq_len
                save_dict['batch_sizes']=batch_sizes
                save_dict['urdf_names']=urdf_names
                save_dict['weight_decay'] = weight_decay
                save_dict['num_nn_params'] = num_nn_params
                torch.save(save_dict,  PATH)

                print('Training losses at iter ' + 
                    str(training_step) + ': ' + 
                    str(np.round(loss_tot_np,2)))

        del fd_input, actions_in, delta_fd
        torch.cuda.empty_cache()

wnwwnw using 100 out of Rollouts 5000
llllll using 100 out of Rollouts 5000
Num NN params: 495784
Training losses at iter 0: 18.63
Training losses at iter 100: -185.51
Training losses at iter 200: -194.93
Training losses at iter 300: -221.23
Training losses at iter 400: -223.74
Training losses at iter 500: -227.19
Training losses at iter 600: -241.83
Training losses at iter 700: -250.51
Training losses at iter 800: -260.44
Training losses at iter 900: -268.68
Training losses at iter 1000: -280.47
Training losses at iter 1100: -278.79
Training losses at iter 1200: -285.35
Training losses at iter 1300: -294.93
Training losses at iter 1400: -280.17
Training losses at iter 1500: -301.6
Training losses at iter 1600: -305.38
Training losses at iter 1700: -315.69
Training losses at iter 1800: -303.07
Training losses at iter 1900: -315.52
Training losses at iter 2000: -315.9
Training losses at iter 2100: -316.38
Training losses at iter 2200: -329.43
Training losses at iter 2300: -329.93
Traini

Training losses at iter 21100: -566.18
Training losses at iter 21200: -559.27
Training losses at iter 21300: -565.07
Training losses at iter 21400: -564.43
Training losses at iter 21500: -570.99
Training losses at iter 21600: -566.94
Training losses at iter 21700: -569.2
Training losses at iter 21800: -567.46
Training losses at iter 21900: -570.51
Training losses at iter 22000: -570.47
Training losses at iter 22100: -575.49
Training losses at iter 22200: -570.08
Training losses at iter 22300: -572.35
Training losses at iter 22400: -566.83
Training losses at iter 22500: -568.51
Training losses at iter 22600: -568.41
Training losses at iter 22700: -568.89
Training losses at iter 22800: -569.21
Training losses at iter 22900: -574.94
Training losses at iter 23000: -572.58
Training losses at iter 23100: -572.22
Training losses at iter 23200: -568.72
Training losses at iter 23300: -565.45
Training losses at iter 23400: -568.57
Training losses at iter 23500: -572.1
Training losses at iter 236

Training losses at iter 12200: -422.06
Training losses at iter 12300: -428.98
Training losses at iter 12400: -425.59
Training losses at iter 12500: -429.32
Training losses at iter 12600: -424.84
Training losses at iter 12700: -426.15
Training losses at iter 12800: -436.92
Training losses at iter 12900: -433.96
Training losses at iter 13000: -420.35
Training losses at iter 13100: -415.9
Training losses at iter 13200: -434.8
Training losses at iter 13300: -426.61
Training losses at iter 13400: -425.59
Training losses at iter 13500: -430.57
Training losses at iter 13600: -432.63
Training losses at iter 13700: -426.77
Training losses at iter 13800: -440.59
Training losses at iter 13900: -433.16
Training losses at iter 14000: -414.77
Training losses at iter 14100: -441.77
Training losses at iter 14200: -439.14
Training losses at iter 14300: -425.4
Training losses at iter 14400: -435.51
Training losses at iter 14500: -443.66
Training losses at iter 14600: -434.51
Training losses at iter 1470

Training losses at iter 3100: -309.25
Training losses at iter 3200: -323.37
Training losses at iter 3300: -324.0
Training losses at iter 3400: -322.56
Training losses at iter 3500: -314.78
Training losses at iter 3600: -327.35
Training losses at iter 3700: -329.76
Training losses at iter 3800: -331.53
Training losses at iter 3900: -319.99
Training losses at iter 4000: -338.86
Training losses at iter 4100: -344.43
Training losses at iter 4200: -338.72
Training losses at iter 4300: -341.79
Training losses at iter 4400: -343.95
Training losses at iter 4500: -347.8
Training losses at iter 4600: -342.36
Training losses at iter 4700: -348.34
Training losses at iter 4800: -344.97
Training losses at iter 4900: -337.05
Training losses at iter 5000: -356.34
Training losses at iter 5100: -350.93
Training losses at iter 5200: -347.23
Training losses at iter 5300: -339.57
Training losses at iter 5400: -358.96
Training losses at iter 5500: -351.05
Training losses at iter 5600: -353.26
Training losse

Training losses at iter 24300: -451.16
Training losses at iter 24400: -450.4
Training losses at iter 24500: -453.31
Training losses at iter 24600: -451.39
Training losses at iter 24700: -457.31
Training losses at iter 24800: -447.51
Training losses at iter 24900: -454.46
LR: 0.000125
Training losses at iter 25000: -453.72
Training losses at iter 25100: -451.72
Training losses at iter 25200: -459.18
Training losses at iter 25300: -456.1
Training losses at iter 25400: -456.64
Training losses at iter 25500: -456.99
Training losses at iter 25600: -451.8
Training losses at iter 25700: -455.15
Training losses at iter 25800: -458.7
Training losses at iter 25900: -455.22
Training losses at iter 26000: -457.56
Training losses at iter 26100: -455.81
Training losses at iter 26200: -455.47
Training losses at iter 26300: -461.21
Training losses at iter 26400: -457.26
Training losses at iter 26500: -453.86
Training losses at iter 26600: -455.76
Training losses at iter 26700: -457.78
Training losses 

In [3]:
# evaluation: load validation data set
states_memory_validation =dict()
actions_memory_validation =dict()
run_lens_validation =dict()
for urdf in urdf_names:
    
    file_names = []
    folder = os.path.join(cwd, 'random_rollouts')
    found = True
    fname_test = os.path.join(folder,urdf+'_random_rollouts_validation.ptx')
    if os.path.isfile(fname_test):
        file_names.append(fname_test)
    print('Found files ')
    print(str(file_names))
    
    states_memory_validation[urdf] = []
    actions_memory_validation[urdf] = []
    run_lens_validation[urdf] = []

    for fname in file_names:
        print('loading ' + fname )
        data_in = torch.load(fname)
        states_memory_validation[urdf] += data_in['states_memory']
        actions_memory_validation[urdf] += data_in['actions_memory']
        run_lens_validation[urdf] += data_in['run_lens']
        del data_in

#     states_memory_tensors[urdf] = [torch.cat(s,0) for s in list(zip(*states_memory)) ]
#     actions_memory_tensors[urdf] = [torch.cat(s,0) for s in list(zip(*actions_memory)) ]

print('loaded and merged data')

Found files 
['random_rollouts/wnwwnw_random_rollouts_validation.ptx']
loading random_rollouts/wnwwnw_random_rollouts_validation.ptx
Found files 
['random_rollouts/llllll_random_rollouts_validation.ptx']
loading random_rollouts/llllll_random_rollouts_validation.ptx
Found files 
['random_rollouts/llwwll_random_rollouts_validation.ptx']
loading random_rollouts/llwwll_random_rollouts_validation.ptx
loaded and merged data


In [7]:
# compute constant prediction baseline:
# If we were to always predict that delta_fd = 0, 
# what would the error be? Used as a baseline.

results_matrix = np.matrix(np.zeros([len(condition_tuples)*3+1, len(urdf_names)]))

const_pred_diff_list = dict()
for i_urdf in range(len(urdf_names)):
    urdf = urdf_names[i_urdf]
    const_pred_diff_list[urdf] = []
    for run_index in range(len(run_lens_validation[urdf])):

        state_seq0 = [s[:-1] for s in states_memory_validation[urdf][run_index]]
        state_seq1 = [s[1:] for s in states_memory_validation[urdf][run_index]]
        action_seq = [a[:-1] for a in actions_memory_validation[urdf][run_index]]

        # process states to move them to vehicle frame
        fd_input_real, delta_fd_real = to_body_frame_batch(state_seq0, state_seq1)
        delta_fd = torch.cat(delta_fd_real,1)

        const_pred_diff_list[urdf].append(delta_fd)
    print(urdf, ' Constant pred baseline:')
    print( torch.abs(torch.cat(const_pred_diff_list[urdf],0)).sum(-1).mean() )

    results_matrix[0,i_urdf] = torch.abs(torch.cat(const_pred_diff_list[urdf],0)).sum(-1).mean().numpy()

print(results_matrix)

wnwwnw  Constant pred baseline:
tensor(7.6640)
llllll  Constant pred baseline:
tensor(12.5311)
llwwll  Constant pred baseline:
tensor(12.6220)
[[ 7.66399145 12.53108692 12.6220026 ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]]


In [20]:
# Evalute MLP accuracy
# for urdf in urdf_names
for i_urdf in range(len(urdf_names)):
    urdf = urdf_names[i_urdf]
    
    for i_condition_tuple in range(len(condition_tuples)):
        condition_tuple = condition_tuples[i_condition_tuple]
        seq_len = condition_tuple[1]
        n_rollouts_to_use = condition_tuple[0]
    
        

        # load network
#         PATH = os.path.join(cwd,'learned_models',
        PATH = os.path.join(cwd,'learned_models/longer_training_3des',
                            urdf + '_pmlp_r' +
                            str(int(n_rollouts_to_use)) + 
                            '_ms'+ str(int(seq_len))+'.pt')

        if os.path.exists(PATH):
            save_dict = torch.load( PATH)#, map_location=lambda storage, loc: storage)
            input_len, output_len = save_dict['fd_network_input_len'], save_dict['fd_network_output_len']
            n_hidden_layers = save_dict['fd_network_n_hidden_layers']
            hidden_layer_size = save_dict['fd_network_hidden_layer_size']

            fd_network = pmlp(input_len = input_len, output_len=output_len,
                n_hidden_layers = n_hidden_layers, hidden_layer_size=hidden_layer_size
                ).to(device)
            fd_network.load_state_dict(save_dict['fd_network_state_dict'])
            fd_network.eval()


            diff_list = []
            diff_list_by_run = []
            for run_index in range(len(run_lens_validation[urdf])):
                with torch.no_grad():
                    state_seq0 = [s[:-1] for s in states_memory_validation[urdf][run_index]]
                    state_seq1 = [s[1:] for s in states_memory_validation[urdf][run_index]]
                    action_seq = [a[:-1] for a in actions_memory_validation[urdf][run_index]]

                    # process states to move them to vehicle frame
                    fd_input_real, delta_fd_real = to_body_frame_batch(state_seq0, state_seq1)

                    # pass through network
                    fd_input = torch.cat(fd_input_real,1).to(device)
                    actions_in = torch.cat(action_seq,1).to(device)
                    delta_fd = torch.cat(delta_fd_real,1).to(device)
                    state_delta_est_mean, state_delta_est_var = fd_network(fd_input, actions_in)

                    diff = (delta_fd - state_delta_est_mean)
                    diff_list_by_run.append(torch.abs(diff).sum(-1).mean())


                diff_list.append(diff)
            diff_list_all = torch.cat(diff_list,0)
            print(urdf + ' ' + str(condition_tuple) + ' MLP differences:')
            diffs_abs = torch.abs(diff_list_all).sum(-1)

            print('Mean: ' + str(diffs_abs.mean().item() ))
            print('Std: ' + str(diffs_abs.std().item()))
            
            results_matrix[i_condition_tuple+1,i_urdf] = diffs_abs.mean().item()
#     print('MLP relative to const baseline mean:')
#     const_pred = torch.abs(torch.cat(const_pred_diff_list[urdf],0)).sum(-1).to(device)
#     diff_rel = diffs_abs/const_pred
#     print('Mean: ' + str(diff_rel.mean().item() ))
#     print('Std: ' + str(diff_rel.std().item()))
#     # print(np.mean(diff_list/const_pred_diff_list[urdf]))

#     print(urdf + ' MLP differences by run:')
#     diffs_sums = torch.stack(diff_list_by_run,-1)
#     print('Mean: ' + str(diffs_sums.mean().item() ))
#     print('Std: ' + str(diffs_sums.std().item()))
del fd_network, fd_input, actions_in, delta_fd
torch.cuda.empty_cache()

wnwwnw (100, 2) MLP differences:
Mean: 7.7614006996154785
Std: 4.886170864105225
wnwwnw (1000, 2) MLP differences:
Mean: 3.7145955562591553
Std: 3.5600602626800537
wnwwnw (5000, 2) MLP differences:
Mean: 2.4433681964874268
Std: 2.4891090393066406
llllll (100, 2) MLP differences:
Mean: 8.790816307067871
Std: 4.93350887298584
llllll (1000, 2) MLP differences:
Mean: 7.702385902404785
Std: 4.226662635803223
llllll (5000, 2) MLP differences:
Mean: 7.818650722503662
Std: 4.077511787414551
llwwll (100, 2) MLP differences:
Mean: 9.042378425598145
Std: 5.316173076629639
llwwll (1000, 2) MLP differences:
Mean: 8.545150756835938
Std: 4.9753193855285645
llwwll (5000, 2) MLP differences:
Mean: 8.220766067504883
Std: 4.607919216156006


In [22]:
# Evalute GNN accuracy for one design
# urdf = urdf_names[0]
for i_urdf in range(len(urdf_names)):
    urdf = urdf_names[i_urdf]
    
    for i_condition_tuple in range(len(condition_tuples)):
        condition_tuple = condition_tuples[i_condition_tuple]
        seq_len = condition_tuple[1]
        n_rollouts_to_use =  condition_tuple[0]

        # load network
#         PATH = os.path.join(cwd,'learned_models', 
        PATH = os.path.join(cwd,'learned_models/longer_training_3des', 
                            urdf + '_pgnn_r' + str(int(n_rollouts_to_use)) + 
                '_ms'+ str(int(seq_len))+'.pt')

    #     PATH = ('learned_models/' + 'multidesign_pgnn_r' + str(int(n_rollouts_to_use)) + 
    #             '_ms'+ str(int(seq_len))+'.pt')

#         print(PATH)

        if os.path.exists(PATH):

            save_dict = torch.load( PATH)#, map_location=lambda storage, loc: storage)
            internal_state_len = save_dict['internal_state_len']
            message_len= save_dict['message_len']
            hidden_layer_size= save_dict['hidden_layer_size']

            gnn_nodes = pgnn.create_GNN_nodes(internal_state_len, message_len, hidden_layer_size, 
                            device, body_input = True)
            pgnn.load_state_dicts(gnn_nodes, save_dict['gnn_state_dicts'])
            for gnn_node in gnn_nodes:
                gnn_node.eval()

            modules = dict()
            modules[urdf] = []
            n_modules = len(modules_types[urdf])
            for i in range(n_modules):
                modules[urdf].append(pgnn.Module(i, gnn_nodes[modules_types[urdf][i]], device))


            diff_list = dict()
            diff_list_by_run =  dict()

            diff_list[urdf] = []
            diff_list_by_run[urdf] = []
            for run_index in range(len(run_lens_validation[urdf])):
                batch_size = run_lens_validation[urdf][run_index]-1
                with torch.no_grad():
                    state_seq0 = [s[:-1] for s in states_memory_validation[urdf][run_index]]
                    state_seq1 = [s[1:]  for s in states_memory_validation[urdf][run_index]]
                    action_seq = [a[:-1] for a in actions_memory_validation[urdf][run_index]]

                    for module in modules[urdf]: # must reset module lstm state
                        module.reset_hidden_states(batch_size) 

                    # process states to move them to vehicle frame
                    fd_input_real, delta_fd_real = to_body_frame_batch(state_seq0, state_seq1)

                    # pass through network
                    fd_input   = to_device(fd_input_real, device) 
                    actions_in = to_device(action_seq, device)
                    node_inputs = [torch.cat([s,a],1) for (s,a) in zip(fd_input, actions_in)]
                    state_delta_est_mean, state_delta_var = pgnn.run_propagations(
                        modules[urdf], attachments[urdf], 2, node_inputs, device)

                    # cat to one tensor and take difference
                    state_delta_est_mean = torch.cat(state_delta_est_mean,-1)
                    delta_fd   = torch.cat(to_device(delta_fd_real, device),-1)
                    diff = (delta_fd - state_delta_est_mean)

                diff_list_by_run[urdf].append(torch.abs(diff).sum(-1).mean())   
                diff_list[urdf].append(diff)

            print('------------------------------')    
            diff_list_all = torch.cat(diff_list[urdf],0)
            print(urdf + ' ' + str(condition_tuple)+ ' GNN differences:')
            diffs_abs = torch.abs(diff_list_all).sum(-1)

            print('Mean: ' + str(diffs_abs.mean().item() ))
            print('Std: ' + str(diffs_abs.std().item()))
            
            results_matrix[i_condition_tuple+1+3,i_urdf] = diffs_abs.mean().item()

#         print(urdf + ' GNN relative to const baseline mean:')
#         const_pred = torch.abs(torch.cat(const_pred_diff_list[urdf],0)).sum(-1).to(device)
#         diff_rel = diffs_abs/const_pred
#         print('Mean: ' + str(diff_rel.mean().item() ))
#         print('Std: ' + str(diff_rel.std().item()))

#         print(urdf + ' GNN differences by run:')
#         diffs_sums = torch.stack(diff_list_by_run[urdf],-1)
#         print('Mean: ' + str(diffs_sums.mean().item() ))
#         print('Std: ' + str(diffs_sums.std().item()))


------------------------------
wnwwnw (100, 2) GNN differences:
Mean: 2.528707265853882
Std: 2.653336524963379
------------------------------
wnwwnw (1000, 2) GNN differences:
Mean: 1.950215220451355
Std: 2.1682562828063965
------------------------------
wnwwnw (5000, 2) GNN differences:
Mean: 1.9581942558288574
Std: 2.107177495956421
------------------------------
llllll (100, 2) GNN differences:
Mean: 4.71076774597168
Std: 3.483875274658203
------------------------------
llllll (1000, 2) GNN differences:
Mean: 5.1610426902771
Std: 3.322169780731201
------------------------------
llllll (5000, 2) GNN differences:
Mean: 4.85939884185791
Std: 3.070873498916626
------------------------------
llwwll (100, 2) GNN differences:
Mean: 4.409536361694336
Std: 3.644829750061035
------------------------------
llwwll (1000, 2) GNN differences:
Mean: 4.543127536773682
Std: 3.427647352218628
------------------------------
llwwll (5000, 2) GNN differences:
Mean: 4.5057759284973145
Std: 3.257702827453

In [8]:
# Evalute GNN accuracy for all designs
for i_condition_tuple in range(len(condition_tuples)):
    condition_tuple = condition_tuples[i_condition_tuple]
    seq_len = condition_tuple[1]
    n_rollouts_to_use = condition_tuple[0]
    
    # load network

#     PATH = os.path.join(cwd,'learned_models', 
    PATH = os.path.join(cwd,'learned_models/longer_training', 
                        'multidesign_pgnn_r' + str(int(n_rollouts_to_use)) + 
            '_ms'+ str(int(seq_len))+'.pt')

    print(PATH)
    save_dict = torch.load( PATH)#, map_location=lambda storage, loc: storage)
    internal_state_len = save_dict['internal_state_len']
    message_len= save_dict['message_len']
    hidden_layer_size= save_dict['hidden_layer_size']

    gnn_nodes = pgnn.create_GNN_nodes(internal_state_len, message_len, hidden_layer_size, 
                    device, body_input = True)
    pgnn.load_state_dicts(gnn_nodes, save_dict['gnn_state_dicts'])
    for gnn_node in gnn_nodes:
        gnn_node.eval()

        
    modules = dict()
    for urdf in urdf_names:
        modules[urdf] = []
        n_modules = len(modules_types[urdf])
        for i in range(n_modules):
            modules[urdf].append(pgnn.Module(i, gnn_nodes[modules_types[urdf][i]], device))


    diff_list = dict()
    diff_list_by_run =  dict()

    for i_urdf in range(len(urdf_names)):
        urdf = urdf_names[i_urdf]
        
        diff_list[urdf] = []
        diff_list_by_run[urdf] = []
        for run_index in range(len(run_lens_validation[urdf])):
            batch_size = run_lens_validation[urdf][run_index]-1
            with torch.no_grad():
                state_seq0 = [s[:-1] for s in states_memory_validation[urdf][run_index]]
                state_seq1 = [s[1:]  for s in states_memory_validation[urdf][run_index]]
                action_seq = [a[:-1] for a in actions_memory_validation[urdf][run_index]]

                for module in modules[urdf]: # must reset module lstm state
                    module.reset_hidden_states(batch_size) 

                # process states to move them to vehicle frame
                fd_input_real, delta_fd_real = to_body_frame_batch(state_seq0, state_seq1)

                # pass through network
                fd_input   = to_device(fd_input_real, device) 
                actions_in = to_device(action_seq, device)
                node_inputs = [torch.cat([s,a],1) for (s,a) in zip(fd_input, actions_in)]
                state_delta_est_mean, state_delta_var = pgnn.run_propagations(
                    modules[urdf], attachments[urdf], 2, node_inputs, device)

                # cat to one tensor and take difference
                state_delta_est_mean = torch.cat(state_delta_est_mean,-1)
                delta_fd   = torch.cat(to_device(delta_fd_real, device),-1)
                diff = (delta_fd - state_delta_est_mean)

            diff_list_by_run[urdf].append(torch.abs(diff).sum(-1).mean())   
            diff_list[urdf].append(diff)

        print('------------------------------')    
        diff_list_all = torch.cat(diff_list[urdf],0)
        print(urdf + ' GNN differences:')
        diffs_abs = torch.abs(diff_list_all).sum(-1)

        print('Mean: ' + str(diffs_abs.mean().item() ))
        print('Std: ' + str(diffs_abs.std().item()))
        
        results_matrix[i_condition_tuple+1+3+3,i_urdf] = diffs_abs.mean().item()

#         print(urdf + ' GNN relative to const baseline mean:')
#         const_pred = torch.abs(torch.cat(const_pred_diff_list[urdf],0)).sum(-1).to(device)
#         diff_rel = diffs_abs/const_pred
#         print('Mean: ' + str(diff_rel.mean().item() ))
#         print('Std: ' + str(diff_rel.std().item()))

#         print(urdf + ' GNN differences by run:')
#         diffs_sums = torch.stack(diff_list_by_run[urdf],-1)
#         print('Mean: ' + str(diffs_sums.mean().item() ))
#         print('Std: ' + str(diffs_sums.std().item()))

print(results_matrix)

learned_models/longer_training/multidesign_pgnn_r100_ms2.pt
------------------------------
wnwwnw GNN differences:
Mean: 2.4028401374816895
Std: 2.749899387359619
------------------------------
llllll GNN differences:
Mean: 5.068259239196777
Std: 3.53294038772583
------------------------------
llwwll GNN differences:
Mean: 6.886737823486328
Std: 3.877237558364868
learned_models/longer_training/multidesign_pgnn_r1000_ms2.pt
------------------------------
wnwwnw GNN differences:
Mean: 2.0955729484558105
Std: 2.3948006629943848
------------------------------
llllll GNN differences:
Mean: 5.254051685333252
Std: 3.3526487350463867
------------------------------
llwwll GNN differences:
Mean: 5.347886562347412
Std: 3.459967613220215
learned_models/longer_training/multidesign_pgnn_r5000_ms2.pt
------------------------------
wnwwnw GNN differences:
Mean: 2.239934206008911
Std: 2.4226791858673096
------------------------------
llllll GNN differences:
Mean: 5.303031921386719
Std: 3.29570007324218

In [6]:
# condition_tuple = condition_tuples[0]
# n_rollouts_to_use = condition_tuple[0]
# seq_len = condition_tuple[1]
# # depending on the length of the multistep sequence we want,
# # only some indexes of the full set of states collected can be sampled.
# sampleable_inds = dict()
# batch_sizes = dict()
# #     for urdf in urdf_names:
# sampleable_inds[urdf] = get_sampleable_inds(run_lens[urdf][:n_rollouts_to_use], seq_len)
# n_sampleable = len(sampleable_inds[urdf])
# batch_sizes[urdf] = batch_size_default
# if batch_sizes[urdf] > n_sampleable:
#     batch_sizes[urdf] = n_sampleable
# print(urdf + ' using ' + str(n_rollouts_to_use) + ' out of Rollouts ' + str(len(run_lens[urdf])))


# sampled_inds = sampleable_inds[urdf][np.random.choice(len(sampleable_inds[urdf]), 10, replace=False)]
# sampled_ranges = sampled_inds.repeat((seq_len,1))
# for si in range(seq_len):
#     sampled_ranges[si] += si
# print(sampled_ranges)

llllll using 100 out of Rollouts 5000
tensor([[1220, 6710, 4214, 6770, 3124, 4047, 1486,  426, 4785, 6434],
        [1221, 6711, 4215, 6771, 3125, 4048, 1487,  427, 4786, 6435]])
