# Training Shield

In [None]:
import torch
import os
import omnisafe
from omnisafe.adapter import ShieldAdapter
from omnisafe.utils.config import get_default_kwargs_yaml, Config

# Terminal-level training options handed to omnisafe.Agent. The "algo" and
# "env_id" entries are additionally passed as the first two positional args.
train_terminal_cfgs = {
    "algo": "PPO",
    "env_id": "SafetyPointGoal1-v0",
    "parallel": 1,
    "total_steps": 1638400,
    "device": "cpu",
    "vector_env_nums": 16,
    "torch_threads": 16,
}

algo_name, env_name = train_terminal_cfgs["algo"], train_terminal_cfgs["env_id"]
agent = omnisafe.Agent(algo_name, env_name, train_terminal_cfgs=train_terminal_cfgs)
# Resolved configuration of the agent; reused below for naming and env setup.
cfgs: Config = agent.cfgs

# The dataset and the checkpoint share one stem: random_<env_id>_<steps>.
steps = 16_000_000
run_stem = f"random_{cfgs.train_cfgs.env_id}_{steps}"

# Input: dataset file that is loaded further down.
data_dir = "/home/juntao/workspace/my_omnisafe/experiments/data"
filename = os.path.join(data_dir, f"{run_stem}.pt")

# Output: checkpoint location; the directory is created if it is missing.
checkpoint_dir = "/home/juntao/workspace/my_omnisafe/experiments/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_filename = os.path.join(checkpoint_dir, f"{run_stem}.ckpt")

In [None]:
# Load the pre-collected dataset stored at `filename`.
# map_location="cpu" makes deserialisation independent of the device the
# tensors were saved from: without it torch.load tries to restore each tensor
# onto its original device and raises if that device is not available here.
# The data is moved to the target device explicitly in a later cell.
data = torch.load(filename, map_location="cpu")

# Quick inspection: list the available fields and the shape of each one.
print(data.keys())
for key, tensor in data.items():
    print(f"{key} shape: {tensor.shape}")

In [None]:
# Print the sum over all entries of the "cost" field of the loaded data.
print(data["cost"].sum())

In [None]:
from omnisafe.models.custom_modes.classifier_shield import ClassifierShield

# Device used for the shield config and, later, for the training data.
# NOTE(review): hard-coded to the second CUDA device — confirm it exists on
# the machine running this notebook.
device = torch.device("cuda:1")

# Network settings for the two sub-models of the shield.
risk_model_cfg = {
    "hidden_sizes": [64, 64],
    "activation": "relu",
    "output_activation": "sigmoid",
    "weight_initialization_mode": "xavier_uniform",
}
classifier_model_cfg = {
    "hidden_sizes": [64, 64],
    "activation": "sigmoid",
    "output_activation": "softmax",
    "weight_initialization_mode": "xavier_uniform",
}

# Full configuration passed to ClassifierShield.
shield_cfg = {
    "device": device,
    "dtype": torch.float32,
    "risk_threshold": 0.775,
    "max_resample_times": 100,
    "resample_batch_size": 20,
    "batch_size": 32,
    "risk_discount": 0.95,
    "risk_model": risk_model_cfg,
    "classifier_model": classifier_model_cfg,
}

# Environment adapter built from the same env id, vector-env count and seed
# as the agent configuration.
env = ShieldAdapter(
    train_terminal_cfgs["env_id"],
    cfgs.train_cfgs.vector_env_nums,
    cfgs.seed,
    cfgs,
)

shield = ClassifierShield(env, shield_cfg)

In [None]:
def flatten_vectorized_dict(d):
    """Return a new dict whose values have their first two axes merged.

    Each value is transposed on axes 0 and 1 and then flattened over those
    two axes, so a value of shape ``(A, B, ...)`` becomes ``(B * A, ...)``
    with the original axis-1 index varying slowest.
    """
    merged = {}
    for name, tensor in d.items():
        swapped = tensor.transpose(0, 1)
        merged[name] = swapped.flatten(0, 1)
    return merged

def dict_to_device(d, device):
    """Return a new dict with every value of *d* passed through ``.to(device)``.

    Keys and their order are preserved; *d* itself is not modified.
    """
    moved = {}
    for name, value in d.items():
        moved[name] = value.to(device)
    return moved

# Merge the two leading axes of every field, then move the result to `device`.
flatten_data = dict_to_device(flatten_vectorized_dict(data), device)

# Derived fields:
#   "risk"       - 1.0 where "cost" is non-zero, 0.0 elsewhere (float tensor)
#   "terminated" - boolean view of the "done" field
flatten_data.update(
    risk=flatten_data["cost"].clone().bool().float(),
    terminated=flatten_data["done"].bool(),
)

print(flatten_data["risk"])



In [None]:
# Number of times shield.update is called on the full flattened dataset.
# NOTE(review): presumably each call is one pass over the data — confirm
# against ClassifierShield.update.
num_epochs: int = 2
for _epoch in range(num_epochs):
    shield.update(flatten_data)

In [None]:

# Persist the shield to the checkpoint path prepared in the first cell.
shield.save_model(checkpoint_filename)