In [18]:
import os
from pathlib import Path
import shutil
import time
from typing import Dict, List, Union
import skill_generator.models.skill_generator as model_sg
import hulc
import torch

In [19]:
def get_all_checkpoints(experiment_folder: Union[str, Path]) -> List[Path]:
    """Return the ``.ckpt`` files found in ``<experiment_folder>/saved_models``.

    Args:
        experiment_folder: Run directory (``str`` or ``Path``) expected to
            contain a ``saved_models`` sub-directory.

    Returns:
        Checkpoint paths sorted by modification time (``st_mtime``), oldest
        first. An empty list if the directory is missing or holds no entry
        with a ``.ckpt`` suffix.
    """
    checkpoint_folder = Path(experiment_folder) / "saved_models"
    if not checkpoint_folder.is_dir():
        return []
    # Filter first so only checkpoint entries are stat()-ed. sorted() is stable,
    # so the result matches sorting everything and filtering afterwards.
    checkpoints = [chk for chk in checkpoint_folder.iterdir() if chk.suffix == ".ckpt"]
    return sorted(checkpoints, key=lambda chk: chk.stat().st_mtime)

In [20]:
def get_last_checkpoint(experiment_folder: Path) -> Union[Path, None]:
    """Return the most recently modified checkpoint of an experiment run.

    Args:
        experiment_folder: Run directory whose ``saved_models`` sub-directory
            is searched by ``get_all_checkpoints``.

    Returns:
        The last entry of ``get_all_checkpoints`` (ordered by ``st_mtime``,
        i.e. modification time, not creation time), or ``None`` when no
        checkpoint is found.
    """
    checkpoints = get_all_checkpoints(experiment_folder)
    return checkpoints[-1] if checkpoints else None

In [29]:
def _sample(mu: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    eps = torch.randn(*mu.size()).to(mu)
    return mu +  1.0 * scale * eps

In [30]:
def _check_direction(dire):
    p = torch.clone(dire)
    n = torch.clone(dire)
    p[dire < 0.] = 0.
    n[dire > 0.] = 0.
    n = torch.abs(n)
    return p, n

In [35]:
def skill_classifier(actions, scale=(1.6, 2.0, 0.75), eps=0.1):
    """Label each action sequence as translation (0), rotation (1) or gripper (2).

    Args:
        actions: Tensor indexed as ``actions[batch, time, dim]``. Dims 0-2 are
            scored as translation, dims 3-5 as rotation and dim 6 as the
            gripper command (assumed action layout — TODO confirm against the
            decoder).
        scale: Normalizers ``(translation, rotation, gripper)`` that divide
            each score before the scores are compared. Any indexable sequence
            of three numbers works.
        eps: Near-tie margin. When the rotation score exceeds the translation
            score by less than ``eps``, rotation is reduced by ``eps``, so such
            near-ties are labelled translation.

    Returns:
        Integer tensor of shape ``(batch,)`` holding the arg-max skill index.
    """
    # Translation / rotation scores: L2 norm of the absolute per-dimension sum
    # of the actions over time.
    energy = torch.abs(torch.sum(actions[:, :, :6], dim=1))
    translation = torch.norm(energy[:, :3], dim=1) / scale[0]
    rotation = torch.norm(energy[:, 3:6], dim=1) / scale[1]
    # Gripper score: total absolute step-to-step change of the gripper command.
    # torch.diff gives an empty time axis for T == 1, so the sum is a zero
    # tensor here. The former Python loop left a float 0., which torch.stack
    # rejected.
    gripper = torch.abs(torch.diff(actions[:, :, 6], dim=1)).sum(dim=1) / scale[2]

    margin = rotation - translation
    rotation[(margin > 0) & (margin < eps)] -= eps
    scores = torch.stack([translation, rotation, gripper], dim=-1)
    return torch.argmax(scores, dim=-1)

In [36]:
# Decode a batch of samples from each skill prior and report how often the
# decoded actions are classified as that prior's skill
# (index 0 = translation, 1 = rotation, 2 = grasp).
batch = 10000
SKILL_LEN = 5  # number of decoded action steps per sampled skill
SEED = 0  # fixed so the printed rates are reproducible on Restart & Run All
torch.manual_seed(SEED)

# load_checkpoint
sg_chk_path = './sg_runs/2022-12-14/16-21-16'
sg_chk_path = Path(hulc.__file__).parent.parent / sg_chk_path
chk = get_last_checkpoint(sg_chk_path)
if chk is None:
    raise FileNotFoundError(f"No .ckpt checkpoint found under {sg_chk_path / 'saved_models'}")
skill_generator = model_sg.SkillGenerator.load_from_checkpoint(chk.as_posix())
skill_generator.freeze()
prior_locator = skill_generator.prior_locator.eval()
action_decoder = skill_generator.decoder.eval()

# Printed label for each prior index; the index doubles as the expected class
# of skill_classifier.
skill_labels = ['left translation rate: ', 'rotation rate: ', 'grasp rate: ']
skill_len = torch.tensor(SKILL_LEN)

with torch.no_grad():
    priors = prior_locator(repeat=batch)
    seq_lens = skill_len.repeat(batch)
    rates = []
    for skill_idx in range(len(skill_labels)):
        sampled = _sample(priors['p_mu'][:, skill_idx, :], priors['p_scale'][:, skill_idx, :])
        decoded = action_decoder(sampled, seq_lens)
        rates.append(torch.sum(skill_classifier(decoded) == skill_idx) / batch)

rate_t, rate_r, rate_g = rates
for label, rate in zip(skill_labels, rates):
    print(label, rate)

left translation rate:  tensor(0.6588)
rotation rate:  tensor(0.3218)
grasp rate:  tensor(0.5895)
