In [1]:
from motif.eval_utils import get_batch_dict_for_reward_model, reward_model_with_mb
from motif.reward_model import RewardModel

In [2]:
import os
import smart_settings

In [3]:
import torch
import numpy as np

In [4]:
# Load the wd49 Motif reward model (checkpoint 49) on the CPU.
model_dir = "/Users/chrisgumbsch/Documents/SENSEI/code/poke_transfer2/pokemon_motif/best_val_motif_wd49"
params = smart_settings.load(
    os.path.join(model_dir, "settings.json"),
    make_immutable=False,
)
motif_reward_model = RewardModel(
    params["reward_model_params"]["model_params"],
    device=torch.device("cpu"),
)
motif_reward_model.load(
    os.path.join(model_dir, "checkpoint_49"),
    device=torch.device("cpu"),
)

  state_dicts = torch.load(path, map_location=device)


In [5]:
class RewardWrapperEmbodied():
    """Attach a Motif reward (``obs_dict["motif_reward"]``) to embodied observations.

    The raw reward comes from a trained ``RewardModel``. It is then optionally
    smoothed with a sliding average or scaled by an EMA derivative estimate,
    and finally clipped to ``[clipping_min, clipping_max]``.
    """

    def __init__(self, env, model_dir='', img_key='image',
                 sliding_avg=0, deriv_scaling=False, deriv_ema_alpha=0.09, deriv_scaling_alpha=0.35,
                 clipping_min=None, clipping_max=None,
                 model_cpt_id=49, device=torch.device("cpu")):
        """Constructor for the Reward wrapper.

        Args:
            env: Environment to be wrapped; ``step()`` forwards actions to ``env.step``.
            model_dir: Directory containing ``settings.json`` and
                ``checkpoint_<model_cpt_id>``. An empty string disables the reward
                model, and the Motif reward is then always 0.
            img_key: Key of obs to get correct image size (image vs hq_img in Dreamer codebase).
            sliding_avg: If > 0, report the mean of the last ``sliding_avg`` rewards.
            deriv_scaling: If True, multiply rewards by
                ``exp(-1 / (deriv_scaling_alpha * ema_deriv))``. Mutually exclusive
                with ``sliding_avg``.
            deriv_ema_alpha: ``alpha`` passed to the EMA derivative estimator.
            deriv_scaling_alpha: Scale inside the exponent of the derivative scaling.
            clipping_min: Lower clipping bound (None -> unbounded).
            clipping_max: Upper clipping bound (None -> unbounded).
            model_cpt_id: Checkpoint index to load from ``model_dir``.
            device: Torch device the reward model is loaded onto.

        Raises:
            AssertionError: If both ``sliding_avg`` and ``deriv_scaling`` are set,
                or if ``clipping_max <= clipping_min``.
        """
        # step() forwards actions to the wrapped env, so it must be kept.
        self.env = env
        # True when the last observation produced by step() was terminal (is_last);
        # the next step() then resets the smoothing state.
        self._done = False

        self.reward_key = "motif_reward"
        self.img_key = img_key

        assert not (sliding_avg and deriv_scaling), "Both variables should not be True at the same time"

        self.sliding_avg = sliding_avg
        self.window = []

        # NOTE(review): EMA_Derivative_Estimator is not imported in the visible cells;
        # deriv_scaling=True raises NameError unless it is defined elsewhere.
        self.ema_deriv_estimator = (
            EMA_Derivative_Estimator(alpha=deriv_ema_alpha) if deriv_scaling else None
        )
        self.deriv_scaling_alpha = deriv_scaling_alpha

        self.clipping_min = clipping_min if clipping_min is not None else -np.inf
        self.clipping_max = clipping_max if clipping_max is not None else np.inf
        assert self.clipping_max > self.clipping_min, "Max clipping value has to be greater than min!"

        if model_dir:
            params = smart_settings.load(os.path.join(model_dir, "settings.json"), make_immutable=False)
            self.motif_reward_model = RewardModel(params["reward_model_params"]["model_params"], device=device)
            self.motif_reward_model.load(
                os.path.join(model_dir, f"checkpoint_{model_cpt_id}"), device=device
            )
        else:
            # Empty model dir -> no Motif model trained -> Motif reward is always 0.
            self.motif_reward_model = None

    def step(self, action):
        """Step the wrapped env and return its observation with the Motif reward added.

        The smoothing state is reset first when ``action['reset']`` is truthy or
        the previous observation was terminal.
        """
        if action['reset'] or self._done:
            self.reset()
        obs_dict = self.update_obs_dict(self.env.step(action))
        # A missing 'is_last' counts as non-terminal; np.any also accepts batched flags.
        self._done = bool(np.any(obs_dict.get('is_last', False)))
        return obs_dict

    def reset(self):
        """Clear the sliding-average window and the derivative estimator state."""
        self.window = []
        self._done = False
        if self.ema_deriv_estimator is not None:
            self.ema_deriv_estimator.reset()

    def update_obs_dict(self, obs_dict):
        """Add the Motif reward to ``obs_dict`` (in place) and return it."""
        return self.compute_motif_reward(obs_dict)

    def _obs_vec_from_dict(self, obs_dict):
        """Build the low-dimensional observation vector for the reward encoder.

        Raises:
            NotImplementedError: If none of the supported observation keys is present.
        """
        # Dummy placeholder: this input does not exist in this form yet and would need
        # to be made generic (the keys below are Robodesk-specific).
        if "qpos_robot" in obs_dict.keys():
            return np.concatenate(
                [
                    obs_dict["qpos_robot"],
                    obs_dict["qvel_robot"],
                    obs_dict["end_effector"],
                    obs_dict["qpos_objects"],
                    obs_dict["qvel_objects"],
                ]
            )
        if "inv_glyphs" in obs_dict.keys():
            # Single binary feature: is glyph 2102 (treated as "the key") in the inventory.
            key_glyph_id = 2102
            key_in_inventory = key_glyph_id in obs_dict["inv_glyphs"]
            return np.array([key_in_inventory * 1.])[None]
        raise NotImplementedError

    def _images_from_dict(self, obs_dict):
        """Return ``(images, batch_dim)`` with images reshaped to (N, S, S, 3).

        ``batch_dim`` is the leading size of a 4-D input, or -1 for a single image.
        """
        images = obs_dict[self.img_key]
        encoder = self.motif_reward_model.encoder
        batched = len(images.shape) == 4
        batch_dim = images.shape[0] if batched else -1
        if not encoder.resize_image:
            im_size = encoder.image_resolution
            if not batched:
                assert im_size == images.shape[0]
        else:
            # Spatial side length: skip the leading batch axis for batched input
            # (reading shape[0] there would return the batch size).
            im_size = images.shape[1] if batched else images.shape[0]
        return images.reshape(batch_dim, im_size, im_size, 3), batch_dim

    @torch.no_grad()
    def compute_motif_reward(self, obs_dict):
        """Compute the Motif reward and store it in ``obs_dict[self.reward_key]``.

        A 3-D image is scored as a single observation. A 4-D image array is
        treated as a batch, and the full reward array for that batch is stored.
        The reward is then smoothed or derivative-scaled (if configured) and clipped.
        ``obs_dict`` is modified in place and returned.

        Raises:
            NotImplementedError: If the reward model does not use images, or uses an
                observation vector whose keys are not recognised.
        """
        if self.motif_reward_model is None:
            # No Motif reward module provided ("Plan2Explore phase").
            motif_reward = 0.0
        else:
            encoder = self.motif_reward_model.encoder
            obs_vec = self._obs_vec_from_dict(obs_dict) if encoder.use_obs_vec else None
            if not encoder.use_image:
                raise NotImplementedError
            rollout_images, batch_dim = self._images_from_dict(obs_dict)
            # Only one image input is supported for now (no "left" image).
            rollout_images_left = None
            batch_dict = get_batch_dict_for_reward_model(
                self.motif_reward_model, rollout_images, obs_vec, rollout_images_left
            )
            reward_dict = reward_model_with_mb(self.motif_reward_model, batch_dict)

            motif_reward = reward_dict.rewards.detach().cpu().numpy()
            if batch_dim == -1:
                # Unbatched input: drop the leading axis introduced by the reshape.
                motif_reward = motif_reward[0]

        if self.sliding_avg:
            self.window.append(motif_reward)
            if len(self.window) > self.sliding_avg:
                self.window.pop(0)
            # NOTE(review): for batched input this mean also averages across the
            # batch axis -- confirm that is intended.
            motif_reward = np.mean(self.window)
        elif self.ema_deriv_estimator is not None:
            ema_deriv = self.ema_deriv_estimator.estimate(motif_reward)
            # Floor the derivative so the exponent below stays finite.
            ema_deriv = np.maximum(1e-3, ema_deriv)
            scale_for_motif = np.exp(-1 / (self.deriv_scaling_alpha * ema_deriv))
            motif_reward = motif_reward * scale_for_motif

        obs_dict[self.reward_key] = np.clip(motif_reward, a_min=self.clipping_min, a_max=self.clipping_max)
        return obs_dict


