In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt

import sys
sys.path.append("../Source")

import numpy as np
from Utilities.Eval import MetricsLogger
from PlotUtilities import LivePlot
from TableVsDeepModel import *

In [None]:
class ValidationMetrics:
    """Mutable record of the two averages logged at each validation checkpoint.

    One instance is created per experiment and wrapped by a MetricsLogger; the
    training loop overwrites both fields right before each
    ``validation_log.append(validation_metrics)``.
    """

    def __init__(self):
        # Both averages start at zero until the first validation run fills them in.
        # (Tuple assignment binds left to right, so attribute order is unchanged.)
        self.training_avg_reward, self.validation_avg_reward = 0.0, 0.0
        
# Seed NumPy's global RNG before any model is built, so model construction and
# training draw from a fixed NumPy random stream.
# NOTE(review): this seeds only NumPy; if the Keras models also draw from
# TensorFlow's RNG, that source is not covered here — confirm.
np.random.seed(643674)

# One (experiment, training_log, validation_metrics, validation_log) tuple per model.
experiments = []
# One figure spec per model for LivePlot, each plotting that model's validation log.
figures = []

constructors = [AlphaMCArrayModel, AlphaMcConv1KerasModel, SarsaConv1KerasModel]
for constructor in constructors:
    experiment = constructor()

    # Logs experiment.method.metrics (appended once per training episode later on).
    training_log = MetricsLogger(experiment.method.metrics, max_length=100000)
    # Holds the averages written at each validation checkpoint.
    validation_metrics = ValidationMetrics()
    validation_log = MetricsLogger(validation_metrics, max_length=10000)

    experiments.append((experiment, training_log, validation_metrics, validation_log))

    # Training average in blue, validation average in green.
    figure_plots = [
        {"metric": "training_avg_reward", "color": "b"},
        {"metric": "validation_avg_reward", "color": "g"},
    ]
    figures.append({"source": validation_log, "plots": figure_plots})

In [None]:
# Build the live figures from the specs assembled in the setup cell (one spec per
# experiment); the training loop below redraws them via livePlot.update_plot().
livePlot = LivePlot(figures)

In [None]:
# Training schedule; all counts are in episodes.
episode_count = 50001
plot_frequency = 201           # redraw the live plot every this many episodes
validation_frequency = 50      # validate and log averages every this many episodes
validation_episode_count = 50  # value passed to experiment.validate(episode_count=...)

try:
    for i in range(episode_count):
        # Snapshot NumPy's global RNG and restore it before each experiment, so every
        # model sees the same random stream for this episode (a paired comparison).
        random_state = np.random.get_state()
        for experiment, training_log, validation_metrics, validation_log in experiments:
            np.random.set_state(random_state)
            experiment.method.run_episode()
            training_log.append(experiment.method.metrics)

            if i % validation_frequency == validation_frequency - 1:
                # Mean of the last `validation_frequency` logged episode rewards,
                # i.e. the training episodes since the previous validation.
                recent_rewards = training_log.data["episode_reward"][-validation_frequency:]
                validation_metrics.training_avg_reward = np.average(recent_rewards)
                validation_metrics.validation_avg_reward = experiment.validate(
                    episode_count=validation_episode_count)
                validation_log.append(validation_metrics)

        if i % plot_frequency == plot_frequency - 1:
            livePlot.update_plot()
except KeyboardInterrupt:
    # Stopping early is expected; note that an interrupt mid-episode can leave the
    # experiments with unequal episode counts.
    print("Keyboard interrupt")

# The periodic redraw above misses the last validation points (the final refresh
# happens at episode 49847) and any progress before an interrupt, so redraw once
# more. Done outside `finally` so it can never mask a real training exception.
livePlot.update_plot()