In [None]:
import warnings 

warnings.filterwarnings('ignore', category=DeprecationWarning)

import os

os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'

from pathlib import Path

import hydra
import numpy as np
import torch
from dm_env import specs

import dmc
import utils
from logger import Logger
from replay_buffer import make_replay_loader
from train_offline import get_domain
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim
from tqdm.notebook import trange

In [None]:
class Encoder(nn.Module):
    """Convolutional regressor mapping an image observation to a flat vector.

    Six 3x3 convolutions with 32 output channels each (strides 2, 2, 1, 1,
    1, 1), every one followed by a ReLU, feed a two-layer fully connected
    head of width ``output_shape``.

    Args:
        obs_shape: ``(channels, height, width)`` of a single observation.
        output_shape: Number of features produced per observation.

    Attributes:
        repr_dim: Length of the flattened convolutional feature map. It is
            derived from ``obs_shape`` (``32 * 12 * 12`` for 84x84 inputs).
    """

    def __init__(self, obs_shape, output_shape):
        super().__init__()

        assert len(obs_shape) == 3

        self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        # Measure the flattened feature size with a dummy pass instead of
        # hard-coding 32 * 12 * 12, which only holds for 84x84 inputs. A zero
        # tensor consumes no random numbers, so seeded runs are unaffected.
        with torch.no_grad():
            self.repr_dim = self.convnet(torch.zeros(1, *obs_shape)).numel()

        self.fc1 = nn.Linear(self.repr_dim, output_shape)
        self.fc2 = nn.Linear(output_shape, output_shape)
        self.relu = nn.ReLU()
        # Project-specific initialisation applied to every submodule.
        self.apply(utils.weight_init)

    def forward(self, obs):
        """Encode a batch of observations.

        Args:
            obs: Batch of observations laid out as ``(batch, *obs_shape)``.
                Presumably raw pixel values in [0, 255], which the scaling
                below would map to [-0.5, 0.5] -- confirm against the loader.

        Returns:
            Tensor of shape ``(batch, output_shape)``.
        """
        obs = obs / 255.0 - 0.5
        h = self.convnet(obs)
        # Flatten everything except the batch dimension for the linear head.
        h = h.reshape(h.shape[0], -1)
        return self.fc2(self.relu(self.fc1(h)))

In [None]:
# Experiment configuration: everything a reader might want to tune.
seed = 0
lr = 1e-3
batch_size = 128
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is moved.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
task = "walker_walk"
replay_buffer_dir = "datasets"
replay_buffer_size = 1000000
expl_agent = "proto"
replay_buffer_num_workers = 0
discount = 0.99

In [None]:
# Seed first so that everything constructed below is reproducible.
utils.set_seed_everywhere(seed)
device = torch.device(device)
work_dir = Path.cwd()

# Environment for the configured task.
env = dmc.make(task, seed=seed)

# Specs describing one transition: observation, action, reward, discount.
data_specs = tuple(
    get_spec() for get_spec in (
        env.observation_spec,
        env.action_spec,
        env.reward_spec,
        env.discount_spec,
    )
)

# Locate the stored dataset: <cwd>/<replay_buffer_dir>/<domain>/<agent>/buffer_img
domain = get_domain(task)
datasets_dir = work_dir / replay_buffer_dir
replay_dir = datasets_dir.resolve().joinpath(domain, expl_agent, 'buffer_img')
print(f'replay dir: {replay_dir}')

# Build the loader over the stored data and keep an iterator that the
# following cells pull batches from.
replay_loader = make_replay_loader(
    env, replay_dir, replay_buffer_size, batch_size,
    replay_buffer_num_workers, discount,
    relabel=False,
)
replay_iter = iter(replay_loader)

In [None]:
# Pull one batch for inspection; `obs_image` and `joint_state` are reused by
# the cells below (frame preview and sizing the model's output layer).
batch = next(replay_iter)
(
    obs,
    action,
    reward,
    # Deliberately not `discount`: that name holds the scalar configured
    # earlier and passed to make_replay_loader, so overwriting it with a
    # batch tensor would break re-running the setup cell.
    batch_discount,
    next_obs,
    obs_image,
    next_obs_image,
    joint_state,
    next_joint_state,
) = batch

In [None]:
# Preview the first two frames of the sampled batch, one figure each.
for frame_idx in range(2):
    plt.imshow(obs_image[frame_idx])
    plt.show()

In [None]:
# Move the model to the target device *before* creating the optimizer, as the
# torch.optim documentation recommends, so the optimizer is guaranteed to
# track the on-device parameters.
# The output width equals the last dimension of the sampled joint_state.
model = Encoder((3, 84, 84), joint_state.shape[-1]).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
model  # display the architecture as the cell output

In [None]:
# Supervised training: regress `joint_state` from the rendered image with an
# MSE objective, recording the loss of every step for the plot below.
num_train_steps = 5000
losses = []
for i in trange(num_train_steps):
    batch = next(replay_iter)
    (
        obs,
        action,
        reward,
        # Deliberately not `discount`: that name holds the scalar configured
        # earlier and passed to make_replay_loader; overwriting it with a
        # batch tensor would break re-running the setup cell.
        batch_discount,
        next_obs,
        obs_image,
        next_obs_image,
        joint_state,
        next_joint_state,
    ) = batch

    # Move the last axis to position 1 (channels-last -> channels-first),
    # the layout nn.Conv2d expects.
    joints_pred = model(obs_image.permute(0, 3, 1, 2).to(device))
    loss = torch.mean((joint_state.to(device) - joints_pred) ** 2)
    optimizer.zero_grad()
    loss.backward()
    losses.append(loss.item())
    optimizer.step()

In [None]:
# Training curve: one point per optimisation step recorded in `losses`.
plt.plot(losses)
plt.xlabel('training step')
plt.ylabel('MSE loss')
plt.title('Joint-state regression loss')
plt.show()