In [20]:
# Reward wrapper around the wd49 Motif model: rewards are clipped to [-100, 100]
# and smoothed with a sliding average over the last 25 calls.
wrapper = RewardWrapperEmbodied(
    env=None,
    model_dir="/Users/chrisgumbsch/Documents/SENSEI/code/poke_transfer2/pokemon_motif/best_val_motif_wd49",
    clipping_min=-100,
    clipping_max=100,
    sliding_avg=25,
)


In [9]:
# Directory of recorded rollout frames (frame_00XXX.png) scored below.
rollout_dir = (
    "/Users/chrisgumbsch/Documents/pokemon/poke_transferred/"
    "pokegym_log_sensei6_0/train_0/train_rollout_20240906T075729/"
)

In [10]:

from PIL import Image
from skimage.transform import resize
def resize_to_obs(img, size=64):
    """Convert a rendered frame into the uint8 image format fed to the reward wrapper.

    Args:
        img: (H, W, C) image array with C >= 3. Channels beyond the first three
            (e.g. alpha) are dropped.
        size: Side length of the square output image (default 64, as before).

    Returns:
        A (size, size, 3) uint8 array.
    """
    rgb = img[:, :, :3]
    # skimage's resize returns floats in [0, 1] for integer input (preserve_range=False),
    # hence the rescale to [0, 255]. astype truncates rather than rounds.
    return (255 * resize(rgb, (size, size, 3))).astype(np.uint8)


