In [2]:
import sys
sys.path.append('../')
from deep_rl.gridworld import ReachGridWorld, PickGridWorld, PORGBEnv, GoalManager, ScaleObsEnv
from deep_rl.network import *
from deep_rl.utils import *
from train import _exp_parser, get_visual_body, get_network, get_env_config, PickGridWorldTask
import os
import random
import argparse
import dill
import json
import copy
import itertools
import numpy as np
import matplotlib.pyplot as plt
from random import shuffle
from collections import Counter, namedtuple
from IPython.display import display
from PIL import Image
from pathlib import Path
from IPython.core.debugger import Tracer
from tqdm import tqdm

def set_seed(s):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)

set_seed(0) # set seed

# Try Fitted Q

In [3]:
n_objs = 4
action_dim = 5
feat_dim = 512
scale = 2
discount = 0.99

def get_expert(weight_path):
    visual_body = TSAMiniConvBody(
        2 + n_objs, 
        feature_dim=feat_dim,
        scale=scale,
    )
    expert = VanillaNet(action_dim, visual_body)
    # load weight
    weight_dict = expert.state_dict()
    loaded_weight_dict = {k: v for k, v in torch.load(
        weight_path,
        map_location=lambda storage, loc: storage)['network'].items()
        if k in weight_dict}
    weight_dict.update(loaded_weight_dict)
    expert.load_state_dict(weight_dict)
    return expert

def get_env(env_config):
    states = []
    positions = []
    qs = []
    reward_config = {'wall_penalty': -0.01, 'time_penalty': -0.01, 'complete_sub_task': 0.1, 'complete_all': 1, 'fail': -1}
    with open(env_config, 'rb') as f:
        env_config = dill.load(f)
    env = ScaleObsEnv(
        PickGridWorld(
                **env_config,
                min_dis=1,
                window=1,
                task_length=1,
                reward_config=reward_config,
                seed=0,
        ),
        2,
    )
    env.reset(sample_obj_pos=False)
    positions = env.unwrapped.pos_candidates
    for pos in positions:
        o, _, _, _ = env.teleport(*pos)
        states.append(o)
        qs.append(env.get_q(discount))
    return env, states, positions, qs

def rollout(env, q, horizon=100, epsilon=0.0, feat_state=False):
    states = []
    actions = []
    rewards = []
    next_states = []
    terminals = []
    qs = []
    returns = 0.0
    done = False
    state = env.reset(sample_obj_pos=False) # very important!
    for _ in range(horizon):
        if feat_state:
            states.append(q.body(tensor([state])).cpu().detach().numpy()[0])
        else:
            states.append(state)
        qval = q([state]).cpu().detach().numpy().flatten()
        qs.append(qval)
        if np.random.rand() < epsilon:
            action = env.action_space.sample()
        else:
            action = qval.argmax()
        state, r, done, _ = env.step(action) # note that info is not used
        actions.append(action)
        if feat_state:
            next_states.append(q.body(tensor([state])).cpu().detach().numpy()[0])
        else:
            next_states.append(state)
        rewards.append(r)
        terminals.append(done)
        returns += r
        if done: break
    return states, actions, next_states, rewards, terminals, qs, returns



In [56]:
%pdb on
n_expert_trajs = 20
epsilon = 0.0
feat_state = False

weight_path = '../log/pick.mask.fourroom-16.0.min_dis-1/dqn/double_q/0.190425-220424/models/step-3000000-mean-0.96'
env_config_path = '../data/env_configs/pick/fourroom-16.0'

expert = get_expert(weight_path)
env, all_states, positions, optimal_q = get_env(env_config_path)

states = []
actions = []
next_states = []
rewards = []
terminals = []
qs = []

for _ in tqdm(range(n_expert_trajs)):
    states_, actions_, next_states_, rewards_, terminals_, qs_, returns = rollout(env, expert, epsilon=epsilon, feat_state=feat_state)
    #print('expert returns:', returns)
    states.append(states_)
    actions.append(actions_)
    next_states.append(next_states_)
    rewards.append(rewards_)
    terminals.append(terminals_)
    qs.append(qs_)

data = dict(
    states=np.concatenate(states),
    actions=np.concatenate(actions),
    next_states=np.concatenate(next_states),
    rewards=np.concatenate(rewards),
    terminals=np.concatenate(terminals),
    expert_q=np.concatenate(qs),
)
print('num of transitions:', len(data['states']))

