In [1]:
import os
import cv2
import numpy as np
import gym
import minerl
import torch as th
import torch.nn as nn
from gym import Env
from gym.spaces import Discrete, Box, MultiDiscrete
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback



In [2]:
class CustomCnnFeatureExtractor(nn.Module):

    def __init__(self, observation_space, features_dim=256):
        super(CustomCnnFeatureExtractor, self).__init__()
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations):
        return self.linear(self.cnn(observations))

In [3]:
class CustomCnnPolicy(nn.Module):

    def __init__(self, observation_space, action_space, net_arch=[64, 64], features_dim=256):
        super(CustomCnnPolicy, self).__init__()
        self.features_extractor = CustomCnnFeatureExtractor(observation_space, features_dim=features_dim)
        self.net_arch = net_arch
        action_dim = action_space.shape[0]
        self.action_layer = nn.Sequential(nn.Linear(self.net_arch[-1], action_dim))

    def forward(self, observations, deterministic=False, use_sde=False):
        features = self.features_extractor(observations)
        action_logits = self.action_layer(features)
        return action_logits

In [4]:
class MinerlTreechopActionSpace(gym.spaces.Box):

    def __init__(self):
        super().__init__(low=0, high=1, shape=(10,), dtype=np.float32)

    def to_dict(self, action):
        return {
            "attack": int(action[0] >= 0.5),
            "back": int(action[1] >= 0.5),
            "camera": [360.0 * action[2], 360.0 * action[3]],
            "forward": int(action[4] >= 0.5),
            "jump": int(action[5] >= 0.5),
            "left": int(action[6] >= 0.5),
            "right": int(action[7] >= 0.5),
            "sneak": int(action[8] >= 0.5),
            "sprint": int(action[9] >= 0.5),
        }

    def from_dict(self, action_dict):
        action = np.zeros(10, dtype=np.float32)
        action[0] = action_dict["attack"]
        action[1] = action_dict["back"]
        action[2] = action_dict["camera"][0] / 360.0
        action[3] = action_dict["camera"][1] / 360.0
        action[4] = action_dict["forward"]
        action[5] = action_dict["jump"]
        action[6] = action_dict["left"]
        action[7] = action_dict["right"]
        action[8] = action_dict["sneak"]
        action[9] = action_dict["sprint"]
        return action

In [5]:
class MinerlTreechopEnv(gym.Env):
    
    def __init__(self, render=False):
        super().__init__()
        self.env = gym.make('MineRLTreechop-v0')
        self.observation_space = Box(low=0, high=255, shape=(64, 64, 3))
        self.action_space = MinerlTreechopActionSpace()
        self.render_enabled = render

    def step(self, action):
        camera_action = action[2]
        horizontal_degree = camera_action[0] - 89
        vertical_degree = camera_action[1] - 90
        action_dict = {
            "attack": action[0],
            "back": action[1],
            "camera": [horizontal_degree, vertical_degree],
            "forward": action[3],
            "jump": action[4],
            "left": action[5],
            "right": action[6],
            "sneak": action[7],
            "sprint": action[8]
        }
        obs, reward, done, info = self.env.step(action_dict)
        if done:
            obs = self.env.reset()

        obs = self.process_observation(obs)
        return obs, reward, done, info

    def reset(self):
        obs = self.env.reset()
        obs = self.process_observation(obs)
        return obs

    def render(self):
        if self.render_enabled:
            self.env.render()

    def close(self):
        self.env.close()

    def process_observation(self, observation):
        pov = observation['pov']
        resized_pov = (255 * pov).astype(np.uint8)
        resized_pov = cv2.resize(resized_pov, (32, 32), interpolation=cv2.INTER_CUBIC)
        return {"pov": resized_pov}

In [6]:
class RandomPolicy:
    
    def __init__(self, action_space):
        self.action_space = action_space

    def sample_action(self):
        camera_action = np.random.uniform(low=-180, high=180, size=(2,))
        action = (
            np.random.randint(2),  
            np.random.randint(2),  
            camera_action,  
            np.random.randint(2), 
            np.random.randint(2), 
            np.random.randint(2), 
            np.random.randint(2), 
            np.random.randint(2), 
            np.random.randint(2), 
        )
        return action

    def predict(self, observation):
        action = self.sample_action()
        return action

In [7]:
env = MinerlTreechopEnv(render=True)
policy = RandomPolicy(env.action_space)

In [8]:
for _ in range(5):
    obs = env.reset()
    done = False
    total_reward = 0

    while not done:
        action = policy.predict(obs)
        obs, reward, done, info = env.step(action)
        total_reward += reward
        env.render()

    print('Total Reward:', total_reward)

0it [00:00, ?it/s]

Total Reward: 0.0
Total Reward: 0.0
Total Reward: 0.0
Total Reward: 0.0
Total Reward: 0.0


In [9]:
class TrainAndLoggingCallback(BaseCallback):
    
    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, f'best_model_{self.n_calls}')
            self.model.save(model_path)
        return True

In [10]:
if __name__ == "__main__":
    CHECKPOINT_DIR = './train/train_defend'
    LOG_DIR = './logs/log_defend'

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    policy_kwargs = dict(features_dim=64)

    policy = CustomCnnPolicy(env.observation_space, MinerlTreechopActionSpace(), **policy_kwargs)

    model = SAC(policy, env, learning_rate=0.0001, batch_size=256, buffer_size=50000, verbose=1, optimize_memory_usage=True)

    device = "cpu"
    model = model.to(device)
    model.learn(total_timesteps=int(200000))

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


TypeError: CustomCnnPolicy.forward() got multiple values for argument 'use_sde'

In [None]:
for _ in range(5):
    obs = env.reset()
    done = False
    total_reward = 0

    while not done:
        action = policy.predict(obs)
        obs, reward, done, info = env.step(action)
        total_reward += reward
        env.render()

    print('Total Reward:', total_reward)