In [21]:
# Score every frame of the recorded rollout with `wrapper`. The wrapper keeps its
# sliding-average window between calls, so reset it first; this makes the cell
# give the same output every time it is re-run.
wrapper.reset()
for i in range(0, 552):  # assumes the rollout directory holds 552 frames
    img_dir = rollout_dir + f"frame_{i:05d}.png"
    # Close the file handle; convert("RGB") drops alpha, like slicing [:, :, :3].
    with Image.open(img_dir) as frame:
        img1 = resize_to_obs(np.asarray(frame.convert("RGB")))
    test_dict = {'image': img1}
    print(i, wrapper.compute_motif_reward(test_dict)['motif_reward'])

0 -2.592322826385498
1 -2.592322826385498
2 -2.592322826385498
3 -2.6354598999023438
4 -2.649954319000244
5 -2.6024792194366455
6 -2.410912036895752
7 -2.251497745513916
8 -2.1031668186187744
9 -2.0225205421447754
10 -1.7889279127120972
11 -1.5795854330062866
12 -1.5516194105148315
13 -1.5974076986312866
14 -1.4172370433807373
15 -0.8553590774536133
16 -0.3294253647327423
17 0.03278642147779465
18 0.3467276096343994
19 0.6437705755233765
20 1.0008128881454468
21 1.4429192543029785
22 1.8467158079147339
23 2.120954751968384
24 2.3425726890563965
25 2.785067558288574
26 3.2040157318115234
27 3.6282427310943604
28 4.108090400695801
29 4.5561628341674805
30 4.962926387786865
31 5.317322254180908
32 5.703067779541016
33 5.9680633544921875
34 6.2863616943359375
35 6.494763374328613
36 6.685171127319336
37 6.946932315826416
38 7.225092887878418
39 7.386293888092041
40 7.263065814971924
41 7.099915027618408
42 6.989477634429932
43 6.8564677238464355
44 6.748262405395508
45 6.550333023071289
46

