In [1]:
import sys
from os import path
sys.path.insert(1, path.join(sys.path[0], '..'))

In [2]:
from itertools import combinations

In [3]:
import torch
import torch.nn as nn
import numpy as np

In [4]:
from levers import IteratedLeverEnvironment
from levers.partners import FixedPatternPartner
from levers.learner import DQNAgent, HistoryShaper, Transition

In [5]:
# Experiment settings
n_eval_mc = 1   # evaluation repetitions per loaded model (range of eval_id)
n_train_mc = 3  # saved training runs per pattern set (range of train_id)

# Table dimensions: the evaluation loop enumerates combinations(patterns, 4)
# over 8 patterns, i.e. C(8, 4) = 70 training-pattern sets, and tests each
# of them against all 8 patterns.
n_train_pattern_sets = 70
n_test_patterns = 8

# Final results tables
# Axis order: train_patterns_id, test_pattern_id, train_id, eval_id
results_shape = (n_train_pattern_sets, n_test_patterns, n_train_mc, n_eval_mc)
scores = torch.zeros(results_shape)
greedy_scores = torch.zeros(results_shape)

In [6]:
def eval_online_q_learner(train_patterns, test_pattern, train_id):
    """Evaluate a saved online Q-learner against one fixed-pattern partner.

    Loads the history-shaper LSTM and the Q-network saved for the given
    ``train_patterns`` / ``train_id`` and plays ``truncated_length`` steps
    against ``FixedPatternPartner(test_pattern)``. The DQN agent keeps
    receiving transitions and has ``train`` called after every step, so the
    Q-network continues to adapt online during the evaluation.

    Args:
        train_patterns: Pattern set the checkpoint was trained on. It is only
            used to build the checkpoint file names via ``str.format``, so its
            string rendering must match the names on disk.
        test_pattern: Pattern handed to ``FixedPatternPartner``.
        train_id: Integer index of the saved run; rendered as two digits in
            the ``eval_id=`` part of the checkpoint file names.

    Returns:
        A ``(score, greedy_score)`` pair. ``score`` is the summed reward
        divided by ``truncated_length``. ``greedy_score`` is the reward
        summed over the steps that ``learner.act`` flagged as greedy, divided
        by the number of such steps; it is NaN when no step was greedy.
    """
    # Environment settings
    payoffs = [1., 1.]
    truncated_length = 100
    include_step = False
    include_payoffs = False

    # Construct environment. Note the second argument is one larger than the
    # number of steps played below (presumably the episode length, so the
    # episode does not terminate inside the loop -- TODO confirm).
    env = IteratedLeverEnvironment(
        payoffs, truncated_length+1, FixedPatternPartner(test_pattern),
        include_step, include_payoffs
    )

    # Location of the saved models (shared by both networks)
    experiment_name = 'online_qlearner_all_length3_manual'
    data_dir = 'data'

    def _checkpoint_path(net_name):
        # Build '<experiment>/<data_dir>/<net>-pattern=<tps>-eval_id=<NN>.pt'.
        model_name = '{net}-pattern={tps}-eval_id={eid:02d}.pt'.format(
            net=net_name,
            tps=train_patterns,
            eid=train_id
        )
        return path.join(experiment_name, data_dir, model_name)

    # History shaper settings
    hs_hidden_size = 4

    # Construct history shaper
    hist_shaper = HistoryShaper(
        hs_net=nn.LSTM(
            input_size=len(env.dummy_obs()),
            hidden_size=hs_hidden_size,
        )
    )

    # Load history shaper from experiment
    hist_shaper.net.load_state_dict(torch.load(_checkpoint_path('hs-net')))

    # Learner settings
    learner_hidden_size = 4
    capacity = 16
    batch_size = 8
    lr = 0.01
    gamma = 0.99
    len_update_cycle = 10

    # Initialize DQN agent
    learner = DQNAgent(
        q_net=nn.Sequential(
            nn.Linear(hs_hidden_size, learner_hidden_size),
            nn.ReLU(),
            nn.Linear(learner_hidden_size, env.n_actions()),
        ),
        capacity=capacity,
        batch_size=batch_size,
        lr=lr,
        gamma=gamma,
        len_update_cycle=len_update_cycle
    )

    # Load q-net from experiment
    learner.q_net.load_state_dict(torch.load(_checkpoint_path('q-net')))
    learner.reset()

    ret, greedy_ret, greedy_steps = 0, 0, 0
    obs = env.reset()
    obs_rep, hidden = hist_shaper.net(obs.unsqueeze(0))
    for step in range(truncated_length):
        # Linear decay from 1 that hits 0 at step == truncated_length / 4;
        # for later steps the value is negative (presumably treated by
        # `learner.act` as "always greedy" -- confirm against DQNAgent.act).
        epsilon = 1 * (1 - 4 * step / truncated_length)
        action, is_greedy = learner.act(obs_rep.squeeze(0), epsilon=epsilon)
        next_obs, reward, done = env.step(action)

        ret += reward
        greedy_ret += reward if is_greedy else 0
        greedy_steps += is_greedy

        # Compute history representation
        next_obs_rep, next_hidden = hist_shaper.net(
            next_obs.unsqueeze(0), hidden)

        # Give experience to learner and train. The representations are
        # detached so the stored transitions carry no autograd history.
        learner.update_memory(
            Transition(
                obs_rep.squeeze(0).detach(),
                action,
                next_obs_rep.squeeze(0).detach(),
                reward, done
            )
        )
        learner.train(done)

        # Roll the recurrent state forward: next representation -> current
        obs_rep = next_obs_rep
        hidden = next_hidden

    score = ret / truncated_length
    # Guard against an episode without a single greedy step, which would
    # otherwise divide by zero; NaN marks "no greedy data" in the results.
    greedy_score = greedy_ret / greedy_steps if greedy_steps else float('nan')
    return score, greedy_score

