In [None]:
import os
import json

# Import SphereEnv while the working directory is the project's
# `base_functions` folder, then go back to where the notebook was started.
# `path` is kept at module level because later cells reuse it to restore the
# working directory after other directory-local imports.
path = os.getcwd()
os.chdir('/mnt/diskSustainability/frederic/sony_RL/sony_RL/base_functions')
try:
    from dm_env_sphere import SphereEnv
finally:
    # Restore the cwd even if the import raises; otherwise re-running this cell
    # would capture the wrong directory in `path`.
    os.chdir(path)

import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
import matplotlib.pyplot as plt
import plotly.graph_objects as go

# Suppress every Python warning from this point on (this also hides
# deprecation notices emitted by the libraries imported above).
import warnings
warnings.filterwarnings("ignore")

##

# Gather, for each sphere model, its directory, its .obj file name and the
# content of its JSON description. The three lists are index-aligned and are
# passed together to SphereEnv.
list_holes = []
objects_path = []
object_name = []

for k in range(10):
    if k == 3:
        # Model 3 is skipped (the reason is not recorded in this notebook).
        continue
    path_k = f'/mnt/diskSustainability/frederic/scanner-gym_models_v2/random_spheres/random_sphere_{k}/'
    objects_path.append(path_k)
    object_name.append(f'random_sphere_{k}.obj')
    # Use a context manager so the JSON file handle is closed right away
    # instead of being left to the garbage collector.
    with open(path_k + f'random_sphere_{k}.json') as holes_file:
        list_holes.append(json.load(holes_file))

# Two identically configured environments over the same object set: `env` is
# used for training and `env_test` for evaluation.
sphere_env_kwargs = dict(
    img_shape=128,
    list_holes=list_holes,
    rmax_T=0.9,
    max_T=50,
    theta_n_positions=8,
    continuous=True,
)
env = SphereEnv(objects_path, object_name, **sphere_env_kwargs)
env_test = SphereEnv(objects_path, object_name, **sphere_env_kwargs)

# Start a first episode on the training environment.
ts = env.reset()

##

# Restrict the process to GPU 1 and set XLA's GPU memory fraction to 0.8.
# NOTE(review): jax is already imported at this point; these variables only
# take effect if no JAX/CUDA backend has been initialised yet -- confirm that
# building the environments above does not run any JAX computation.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.8'

# Import the RL modules while the working directory is the project's `rl`
# folder, then return to the notebook's original directory (`path`).
os.chdir('/mnt/diskSustainability/frederic/sony_RL/sony_RL/rl')
try:
    import vae
    from sac import SAC
    from sac_ae import SAC_AE
    from dqn import DQN
    from base_trainer import Trainer
finally:
    # Restore the cwd even if one of the imports raises, so later relative
    # paths are not silently resolved against the `rl` folder.
    os.chdir(path)

# --- VAE architecture settings (consumed by `classic_vae`) ---
input_size = 128
filter_sizes = [16, 32, 64, 128]
input_channels = 1
output_channels = 1

# Identity function wrapped in jit; passed to the VAE as its final activation.
final_activation = jax.jit(lambda x: x)

# --- Latent space / KL weight ---
# These two values also appear in the file name of the checkpoint loaded below.
latent_dim = 14
lambda_kl = 9.21e-5

def classic_vae(s, is_training):
    """Build the VAE from the module-level settings and apply it to `s`.

    Meant to be wrapped by a Haiku transform: the `vae.VAE` module is
    constructed inside the function and called immediately.

    Args:
        s: input forwarded unchanged to the VAE module.
        is_training: flag forwarded unchanged to the VAE module.

    Returns:
        Whatever the `vae.VAE` module returns for `(s, is_training)`.
    """
    model = vae.VAE(
        input_size,
        latent_dim,
        filter_sizes,
        output_channels,
        final_activation,
        coord_conv=True,
    )
    return model(s, is_training)