In [13]:
# Load the "new_wd30" Motif reward model on the CPU (checkpoint 49, as above).
model_dir = "/Users/chrisgumbsch/Documents/SENSEI/code/poke_transfer2/pokemon_motif/best_val_motif_new_wd30"
params = smart_settings.load(
    os.path.join(model_dir, "settings.json"),
    make_immutable=False,
)
motif_reward_model = RewardModel(
    params["reward_model_params"]["model_params"],
    device=torch.device("cpu"),
)
motif_reward_model.load(
    os.path.join(model_dir, "checkpoint_49"),
    device=torch.device("cpu"),
)

  state_dicts = torch.load(path, map_location=device)


In [23]:
# Same configuration as `wrapper`, but around the "new_wd30" Motif model.
wrapper2 = RewardWrapperEmbodied(
    env=None,
    model_dir="/Users/chrisgumbsch/Documents/SENSEI/code/poke_transfer2/pokemon_motif/best_val_motif_new_wd30",
    clipping_min=-100,
    clipping_max=100,
    sliding_avg=25,
)


In [28]:
# Scratch check: `(1)` is just the int 1 (not a tuple), and it is truthy.
xs = 1
if xs != 0:
    print("Got here")

Got here


In [24]:
# Score every frame of the recorded rollout with `wrapper2`. The wrapper keeps its
# sliding-average window between calls, so reset it first; this makes the cell
# give the same output every time it is re-run.
wrapper2.reset()
for i in range(0, 552):  # assumes the rollout directory holds 552 frames
    img_dir = rollout_dir + f"frame_{i:05d}.png"
    # Close the file handle; convert("RGB") drops alpha, like slicing [:, :, :3].
    with Image.open(img_dir) as frame:
        img1 = resize_to_obs(np.asarray(frame.convert("RGB")))
    test_dict = {'image': img1}
    print(i, wrapper2.compute_motif_reward(test_dict)['motif_reward'])

0 -4.638790607452393
1 -4.638790607452393
2 -4.638790607452393
3 -4.685917854309082
4 -4.783539295196533
5 -4.7548980712890625
6 -4.690854549407959
7 -4.597592353820801
8 -4.429993152618408
9 -4.367526054382324
10 -4.227407455444336
11 -3.9737348556518555
12 -3.86816143989563
13 -3.7366421222686768
14 -3.5412631034851074
15 -3.011653423309326
16 -2.541410446166992
17 -2.1987485885620117
18 -1.9489734172821045
19 -1.7198585271835327
20 -1.3869237899780273
21 -0.9647157788276672
22 -0.5951377749443054
23 -0.31607523560523987
24 -0.07771375775337219
25 0.3714440166950226
26 0.741273820400238
27 1.1372300386428833
28 1.5911576747894287
29 1.996476411819458
30 2.4112656116485596
31 2.7821755409240723
32 3.168322801589966
33 3.429257392883301
34 3.734895706176758
35 3.995413303375244
36 4.183475017547607
37 4.412973403930664
38 4.578169822692871
39 4.7130632400512695
40 4.588732719421387
41 4.443906307220459
42 4.318098068237305
43 4.2221455574035645
44 4.147317886352539
45 3.962020874023437

In [57]:
# NOTE: this call also appends to wrapper's sliding-average window and writes
# 'motif_reward' into test_dict, so it changes state used by later calls.
wrapper.compute_motif_reward(obs_dict=test_dict)["motif_reward"]

array([-2.5923228], dtype=float32)

In [70]:
import os
import numpy as np
from embodied.core.path import Path


In [71]:
def scan(directory, capacity=None, shorten=0):
    """List the replay ``.npz`` chunk files in ``directory`` in sorted order.

    The length of each file is read from the 4th '-'-separated field of its stem
    (presumably the chunk length; verify against the replay writer), reduced by
    ``shorten`` and floored at 0. Once the summed lengths reach ``capacity``
    (if truthy), no further files are added.

    Returns:
        List of paths, in sorted order.
    """
    selected = []
    budget_used = 0
    for path in sorted(Path(directory).glob('*.npz')):
        if capacity and budget_used >= capacity:
            break
        selected.append(path)
        length = int(path.stem.split('-')[3])
        budget_used += max(0, length - shorten)
    return selected

In [79]:

