# In [None]:
from customenv import CustomDoorKey

from minigrid.wrappers import ImgObsWrapper
from stable_baselines3 import PPO

from customfeatureextractor import CNNFeaturesExtractor, CustomFeatureExtractor, CustomImgObsWrapper

from callback import CustomRewardCallback
from plot import make_plot

import matplotlib.pyplot as plt

# Environment configuration:
#   size                -- grid size
#   intermediate_reward -- whether picking up the key / opening the door gives reward
#   randomize_goal      -- whether the goal is randomized (to a cell in the last column)
# render_mode uses "rgb_array": Minigrid's standard render modes are "human" and
# "rgb_array", and "rgb" is not one of them. (Assumes CustomDoorKey forwards
# render_mode to the Minigrid base env -- confirm in customenv.py.)
env = CustomDoorKey(size=8, intermediate_reward=True, randomize_goal=True, render_mode="rgb_array")

# Two observation wrappers around the SAME underlying env instance, so they share
# its state; do not step them concurrently.
#   default_env -- Minigrid's ImgObsWrapper (image-only observations)
#   custom_env  -- project-specific CustomImgObsWrapper
default_env = ImgObsWrapper(env)
custom_env = CustomImgObsWrapper(env)

# Policy kwargs for the default CNN feature extractor (regularization disabled).
policy_kwargs = {
    "features_extractor_class": CNNFeaturesExtractor,
    "features_extractor_kwargs": {"features_dim": 128, "regularization": False},
}

# Policy kwargs for the project's custom feature extractor; the two sizes are
# forwarded to CustomFeatureExtractor as cnn_features_dim / mlp_features_dim.
custom_policy_kwargs = {
    "features_extractor_class": CustomFeatureExtractor,
    "features_extractor_kwargs": {"cnn_features_dim": 128, "mlp_features_dim": 32},
}

# Reward threshold handed to CustomRewardCallback, which is checked every
# check_freq=1000 steps. Presumably used for early stopping -- confirm in callback.py.
max_reward = 0.9
callback = CustomRewardCallback(check_freq=1000, reward_threshold=max_reward)

# Environment steps per training run (SB3 documents total_timesteps as an int).
TOTAL_TIMESTEPS = 200_000

# Custom feature extractor; Babak can tamper with this.
# MultiInputPolicy expects a Dict observation space -- presumably what
# CustomImgObsWrapper produces; confirm in customfeatureextractor.py.
custom_model = PPO("MultiInputPolicy", custom_env, policy_kwargs=custom_policy_kwargs, verbose=1)
custom_model.learn(TOTAL_TIMESTEPS, callback=callback)
# Keep this run's result: `model` is bound to a different model below, so
# without saving here the custom model's training would be thrown away.
custom_model.save("MODEL_NAME_custom")

# Default CNN feature extractor; Baldur can use this.
# Use a fresh callback instance so that whatever state CustomRewardCallback keeps
# (not visible here) is not carried over from the run above.
default_callback = CustomRewardCallback(check_freq=1000, reward_threshold=max_reward)
model = PPO("CnnPolicy", default_env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(TOTAL_TIMESTEPS, callback=default_callback)

model.save("MODEL_NAME")

# NOTE(review): make_plot is imported but never called in this script, so unless
# something else has drawn on the current matplotlib figure this saves an empty
# image. TODO: call make_plot(...) with the training data (it reportedly also
# saves the plot) before/instead of this.
plt.savefig('IMAGE_NAME.png')