print('##### VAE initialization #####')

# Turn the VAE forward function into a pure (init, apply) pair that carries
# Haiku state and needs no RNG key at apply time.
vae_init, vae_apply  = hk.without_apply_rng(hk.transform_with_state(classic_vae))
# The apply function is called as (params, state, s, is_training); positional
# index 3 (`is_training`) is marked static so jit specialises on it.
vae_apply_jit = jax.jit(vae_apply, static_argnums=3)

# SECURITY: allow_pickle=True unpickles Python objects stored in the archive,
# which can execute arbitrary code. Only load checkpoints from a trusted source.
# The archive is opened in a `with` block so its file handle is closed as soon
# as the two entries have been read.
with jnp.load('/mnt/diskSustainability/frederic/sony_RL/params_vae_lat=14_kl=9.21e-05.npz', allow_pickle=True) as weights:
    # `[()]` unwraps the stored Python object -- presumably each entry is a
    # 0-d object array holding a parameter/state tree; TODO confirm.
    params_vae = weights['params_vae'][()]
    bn_vae_state =  weights['bn_vae_state'][()]

print()
print('##### Initialization finished #####')

##

# A fresh seed is drawn on every run; it is printed and also stored in
# `agent_params` so that a run can be reproduced from its saved hyperparameters.
seed = np.random.randint(100)
print()
print('seed = {}'.format(seed))
print()
print('##### Agent initialization #####')

# The pre-trained VAE (jitted apply function, parameters, state) is handed to
# the agent as its encoder.
encoder = (vae_apply_jit, params_vae, bn_vae_state)

# Agent hyperparameters, named once so the values given to the agent and the
# values recorded in `agent_params` cannot drift apart.
num_agent_steps = 10**6
start_steps = 10**3
buffer_size = 10**3
batch_size = 32
gamma = 0.6
scale_reward = 5

# Everything needed to reproduce the agent configuration. The original 'gamma'
# key is kept; the other keys are additions.
agent_params = {
    'gamma': gamma,
    'seed': int(seed),  # plain int so the value is always JSON-serialisable
    'scale_reward': scale_reward,
    'beta': lambda_kl,
    'num_agent_steps': num_agent_steps,
    'start_steps': start_steps,
    'buffer_size': buffer_size,
    'batch_size': batch_size,
}

agent = SAC_AE(
    num_agent_steps=num_agent_steps,
    state_space=np.empty(env.observation_shape),
    action_space=np.empty((1,1)),
    seed=seed,
    start_steps=start_steps,
    gamma=gamma,
    buffer_size=buffer_size,
    batch_size=batch_size,
    encoder=encoder,
    scale_reward=scale_reward,
    beta=lambda_kl,
)

# Alternative agent, kept for reference (disabled):
# agent = DQN(num_agent_steps=10**6, state_space=np.empty(env.observation_shape), action_space=np.array(list(env.actions.keys())),
#            seed=seed, start_steps=10**3, gamma=gamma, buffer_size=10**3, batch_size=32, encoder=encoder, use_goal=True)

print()
print('##### Initialization finished #####')
print()
print('##### Training RL agent #####')

log_dir = 'sac_joint_vae_10env_logs/'

# Training runs on `env` and evaluates on `env_test`; evaluation and parameter
# saving intervals are expressed in agent steps.
trainer = Trainer(
        env=env,
        env_test=env_test,
        algo=agent,
        log_dir=log_dir,
        num_agent_steps=10**6,
        action_repeat=1,
        eval_interval=10**3,
        save_params=True,
        save_interval=10**4
    )

# NOTE(review): the Trainer receives the relative `log_dir`, while the
# hyperparameters are written under an absolute root; both point to the same
# folder only when the notebook runs from that root -- confirm.
hyperparameters_dir = os.path.join('/mnt/diskSustainability/frederic/sony_RL/', log_dir)
# Make sure the target folder exists so that saving the hyperparameters cannot
# fail with FileNotFoundError before training has even started.
os.makedirs(hyperparameters_dir, exist_ok=True)
with open(os.path.join(hyperparameters_dir, 'hyperparameters.json'), 'w') as f:
    json.dump(agent_params, f)

