In [5]:
import os
# change directory in order to be able to import python scripts from the common package
if os.getcwd().endswith('crafter_rl'): os.chdir(os.path.expanduser(".."))

In [6]:
from env import VanillaEnv, CrafterReplayBuffer, create_local_semantic
import numpy as np
import matplotlib.pyplot as plt
import os
import random
import crafter
from policy import ActorFCNet
import torch
import cv2

In [9]:
folder_name: str = './crafter_rl/experiments/231015-084433/'

def play_episode(model: ActorFCNet, seed=1)->bool:
    """
    returns true of the episode is concidered solved
    """
    env = crafter.Env( size=(512, 512))

    model.eval()
    done = False
    episode_return = 0
    obs = env.reset()
    frames = [obs]
    semantic = np.zeros((1, 9, 9), dtype=np.float32)
    while not done:       
        action_logits = model.forward(torch.FloatTensor(semantic).unsqueeze(0), contrastive=False)
        action = torch.argmax(action_logits)
        obs, rewards, done, info = env.step(action.item())
        semantic = create_local_semantic(
                info['semantic'], info['player_pos'][0], info['player_pos'][1],
                info['inventory']['health'], info['inventory']['food'],
                info['inventory']['drink'], info['inventory']['energy'],
                info['inventory']['sapling'], info['inventory']['wood'],
                info['inventory']['stone'], info['inventory']['coal'],
                info['inventory']['iron'], info['inventory']['diamond'],
                info['inventory']['wood_pickaxe'], info['inventory']['stone_pickaxe'],
                info['inventory']['iron_pickaxe'], info['inventory']['wood_sword'],
                info['inventory']['stone_sword'], info['inventory']['iron_sword'],
            )
        semantic = semantic.astype(np.float32).reshape((1, 9, 9))
        episode_return += rewards
        frames.append(obs)
    return frames, episode_return



models_names = [n for n in os.listdir(folder_name) if n.endswith('.pth')]
for models_name in models_names:
    model = ActorFCNet()

    ckp = torch.load(folder_name + models_name, map_location=torch.device('cpu'))
    model.load_state_dict(ckp['state_dict'])

    episode_return = 0
    while episode_return < 2:
        frames, episode_return = np.array(play_episode(model, VanillaEnv(seed=1, semantic=True)))
        print(episode_return)

    fps = 8
    out = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'DIVX'), fps, (512,512))

    for i in range(len(frames)):
        rgb_img = cv2.cvtColor(frames[i], cv2.COLOR_RGB2BGR)
        out.write(rgb_img)
    out.release()

  frames, episode_return = np.array(play_episode(model, VanillaEnv(seed=1, semantic=True)))


2.099999999999999
