In [153]:
import os
from pathlib import Path
import shutil
import time
from typing import Dict, List, Union
import skill_generator.models.skill_generator as model_sg
import hulc
import torch

In [154]:
def get_all_checkpoints(experiment_folder: Union[str, Path]) -> List[Path]:
    """Return the ``.ckpt`` files in ``<experiment_folder>/saved_models``, oldest first.

    Checkpoints are ordered by modification time (``st_mtime``), so the last element
    is the most recently written checkpoint.

    Args:
        experiment_folder: Run directory containing a ``saved_models`` subfolder.
            A plain string path is accepted as well as a ``Path``.

    Returns:
        Checkpoint paths sorted by ascending mtime. Empty when the run directory or
        its ``saved_models`` subfolder does not exist, or holds no ``.ckpt`` files.
    """
    checkpoint_folder = Path(experiment_folder) / "saved_models"
    # A missing run directory implies a missing ``saved_models`` subfolder too.
    if not checkpoint_folder.is_dir():
        return []
    # Filter before sorting so stat() only touches checkpoint files: an unrelated
    # entry that cannot be stat'ed (e.g. a dangling symlink) must not break the listing.
    checkpoints = [entry for entry in checkpoint_folder.iterdir() if entry.suffix == ".ckpt"]
    return sorted(checkpoints, key=lambda chk: chk.stat().st_mtime)

In [155]:
def get_last_checkpoint(experiment_folder: Path) -> Union[Path, None]:
    """Return the newest checkpoint of a run, or ``None`` when it has none.

    "Newest" means the last entry of ``get_all_checkpoints``, which orders files by
    modification time (``st_mtime``), not creation time.
    """
    all_checkpoints = get_all_checkpoints(experiment_folder)
    return all_checkpoints[-1] if all_checkpoints else None

In [177]:
def _sample(mu: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    eps = torch.randn(*mu.size()).to(mu)
    return mu + scale * eps

In [178]:
def skill_type_classifier(actions: torch.Tensor, magic_scale=(0.8, 1.3, 1.0)) -> torch.Tensor:
    """Label each action sequence by its dominant motion type using an energy heuristic.

    For each sequence three scores are computed and the largest one wins:
      * translation: summed |action| over time, averaged over action dims 0-2,
      * rotation:    summed |action| over time, averaged over action dims 3-5,
      * gripper:     total |change| of action dim 6 between consecutive steps,
    each divided by the matching entry of ``magic_scale``.

    Args:
        actions: Tensor indexed as ``[batch, time, dim]`` with at least 7 action dims.
            Sequences with a single time step (or none) are supported; their gripper
            score is 0.
        magic_scale: Per-type divisors (translation, rotation, gripper) used to balance
            the three scores. They look hand-tuned; their origin is not documented here.

    Returns:
        An int64 tensor of shape ``[batch]`` with labels 0 = translation, 1 = rotation,
        2 = gripper. Ties resolve to the lowest label (``torch.argmax`` semantics).
    """
    # Per-dimension absolute action magnitude, summed over the time axis -> [B, 6].
    arm_energy = actions[:, :, :6].abs().sum(dim=1)
    # Absolute step-to-step change of the gripper channel, summed over time -> [B].
    # An empty slice (T <= 1) sums to zeros, so short sequences need no special case.
    gripper_energy = (actions[:, 1:, 6] - actions[:, :-1, 6]).abs().sum(dim=1)

    translation = arm_energy[:, :3].mean(dim=-1) / magic_scale[0]
    rotation = arm_energy[:, 3:6].mean(dim=-1) / magic_scale[1]
    gripper = gripper_energy / magic_scale[2]

    scores = torch.stack([translation, rotation, gripper], dim=-1)
    return torch.argmax(scores, dim=-1)

In [179]:
# Sample from each skill-type prior, decode the samples into action sequences, and
# measure how often the heuristic classifier assigns a sample to its own prior's type
# (prior index k along dim 1 of p_mu/p_scale is compared against classifier label k).
batch = 10000
SEED = 0  # fixes the prior samples so the reported rates are reproducible
torch.manual_seed(SEED)

# load_checkpoint
sg_chk_path = Path(hulc.__file__).parent.parent / './sg_runs/2022-12-04/09-41-58'
chk = get_last_checkpoint(sg_chk_path)
if chk is None:
    raise FileNotFoundError(f"no .ckpt file found in {sg_chk_path / 'saved_models'}")
skill_generator = model_sg.SkillGenerator.load_from_checkpoint(chk.as_posix())
skill_generator.freeze()
prior_locator = skill_generator.prior_locator.eval()
action_decoder = skill_generator.decoder.eval()

skill_len = torch.tensor(5)
skill_lens = skill_len.repeat(batch)  # the same skill length for every sample

with torch.no_grad():
    priors = prior_locator(repeat=batch)
    # Index 0 / 1 / 2 along dim 1 is treated as the translation / rotation / grasp prior.
    t_mu, r_mu, g_mu = (priors['p_mu'][:, k, :] for k in range(3))
    t_scale, r_scale, g_scale = (priors['p_scale'][:, k, :] for k in range(3))
    # Draws happen in t, r, g order, matching the original sampling sequence.
    t_sampled, r_sampled, g_sampled = (
        _sample(mu, scale)
        for mu, scale in zip((t_mu, r_mu, g_mu), (t_scale, r_scale, g_scale))
    )
    t_actions, r_actions, g_actions = (
        action_decoder(z, skill_lens) for z in (t_sampled, r_sampled, g_sampled)
    )
    # Fraction of samples from prior k that the classifier labels as type k.
    rate_t, rate_r, rate_g = (
        (skill_type_classifier(actions) == label).float().mean()
        for label, actions in enumerate((t_actions, r_actions, g_actions))
    )

print('translation rate: ', rate_t)
print('rotation rate: ', rate_r)
print('grasp rate: ', rate_g)

translation rate:  tensor(0.7839)
rotation rate:  tensor(0.5609)
grasp rate:  tensor(0.7317)