trainer.train()

# To reuse a saved agent instead of training, load its parameters:
#agent.load_params('/mnt/diskSustainability/frederic/sony_RL/'+ log_dir+ '/param/step1000000')
print()
print('##### Training finished #####')

In [None]:
# Roll the agent out on test object `j` for at most 10 steps and report the
# per-step rewards and the environment's total reward.
j = 0
ts = env_test.reset(obj=j)
print(env_test.current_obj)

# Kept at module level: the trajectory figure plots these points.
M = env_test.current_spc.neigh_ijk

rewards = []
max_rollout_steps = 10
for step_count in range(1, max_rollout_steps + 1):
    action = agent.select_action(ts.observation)
    ts = env_test.step(action)
    rewards.append(ts.reward)
    # presumably 2 is the "last step" code of the time step -- TODO confirm
    episode_over = ts.step_type == 2
    if episode_over:
        print('finished in {} steps'.format(step_count))
        break

print(rewards)
print(env_test.total_reward)

In [None]:
# Convert the positions visited during the rollout from spherical angles to
# the coordinates used by the trajectory figure.
angles = np.array(env_test.visited_positions)
theta, phi = angles[:,0], angles[:,1]

if not env_test.continuous:
    # Discrete positions are indices; map them to angles in radians.
    theta = (theta+1)*np.pi/(2*env_test.theta_n_positions)
    phi = phi*2*np.pi/env_test.phi_n_positions

# Spherical -> Cartesian on a sphere of radius R.
R = 5
sin_theta = np.sin(theta)
a = R*sin_theta*np.cos(phi)
b = R*sin_theta*np.sin(phi)
c = R*np.cos(theta)

# Shift and rescale each axis.
# NOTE(review): 4.97 and 0.4 look like a grid origin offset and a cell size
# matching the points in `M` -- confirm.
a, b, c = ((axis+4.97)/0.4 for axis in (a, b, c))

In [None]:
# Build a stepped (discrete) colorscale with one flat 'jet' colour band per
# timestep: every colour is listed at both ends of its [i/n, (i+1)/n] interval.
n_steps = len(a)
jet = plt.get_cmap('jet', n_steps)  # hoisted: identical for every timestep

l = []
for i in range(n_steps):
    # Cast the uint8 channels to plain ints before formatting: with NumPy >= 2
    # str() of a tuple of np.uint8 gives 'rgb(np.uint8(0), ...)', which is not
    # a valid colour string.
    rgb = 'rgb' + str(tuple(int(channel) for channel in jet(i, bytes=True)[:3]))
    l.append([i/n_steps, rgb])
    l.append([(i+1)/n_steps, rgb])

fig = go.Figure()

# Trajectory of the agent, coloured by timestep; hovering shows the step index.
fig.add_trace(go.Scatter3d(
    x=a, y=b, z=c,
    marker=dict(
        color=np.arange(n_steps),
        colorscale=l,
        colorbar=dict(
            thickness=20,
            title={'text': 'Timesteps', 'side': 'bottom'},
            tick0=0, dtick=1, x=0.8, y=0.4, len=0.75,
        ),
    ),
    text=[str(k) for k in range(n_steps)],
    hoverinfo='text',
    showlegend=False,
))

# Points stored in `M` when the rollout was started.
fig.add_trace(go.Scatter3d(x=M[:,0],y=M[:,1],z=M[:,2],mode='markers',showlegend=False))

fig.update_layout(
    margin=dict(l=0, r=0, b=0, t=0),
    hovermode='closest',
    width=700,
    height=450,
    title={
        'text': f'Trajectory using SAC(env={j})',
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'},
)
fig.show()