### Everything combined in one notebook



In [1]:
!pip install wandb
!pip install git+git://github.com/openai/baselines
!pip install git+git://github.com/ankeshanand/pytorch-a2c-ppo-acktr-gail
!pip install git+git://github.com/mila-iqia/atari-representation-learning.git

Collecting wandb
[?25l  Downloading https://files.pythonhosted.org/packages/00/8e/d43984196a0fa8ef961ae3dce91ada52ae7747fbf39d41f5743c27152d97/wandb-0.9.2-py2.py3-none-any.whl (1.4MB)
[K     |████████████████████████████████| 1.4MB 3.5MB/s 
[?25hCollecting shortuuid>=0.5.0
  Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl
Collecting GitPython>=1.0.0
[?25l  Downloading https://files.pythonhosted.org/packages/8c/f9/c315aa88e51fabdc08e91b333cfefb255aff04a2ee96d632c32cb19180c9/GitPython-3.1.3-py3-none-any.whl (451kB)
[K     |████████████████████████████████| 460kB 19.7MB/s 
Collecting sentry-sdk>=0.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/a9/84/b4cfabda293bdf40742510c1145534e063f8e629446e619b9d8b2d549390/sentry_sdk-0.16.0-py2.py3-none-any.whl (107kB)
[K     |████████████████████████████████| 112kB 24.5MB/s 
[?25hCollecting configparser>=3.8.1
  Downloading h

Collecting git+git://github.com/mila-iqia/atari-representation-learning.git
  Cloning git://github.com/mila-iqia/atari-representation-learning.git to /tmp/pip-req-build-vqed1fo2
  Running command git clone -q git://github.com/mila-iqia/atari-representation-learning.git /tmp/pip-req-build-vqed1fo2
Building wheels for collected packages: atariari
  Building wheel for atariari (setup.py) ... [?25l[?25hdone
  Created wheel for atariari: filename=atariari-0.0.1-cp36-none-any.whl size=46714 sha256=732bf8a0834f5b6f9fb5d9619ff9e9f90fca5822128b0035d52edc123bd4b356
  Stored in directory: /tmp/pip-ephem-wheel-cache-0z3b6tqb/wheels/3d/69/51/5e436e5ae566c5b4dec5c53e65396d516459877a42a11d7aa4
Successfully built atariari
Installing collected packages: atariari
Successfully installed atariari-0.0.1


In [None]:
import torch
from google.colab import drive

# Mount Google Drive into the Colab file system (force_remount=True re-mounts even if already mounted).
drive.mount('/content/gdrive', force_remount=True)
root_dir = "/content/gdrive/My Drive/"
# Directory used for model files; it is the default `model_dir` of AtariARIHandler below.
# NOTE(review): 'path/to/dir' looks like a placeholder - replace it with a real Drive folder.
base_dir = root_dir + 'path/to/dir'

In [2]:
import os
import gym
import torch
import numpy as np
from atariari.benchmark.episodes import get_episodes
from atariari.benchmark.probe import ProbeTrainer, LinearProbe
from atariari.benchmark.utils import EarlyStopping
from atariari.benchmark.wrapper import AtariARIWrapper
from atariari.methods.encoders import NatureCNN
from atariari.methods.stdim import InfoNCESpatioTemporalTrainer
from atariari.benchmark.envs import GrayscaleWrapper


# Based on https://github.com/mila-iqia/atari-representation-learning


class AtariARIHandler:
    """Loads or trains an AtariARI encoder and its linear probes, and predicts labels for observations.

    The handler owns a grayscale AtariARI gym environment for ``args.env_name``. Encoder and
    probe weights are stored in ``model_dir`` as ``<env_name>-encoder.pt`` and
    ``<env_name>-<label>.pt``; missing model files trigger (re)training.
    """

    # NOTE: the default ``model_dir=base_dir`` is evaluated once, when this class is defined,
    # so ``base_dir`` must already exist at that point.
    def __init__(self, args, wandb,
                 observation_shape=torch.Size([1, 210, 160]),
                 model_dir=base_dir):
        """
        :param args: object containing all parameters needed for AtariARI to work
        :param wandb: wandb is not used in this project, but still a (mock) object is needed
        :param observation_shape: shape of a single Atari observation tensor (channels first)
        :param model_dir: directory where encoder/probe models are stored
        """
        self.args = args
        self.wandb = wandb
        self.observation_shape = observation_shape
        # Make sure the directory ends with a slash so file names can simply be appended.
        # An empty string is left untouched (paths then become relative to the working dir).
        if model_dir and not model_dir.endswith('/'):
            model_dir += '/'
        self.model_dir = model_dir
        self.encoder_model_path = self.model_dir + args.env_name + '-encoder.pt'  # file name of encoder model
        self.probe_model_path = self.model_dir + args.env_name + '-{}.pt'  # file name format for probe models
        self.probe_trainer = None  # object of type ProbeTrainer, set by probe_setup()
        gym_env = AtariARIWrapper(gym.make(self.args.env_name))  # Create Atari env
        self.gym_env = GrayscaleWrapper(gym_env)

    def get_gym_env(self):
        """Return the wrapped gym environment created in the constructor."""
        return self.gym_env

    def train_encoder(self, encoder):
        """Train ``encoder`` with InfoNCESpatioTemporalTrainer on freshly collected episodes.

        Uses ``cuda:<args.cuda_id>`` when CUDA is available, otherwise the CPU.

        :param encoder: encoder module to train
        :return: the encoder, moved to the training device
        """
        device = torch.device("cuda:" + str(self.args.cuda_id) if torch.cuda.is_available() else "cpu")
        encoder = encoder.to(device)
        tr_episodes, val_episodes = get_episodes(env_name=self.args.env_name, steps=self.args.pretraining_steps,
                                                 train_mode="train_encoder")
        print('Obtained episodes for training encoder')

        torch.set_num_threads(1)
        # The trainer takes a plain dict: all attributes of args plus the number of input channels.
        config = dict(vars(self.args))
        config['obs_space'] = encoder.input_channels

        print('Training encoder...')
        trainer = InfoNCESpatioTemporalTrainer(encoder, config, device=device, wandb=self.wandb)
        trainer.train(tr_episodes, val_episodes)
        print('Done training encoder')
        return encoder

    def load_encoder(self):
        """Return an encoder, loading stored weights if present and training a new one otherwise.

        A newly trained encoder is saved to ``self.encoder_model_path``.
        """
        encoder = NatureCNN(self.observation_shape[0], self.args)
        if os.path.isfile(self.encoder_model_path):
            print('Encoder model exists, loading weights')
            encoder.load_state_dict(torch.load(self.encoder_model_path, map_location=self.args.device))
            encoder.eval()
        else:
            print(f'No encoder model with name {self.encoder_model_path} found, training a new encoder')
            encoder = self.train_encoder(encoder)
            torch.save(encoder.state_dict(), self.encoder_model_path)
        return encoder

    def set_probes(self, probe_trainer: ProbeTrainer, probes: dict, labels):
        """Attach already-loaded probes to ``probe_trainer``.

        Besides the probes themselves, an early stopper, an Adam optimizer and a
        ReduceLROnPlateau scheduler are created for every key in ``labels``.

        :param probe_trainer: trainer to configure
        :param probes: mapping of label name to probe model
        :param labels: mapping whose keys are the label names to set up
        """
        probe_trainer.probes = probes

        probe_trainer.early_stoppers = {
            k: EarlyStopping(patience=probe_trainer.patience, verbose=False, name=k + "_probe",
                             save_dir=probe_trainer.save_dir)
            for k in labels}

        probe_trainer.optimizers = {k: torch.optim.Adam(list(probe_trainer.probes[k].parameters()),
                                                        eps=1e-5, lr=probe_trainer.lr) for k in labels}
        probe_trainer.schedulers = {
            k: torch.optim.lr_scheduler.ReduceLROnPlateau(probe_trainer.optimizers[k], patience=5, factor=0.2,
                                                          verbose=True,
                                                          mode='max', min_lr=1e-5) for k in labels}

    def train_probes(self, probe_trainer: ProbeTrainer):
        """Collect episodes, train the probes of ``probe_trainer`` and save every probe to disk."""
        print('Obtaining episodes for probe training')
        # The test split is returned as well but is not needed for training.
        tr_episodes, val_episodes, \
        tr_labels, val_labels, \
        test_episodes, test_labels = get_episodes(env_name=self.args.env_name, steps=self.args.probe_steps)
        print('Training probes')
        probe_trainer.train(tr_episodes, val_episodes, tr_labels, val_labels)
        print('Probe training complete')

        for k, probe in probe_trainer.probes.items():
            torch.save(probe.state_dict(), self.probe_model_path.format(k))
        print(f'Saved {len(probe_trainer.probes)} probe models to {self.model_dir} directory')

    def load_probes(self, encoder, labels):
        """Create a ProbeTrainer whose probes are loaded from disk, or trained if any is missing.

        :param encoder: encoder the probes operate on
        :param labels: mapping whose keys are the label names that need a probe
        :return: the configured ProbeTrainer
        """
        probe_trainer = ProbeTrainer(encoder=encoder, representation_len=encoder.feature_size, epochs=self.args.epochs)
        probes = {}
        for k in labels:
            path = self.probe_model_path.format(k)
            if not os.path.isfile(path):
                # One missing probe means all probes are retrained below, so stop loading.
                print(f'Probe model for label {k} not found')
                break
            probes[k] = LinearProbe(input_dim=probe_trainer.feature_size,
                                    num_classes=probe_trainer.num_classes).to(probe_trainer.device)
            print(f'- Loading probe model for label {k}')
            probes[k].load_state_dict(torch.load(path, map_location=self.args.device))
        if len(probes) == len(labels):
            self.set_probes(probe_trainer, probes, labels)
        else:
            print('Training new probe models')
            self.train_probes(probe_trainer)
        return probe_trainer

    def probe_setup(self, ignore_labels=None):
        """Load or train the encoder and probes; must be called before predict().

        :param ignore_labels: optional list of label names for which no probe is loaded/used
        """
        if ignore_labels is None:
            ignore_labels = []
        labels = self.gym_env.labels()
        for lbl in ignore_labels:
            try:
                del labels[lbl]
                print(f'Ignoring label {lbl}')
            except KeyError:
                # Only an unknown label is expected here; any other error should surface.
                print(f'Ignore label {lbl}: no such label')

        encoder = self.load_encoder()
        self.probe_trainer = self.load_probes(encoder, labels)
        print('Encoder en probe setup complete')

    def predict(self, obs):
        """Predict a value for every probed label from a single observation.

        :param obs: numpy array holding one observation; it is reshaped to a batch of one
            with shape ``(1, *self.observation_shape)``
        :return: dict mapping label name to the class index with the highest probe output
        """
        assert self.probe_trainer is not None, 'ProbeTrainer is not initialized, call probe_setup() first'
        pt = self.probe_trainer
        device = self.args.device
        batch = torch.from_numpy(obs.reshape(1, *self.observation_shape)).float()
        preds = {}
        # Inference only: no gradients are needed for the encoder or the probes.
        with torch.no_grad():
            pt.encoder.to(device)
            features = pt.encoder(batch.to(device))
            for k, probe in pt.probes.items():
                probe.to(device)
                preds[k] = np.argmax(probe(features).cpu().numpy(), axis=1)[0]
        return preds


In [3]:
import json
import random
import pickle

'''
The MDPBuilder class can be used to collect observations and build a Markov Decision Process. The information needed to
create an MDP is stored in a structure of dicts where every distinct observation (in the form of a dict with 
label - value pairs) is treated as a state and assigned a unique numeric ID. The MDPBuilder keeps a count of how many 
times the same transitions occur between states. The MDPBuilder state can be directly saved to a file (using pickle),
so building of an MDP can be continued later. An MDP can be generated using either PRISM or JANI (JSON) format. The 
probability of a transition is calculated as the fraction of the total number of transitions of the action. The 
accuracy of these probabilities of course depends on the number of observations used to create the MDP. For usage 
examples, see the run_atari.py file.
'''


class MDPBuilder:
    """Collects observed states and transitions and exports them as a PRISM or JANI MDP.

    Every distinct combination of label values is a state with a unique numeric ID. For each
    state and action the builder counts how often every successor state was observed; these
    counts are turned into transition probabilities when a model file is generated.
    """

    def __init__(self, labels, actions, log=True, probability_decimals=2):
        """
        :param labels: names of the labels that make up a state; other labels are dropped
        :param actions: actions that may occur in the MDP
        :param log: print a message whenever a new state is added
        :param probability_decimals: number of decimals (non-negative int) used for probabilities
        """
        self.fresh_state_id = 0  # next unused state ID
        self.prev_state_id = None  # state observed in the previous step, None at (re)start
        self.states_ids = {}  # frozenset of (label, value) pairs -> state ID
        self.states = {}  # state ID -> {'labels': {...}, 'transitions': {action: {target ID: count}}}
        self.initial_states = set()
        self.labels = labels
        self.actions = actions
        self.log = log
        self.probability_decimals = probability_decimals

    def num_states(self):
        """Return the number of distinct states observed so far."""
        return len(self.states)

    def get_random_action(self):
        """Return a uniformly random action from the MDP actions."""
        return random.choice(self.actions)

    def get_fresh_state_id(self):
        """Return the next unused state ID and advance the counter."""
        r = self.fresh_state_id
        self.fresh_state_id += 1
        return r

    def insert_state(self, label_dict, state_rep_hash):
        """Register a new state and return its ID.

        :param label_dict: label -> value mapping describing the state
        :param state_rep_hash: hashable key under which the state is indexed in ``states_ids``
        """
        fid = self.get_fresh_state_id()
        self.log_message(f'Adding new state with ID {fid}')
        self.states_ids[state_rep_hash] = fid
        self.states[fid] = {
            'labels': label_dict,
            'transitions': {a: {} for a in self.actions}
        }
        return fid

    def get_state_id(self, label_dict):
        """Return the ID of the state described by ``label_dict``, creating the state if needed."""
        # Index by the frozenset itself rather than by its hash(): string hashes are salted per
        # process, so stored hash values are useless after unpickling in a new session, and
        # two different states with colliding hashes would be merged.
        state_key = frozenset(label_dict.items())
        s_id = self.states_ids.get(state_key)
        return s_id if s_id is not None else self.insert_state(label_dict, state_key)

    def add_transition(self, from_id, to_id, action):
        """Count one observation of ``from_id --action--> to_id``.

        :raises ValueError: if ``action`` is not one of the MDP actions
        """
        if action not in self.actions:
            raise ValueError(f'Action value {action} was provided, but does not occur in MDP actions')
        counts = self.states[from_id]['transitions'][action]
        counts[to_id] = counts.get(to_id, 0) + 1

    def add_state_info(self, label_dict, action):
        """Record an observed state and the transition that led to it.

        The first state after construction or restart() is stored as an initial state
        instead of as a transition target.

        :param label_dict: label -> value mapping; labels not in ``self.labels`` are ignored
        :param action: action that was applied to reach this state
        """
        label_dict = {k: v for k, v in label_dict.items() if k in self.labels}
        s_id = self.get_state_id(label_dict)
        if self.prev_state_id is not None:
            self.add_transition(self.prev_state_id, s_id, action)
        else:
            self.initial_states.add(s_id)
        self.prev_state_id = s_id

    def log_message(self, msg):
        """Print ``msg`` when logging is enabled."""
        if self.log:
            print(msg)

    def save_builder_to_file(self, path):
        """Pickle the builder state to ``path`` so building can be continued later."""
        info = {
            'fresh_state_id': self.fresh_state_id,
            'states_ids': self.states_ids,
            'states': self.states,
            'initial_states': self.initial_states,
            'labels': self.labels,
            'actions': self.actions
        }
        with open(path, 'wb') as fh:
            pickle.dump(info, fh)

    def load_from_file(self, path):
        """Restore builder state written by save_builder_to_file() and restart the builder.

        WARNING: pickle can execute arbitrary code while loading; only load trusted files.
        """
        with open(path, 'rb') as fh:
            info = pickle.load(fh)
        self.fresh_state_id = info['fresh_state_id']
        self.states = info['states']
        self.initial_states = info['initial_states']
        self.labels = info['labels']
        self.actions = info['actions']
        # Rebuild the lookup index from the states instead of trusting the stored one: files
        # written by older versions are keyed by process-specific hash values. If such a file
        # contains duplicate states, the lowest ID is kept for lookups.
        self.states_ids = {}
        for s_id, state in self.states.items():
            self.states_ids.setdefault(frozenset(state['labels'].items()), s_id)
        self.restart()

    def restart(self):
        """Forget the previous state, so the next observed state is treated as an initial state."""
        self.prev_state_id = None

    def _transition_probabilities(self, transitions):
        """Turn successor counts into probabilities that sum to exactly 1.

        Rounding every fraction independently can yield a total such as 0.99, which makes the
        model invalid. The probability mass is therefore split into ``10 ** probability_decimals``
        units with the largest-remainder method.

        :param transitions: target state ID -> observation count (non-empty)
        :return: target state ID -> probability, in the iteration order of ``transitions``
        """
        total = sum(transitions.values())
        units = 10 ** self.probability_decimals
        shares = {t_id: (n * units) // total for t_id, n in transitions.items()}
        leftover = units - sum(shares.values())
        by_remainder = sorted(transitions, key=lambda t_id: (transitions[t_id] * units) % total, reverse=True)
        for t_id in by_remainder[:leftover]:
            shares[t_id] += 1
        return {t_id: round(k / units, self.probability_decimals) for t_id, k in shares.items()}

    def _initial_labels(self):
        """Return the labels of an initial state (an arbitrary one if several were observed).

        :raises ValueError: if no state has been observed yet
        """
        if not self.initial_states:
            raise ValueError('Cannot build a model: no initial state has been observed yet')
        return self.states[next(iter(self.initial_states))]['labels']

    def get_prism_commands(self):
        """Return one PRISM command string per (state, action) pair that has observed transitions.

        A single observed successor is written without a probability; several successors are
        written as ``p1 : update1 + p2 : update2 ...``.
        """
        commands = []
        for info in self.states.values():
            guard = ' & '.join(f'{l}={v}' for l, v in info['labels'].items())
            for transitions in info['transitions'].values():
                if not transitions:
                    continue
                probs = self._transition_probabilities(transitions) if len(transitions) > 1 else None
                alternatives = []
                for t_id in transitions:
                    update = ' & '.join(f"({l}'={v})" for l, v in self.states[t_id]['labels'].items())
                    prefix = f'{probs[t_id]} : ' if probs is not None else ''
                    alternatives.append(f'{prefix}{update}')
                commands.append(f'[] {guard} -> {" + ".join(alternatives)};')
        return commands

    def get_jani_guard(self, label_list):
        """Build a nested JANI conjunction of ``label = value`` comparisons.

        :param label_list: list of (label, value) pairs
        :return: expression dict; an empty dict for an empty list
        """
        if len(label_list) == 0:
            return {}
        guard = {
            'op': '=',
            'left': label_list[0][0],
            'right': int(label_list[0][1])
        }

        if len(label_list) == 1:
            return guard
        return {
            'op': '∧',
            'left': self.get_jani_guard(label_list[1:]),
            'right': guard
        }

    def get_jani_automaton(self):
        """Return the JANI automaton dict: one location and one edge per (state, action) with transitions."""
        aut = {
            'name': 'atari_game',
            'locations': [{'name': 'l'}],
            'initial-locations': ['l'],
            'edges': []
        }

        for info in self.states.values():
            guards = {'exp': self.get_jani_guard(list(info['labels'].items()))}
            for action, transitions in info['transitions'].items():
                if not transitions:
                    continue
                probs = self._transition_probabilities(transitions) if len(transitions) > 1 else None

                edge = {
                    'location': 'l',
                    'action': str(action),
                    'guard': guards,
                    'destinations': []
                }
                for t_id in transitions:
                    edge['destinations'].append({
                        'location': 'l',
                        'probability': {'exp': probs[t_id] if probs is not None else 1},
                        # int() keeps the values JSON serializable
                        'assignments': [{'ref': l, 'value': int(v)}
                                        for l, v in self.states[t_id]['labels'].items()]
                    })
                aut['edges'].append(edge)
        return aut

    def build_jani_model(self, file_path):
        """Write the MDP to ``file_path`` in JANI (JSON) format.

        :raises ValueError: if no state has been observed yet
        """
        init_state = self._initial_labels()
        jani = {
            'jani-version': 1,
            'name': 'atari-jani-model',
            'type': 'mdp',
            'features': ['derived-operators'],
            'actions': [{'name': str(action)} for action in self.actions],
            'variables': [{'name': label,
                           'type': {'kind': 'bounded', 'base': 'int', 'lower-bound': 0, 'upper-bound': 256}}
                          for label in self.labels],
            'restrict-initial': {'exp': self.get_jani_guard(list(init_state.items()))},
            'properties': [],
            'automata': [self.get_jani_automaton()],
            'system': {
                'elements': [{'automaton': 'atari_game'}]
            }
        }
        with open(file_path, "w") as fh:
            json.dump(jani, fh)

    def build_prism_model(self, file_path):
        """Write the MDP to ``file_path`` in PRISM format.

        :raises ValueError: if no state has been observed yet
        """
        indent_fmt = '    {}\n'
        var_fmt = indent_fmt.format('{} : [0..256] init {};')
        # Resolve everything before opening the file, so a failure does not leave a partial file.
        init_state = self._initial_labels()
        commands = self.get_prism_commands()
        with open(file_path, "w") as f:
            f.write('mdp\n\n')
            f.write('module atari_game\n')
            for l, v in init_state.items():
                f.write(var_fmt.format(l, v))
            for c in commands:
                f.write(indent_fmt.format(c))
            f.write('endmodule')

    def build_model_file(self, file_path, format='prism'):
        """Write the MDP to ``file_path`` in the given format ('prism' or 'jani').

        :raises Exception: if ``format`` is not supported
        """
        if format == 'prism':
            self.build_prism_model(file_path)
        elif format == 'jani':
            self.build_jani_model(file_path)
        else:
            raise Exception(f"Model format '{format}' is not supported")


In [19]:
class Object(object):
    """Empty attribute container, used below to mock the objects AtariARI expects."""
    pass


# Mock command line arguments needed for AtariARI
args = Object()
args.method = 'infonce-stdim'
args.feature_size = 256
args.no_downsample = True
args.end_with_relu = False
args.env_name = 'PongNoFrameskip-v4'  # Atari env name of the game to use
args.pretraining_steps = 100000  # Steps to use for training encoder
args.probe_steps = 50000  # Steps to use for training linear probes
args.cuda_id = '0'
args.epochs = 100  # Number of training epochs
args.batch_size = 64
args.patience = 15
args.lr = 3e-4
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mock wandb object needed for AtariARI
wandb = Object()
wandb.run = Object()
wandb.run.dir = 'wandb'
# No-op logger. It accepts any call signature, so calls that omit `step` or `commit`
# do not raise a TypeError.
wandb.log = lambda *_args, **_kwargs: None

In [20]:
# --- MDP configuration -------------------------------------------------------
# Only these labels are collected into the MDP; all other labels are dropped.
labels_to_use = ['ball_x', 'ball_y', 'player_y']

# Arcade Learning Environment action IDs used for the MDP.
# For Pong only 3 actions are relevant: none (0), up (2) and down (5).
actions_to_use = [0, 2, 5]

# --- Encoder / probe setup ---------------------------------------------------
# Labels for which no probe is trained or used.
ignore_labels = ['player_x', 'enemy_x']

handler = AtariARIHandler(args, wandb)
gym_env = handler.get_gym_env()
# Loads the encoder/probe models if they exist, trains them otherwise.
handler.probe_setup(ignore_labels=ignore_labels)

Ignoring label player_x
Ignoring label enemy_x
Encoder model exists, loading weights
- Loading probe model for label player_y
- Loading probe model for label enemy_y
- Loading probe model for label ball_x
- Loading probe model for label ball_y
- Loading probe model for label enemy_score
- Loading probe model for label player_score
Encoder en probe setup complete


In [None]:
from atariari.benchmark.episodes import get_episodes

# Collect fresh episodes in probe mode; only the test split is used below.
(tr_episodes, val_episodes,
 tr_labels, val_labels,
 test_episodes, test_labels) = get_episodes(
    env_name=args.env_name,
    steps=args.probe_steps,
    train_mode="probe",
)
print('Testing:')
# Evaluate the loaded/trained probes on the test split.
test_acc, test_f1score = handler.probe_trainer.test(test_episodes, test_labels)

-------Collecting samples----------
Deleting player_x for being too low in entropy! Sorry, dood!
Deleting enemy_x for being too low in entropy! Sorry, dood!
Duplicates: 1144, Test Len: 8792
Testing:
Total Steps: 8792
In our paper, we report F1 scores and accuracies averaged across each category. 
              That is, we take a mean across all state variables in a category to get the average score for that category.
              Then we average all the category averages to get the final score that we report per game for each method. 
              These scores are called 'across_categories_avg_acc' and 'across_categories_avg_f1' respectively
              We do this to prevent categories with large number of state variables dominating the mean F1 score.
              
Epoch: Test
	 player_y_acc:   0.5099
	 enemy_y_acc:   0.8977
	 ball_x_acc:   0.8510
	 ball_y_acc:   0.8986
	 enemy_score_acc:   0.9867
	 player_score_acc:   0.9992
	 small_object_localization_avg_acc:   0.8748
	 agent_l

In [21]:
# Fresh builder restricted to the selected labels and actions; per-state logging is off.
mdp_builder = MDPBuilder(labels=labels_to_use, actions=actions_to_use, log=False)
# To continue an earlier run instead, load a pickled MDPBuilder from file:
# mdp_builder.load_from_file(base_dir + '/mdp_3_000_000_steps.pkl')

In [None]:
# Reuse the `mdp_builder` from the previous cell. Re-creating it here would silently discard
# a builder that was loaded from file, making it impossible to continue an earlier run.
# The env is reset below, so the next observed state must be treated as an initial state.
mdp_builder.restart()
obs = gym_env.reset()  # Reset env before use
n_steps = 1000  # No. of steps to run on the Gym environment
for i in range(n_steps):
    # gym_env.render()  # Optionally render the game to the screen
    action = mdp_builder.get_random_action()  # Get an action to apply on the env
    obs, reward, done, info = gym_env.step(action)  # Perform step on env using given action
    prediction = handler.predict(obs)  # Obtain prediction using observation from env
    mdp_builder.add_state_info(prediction, action)  # Add the predicted info to the MDPBuilder

    # Instead of prediction, MDP can be built using ground truth available in info['labels']
    # mdp_builder.add_state_info(info['labels'], action)

    if done:  # Game finished, reset env to continue
        print(f'Resetting env (step {i})')
        gym_env.reset()
        mdp_builder.restart()  # Treat next observed state repr as initial state
print(f'DONE: found {mdp_builder.num_states()} states')
# mdp_builder.save_builder_to_file(base_dir + '/mdp_demo.pkl')

In [23]:
# Show the collected states: state ID -> its label values and, per action, the observed
# successor state IDs with their occurrence counts.
mdp_builder.states

{0: {'labels': {'ball_x': 88, 'ball_y': 53, 'player_y': 106},
  'transitions': {0: {}, 2: {1: 1}, 5: {}}},
 1: {'labels': {'ball_x': 147, 'ball_y': 48, 'player_y': 94},
  'transitions': {0: {1: 2, 3: 1}, 2: {}, 5: {2: 1}}},
 2: {'labels': {'ball_x': 147, 'ball_y': 48, 'player_y': 104},
  'transitions': {0: {2: 1, 403: 1}, 2: {2: 1, 8: 1}, 5: {1: 1}}},
 3: {'labels': {'ball_x': 163, 'ball_y': 48, 'player_y': 94},
  'transitions': {0: {2: 1, 3: 1}, 2: {3: 1}, 5: {4: 1, 398: 1}}},
 4: {'labels': {'ball_x': 163, 'ball_y': 53, 'player_y': 92},
  'transitions': {0: {4: 2, 5: 2}, 2: {5: 1}, 5: {4: 3}}},
 5: {'labels': {'ball_x': 174, 'ball_y': 66, 'player_y': 203},
  'transitions': {0: {3: 1, 5: 2, 6: 1, 124: 1},
   2: {5: 5, 7: 1},
   5: {4: 1, 5: 2}}},
 6: {'labels': {'ball_x': 163, 'ball_y': 48, 'player_y': 92},
  'transitions': {0: {6: 1}, 2: {4: 1}, 5: {6: 2}}},
 7: {'labels': {'ball_x': 163, 'ball_y': 48, 'player_y': 110},
  'transitions': {0: {7: 1, 14: 1, 322: 1, 396: 1, 454: 1},
   2

In [24]:
# Export the MDP to a model file; build_model_file defaults to the PRISM format.
mdp_builder.build_model_file(base_dir + '/mdp.pm')