# input: experiences, feature extractor
# output: linear layer
def fitted_q(data, body, feat_dim, action_dim):
    A = np.zeros((feat_dim * action_dim + action_dim, feat_dim * action_dim + action_dim))
    b = np.zeros(feat_dim * action_dim + action_dim)
    N = len(data['states'])

    pbar = tqdm(total=N)
    for i, transition in enumerate(zip(data['states'], data['actions'], data['next_states'], data['rewards'], data['terminals'])):
        state, action, next_state, reward, terminal = transition
        phi = np.zeros(feat_dim * action_dim + action_dim)
        phi[feat_dim * action: feat_dim * (action + 1)] = body(tensor([state])).detach().cpu().numpy()[0]
        phi[feat_dim * action_dim + action] = 1
        b += reward * phi / N
        if terminal:
            A += np.outer(phi, phi) / N
        else:
            next_phi = np.zeros(feat_dim * action_dim + action_dim)
            s_idx = feat_dim * data['actions'][i+1] # assume trajectories is contiguous
            next_phi[s_idx: s_idx + feat_dim] = body(tensor([next_state])).detach().cpu().numpy()[0]
            next_phi[feat_dim * action_dim + data['actions'][i+1]] = 1
            A += np.outer(phi, phi - discount * next_phi) / N
        pbar.update(1)
    pbar.close() 

    # update weight
    total_weight = np.linalg.lstsq(A, b)[0]
    weight = total_weight[:-action_dim].reshape(-1, feat_dim).T
    bias = total_weight[-action_dim:]
    return weight, bias
    