def copy_modified_replay(working_dir, target_dir, modify, save=False):
    """Apply ``modify`` to every chunk of a replay buffer and optionally save the result.

    Used to overwrite values in a replay buffer, e.g. to replace rewards from
    past Motif annotations.

    Args:
        working_dir: Run directory. Chunks are read from ``f"{working_dir}/replay"``.
        target_dir: Run directory for the output. Chunks are written to
            ``f"{target_dir}/replay"``, but only when ``save`` is True.
        modify: Callable that receives one chunk as a dict of arrays and returns
            the dict to store.
        save: If False (the default, and the previous behaviour), ``modify`` runs
            on every chunk but nothing is written (a dry run). If True, each result
            is written with ``np.savez`` under the chunk's original file stem.

    Raises:
        ValueError: If ``save`` is True and ``target_dir`` is empty, which would
            otherwise write into ``/replay``.
    """
    data_dir = f"{working_dir}/replay"
    tar_dir = os.path.expanduser(f"{target_dir}/replay")
    if save:
        if not target_dir:
            raise ValueError("target_dir must be non-empty when save=True")
        os.makedirs(tar_dir, exist_ok=True)
    for filename in scan(data_dir, capacity=None, shorten=0):
        # Read every array into memory while both the file and the NpzFile are open.
        with Path(filename).open("rb") as f, np.load(f) as npz:
            chunk = dict(npz)
        modified = modify(chunk)
        if save:
            np.savez(os.path.join(tar_dir, filename.stem), **modified)

In [150]:
def relabel_motif(obs, motif_wrapper, verbose=True):
    """Recompute ``obs['motif_reward']`` using ``motif_wrapper``.

    Intended as a ``modify`` callback for ``copy_modified_replay``.

    Args:
        obs: Replay chunk as a dict of arrays. It is modified in place.
        motif_wrapper: Object with a ``compute_motif_reward(obs)`` method, e.g. a
            ``RewardWrapperEmbodied``.
        verbose: Print the resulting keys (always done before this flag existed).

    Returns:
        The dict returned by ``motif_wrapper.compute_motif_reward``.
    """
    # Drop the stale annotation; chunks without one are tolerated.
    obs.pop('motif_reward', None)
    # Use the wrapper that was passed in, not the notebook-global `wrapper`.
    obs = motif_wrapper.compute_motif_reward(obs)
    if verbose:
        print(obs.keys())
    return obs

In [151]:
def relabel_motif1(obs):
    """``modify`` callback: relabel ``obs`` using the notebook-level ``wrapper``."""
    return relabel_motif(obs, wrapper)

In [142]:
def printify(obs):
    """Debug ``modify`` callback: remove ``'reset'`` in place, print the remaining keys.

    Raises:
        KeyError: If ``obs`` has no ``'reset'`` entry.
    """
    del obs['reset']
    print(obs.keys())
    return obs

In [152]:
# Relabel the Motif rewards of the dummy run's replay buffer with `wrapper`.
copy_modified_replay(
    working_dir="~/logdir/pkmn_motif_dummy",
    target_dir='',
    modify=relabel_motif1,
)

dict_keys(['image', 'reward', 'is_first', 'is_last', 'is_terminal', 'action', 'reset', 'id', 'motif_reward'])
dict_keys(['image', 'reward', 'is_first', 'is_last', 'is_terminal', 'action', 'reset', 'id', 'motif_reward'])
dict_keys(['image', 'reward', 'is_first', 'is_last', 'is_terminal', 'action', 'reset', 'id', 'motif_reward'])
dict_keys(['image', 'reward', 'is_first', 'is_last', 'is_terminal', 'action', 'reset', 'id', 'motif_reward'])
dict_keys(['image', 'reward', 'is_first', 'is_last', 'is_terminal', 'action', 'reset', 'id', 'motif_reward'])
dict_keys(['image', 'reward', 'is_first', 'is_last', 'is_terminal', 'action', 'reset', 'id', 'motif_reward'])
dict_keys(['image', 'reward', 'is_first', 'is_last', 'is_terminal', 'action', 'reset', 'id', 'motif_reward'])
dict_keys(['image', 'reward', 'is_first', 'is_last', 'is_terminal', 'action', 'reset', 'id', 'motif_reward'])
dict_keys(['image', 'reward', 'is_first', 'is_last', 'is_terminal', 'action', 'reset', 'id', 'motif_reward'])
dict_keys(