In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

# The project modules are not installed as a package; they are imported by
# temporarily switching to their directory (presumably relying on the current
# directory being on sys.path, as it is in a notebook kernel). `path` keeps the
# notebook's own directory and is reused by later cells to switch back.
path = os.getcwd()
os.chdir('/mnt/diskSustainability/frederic/sony_RL/sony_RL/base_functions')
try:
    from dm_env_sphere import SphereEnv
    from utils import get_mask_background
finally:
    # Restore the working directory even if an import fails, so relative paths
    # used by later cells keep resolving from the notebook directory.
    os.chdir(path)

import optuna
from acme import specs
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from copy import deepcopy
import json
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import warnings
# NOTE(review): this silences *all* warnings for the whole kernel, including
# deprecation warnings from any library used in later cells. Consider narrowing
# it with `category=` / `module=` filters.
warnings.filterwarnings("ignore")

# Sphere creation

In [None]:
import openpyscad as ops
import subprocess
import urllib
import json
import numpy as np

In [None]:
class NpEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that also serialises NumPy values.

    NumPy integer and floating scalars become Python ``int`` / ``float``, and
    ``ndarray`` objects become (nested) lists. Any other object is handed to the
    base class, which raises ``TypeError`` for unsupported types.
    """

    # (NumPy type, converter) pairs tried in order by ``default``.
    _CONVERTERS = (
        (np.integer, int),
        (np.floating, float),
        (np.ndarray, np.ndarray.tolist),
    )

    def default(self, obj):
        for np_type, convert in self._CONVERTERS:
            if isinstance(obj, np_type):
                return convert(obj)
        return super().default(obj)

def scad_to_stl_to_obj(path_k, obj_name):
    """Convert ``<path_k>/<obj_name>.scad`` to ``.stl`` with OpenSCAD, then to ``.obj``.

    The STL -> OBJ step is delegated to an HTTP conversion service expected on
    ``localhost:5000`` (``/stlobj`` endpoint; presumably Blender-backed — TODO confirm).

    Raises:
        subprocess.CalledProcessError: if OpenSCAD exits with a non-zero status.
        urllib.error.URLError: if the conversion service cannot be reached.
    """
    import urllib.parse
    import urllib.request

    scad_file = os.path.join(path_k, obj_name + '.scad')
    stl_file = os.path.join(path_k, obj_name + '.stl')
    obj_file = os.path.join(path_k, obj_name + '.obj')

    # transform scad to stl; check=True so a failed render is not silently ignored
    subprocess.run(['/usr/bin/openscad', '-o', stl_file, scad_file], check=True)

    # transform stl to obj using blender. The previous version split this URL over two
    # statements: the "&forward=X&up=Z" line was a separate expression (unary + on a
    # str -> TypeError), so the request was never sent. urlencode also escapes paths.
    query = urllib.parse.urlencode({'stl': stl_file, 'obj': obj_file, 'forward': 'X', 'up': 'Z'})
    with urllib.request.urlopen("http://localhost:5000/stlobj?" + query) as response:
        response.read()

def random_sphere_creation(n_spheres, n_holes, path,
                           template_dir='/mnt/diskSustainability/frederic/scanner-gym_models_v2/sphere_hole'):
    """Generate ``n_spheres`` spheres (radius 5) pierced by ``n_holes`` random cylindrical holes.

    For each sphere ``k`` a folder ``<path>/random_sphere_<k>/`` is created (parents
    included) and receives:

    * ``random_sphere_<k>.scad`` plus the ``.stl`` / ``.obj`` produced by ``scad_to_stl_to_obj``;
    * ``random_sphere_<k>.json``: list of hole directions ``(theta, phi)`` in radians;
    * copies of ``camera_model.json``, ``bbox_general.json`` and ``params.json`` from ``template_dir``.

    Args:
        n_spheres: number of spheres to generate.
        n_holes: number of holes per sphere.
        path: parent directory of the generated object folders.
        template_dir: folder holding the configuration files copied into each object folder.
    """
    import shutil

    for k in range(n_spheres):
        # Hole directions in integer degrees: polar angle in [0, 90), azimuth in [0, 360).
        theta = np.random.randint(0, 90, n_holes)
        phi = np.random.randint(0, 360, n_holes)

        # Sphere minus one centred cylinder (r=0.5, h=12) per hole. A single
        # difference() with several subtracted children is equivalent to the
        # previously nested difference(difference(...), cylinder) chain.
        d = ops.Difference()
        d.append(ops.Sphere(r=5, _fn=100))
        for i in range(n_holes):
            c = ops.Cylinder(r=0.5, h=12, center=True, _fn=100)
            d.append(c.rotate([0, theta[i], phi[i]]))

        obj_name = f'random_sphere_{k}'
        path_k = os.path.join(path, obj_name)
        os.makedirs(path_k, exist_ok=True)
        d.write(os.path.join(path_k, obj_name + '.scad'))
        with open(os.path.join(path_k, obj_name + '.json'), 'w') as f:
            json.dump(list(zip(theta*np.pi/180, phi*np.pi/180)), f, cls=NpEncoder)

        scad_to_stl_to_obj(path_k, obj_name)

        # shutil.copy raises on failure, unlike the unchecked `cp` subprocess used before.
        for fname in ('camera_model.json', 'bbox_general.json', 'params.json'):
            shutil.copy(os.path.join(template_dir, fname), path_k)
        

def random_plant_creation(n_plants, n_branches, path, template_dir=None):
    """Generate ``n_plants`` synthetic "plants", each a union of ``n_branches`` random branches.

    Each branch is a thin cylinder (r=0.5) combined with a small 4-sided extruded plate
    placed at z = 7.01 (presumably a leaf), and the whole branch is rotated by a random
    ``(theta, phi)``. The first branch is a centred cylinder of height 12; the others
    have height 6 and start at the origin.

    For each plant ``k`` the folder ``<path>/random_plant_<k>/`` (created with parents
    if needed) receives the ``.scad`` / ``.stl`` / ``.obj`` mesh and
    ``random_plant_<k>.json`` with the branch directions ``(theta, phi)`` in radians.

    Args:
        n_plants: number of plants to generate.
        n_branches: number of branches per plant.
        path: parent directory of the generated object folders.
        template_dir: optional folder whose ``camera_model.json``, ``bbox_general.json``
            and ``params.json`` are copied into each plant folder (as done for the
            random spheres). ``None`` (default) copies nothing.
    """
    import shutil

    for k in range(n_plants):
        # Branch directions in integer degrees: polar angle in [0, 90), azimuth in [0, 360).
        theta = np.random.randint(0, 90, n_branches)
        phi = np.random.randint(0, 360, n_branches)
        u = ops.Union()
        for i in range(n_branches):
            if i == 0:
                c = ops.Cylinder(r=0.5, h=12, center=True, _fn=10)
            else:
                c = ops.Cylinder(r=0.5, h=6, _fn=10)
            d = ops.Circle(r=1, _fn=4)
            d = d.linear_extrude(height=1, center=True, _fn=10)
            d = d.rotate([0, 90, 0])
            d = d.translate([0, 0, 7.01])
            e = c + d
            e = e.rotate([0, theta[i], phi[i]])
            u.append(e)

        obj_name = f'random_plant_{k}'
        path_k = os.path.join(path, obj_name)
        os.makedirs(path_k, exist_ok=True)
        u.write(os.path.join(path_k, obj_name + '.scad'))
        with open(os.path.join(path_k, obj_name + '.json'), 'w') as f:
            json.dump(list(zip(theta*np.pi/180, phi*np.pi/180)), f, cls=NpEncoder)

        scad_to_stl_to_obj(path_k, obj_name)

        if template_dir is not None:
            for fname in ('camera_model.json', 'bbox_general.json', 'params.json'):
                shutil.copy(os.path.join(template_dir, fname), path_k)

In [None]:
# Generate 10 random plants with 3 branches each into the folder below.
random_plant_creation(n_plants=10, n_branches=3,
                      path='/mnt/diskSustainability/frederic/scanner-gym_models_v2/random_plants')

# Environment creation

In [None]:
# Ground-truth voxel grid of the holed sphere. Voxels equal to -1 are the candidates
# considered here (presumably the object / hole voxels — TODO confirm the encoding).
gt = np.load('/mnt/diskSustainability/frederic/scanner-gym_models_v2/sphere_hole/ground_truth_volumes/gt_0.npy')

x, y, z = np.where(gt == -1)

# Keep three thin "skewers" of those voxels through the centre of the 25^3 grid:
#  - mask : 2x2 bundle along the x axis (y, z in {12, 13}) for 4 < x < 21
#  - mask1: diagonal y == z inside the slabs x in {12, 13}, for 4 < y < 21
#  - mask2: anti-diagonal y == 25 - z inside the same slabs, for 4 < y < 21
mask = ((y == 12) & (z == 12)) | ((y == 12) & (z == 13)) | ((y == 13) & (z == 12)) | ((y == 13) & (z == 13))
mask = mask & (x > 4) & (x < 21)
mask1 = ((y == z) & (x == 12)) | ((y == z) & (x == 13))
mask1 = mask1 & (y > 4) & (y < 21)
mask2 = ((y == 25 - z) & (x == 12)) | ((y == 25 - z) & (x == 13))
mask2 = mask2 & (y > 4) & (y < 21)

mask = mask | mask1 | mask2

# Weight 1 on the selected voxels, 0 elsewhere. One fancy-indexing assignment replaces
# the per-voxel Python loop, which also re-evaluated x[mask], y[mask], z[mask] on every
# iteration.
voxel_weights = np.zeros((25, 25, 25))
voxel_weights[x[mask], y[mask], z[mask]] = 1

In [None]:
# Load the 10 generated plants and their branch directions (theta, phi in radians).
list_holes = []
objects_path = []
object_name = []

for k in range(10):
    obj = f'random_plant_{k}'
    path_k = '/mnt/diskSustainability/frederic/scanner-gym_models_v2/random_plants/' + obj
    objects_path.append(path_k)
    object_name.append(obj + '.obj')
    # `with` closes the handle; json.load(open(...)) leaked it.
    with open(path_k + '/' + obj + '.json') as f:
        list_holes.append(json.load(f))

# Two identically configured environments (the trainer later takes env and env_test).
env = SphereEnv(objects_path, object_name, list_holes=list_holes, rmax_T=100, max_T=100000, theta_n_positions=90)
env_test = SphereEnv(objects_path, object_name, list_holes=list_holes, rmax_T=100, max_T=100000, theta_n_positions=90)
ts = env.reset()

In [None]:
import plotly.graph_objects as go

# For every plant, reset the env on it and scatter-plot `env.current_spc.neigh_ijk`
# (presumably voxel indices i, j, k — TODO confirm).
for k in range(10):
    ts = env.reset(obj=object_name[k])
    print(env.current_obj)
    M = env.current_spc.neigh_ijk
    fig = go.Figure(go.Scatter3d(x=M[:, 0], y=M[:, 1], z=M[:, 2], mode='markers', showlegend=False))
    fig.show()

Build a ground-truth (gt) model by taking many views on a spherical grid, and record the maximum reward obtained (which can differ from 1) for later reward normalization.

In [None]:
# Ground-truth pass over the 10 plants: visit every viewpoint of the discrete
# (theta, phi) grid, then save the env's last volume as gt.npy and, only the first
# time, store the total reward reached as 'rmax_inf' (used for reward normalisation).
# NOTE(review): each plant folder must already contain a params.json (e.g. copied from
# the sphere_hole template).
for k in range(10):
    # NOTE(review): the visualisation cell resets with obj=object_name[k]; this one passes
    # the integer index k — confirm SphereEnv.reset accepts both forms.
    ts = env.reset(obj=k)
    for theta in range(env.theta_n_positions):
        for phi in range(env.phi_n_positions):
            # index i -> polar angle (i+1)*pi/n_theta ; index j -> azimuth 2*pi*j/n_phi
            ts = env.step_angle((theta+1)*np.pi/(env.theta_n_positions), phi*2*np.pi/env.phi_n_positions)
            if env.done:
                print('done')

    path_k = f'/mnt/diskSustainability/frederic/scanner-gym_models_v2/random_plants/random_plant_{k}/'
    # `with` closes the handle; json.load(open(...)) leaked it.
    with open(path_k + 'params.json') as f:
        l = json.load(f)
    if 'rmax_inf' not in l['train']:  # do it only once per object to set the maximum reward value
        l['train']['rmax_inf'] = env.total_reward
        if env.total_reward == 0:
            print(0)
        with open(path_k + 'params.json', 'w') as f:
            json.dump(l, f)
    np.save(path_k + 'gt', env.current_spc.last_volume)

# VAE pretraining using Optuna for hyperparameter search.
# Several encoder architectures: simple encoder, ResNet, VQVAE

In [None]:
# Build the pretraining dataset: 5 polar angles x 100 azimuths = 500 views, each
# stored as env.canny (presumably a Canny edge map) with a leading batch axis.
obs = []
ts = env.reset()
for theta in np.linspace(np.pi/8, np.pi/2, 5):
    for phi in np.linspace(0, 2*np.pi, 100):
        ts = env.step_angle(theta, phi)
        obs.append(env.canny[None])

# Stack into one float array; its variance normalises the reconstruction losses below.
obs = np.concatenate(obs) * 1.
var = np.var(obs)

In [None]:
# Pin the process to GPU 1 and cap the fraction of GPU memory JAX preallocates.
# These variables are read when JAX initialises its backend, so set them first.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.8'

# Import the project's VAE/VQVAE modules from their source directory, restoring the
# working directory even if an import fails.
os.chdir('/mnt/diskSustainability/frederic/sony_RL/sony_RL/sac')
try:
    import vae
    import vqvae
finally:
    os.chdir(path)

import jax
import jax.numpy as jnp
import haiku as hk
import optax

# Classic VAE

In [None]:
# Shared settings of the classic-VAE hyper-parameter search.
num_linear_epochs = 5000  # NOTE(review): not referenced by any cell of this notebook
num_batches = 1000  # passed to vae.disantanglement_data
num_epochs = 3000  # gradient steps per trial
num_factors = 2
batch_size = 16
batch_size_disantanglement = 64
input_channels = output_channels = 1
input_size = 128
filter_sizes = [16,32,64,128]
final_activation = jax.jit(lambda s:s)  # identity on the decoder output

def objective_vae(trial):
    """Optuna objective: train a classic (Gaussian) VAE and score its disentanglement.

    Samples ``latent_dim`` in [3, 30] and the KL weight ``lambda_kl`` (log-uniform in
    [1e-8, 1]), trains for ``num_epochs`` steps on the global ``obs`` dataset, then fits a
    logistic regression on the features returned by ``vae.disantanglement_data`` and
    returns its held-out accuracy (the study maximises it). Loss curves and one random
    reconstruction are plotted as a side effect.
    """
    seed = np.random.randint(100)

    latent_dim = trial.suggest_int('latent_dim', 3, 30)
    lambda_kl = trial.suggest_float('lambda_kl', 1e-8, 1, log=True)

    def classic_vae(s, is_training):
        return vae.VAE(input_size, latent_dim, filter_sizes, output_channels, final_activation, coord_conv=True)(s, is_training)

    vae_init, vae_apply  = hk.without_apply_rng(hk.transform_with_state(classic_vae))
    # is_training (positional arg 3 of apply: params, state, s, is_training) is static.
    vae_apply_jit = jax.jit(vae_apply, static_argnums=3)
    params_vae, bn_vae_state = vae_init(jax.random.PRNGKey(seed), np.random.uniform(0, 1, (1,input_size,input_size,input_channels)), True)
    optimizer_vae = optax.radam(5e-4)
    opt_vae_state = optimizer_vae.init(params_vae)

    @jax.jit
    def loss_ae(params, bn_state, obs):
        """Variance-normalised MSE + weighted Gaussian KL; aux carries the updated state."""
        aux, bn_state = vae_apply_jit(params, bn_state, obs, True)
        reconst, latent, mu, log_var = aux

        # MSE for reconstruction errors, normalised by the dataset variance.
        loss_reconst = jnp.square(obs - reconst).mean()
        loss_reconst = loss_reconst/var

        loss_kl = vae.kl_gaussian(mu, log_var)
        loss_kl = loss_kl * lambda_kl

        return loss_reconst + loss_kl, (loss_reconst, loss_kl, bn_state)

    @jax.jit
    def update(params, opt_state, bn_state, obs):
        """One optimiser step on `loss_ae`; returns new params/opt state, loss and aux."""
        (loss, aux), grad = jax.value_and_grad(loss_ae, has_aux=True)(params, bn_state, obs)
        updates, new_opt_state = optimizer_vae.update(grad, opt_state)
        new_params = optax.apply_updates(params, updates)

        return new_params, new_opt_state, loss, aux

    losses = []

    for _ in range(num_epochs):
        idx = np.random.choice(len(obs), batch_size, replace=False)
        batch_obs = jnp.array(obs[idx])
        params_vae, opt_vae_state, loss, aux = update(params_vae, opt_vae_state, bn_vae_state, batch_obs)
        loss_reconst, loss_kl, bn_vae_state = aux
        losses.append([loss_reconst.item(), loss_kl.item(), loss.item()])

    losses = np.array(losses)

    # Disentanglement metric: held-out accuracy of a logistic regression trained on
    # the (X, y) pairs produced by vae.disantanglement_data.
    X, y = vae.disantanglement_data(vae_apply_jit, params_vae, bn_vae_state, num_batches, batch_size_disantanglement, num_factors, input_size, input_size)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)

    clf = LogisticRegression()
    clf.fit(X_train, y_train)
    accuracy = clf.score(X_test, y_test)

    # Diagnostics: loss curves, then one random image next to its reconstruction.
    fig, ax = plt.subplots(1,2,figsize=(15,5))
    ax[0].plot(losses[:,0],label='reconst')
    ax[0].legend()
    ax[1].plot(losses[:,1],label='kl')
    ax[1].legend()

    fig, ax = plt.subplots(1,2,figsize=(15,5))
    idx = np.random.choice(len(obs))
    ax[0].imshow(obs[idx], cmap='gray')
    aux, _ = vae_apply_jit(params_vae, bn_vae_state, obs[idx][None], False)
    reconst, latent, mu, log_var = aux
    ax[1].imshow(reconst[0], cmap='gray')
    plt.show()

    return accuracy

study = optuna.create_study(direction='maximize')  # Create a new study.
study.optimize(objective_vae, n_trials=20, show_progress_bar=True)

# VQVAE

In [None]:
def objective_vqvae(trial):
    """Optuna objective: train a VQVAE on the global ``obs`` and return its reconstruction loss.

    Searches ``latent_dim`` in [3, 10] and ``commitment_cost`` (log-uniform in
    [1e-5, 1e2]); the codebook size is fixed at 512. Returns the variance-normalised
    reconstruction error of the last training batch (the study below uses Optuna's
    default direction, i.e. minimises it). Loss curves and one random reconstruction
    are plotted as a side effect.
    """

    input_size = 128
    num_epochs = 3000
    batch_size = 16
    decay = 0.99
    input_channels = output_channels = 1
    seed = np.random.randint(100)
    
    num_hiddens = 32
    num_residual_hiddens = 32
    num_residual_layers = 2

    latent_dim = trial.suggest_int('latent_dim', 3, 10)
    commitment_cost = trial.suggest_float('commitment_cost', 1e-5, 1e2, log=True)
    #lambda_decoder = trial.suggest_float('lambda_decoder', 1e-8, 1e-2, log=True)
    #num_embeddings = trial.suggest_int('num_embeddings', 64, 512)
    num_embeddings = 512

    def VQVAE(s, is_training):
        return vqvae.VQVAE(latent_dim, num_embeddings, num_hiddens, num_residual_layers, num_residual_hiddens, commitment_cost, decay, output_channels)(s, is_training)

    vae_init, vae_apply  = hk.without_apply_rng(hk.transform_with_state(VQVAE))
    # is_training (positional arg 3 of apply: params, state, s, is_training) is static.
    vae_apply_jit = jax.jit(vae_apply, static_argnums=3)
    params_vae, bn_vae_state = vae_init(jax.random.PRNGKey(seed), np.random.uniform(0, 1, (1,input_size,input_size,input_channels)), True)
    # NOTE(review): the search uses adam while the VQVAE pretraining cell uses radam.
    optimizer_vae = optax.adam(5e-4)
    opt_vae_state = optimizer_vae.init(params_vae)

    @jax.jit
    def loss_ae(params, bn_state, obs):
        """VQ loss + variance-normalised reconstruction error; aux carries the updated state."""

        vq_output, bn_state = vae_apply_jit(params, bn_state, obs, True)
        vq_loss = vq_output['vq_loss']
        loss_reconst = vq_output['recon_error']
        loss_reconst = loss_reconst/var

        # NOTE(review): disabled decoder-weight penalty kept as a string literal (no effect).
        '''decoder_params = {k:params[k] for k in params.keys() if 'decoder' in k}
        leaves, _ = jax.tree_flatten(decoder_params)
        loss_decoder = lambda_decoder * jnp.linalg.norm(jax.flatten_util.ravel_pytree(leaves)[0])**2'''
        # Placeholder terms so the aux tuple has the same layout as the other VAE cells.
        loss_latent = 0
        loss_decoder = 0

        return vq_loss + loss_reconst + loss_decoder + loss_latent, (vq_loss, loss_reconst, loss_decoder, loss_latent, bn_state)
    
    def update(params, opt_state, bn_state, obs):
        """One optimiser step on `loss_ae`; returns new params/opt state, loss and aux."""
        (loss, aux), grad = jax.value_and_grad(loss_ae, has_aux=True)(params, bn_state, obs)
        updates, new_opt_state = optimizer_vae.update(grad, opt_state)
        new_params = optax.apply_updates(params, updates)

        return new_params, new_opt_state, loss, aux

    update_jit = jax.jit(update)
    
    losses = []
    
    # Batches are sampled with replacement here (the classic-VAE objective samples without).
    for k in range(num_epochs):

        idx = np.random.choice(len(obs), batch_size)
        batch_obs = jnp.array(obs[idx])
        params_vae, opt_vae_state, loss, aux = update_jit(params_vae, opt_vae_state, bn_vae_state, batch_obs)
        vq_loss, loss_reconst, loss_decoder, loss_latent, bn_vae_state = aux
        losses.append([loss_reconst.item(), vq_loss.item(), loss_decoder.item(), loss_latent.item(), loss.item()])

    # Diagnostics: loss curves, then one random image next to its reconstruction.
    losses = np.array(losses)
    fig, ax = plt.subplots(2,2,figsize=(15,5))
    ax[0,0].plot(losses[:,0],label='reconst')
    ax[0,0].legend()
    ax[0,1].plot(losses[:,1],label='vq')
    ax[0,1].legend()
    ax[1,0].plot(losses[:,2],label='decoder')
    ax[1,0].legend()
    ax[1,1].plot(losses[:,3],label='latent')
    ax[1,1].legend()

    
    fig, ax = plt.subplots(1,2,figsize=(15,5))
    idx = np.random.choice(len(obs))
    ax[0].imshow(obs[idx], cmap='gray')
    vq_output, _ = vae_apply_jit(params_vae, bn_vae_state, obs[idx][None], False)
    reconst = vq_output['x_recon']
    ax[1].imshow(reconst[0], cmap='gray')
    plt.show()

    # NOTE(review): the loss of a single 16-image batch makes this objective noisy;
    # averaging the last few hundred steps would give Optuna a steadier signal.
    return losses[-1,0]

study = optuna.create_study()  # Create a new study.
study.optimize(objective_vqvae, n_trials=15, show_progress_bar=True)

# CatVAE

In [None]:
def objective_catvae(trial):
    """Optuna objective: train a categorical VAE and return its reconstruction loss.

    Samples ``latent_dim`` in [3, 40] and, log-uniformly, the KL weight ``lambda_kl``
    in [1e-8, 1e-4], the temperature ``temp`` in [1e-3, 1e3] (presumably the
    relaxation temperature of vae.CatVAE) and the scalar ``prior`` in [1e-3, 1e-1].
    Trains for ``num_epochs`` steps on the global ``obs`` and returns the
    variance-normalised reconstruction loss of the last batch (lower is better).

    Renamed from ``objective_vae``: reusing that name silently replaced the
    classic-VAE objective defined earlier in the notebook.
    """

    input_size = 128
    num_epochs = 3000
    batch_size = 16
    seed = np.random.randint(100)
    filter_sizes = [16,32,64,128]
    input_channels = output_channels = 1

    latent_dim = trial.suggest_int('latent_dim', 3, 40)
    lambda_kl = trial.suggest_float('lambda_kl', 1e-8, 1e-4, log=True)
    temp = trial.suggest_float('temp', 1e-3, 1e3, log=True)
    prior =  trial.suggest_float('prior', 1e-3, 1e-1, log=True)
    num_classes = 2
    final_activation = jax.jit(lambda s:s)

    def cat_vae(s, is_training, T):
        return vae.CatVAE(input_size, latent_dim, num_classes, filter_sizes, output_channels, final_activation)(s, is_training, T)

    vae_init, vae_apply  = hk.without_apply_rng(hk.transform_with_state(cat_vae))
    # is_training (positional arg 3 of apply) is static; T stays a traced argument.
    vae_apply_jit = jax.jit(vae_apply, static_argnums=3)
    params_vae, bn_vae_state = vae_init(jax.random.PRNGKey(seed), np.random.uniform(0, 1, (1,input_size,input_size,input_channels)), True, temp)
    optimizer_vae = optax.radam(5e-4)
    opt_vae_state = optimizer_vae.init(params_vae)

    @jax.jit
    def loss_ae(params, bn_state, obs, T, eps = 1e-7):
        """Variance-normalised MSE + weighted categorical KL; aux carries the updated state."""

        aux, bn_state = vae_apply_jit(params, bn_state, obs, True, T)
        reconst, latent, q = aux
        q_p = jax.nn.softmax(q, axis=-1)

        # MSE for reconstruction errors.
        loss_reconst = jnp.square(obs - reconst).mean()
        loss_reconst = loss_reconst/var

        loss_decoder = 0
        loss_latent = 0

        h1 = q_p * jnp.log(q_p + eps)

        # Cross entropy with the categorical distribution.
        # NOTE(review): `prior` is a scalar and q_p sums to 1 over the class axis, so h2
        # adds a constant per latent and has no gradient — confirm whether a per-class
        # prior vector was intended.
        h2 = q_p * jnp.log(prior + eps)
        loss_kl = jnp.mean(jnp.sum(h1 - h2, axis =(1,2)), axis=0)
        loss_kl = lambda_kl*loss_kl

        return loss_reconst + loss_kl + loss_decoder + loss_latent, (loss_reconst, loss_kl, loss_decoder, loss_latent, bn_state)

    def update(params, opt_state, bn_state, obs, T):
        """One optimiser step on `loss_ae`; returns new params/opt state, loss and aux."""

        (loss, aux), grad = jax.value_and_grad(loss_ae, has_aux=True)(params, bn_state, obs, T)
        updates, new_opt_state = optimizer_vae.update(grad, opt_state)
        new_params = optax.apply_updates(params, updates)

        return new_params, new_opt_state, loss, aux

    update_jit = jax.jit(update)

    losses = []

    for k in range(1,num_epochs+1):
        idx = np.random.choice(len(obs), batch_size)
        batch_obs = jnp.array(obs[idx])
        params_vae, opt_vae_state, loss, aux = update_jit(params_vae, opt_vae_state, bn_vae_state, batch_obs, temp)
        loss_reconst, loss_kl, loss_decoder, loss_latent, bn_vae_state = aux
        losses.append([loss_reconst.item(), loss_kl.item(), loss_decoder.item(), loss_latent.item(), loss.item()])

    losses = np.array(losses)
    fig, ax = plt.subplots(2,2,figsize=(15,5))
    ax[0,0].plot(losses[:,0],label='reconst')
    ax[0,0].legend()
    ax[0,1].plot(losses[:,1],label='kl')
    ax[0,1].legend()
    ax[1,0].plot(losses[:,2],label='decoder')
    ax[1,0].legend()
    ax[1,1].plot(losses[:,3],label='latent')
    ax[1,1].legend()

    fig, ax = plt.subplots(1,2,figsize=(15,5))
    idx = np.random.choice(len(obs))
    ax[0].imshow(obs[idx], cmap='gray')
    # Reconstruct only the displayed image. Previously the whole dataset was encoded and
    # reconst[0] (the reconstruction of obs[0]) was shown next to obs[idx].
    aux, _ = vae_apply_jit(params_vae, bn_vae_state, obs[idx][None], False, temp)
    reconst, latent, q_p = aux
    ax[1].imshow(reconst[0], cmap='gray')
    plt.show()

    return losses[-1,0]

study = optuna.create_study()  # minimises the returned reconstruction loss
study.optimize(objective_catvae, n_trials=30, show_progress_bar=True)

# SAC

## Initialize the autoencoder

In [None]:
# Classic-VAE configuration for pretraining. This cell and the VQVAE one both define
# vae_init / vae_apply_jit; whichever runs last is used by the pretraining cells.
input_size, input_channels, output_channels = 128, 1, 1
filter_sizes = [16, 32, 64, 128]
num_epochs, batch_size = 5000, 16
seed = np.random.randint(100)
final_activation = jax.jit(lambda s: s)  # identity on the decoder output

latent_dim = 14  # presumably chosen from the Optuna study above


def classic_vae(s, is_training):
    """Haiku forward function of the classic VAE (coord_conv enabled)."""
    model = vae.VAE(input_size, latent_dim, filter_sizes, output_channels, final_activation, coord_conv=True)
    return model(s, is_training)


vae_init, vae_apply = hk.without_apply_rng(hk.transform_with_state(classic_vae))
# is_training (positional arg 3 of apply: params, state, s, is_training) is static.
vae_apply_jit = jax.jit(vae_apply, static_argnums=3)

In [None]:
# VQVAE configuration for pretraining. This cell and the classic-VAE one both define
# vae_init / vae_apply_jit; whichever runs last is used by the pretraining cells.
input_size, input_channels, output_channels = 128, 1, 1
num_epochs, batch_size = 3000, 16
seed = np.random.randint(100)

# Encoder / decoder capacity.
num_hiddens = num_residual_hiddens = 32
num_residual_layers = 2

# Latent and codebook settings passed to vqvae.VQVAE.
latent_dim = 5
num_embeddings = 512
commitment_cost = 0.00044  # presumably chosen from the Optuna study above
decay = 0.99


def VQVAE(s, is_training):
    """Haiku forward function wrapping ``vqvae.VQVAE``."""
    model = vqvae.VQVAE(latent_dim, num_embeddings, num_hiddens, num_residual_layers,
                        num_residual_hiddens, commitment_cost, decay, output_channels)
    return model(s, is_training)


vae_init, vae_apply = hk.without_apply_rng(hk.transform_with_state(VQVAE))
vae_apply_jit = jax.jit(vae_apply, static_argnums=3)  # is_training is static

## Pretrain the VAE

Classic VAE

In [None]:
# Pretrain the classic VAE. NOTE(review): vae_init / vae_apply_jit must come from the
# classic-VAE initialisation cell (the VQVAE cell defines the same names).
lambda_kl = 9.21e-5

params_vae, bn_vae_state = vae_init(jax.random.PRNGKey(seed), np.random.uniform(0, 1, (1,input_size,input_size,input_channels)), True)
optimizer_vae = optax.radam(5e-4)
opt_vae_state = optimizer_vae.init(params_vae)

@jax.jit
def loss_ae(params, bn_state, obs):
    """Variance-normalised MSE + weighted Gaussian KL; aux carries the updated state."""

    aux, bn_state = vae_apply_jit(params, bn_state, obs, True)
    reconst, latent, mu, log_var = aux

    # MSE for reconstruction errors.
    loss_reconst = jnp.square(obs - reconst).mean()
    loss_reconst = loss_reconst/var

    # Placeholder terms kept so the aux tuple matches the other pretraining cells.
    loss_decoder = 0
    loss_latent = 0

    loss_kl = vae.kl_gaussian(mu, log_var)
    loss_kl = loss_kl * lambda_kl
    
    return loss_reconst + loss_kl + loss_decoder + loss_latent, (loss_reconst, loss_kl, loss_decoder, loss_latent, bn_state)

def update(params, opt_state, bn_state, obs):
    """One optimiser step on `loss_ae`; returns new params/opt state, loss and aux."""

    (loss, aux), grad = jax.value_and_grad(loss_ae, has_aux=True)(params, bn_state, obs)
    updates, new_opt_state = optimizer_vae.update(grad, opt_state)
    new_params = optax.apply_updates(params, updates)

    return new_params, new_opt_state, loss, aux

update_jit = jax.jit(update)

losses = []

# Mini-batches of `batch_size` images sampled with replacement from `obs`.
for k in range(num_epochs):

    idx = np.random.choice(len(obs), batch_size)
    batch_obs = jnp.array(obs[idx])
    params_vae, opt_vae_state, loss, aux = update_jit(params_vae, opt_vae_state, bn_vae_state, batch_obs)
    loss_reconst, loss_kl, loss_decoder, loss_latent, bn_vae_state = aux
    losses.append([loss_reconst.item(), loss_kl.item(), loss_decoder.item(), loss_latent.item(), loss.item()])

losses = np.array(losses)
fig, ax = plt.subplots(2,2,figsize=(15,5))
ax[0,0].plot(losses[:,0],label='reconst')
ax[0,0].legend()
ax[0,1].plot(losses[:,1],label='kl')
ax[0,1].legend()
ax[1,0].plot(losses[:,2],label='decoder')
ax[1,0].legend()
ax[1,1].plot(losses[:,3],label='latent')
ax[1,1].legend()

# Saved in the working directory; the "load weights" cell reads the same file-name pattern.
jnp.savez('params_vae_lat={}_kl={:.2e}.npz'.format(latent_dim, lambda_kl), params_vae=params_vae, bn_vae_state=bn_vae_state)

VQVAE

In [None]:
# Pretrain the VQVAE. vae_init / vae_apply_jit must come from the VQVAE initialisation
# cell (the classic-VAE cell defines the same names).
params_vae, bn_vae_state = vae_init(jax.random.PRNGKey(seed), np.random.uniform(0, 1, (1,input_size,input_size,input_channels)), True)
# NOTE(review): the Optuna search for the VQVAE used optax.adam; radam is kept as before.
optimizer_vae = optax.radam(5e-4)
opt_vae_state = optimizer_vae.init(params_vae)

@jax.jit
def loss_ae(params, bn_state, obs):
    """VQ loss + variance-normalised reconstruction error; aux carries the updated state."""

    vq_output, bn_state = vae_apply_jit(params, bn_state, obs, True)
    vq_loss = vq_output['vq_loss']
    loss_reconst = vq_output['recon_error']
    loss_reconst = loss_reconst/var

    # Placeholder terms kept so the aux tuple matches the other pretraining cells.
    loss_latent = 0
    loss_decoder = 0

    return vq_loss + loss_reconst + loss_decoder + loss_latent, (vq_loss, loss_reconst, loss_decoder, loss_latent, bn_state)

def update(params, opt_state, bn_state, obs):
    """One optimiser step on `loss_ae`; returns new params/opt state, loss and aux.

    The previous version took an extra `mask` argument that it forwarded to the
    three-argument `loss_ae`, and the training loop never passed it, so the cell
    raised a TypeError before training started.
    """
    (loss, aux), grad = jax.value_and_grad(loss_ae, has_aux=True)(params, bn_state, obs)
    updates, new_opt_state = optimizer_vae.update(grad, opt_state)
    new_params = optax.apply_updates(params, updates)

    return new_params, new_opt_state, loss, aux

update_jit = jax.jit(update)

losses = []

for k in range(num_epochs):

    idx = np.random.choice(len(obs), batch_size)
    batch_obs = jnp.array(obs[idx])
    params_vae, opt_vae_state, loss, aux = update_jit(params_vae, opt_vae_state, bn_vae_state, batch_obs)
    vq_loss, loss_reconst, loss_decoder, loss_latent, bn_vae_state = aux
    losses.append([loss_reconst.item(), vq_loss.item(), loss_decoder.item(), loss_latent.item(), loss.item()])

losses = np.array(losses)
fig, ax = plt.subplots(2,2,figsize=(15,5))
ax[0,0].plot(losses[:,0],label='reconst')
ax[0,0].legend()
ax[0,1].plot(losses[:,1],label='vq')
ax[0,1].legend()
ax[1,0].plot(losses[:,2],label='decoder')
ax[1,0].legend()
ax[1,1].plot(losses[:,3],label='latent')
ax[1,1].legend()

jnp.savez('params_vqvae_lat={}_lambda={:.2e}.npz'.format(latent_dim, commitment_cost), params_vae=params_vae, bn_vae_state=bn_vae_state)

CatVAE

In [None]:
# Final CatVAE configuration (values presumably taken from the Optuna study above).
input_size = 128
num_epochs = 3000
batch_size = 16
seed = np.random.randint(100)
filter_sizes = [16,32,64,128]
input_channels = output_channels = 1

latent_dim = 24
lambda_kl = 3.6644894566489406e-07
temp = 4.911724757502476
prior =  0.005071067050815106
num_classes = 2
final_activation = jax.jit(lambda s:s)

def cat_vae(s, is_training, T):
    """Haiku forward function of vae.CatVAE; T is its temperature argument."""
    return vae.CatVAE(input_size, latent_dim, num_classes, filter_sizes, output_channels, final_activation)(s, is_training, T)

vae_init, vae_apply  = hk.without_apply_rng(hk.transform_with_state(cat_vae))
vae_apply_jit = jax.jit(vae_apply, static_argnums=3)  # is_training is static, T is traced
params_vae, bn_vae_state = vae_init(jax.random.PRNGKey(seed), np.random.uniform(0, 1, (1,input_size,input_size,input_channels)), True, temp)
optimizer_vae = optax.radam(5e-4)
opt_vae_state = optimizer_vae.init(params_vae)

@jax.jit
def loss_ae(params, bn_state, obs, T, eps = 1e-7):
    """Variance-normalised MSE + weighted categorical KL; aux carries the updated state."""

    aux, bn_state = vae_apply_jit(params, bn_state, obs, True, T)
    reconst, latent, q = aux
    q_p = jax.nn.softmax(q, axis=-1)

    # MSE for reconstruction errors.
    loss_reconst = jnp.square(obs - reconst).mean()
    loss_reconst = loss_reconst/var

    loss_decoder = 0
    loss_latent = 0

    h1 = q_p * jnp.log(q_p + eps)

    # Cross entropy with the categorical distribution.
    # NOTE(review): with a scalar `prior` and q_p summing to 1 over classes, h2 only adds a
    # constant per latent (no gradient) — confirm whether a per-class prior was intended.
    h2 = q_p * jnp.log(prior + eps)
    loss_kl = jnp.mean(jnp.sum(h1 - h2, axis =(1,2)), axis=0)
    loss_kl = lambda_kl*loss_kl

    return loss_reconst + loss_kl + loss_decoder + loss_latent, (loss_reconst, loss_kl, loss_decoder, loss_latent, bn_state)

def update(params, opt_state, bn_state, obs, T):
    """One optimiser step on `loss_ae`; returns new params/opt state, loss and aux."""

    (loss, aux), grad = jax.value_and_grad(loss_ae, has_aux=True)(params, bn_state, obs, T)
    updates, new_opt_state = optimizer_vae.update(grad, opt_state)
    new_params = optax.apply_updates(params, updates)

    return new_params, new_opt_state, loss, aux

update_jit = jax.jit(update)

losses = []

for k in range(1,num_epochs+1):
    idx = np.random.choice(len(obs), batch_size)
    batch_obs = jnp.array(obs[idx])
    params_vae, opt_vae_state, loss, aux = update_jit(params_vae, opt_vae_state, bn_vae_state, batch_obs, temp)
    loss_reconst, loss_kl, loss_decoder, loss_latent, bn_vae_state = aux
    losses.append([loss_reconst.item(), loss_kl.item(), loss_decoder.item(), loss_latent.item(), loss.item()])

losses = np.array(losses)
fig, ax = plt.subplots(2,2,figsize=(15,5))
ax[0,0].plot(losses[:,0],label='reconst')
ax[0,0].legend()
ax[0,1].plot(losses[:,1],label='kl')
ax[0,1].legend()
ax[1,0].plot(losses[:,2],label='decoder')
ax[1,0].legend()
ax[1,1].plot(losses[:,3],label='latent')
ax[1,1].legend()

# Save the trained weights like the classic-VAE and VQVAE cells do, so they can be
# reloaded later instead of retraining.
jnp.savez('params_catvae_lat={}_kl={:.2e}.npz'.format(latent_dim, lambda_kl), params_vae=params_vae, bn_vae_state=bn_vae_state)

## Or load pretrained weights

In [None]:
# Load pretrained classic-VAE weights (latent_dim=14, lambda_kl=9.21e-5 per the file name)
# written with jnp.savez; each pytree is stored as a 0-d object array, hence
# allow_pickle=True and the [()] unwrapping.
weights = jnp.load('/mnt/diskSustainability/frederic/sony_RL/params_vae_lat=14_kl=9.21e-05.npz', allow_pickle=True)
params_vae, bn_vae_state = (weights[key][()] for key in ('params_vae', 'bn_vae_state'))

Check the reconstructed images

In [None]:
# Render one view and show it next to the autoencoder's reconstruction.
env.reset()
ts = env.step_angle(0.5, 0.1)
fig, ax = plt.subplots(1,2,figsize=(15,5))
ax[0].imshow(env.canny, cmap='gray')
ax[0].set_title('input')
model_output, _ = vae_apply_jit(params_vae, bn_vae_state, env.canny[None]*1., False)
# The classic VAE returns a tuple (reconst, latent, mu, log_var); the VQVAE returns a
# mapping with the reconstruction under 'x_recon'. Handle both, replacing the VQVAE
# variant that was kept as a dead string literal.
if hasattr(model_output, 'keys'):
    reconst = model_output['x_recon']
else:
    reconst = model_output[0]
ax[1].imshow(reconst[0], cmap='gray')
ax[1].set_title('reconstruction')
plt.show()

## UMAP clustering

Check the interpretability of the latent space

In [None]:
# Encode a dense sweep of views (15 polar angles x 100 azimuths) to inspect the latent space.
latent = []
pos = []
imgs = []
# One label per view: its polar angle theta (100 consecutive views share a theta).
labels = np.repeat(np.linspace(np.pi/8, np.pi/2, 15), 100)
labels_bokeh = []
ts = env.reset()
for theta in np.linspace(np.pi/8, np.pi/2, 15):
    for phi in np.linspace(0, 2*np.pi, 100):
        ts = env.step_angle(theta, phi)
        pos.append(env.pos)
        vae_output, _ = vae_apply_jit(params_vae, bn_vae_state, env.canny[None]*1., False)
        latent.append(vae_output[1][0])  # second output of the classic VAE: latent code
        imgs.append(env.canny[...,0])
        # Dedicated names: the previous `x, y, z = env.pos` overwrote the global voxel
        # coordinates x, y, z computed from the ground-truth grid.
        cam_x, cam_y, cam_z = env.pos
        labels_bokeh.append('x={:.2f}, y={:.2f}'.format(cam_x, cam_y))

latent = np.array(latent)
pos = np.array(pos)
imgs = np.array(imgs)

In [None]:
import umap
import umap.plot
# Fit UMAP (default 2 output components, 20 neighbours, all cores) on the latent codes and
# colour each point by the rounded polar angle theta of its view.
embedding = umap.UMAP(n_neighbors=20, low_memory=False, n_jobs=-1).fit(latent)
umap.plot.points(embedding, labels=np.around(labels,2))

In [None]:
from io import BytesIO
from PIL import Image
import base64

def embeddable_image(img):
    """Encode a 2-D grayscale image as a base64 PNG data URI (used in bokeh tooltips).

    Pillow's 8-bit grayscale mode expects ``uint8`` data; boolean or wider
    integer/float arrays are otherwise reinterpreted byte-wise or rejected. The
    input is therefore converted first: booleans map to 0/255, any other
    non-``uint8`` values are clipped to [0, 255].
    """
    img = np.asarray(img)
    if img.dtype == np.bool_:
        img = img.astype(np.uint8) * 255
    elif img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    image = Image.fromarray(img)  # 2-D uint8 -> mode 'L'
    buffer = BytesIO()
    image.save(buffer, format='png')
    for_encoding = buffer.getvalue()
    return 'data:image/png;base64,' + base64.b64encode(for_encoding).decode()

from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource

# Send bokeh output to the notebook (inline) rather than to a standalone HTML file.
output_notebook()

In [None]:
import pandas as pd

# One row per encoded view: 2-D UMAP coordinates, thumbnail and position label.
# Use the fitted training embedding (the same coordinates as the umap.plot figure)
# rather than re-projecting the training points with transform(), which only
# approximates them.
df = pd.DataFrame(embedding.embedding_, columns=('x', 'y'))
df['image'] = list(map(embeddable_image, imgs))
df['pos'] = labels_bokeh
datasource = ColumnDataSource(df)

# `width` / `height` replace `plot_width` / `plot_height`, which Bokeh 2.4 deprecated
# and Bokeh 3.0 removed.
plot_figure = figure(
    title='UMAP projection of the encoder output',
    width=600,
    height=600,
    tools=('pan, wheel_zoom, reset')
)

# Hover tooltip: thumbnail of the view plus its position label.
plot_figure.add_tools(HoverTool(tooltips="""
<div>
    <div>
        <img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
    </div>
    <div>
        <span style='font-size: 16px; color: #224499'>pos:</span>
        <span style='font-size: 18px'>@pos</span>
    </div>