In [7]:
patterns = [
    (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
    (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1),
]


def _mean_over_patterns(table, set_id, pattern_ids):
    """Average `table[set_id, p]` for each pattern p, then across patterns."""
    per_pattern = table[set_id, pattern_ids].mean(dim=(1, 2))
    return per_pattern.mean().item()


# Print options are global numpy state, so set them once up front instead of
# on every iteration of the reporting loop.
np.set_printoptions(precision=3)

for train_patterns_id, train_patterns in enumerate(combinations(patterns, 4)):
    print('-' * 50)
    for test_pattern_id, test_pattern in enumerate(patterns):
        # Fill the results tables for every saved run / evaluation repetition
        for train_id in range(n_train_mc):
            for eval_id in range(n_eval_mc):
                score, greedy_score = eval_online_q_learner(
                    train_patterns, test_pattern, train_id)
                scores[train_patterns_id, test_pattern_id, train_id, eval_id] = score
                greedy_scores[train_patterns_id, test_pattern_id, train_id, eval_id] = greedy_score

        # One report line per test pattern; '*' marks patterns seen in training.
        # The arrays hold one value per train_id (mean over eval repetitions).
        in_train_patterns = '*' if test_pattern in train_patterns else ' '
        print('{tps_id: 2d}/70 | {testp} ({flag}): actual: {actual}, greedy: {greedy}'.format(
            tps_id=train_patterns_id+1,
            testp=test_pattern,
            flag=in_train_patterns,
            actual=scores[train_patterns_id, test_pattern_id,:,:].mean(dim=1).numpy(),
            greedy=greedy_scores[train_patterns_id, test_pattern_id,:,:].mean(dim=1).numpy(),
        ))

    # Split pattern indices into those seen during training and held-out ones
    seen_ids = [
        pattern_id for pattern_id, pattern in enumerate(patterns)
        if pattern in train_patterns
    ]
    unseen_ids = [
        pattern_id for pattern_id, pattern in enumerate(patterns)
        if pattern not in train_patterns
    ]

    avg_train_score = _mean_over_patterns(scores, train_patterns_id, seen_ids)
    avg_test_scores = _mean_over_patterns(scores, train_patterns_id, unseen_ids)
    avg_greedy_train_score = _mean_over_patterns(greedy_scores, train_patterns_id, seen_ids)
    avg_greedy_test_scores = _mean_over_patterns(greedy_scores, train_patterns_id, unseen_ids)
    print(f'Actual-train: {avg_train_score:5.2f}, Actual-test: {avg_test_scores:5.2f}')
    print(f'Greedy-train: {avg_greedy_train_score:5.2f}, Greedy-test: {avg_greedy_test_scores:5.2f}')


--------------------------------------------------
 0/70 | (0, 0, 0) (*): actual: [0.942 0.932 0.932], greedy: [1. 1. 1.]
 0/70 | (0, 0, 1) (*): actual: [0.912 0.884 0.756], greedy: [0.971 0.956 0.805]
 0/70 | (0, 1, 0) (*): actual: [0.862 0.908 0.76 ], greedy: [0.93  0.955 0.798]
 0/70 | (0, 1, 1) (*): actual: [0.8   0.88  0.734], greedy: [0.847 0.931 0.767]
 0/70 | (1, 0, 0) ( ): actual: [0.866 0.896 0.744], greedy: [0.931 0.952 0.789]
 0/70 | (1, 0, 1) ( ): actual: [0.794 0.826 0.758], greedy: [0.847 0.892 0.8  ]
 0/70 | (1, 1, 0) ( ): actual: [0.808 0.854 0.792], greedy: [0.851 0.902 0.842]
 0/70 | (1, 1, 1) ( ): actual: [0.92  0.926 0.884], greedy: [0.987 0.991 0.948]
Actual-train:  0.86, Actual-test:  0.84
Greedy-train:  0.91, Greedy-test:  0.89
--------------------------------------------------
 1/70 | (0, 0, 0) (*): actual: [0.942 0.93  0.948], greedy: [1. 1. 1.]
 1/70 | (0, 0, 1) (*): actual: [0.748 0.914 0.742], greedy: [0.777 0.972 0.785]
 1/70 | (0, 1, 0) (*): actual: [0.81

KeyboardInterrupt: 