In [1]:
import sys
sys.path.append('../')
from deep_rl.gridworld import ReachGridWorld, PickGridWorld, PORGBEnv, GoalManager, ScaleObsEnv
from deep_rl.network import *
from deep_rl.utils import *
from train import _exp_parser, get_visual_body, get_network, get_env_config, PickGridWorldTask
import os
import random
import argparse
import dill
import json
import itertools
import numpy as np
import matplotlib.pyplot as plt
from random import shuffle
from collections import Counter, namedtuple
from IPython.display import display
from PIL import Image
from pathlib import Path
from IPython.core.debugger import Tracer

def set_seed(s):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)

set_seed(0) # set seed

pygame 1.9.4
Hello from the pygame community. https://www.pygame.org/contribute.html


# Try Fitted Q

In [2]:
n_objs = 4
action_dim = 5
feat_dim = 512
scale = 2
discount = 0.99

def get_expert(weight_path):
    visual_body = TSAMiniConvBody(
        2 + n_objs, 
        feature_dim=feat_dim,
        scale=scale,
    )
    expert = VanillaNet(action_dim, visual_body)
    # load weight
    weight_dict = expert.state_dict()
    loaded_weight_dict = {k: v for k, v in torch.load(
        weight_path,
        map_location=lambda storage, loc: storage)['network'].items()
        if k in weight_dict}
    weight_dict.update(loaded_weight_dict)
    expert.load_state_dict(weight_dict)
    return expert

def get_env(env_config):
    states = []
    positions = []
    qs = []
    reward_config = {'wall_penalty': -0.01, 'time_penalty': -0.01, 'complete_sub_task': 0.1, 'complete_all': 1, 'fail': -1}
    with open(env_config, 'rb') as f:
        env_config = dill.load(f)
    env = ScaleObsEnv(
        PickGridWorld(
                **env_config,
                min_dis=1,
                window=1,
                task_length=1,
                reward_config=reward_config,
                seed=0,
        ),
        2,
    )
    env.reset(sample_obj_pos=False)
    positions = env.unwrapped.pos_candidates
    for pos in positions:
        o, _, _, _ = env.teleport(*pos)
        states.append(o)
        qs.append(env.get_q(discount))
    return env, states, positions, qs

def rollout(env, q, horizon=100, epsilon=0.0, feat_state=False):
    states = []
    actions = []
    rewards = []
    next_states = []
    terminals = []
    qs = []
    returns = 0.0
    done = False
    state = env.reset(sample_obj_pos=False) # very important!
    for _ in range(horizon):
        if feat_state:
            states.append(q.body(tensor([state])).cpu().detach().numpy()[0])
        else:
            states.append(state)
        qval = q([state]).cpu().detach().numpy().flatten()
        qs.append(qval)
        if np.random.rand() < epsilon:
            action = env.action_space.sample()
        else:
            action = qval.argmax()
        state, r, done, _ = env.step(action) # note that info is not used
        actions.append(action)
        if feat_state:
            next_states.append(q.body(tensor([state])).cpu().detach().numpy()[0])
        else:
            next_states.append(state)
        rewards.append(r)
        terminals.append(done)
        returns += r
        if done: break
    return states, actions, next_states, rewards, terminals, qs, returns



In [3]:
%pdb on
n_expert_trajs = 500
epsilon = 0.0
feat_state = True

weight_path = '../log/pick.mask.fourroom-16.0.min_dis-1/dqn/double_q/0.190425-220424/models/step-3000000-mean-0.96'
env_config_path = '../data/env_configs/pick/fourroom-16.0'


expert = get_expert(weight_path)
env, all_states, positions, optimal_q = get_env(env_config_path)

states = []
actions = []
next_states = []
rewards = []
terminals = []
qs = []

for _ in range(n_expert_trajs):
    states_, actions_, next_states_, rewards_, terminals_, qs_, returns = rollout(env, expert, epsilon=epsilon, feat_state=feat_state)
    print('expert returns:', returns)
    states.append(states_)
    actions.append(actions_)
    next_states.append(next_states_)
    rewards.append(rewards_)
    terminals.append(terminals_)
    qs.append(qs_)

data = dict(
    states=np.concatenate(states),
    actions=np.concatenate(actions),
    next_states=np.concatenate(next_states),
    rewards=np.concatenate(rewards),
    terminals=np.concatenate(terminals),
    expert_q=np.concatenate(qs),
)

# Fitted Q iteration
A = np.zeros((feat_dim * action_dim + action_dim, feat_dim * action_dim + action_dim))
b = np.zeros(feat_dim * action_dim + action_dim)
N = len(data['states'])