</div>
"""))

# scatter() draws circle markers by default; circle(size=...) is deprecated in recent Bokeh.
plot_figure.scatter(
    'x',
    'y',
    source=datasource,
    line_alpha=0.6,
    fill_alpha=0.6,
    size=4
)
show(plot_figure)

## Initialize RL agent

In [None]:
# Single-object environment (holed sphere) with continuous actions, image observations
# and voxel_weights passed through (presumably to weight the reward per voxel).
# Two identical instances are built: env (training) and env_test (evaluation).
objects_path = '/mnt/diskSustainability/frederic/scanner-gym_models_v2/sphere_hole'
object_name = 'sphere.obj'
env, env_test = (
    SphereEnv(objects_path, object_name, img_shape=128, voxel_weights=voxel_weights,
              rmax=0.7, continuous=True, use_img=True, max_T=50)
    for _ in range(2)
)
ts = env.reset()

In [None]:
# Import the RL code from its source directory, restoring the working directory
# even if an import fails.
os.chdir('/mnt/diskSustainability/frederic/sony_RL/sony_RL/sac')
try:
    from sac import SAC
    from ddpg import DDPG
    from base_trainer import Trainer
finally:
    os.chdir(path)

seed = np.random.randint(100)
# Pretrained encoder handed to the agent as (jitted apply fn, params, haiku state).
encoder = (vae_apply_jit, params_vae, bn_vae_state)

# action_space is passed as an empty (2, 2) array (presumably only its shape is used —
# TODO confirm against SAC).
agent = SAC(num_agent_steps=10**6, state_space=np.empty(env.observation_shape), action_space=np.empty((2,2)),
             seed=seed, start_steps=10**3, gamma=0.75, buffer_size=10**3, batch_size=32, encoder=encoder, scale_reward=5)
#agent.load_params('/mnt/diskSustainability/frederic/sony_RL/sac_beta_coordconv_vae_logs/param/step1000000')
#agent.load_params('/mnt/diskSustainability/frederic/sony_RL/sac_beta_coordconv_vae_logs/param/step1000000')

Train the agent

In [None]:
# Train for 10**6 agent steps on env, evaluating on env_test every 10**3 steps and
# saving parameters every 10**4 steps under log_dir.
trainer = Trainer(
    env=env,
    env_test=env_test,
    algo=agent,
    num_agent_steps=10**6,
    action_repeat=1,
    eval_interval=10**3,
    save_params=True,
    save_interval=10**4,
    log_dir='sac_vae_scale=5_logs',
)

trainer.train()

# Tests

Check the trajectory obtained by the trained agent

In [None]:
# Roll out the agent's select_action policy on env_test for at most 20 steps.
ts = env_test.reset()
rewards = []
for k in range(20):
    action = agent.select_action(ts.observation)
    ts = env_test.step(action)
    rewards.append(ts.reward)
    # step_type 2 presumably is dm_env's StepType.LAST (final step of the episode).
    if ts.step_type == 2:
        print(f'finished in {k + 1} steps')
        break
print(rewards)

In [None]:
# Turn the visited viewpoints (theta, phi) into camera positions in voxel coordinates.
angles = np.array(env_test.visited_positions)
theta, phi = angles[:, 0], angles[:, 1]

if not env_test.continuous:
    # Discrete envs store grid indices: theta index i -> (i+1)*pi/8, phi index j -> j*pi/90.
    # NOTE(review): hard-coded grid sizes; the ground-truth sweep derives angles from
    # env.theta_n_positions / env.phi_n_positions instead — confirm they agree.
    theta = (theta+1)*np.pi/8
    phi = phi*np.pi/90

# Point on the sphere of radius R for each viewpoint.
R = 5
a, b, c = (R*np.sin(theta)*np.cos(phi), R*np.sin(theta)*np.sin(phi), R*np.cos(theta))

# World -> voxel index, presumably (coord - origin) / voxel_size with origin -4.97 and
# voxel size 0.4 (TODO confirm against the scene's bounding box).
a, b, c = ((coord + 4.97)/0.4 for coord in (a, b, c))

In [None]:
import plotly.graph_objects as go

# Discrete colour scale: timestep i gets one constant 'jet' colour over [i/n, (i+1)/n].
# The colormap is built once, and the RGB bytes are formatted as plain ints (str() of a
# tuple of np.uint8 yields 'np.uint8(..)' under NumPy 2, an invalid plotly colour).
n_steps = len(a)
jet = plt.get_cmap('jet', n_steps)
l = []
for i in range(n_steps):
    colour = 'rgb({}, {}, {})'.format(*(int(v) for v in jet(i, bytes=True)[:3]))
    l.append([i/n_steps, colour])
    l.append([(i+1)/n_steps, colour])

# Voxels with a non-zero reward weight (the masked ground-truth voxels), read from
# voxel_weights so the plot does not depend on the global x, y, z that other cells reuse.
wx, wy, wz = np.nonzero(voxel_weights)

fig = go.Figure()
fig.add_trace(go.Scatter3d(x=a,y=b,z=c,marker=dict(
        color=np.arange(n_steps),
        colorscale=l,
        colorbar=dict(thickness=20,title={
        'text': 'Timesteps','side':'bottom'},
           tick0=0,dtick=1,x=0.8, y=0.4, len=0.75)),
        text=[str(k) for k in range(n_steps)],hoverinfo='text',showlegend=False))
fig.add_trace(go.Scatter3d(x=wx,y=wy,z=wz,mode='markers',showlegend=False))
# Title names the actual agent class (it previously said DQN for the SAC agent).
fig.update_layout(margin=dict(l=0, r=0, b=0, t=0),hovermode='closest', width=700, height=450,title={
        'text': f'Trajectory using {type(agent).__name__}',
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'}
           )
fig.show()