weight, bias = fitted_q(data, expert.body, feat_dim, action_dim)
estimate_q = expert.body(tensor(data['states'])).detach().cpu().numpy() @ weight + bias
print(((estimate_q - data['expert_q']) ** 2).mean())
print('difference between argmax:', (estimate_q.argmax(1) != data['expert_q'].argmax(1)).sum())





  0%|          | 0/20 [00:00<?, ?it/s][A[A[A[A

Automatic pdb calling has been turned ON
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
maps: [(0, 'fourroom-16')]
tasks: [(0, ('A',)), (1, ('B',)), (2, ('C',)), (3, ('D',))]
train: [(0, 0)]
test: [(0, 0)]






 10%|█         | 2/20 [00:00<00:00, 18.35it/s][A[A[A[A



 20%|██        | 4/20 [00:00<00:00, 17.47it/s][A[A[A[A



 30%|███       | 6/20 [00:00<00:00, 16.59it/s][A[A[A[A



 45%|████▌     | 9/20 [00:00<00:00, 17.82it/s][A[A[A[A



 55%|█████▌    | 11/20 [00:00<00:00, 18.12it/s][A[A[A[A



 70%|███████   | 14/20 [00:00<00:00, 19.83it/s][A[A[A[A



 80%|████████  | 16/20 [00:00<00:00, 18.93it/s][A[A[A[A



 90%|█████████ | 18/20 [00:00<00:00, 18.20it/s][A[A[A[A



100%|██████████| 20/20 [00:01<00:00, 19.17it/s][A[A[A[A



  0%|          | 0/291 [00:00<?, ?it/s][A[A[A[A



  1%|▏         | 4/291 [00:00<00:07, 36.98it/s][A[A[A[A

num of transitions: 291






  3%|▎         | 8/291 [00:00<00:08, 35.15it/s][A[A[A[A



  4%|▍         | 12/291 [00:00<00:08, 34.04it/s][A[A[A[A



  5%|▌         | 16/291 [00:00<00:08, 33.30it/s][A[A[A[A



  7%|▋         | 20/291 [00:00<00:08, 33.05it/s][A[A[A[A



  8%|▊         | 24/291 [00:00<00:08, 32.61it/s][A[A[A[A



 10%|▉         | 28/291 [00:00<00:08, 32.32it/s][A[A[A[A



 11%|█         | 32/291 [00:00<00:07, 32.38it/s][A[A[A[A



 12%|█▏        | 36/291 [00:01<00:07, 32.15it/s][A[A[A[A



 14%|█▎        | 40/291 [00:01<00:07, 32.01it/s][A[A[A[A



 15%|█▌        | 44/291 [00:01<00:07, 32.14it/s][A[A[A[A



 16%|█▋        | 48/291 [00:01<00:07, 32.01it/s][A[A[A[A



 18%|█▊        | 52/291 [00:01<00:07, 31.90it/s][A[A[A[A



 19%|█▉        | 56/291 [00:01<00:07, 31.82it/s][A[A[A[A



 21%|██        | 60/291 [00:01<00:07, 31.78it/s][A[A[A[A



 22%|██▏       | 64/291 [00:01<00:07, 31.99it/s][A[A[A[A



 23%|██▎       | 68/291 [00:02<00:06,

0.04808895173094181
difference between argmax: 113


# Meta Linear Q

In [9]:
# D_1, ..., D_n
# D, body / \phi -> A, b -> w(\phi)
# Q(\phi) - Q_E as loss

n_expert_trajs = 30
epsilon = 0.0

weight_path = '../log/pick.mask.fourroom-16.0.min_dis-1/dqn/double_q/0.190425-220424/models/step-3000000-mean-0.96'
env_config_path = '../data/env_configs/pick/fourroom-16.0'

expert = get_expert(weight_path)
env, all_states, positions, optimal_q = get_env(env_config_path)

states = []
actions = []
next_states = []
rewards = []
terminals = []
qs = []

for _ in tqdm(range(n_expert_trajs)):
    states_, actions_, next_states_, rewards_, terminals_, qs_, returns = rollout(env, expert, epsilon=epsilon)
    #print('expert returns:', returns)
    states.append(states_)
    actions.append(actions_)
    next_states.append(next_states_)
    rewards.append(rewards_)
    terminals.append(terminals_)
    qs.append(qs_)

data = dict(
    states=np.concatenate(states),
    actions=np.concatenate(actions),
    next_states=np.concatenate(next_states),
    rewards=np.concatenate(rewards),
    terminals=np.concatenate(terminals),
    expert_q=np.concatenate(qs),
)
print('num of transitions:', len(data['states']))

def fitted_q(data, body, feat_dim, action_dim):
    A = torch.zeros(feat_dim * action_dim + action_dim, feat_dim * action_dim + action_dim)
    b = torch.zeros(feat_dim * action_dim + action_dim)
    N = len(data['states'])

    feats = body(tensor(data['states'])).repeat(1, action_dim)
    a_vec = one_hot.encode(tensor(data['actions'], torch.long), action_dim)
    phis = torch.cat([feats * a_vec.repeat_interleave(feat_dim, 1), a_vec], 1)
    
    A = torch.matmul(phis.t(), phis - discount * tensor(1 - data['terminals']).unsqueeze(1) * phis.roll(-1, 0)) / N
    b = torch.matmul(phis.t(), tensor(data['rewards'])) / N
    
    # update weight
    print(torch.isnan(A).any())
    total_weight = torch.matmul(torch.inverse(A + 1e-4 * torch.eye(A.shape[0])), b)
    weight = total_weight[:-action_dim].view(-1, feat_dim).t()
    bias = total_weight[-action_dim:]
    return weight, bias


model = copy.deepcopy(expert)
optim = torch.optim.RMSprop(
        filter(lambda p: p.requires_grad, model.parameters()), lr=0.00025, alpha=0.95, eps=0.01, centered=True)
#weight, bias = torch.randn(feat_dim, action_dim), torch.randn(action_dim)
for i in range(100):
    weight, bias = fitted_q(data, model.body, feat_dim, action_dim)
    model.fc_head.weight.data.copy_(weight.t())
    model.fc_head.bias.data.copy_(bias)
    estimate_q = model(data['states'])
    #estimate_q = torch.matmul(model.body(tensor(data['states'])), weight) + bias
    loss = F.mse_loss(estimate_q, tensor(data['expert_q']))
    print('{}-th loss:'.format(i), loss.detach().cpu().numpy())
    optim.zero_grad()
    loss.backward()
    optim.step()

  0%|          | 0/30 [00:00<?, ?it/s]

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
maps: [(0, 'fourroom-16')]
tasks: [(0, ('A',)), (1, ('B',)), (2, ('C',)), (3, ('D',))]
train: [(0, 0)]
test: [(0, 0)]


100%|██████████| 30/30 [00:01<00:00, 29.81it/s]


num of transitions: 420
tensor(0, dtype=torch.uint8)
0-th loss: 0.036114257
tensor(0, dtype=torch.uint8)
1-th loss: 0.034457173
tensor(0, dtype=torch.uint8)
2-th loss: 0.03245506
tensor(0, dtype=torch.uint8)
3-th loss: 0.031068215
tensor(0, dtype=torch.uint8)
4-th loss: 0.029790074
tensor(0, dtype=torch.uint8)
5-th loss: 0.02801595
tensor(0, dtype=torch.uint8)
6-th loss: 0.026512412
tensor(0, dtype=torch.uint8)
7-th loss: 0.02571906
tensor(0, dtype=torch.uint8)
8-th loss: 0.024330597
tensor(0, dtype=torch.uint8)
9-th loss: 0.023701595
tensor(0, dtype=torch.uint8)
10-th loss: 0.022562206
tensor(0, dtype=torch.uint8)
11-th loss: 0.021584546
tensor(0, dtype=torch.uint8)
12-th loss: 0.02064554
tensor(0, dtype=torch.uint8)
13-th loss: 0.020065848
tensor(0, dtype=torch.uint8)
14-th loss: 0.019391432
tensor(0, dtype=torch.uint8)
15-th loss: 0.018746924
tensor(0, dtype=torch.uint8)
16-th loss: 0.018230017
tensor(0, dtype=torch.uint8)
17-th loss: 0.017486038
tensor(0, dtype=torch.uint8)
18-th l

KeyboardInterrupt: 

In [10]:
# import torch
# print(torch.__version__)

# class DummyModule(torch.nn.Module):
#     def forward(self, x):
#         V = torch.Tensor(2, 2)
#         V[0, 0] = x
#         return torch.sum(V * 3)


# x = torch.tensor([1], requires_grad=True)
# r = DummyModule()(x)
# r.backward()
# print(x.grad)


print(torch.__version__)

1.1.0