for i, transition in enumerate(zip(data['states'], data['actions'], data['next_states'], data['rewards'], data['terminals'])):
    state, action, next_state, reward, terminal = transition
    phi = np.zeros(feat_dim * action_dim + action_dim)
    phi[feat_dim * action: feat_dim * (action + 1)] = state
    phi[feat_dim * action_dim + action] = 1
    if terminal:
        A += np.outer(phi, phi) / N
    else:
        next_phi = np.zeros(feat_dim * action_dim + action_dim)
        s_idx = feat_dim * data['actions'][i+1]
        next_phi[s_idx: s_idx + feat_dim] = state
        next_phi[feat_dim * action_dim + data['actions'][i+1]] = 1
        A += np.outer(phi, phi - discount * next_phi)
    b += reward * phi / N
    if i % 10 == 0: print('at {}'.format(i))
    
# update weight
total_weight = np.linalg.lstsq(A, b)[0]
weight = total_weight[:-action_dim].reshape(-1, action_dim)
bias = total_weight[-action_dim:]

estimate_q = data['states'] @ weight + bias
print(((estimate_q - data['expert_q']) ** 2).mean())

Automatic pdb calling has been turned ON
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
maps: [(0, 'fourroom-16')]
tasks: [(0, ('A',)), (1, ('B',)), (2, ('C',)), (3, ('D',))]
train: [(0, 0)]
test: [(0, 0)]
expert returns: 0.91
expert returns: 1.0
expert returns: 0.9800000000000001
expert returns: 0.88
expert returns: 0.9
expert returns: 0.9400000000000001
expert returns: 0.9600000000000001
expert returns: 1.0
expert returns: 0.9700000000000001
expert returns: 0.9400000000000001
expert returns: 0.9800000000000001
expert returns: 1.0
expert returns: 0.9600000000000001
expert returns: 1.03
expert returns: 0.93
expert returns: 0.9600000000000001
expert returns: 0.87
expert returns: 0.9400000000000001
expert returns: 0.9500000000000001
expert returns: 0.9900000000000001
expert returns: 0.93
expert returns: 1.03
expert returns: 1.05
expert returns: 1.03
expert returns: 0.9
expert returns: 0.9800000000000001
expert returns: 0.93
exp

expert returns: 0.92
expert returns: 0.88
expert returns: 0.91
expert returns: 0.93
expert returns: 0.93
expert returns: 0.88
expert returns: 0.9600000000000001
expert returns: 0.9600000000000001
expert returns: 0.9900000000000001
expert returns: 1.02
expert returns: 0.9700000000000001
expert returns: 1.01
expert returns: 0.9400000000000001
expert returns: 1.02
expert returns: 1.03
expert returns: 1.03
expert returns: 0.93
expert returns: 0.9500000000000001
expert returns: 0.9900000000000001
expert returns: 0.9800000000000001
expert returns: 0.93
expert returns: 0.9800000000000001
expert returns: 1.06
expert returns: 0.88
expert returns: 0.9
expert returns: 0.89
expert returns: 0.9
expert returns: 0.88
expert returns: 0.92
expert returns: 0.9700000000000001
expert returns: 1.08
expert returns: 0.92
expert returns: 0.9
expert returns: 0.9500000000000001
expert returns: 0.8300000000000001
expert returns: 1.03
expert returns: 0.9700000000000001
expert returns: 0.85
expert returns: 0.93
ex

at 3770
at 3780
at 3790
at 3800
at 3810
at 3820
at 3830
at 3840
at 3850
at 3860
at 3870
at 3880
at 3890
at 3900
at 3910
at 3920
at 3930
at 3940
at 3950
at 3960
at 3970
at 3980
at 3990
at 4000
at 4010
at 4020
at 4030
at 4040
at 4050
at 4060
at 4070
at 4080
at 4090
at 4100
at 4110
at 4120
at 4130
at 4140
at 4150
at 4160
at 4170
at 4180
at 4190
at 4200
at 4210
at 4220
at 4230
at 4240
at 4250
at 4260
at 4270
at 4280
at 4290
at 4300
at 4310
at 4320
at 4330
at 4340
at 4350
at 4360
at 4370
at 4380
at 4390
at 4400
at 4410
at 4420
at 4430
at 4440
at 4450
at 4460
at 4470
at 4480
at 4490
at 4500
at 4510
at 4520
at 4530
at 4540
at 4550
at 4560
at 4570
at 4580
at 4590
at 4600
at 4610
at 4620
at 4630
at 4640
at 4650
at 4660
at 4670
at 4680
at 4690
at 4700
at 4710
at 4720
at 4730
at 4740
at 4750
at 4760
at 4770
at 4780
at 4790
at 4800
at 4810
at 4820
at 4830
at 4840
at 4850
at 4860
at 4870
at 4880
at 4890
at 4900
at 4910
at 4920
at 4930
at 4940
at 4950
at 4960
at 4970
at 4980
at 4990
at 5000
at 5010




0.8327242301950585


In [8]:
print((estimate_q.argmax(1) != data['expert_q'].argmax(1)).sum())

7102


In [5]:
print()

[2 0 0 ... 2 2 4]
