In [None]:
import gym
import minerl
import logging
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from pathlib import Path
import collections
import pickle
import torch.nn.functional as F
import math
from tqdm import tqdm
from itertools import product
from collections import OrderedDict
import copy
import torch.nn.utils as torch_utils
from copy import deepcopy

logging.disable(logging.ERROR)

In [None]:
def extract_frames(video_path):
    """Decode every frame of a video file into memory.

    Args:
        video_path: Location of the video (string or path-like); it is
            converted to ``str`` before being handed to ``cv2.VideoCapture``.

    Returns:
        list: Frames in playback order, exactly as produced by ``cap.read()``
        (OpenCV decodes colour video as BGR by default). The list is empty when
        the capture cannot be opened or yields no frame.
    """
    frames = []
    cap = cv2.VideoCapture(str(video_path))
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of stream (or a decode failure): stop reading.
                break
            frames.append(frame)
    finally:
        # Release the decoder handle even if reading raises part-way through.
        cap.release()
    return frames

def get_and_format_data(path_list, chunk_size=128):
    """Pair every video frame with the values recorded for the same step.

    For each directory in ``path_list`` this reads ``rendered.npz`` (per-step
    arrays, including a ``'reward'`` array) and ``recording.mp4`` (the frames),
    aligns them by index and keeps only whole chunks of ``chunk_size`` steps.

    Args:
        path_list: Iterable of recording directories.
        chunk_size: Each recording is truncated to a multiple of this many
            steps. Defaults to 128, the value that was previously hard-coded.

    Returns:
        list: One flat list over all recordings of ``(frame, step_dict)``
        tuples, where ``step_dict`` maps every key of ``rendered.npz`` to its
        value at that step.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")

    all_pairs = []

    for path in path_list:
        # Load every array exactly once: an NpzFile reads the array from disk
        # again on each key access, so iterating it per frame is very slow.
        # NOTE: allow_pickle=True can execute code embedded in the archive;
        # only use it on recordings from a trusted source.
        with np.load(f"{path}/rendered.npz", allow_pickle=True) as archive:
            action_values = {key: archive[key] for key in archive.files}

        frames = extract_frames(f"{path}/recording.mp4")

        # Align on the shorter of the two streams; the video may contain fewer
        # frames than there are logged steps.
        # NOTE(review): assumes every array in the archive has at least as many
        # entries as 'reward' -- confirm against the dataset format.
        usable = min(len(frames), len(action_values['reward']))
        usable = (usable // chunk_size) * chunk_size

        all_pairs.extend(
            (frames[step], {key: values[step] for key, values in action_values.items()})
            for step in range(usable)
        )

    return all_pairs

In [None]:
# NOTE(review): every sub-directory is assumed to hold one recording
# (rendered.npz + recording.mp4) -- confirm against the dataset layout.
directory_path = Path("MineRLTreechop-v0/")

# iterdir() yields entries in arbitrary order and also returns plain files;
# sort for a reproducible sample order and keep directories only.
path_list = [str(entry) for entry in sorted(directory_path.iterdir()) if entry.is_dir()]

data = get_and_format_data(path_list)

if not data:
    raise RuntimeError(f"No usable recordings found under {directory_path}")

# `data` is flattened over all recordings, so its length counts
# (frame, action) samples rather than episodes.
print(f"Amount of samples: {len(data)} (from {len(path_list)} recordings)")
print("\nEach element in data is a tuple of form: (frame, action-value dict)")
print(f"Shape of pov frame: {data[0][0].shape}")

print("\nActions:")
for key in data[0][1].keys():
    print("  " + key)

In [None]:

# Preview ten consecutive samples together with their recorded step values.
for idx in range(20, 30):
    frame_bgr, step_values = data[idx]

    # The frames were decoded with cv2.VideoCapture and are therefore BGR;
    # matplotlib interprets arrays as RGB, so convert before displaying.
    plt.imshow(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    plt.title("Sample Frame from Dataset")
    plt.axis("off")
    plt.show()

    print(step_values)


In [None]:
# Cache the aligned (frame, step-values) list on disk so the videos do not
# have to be decoded again on the next run.
with Path("video_and_actions.pkl").open("wb") as cache_file:
    pickle.dump(data, cache_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Restore the cached dataset.
# NOTE: unpickling can execute arbitrary code -- only load a cache file that
# was produced locally or comes from a trusted source.
with Path("video_and_actions.pkl").open("rb") as cache_file:
    data = pickle.load(cache_file)

# Optional: uncomment to work with a subset of the samples only.
# data = data[:250000]
print(len(data))

In [None]:
class ActionManager:
    """Maps between MineRL-style action dicts and a single discrete action id.

    The discrete action set is every 0/1 combination of the movement/attack
    keys plus four camera directions that survives the constraints configured
    in ``__init__`` (mutually exclusive keys, sprint only together with
    forward, at most ``remove_size`` active entries).

    Attributes set in ``__init__``:
        fully_connected: Ordered key names making up one combination tuple.
        fully_connected_list: All valid combination tuples; the position of a
            tuple in this list is its action id.
        action_list: For every id, the matching ``'action$...'`` dict.
    """

    def __init__(self, device, c_action_magnitude=10):
        """Enumerate the valid discrete actions.

        Args:
            device: Stored on the instance; not used by this class itself.
            c_action_magnitude: Camera step (same unit as the recorded camera
                values) applied by each of the four camera actions.
        """
        self.device = device
        self.c_action_magnitude = c_action_magnitude

        self.zero_action = OrderedDict([('action$attack', 0),
                                        ('action$back', 0),
                                        ('action$camera', np.array([0., 0.])),
                                        ('action$forward', 0),
                                        ('action$jump', 0),
                                        ('action$left', 0),
                                        ('action$right', 0),
                                        ('action$sneak', 0),
                                        ('action$sprint', 0)])

        # Camera discretization: one axis at a time, fixed magnitude.
        self.camera_dict = OrderedDict([
            ('action$turn_up', np.array([-c_action_magnitude, 0.])),
            ('action$turn_down', np.array([c_action_magnitude, 0.])),
            ('action$turn_left', np.array([0., -c_action_magnitude])),
            ('action$turn_right', np.array([0., c_action_magnitude]))
        ])

        self.fully_connected_no_camera = ['action$attack', 'action$back', 'action$forward', 'action$jump',
                                          'action$left', 'action$right', 'action$sprint']
        self.camera_actions = list(self.camera_dict.keys())
        self.fully_connected = self.fully_connected_no_camera + self.camera_actions

        # Action constraints: at most one key of each tuple may be active ...
        self.exclude = [('action$forward', 'action$back'),
                        ('action$left', 'action$right'),
                        ('action$attack', 'action$jump'),
                        ('action$turn_up', 'action$turn_down', 'action$turn_left', 'action$turn_right')]
        # ... and the first key of each pair requires the second one.
        self.only_if = [('action$sprint', 'action$forward')]

        # If more than `remove_size` entries are active, drop them in this order.
        self.remove_size = 3
        self.remove_first_list = ['action$sprint', 'action$left', 'action$right', 'action$back',
                                  'action$turn_up', 'action$turn_down', 'action$turn_left', 'action$turn_right',
                                  'action$attack', 'action$jump', 'action$forward']

        position = {name: index for index, name in enumerate(self.fully_connected)}

        def is_valid(combo):
            """Return True when `combo` satisfies every configured constraint."""
            for group in self.exclude:
                if sum(combo[position[name]] for name in group) > 1:
                    return False
            for dependent, required in self.only_if:
                if combo[position[dependent]] == 1 and combo[position[required]] == 0:
                    return False
            return sum(combo) <= self.remove_size

        # Single filtering pass; the enumeration order of `product` is kept, so
        # ids are stable for a given configuration.
        self.fully_connected_list = [
            combo for combo in product(range(2), repeat=len(self.fully_connected)) if is_valid(combo)
        ]

        self.action_list = []
        for combo in self.fully_connected_list:
            new_action = copy.deepcopy(self.zero_action)
            for key, value in zip(self.fully_connected, combo):
                if key in self.camera_actions:
                    if value:
                        # Copy so entries never share an array with camera_dict.
                        new_action['action$camera'] = self.camera_dict[key].copy()
                else:
                    new_action[key] = value
            self.action_list.append(new_action)

        self.num_action_ids_list = [len(self.action_list)]
        self.act_continuous_size = 0

    def quantize_camera(self, camera):
        """Snap both camera components to the nearest discrete step.

        Args:
            camera: Mutable two-element sequence (e.g. a NumPy array). It is
                modified in place.

        Returns:
            The same ``camera`` object, with both entries replaced.
        """
        # Symmetric around zero; the table previously listed 5 twice and had no
        # -5, so small negative movements snapped to -10 or 0.
        camera_steps = np.array([-40, -20, -10, -5, 0, 5, 10, 20, 40])

        camera[0] = camera_steps[np.abs(camera_steps - camera[0]).argmin()]
        camera[1] = camera_steps[np.abs(camera_steps - camera[1]).argmin()]
        return camera

    def get_action(self, id):
        """Return the action dict for a discrete id.

        Gaussian noise (mean 0, standard deviation 0.5) is added to both camera
        components on every call, so repeated calls with the same id differ in
        ``'action$camera'``.

        Args:
            id: Index into ``action_list``; converted with ``int()``.

        Returns:
            dict: ``'action$camera'`` as a NumPy array, every other entry as
            ``np.int64``.
        """
        chosen = copy.deepcopy(self.action_list[int(id)])
        chosen['action$camera'] += np.random.normal(0., 0.5, 2)

        return {
            key: np.array(value) if key == 'action$camera' else np.int64(value)
            for key, value in chosen.items()
        }

    def get_id(self, action):
        """Convert an action dict into the id of the closest discrete action.

        The input is deep-copied and then forced to satisfy the constraints
        from ``__init__`` before it is looked up.

        Args:
            action: Mapping containing ``'action$camera'`` (two-element mutable
                sequence) and every key of ``fully_connected_no_camera``.
                Additional keys are ignored.

        Returns:
            int: Position of the resulting combination in
            ``fully_connected_list``.

        Raises:
            KeyError: If a required key is missing from ``action``.
            ValueError: If the resulting combination is not a valid one.
        """
        action = copy.deepcopy(action)

        # Pairwise exclusions: when both keys are set, the first one wins.
        for tuple_actions in self.exclude:
            if len(tuple_actions) == 2:
                a, b = tuple_actions
                if action[a] and action[b]:
                    action[b] = 0

        # Dependent keys are cleared when their requirement is not active.
        for a, b in self.only_if:
            if action[a] and not action[b]:
                action[a] = 0

        # Discretize the camera: movements below a third of the step count as
        # "no movement"; the first component takes priority over the second,
        # so at most one axis stays non-zero.
        camera = action['action$camera']
        threshold = self.c_action_magnitude / 3.
        camera_action_amount = 0
        if -threshold < camera[0] < threshold:
            camera[0] = 0.
            if -threshold < camera[1] < threshold:
                camera[1] = 0.
            else:
                camera_action_amount = 1
                camera[1] = self.c_action_magnitude * np.sign(camera[1])
        else:
            camera_action_amount = 1
            camera[0] = self.c_action_magnitude * np.sign(camera[0])
            camera[1] = 0.

        # Too many simultaneous entries: drop them in priority order until the
        # budget (`remove_size`, shared with the camera action) is met.
        for a in self.remove_first_list:
            active = sum(action[key] for key in self.fully_connected_no_camera)
            if active <= self.remove_size - camera_action_amount:
                break
            if a in self.camera_actions:
                action['action$camera'] = np.array([0., 0.])
                camera_action_amount = 0
            else:
                action[a] = 0

        # Translate the discretized camera vector into its one-hot turn key.
        for key in self.camera_actions:
            action[key] = 0
        for key, val in self.camera_dict.items():
            if (action['action$camera'] == val).all():
                action[key] = 1
                break

        non_separate_values = tuple(action[key] for key in self.fully_connected)
        return self.fully_connected_list.index(non_separate_values)

In [None]:


class ChopTreeAgent(nn.Module):
    """Convolutional policy network over a stack of consecutive frames.

    Args:
        output_dim: Size of the output vector (one value per output unit).
        num_frames: Number of frames concatenated along the channel axis.
        input_channels: Channels of a single frame (3 for RGB). The first
            convolution receives ``input_channels * num_frames`` channels.

    Input:  ``(batch, input_channels * num_frames, H, W)``.
    Output: ``(batch, output_dim)`` raw (un-normalised) scores.
    """

    def __init__(self, output_dim=10, num_frames=2, input_channels=3):
        super(ChopTreeAgent, self).__init__()

        stacked_channels = input_channels * num_frames
        # Spatial sizes in the comments are for an 84x84 input.
        self.conv = nn.Sequential(
            nn.Conv2d(stacked_channels, 32, kernel_size=8, stride=4),  # (32, 20, 20)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),  # (64, 9, 9)
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),  # (64, 7, 7)
            nn.ReLU()
        )
        # Brings the feature map to the 7x7 grid the classifier expects. For
        # 84x84 inputs the map is already 7x7, so this is a no-op there; other
        # resolutions (e.g. 64x64, which gives 4x4) no longer crash in `fc`.
        # The layer has no parameters, so saved state dicts stay compatible.
        self.pool = nn.AdaptiveAvgPool2d((7, 7))
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim)  # Output: one score per output unit
        )

    def forward(self, x):
        """Return the ``(batch, output_dim)`` scores for a batch of frame stacks."""
        features = self.pool(self.conv(x))
        return self.fc(features.flatten(1))
    

class BCDataset(Dataset):
    """Behavioural-cloning samples: a stack of consecutive frames and a target.

    Each item covers ``num_frames`` consecutive entries of ``aligned_data``.

    Args:
        aligned_data: Sequence of ``(frame, action_dict)`` pairs. Frames are
            ``H x W x 3`` BGR arrays; ``action_dict`` uses ``'action$<name>'``
            keys and may contain a ``'reward'`` entry.
        num_frames: Number of consecutive frames stacked along the channel axis.
        transform: Optional callable applied to every RGB frame; it must return
            a ``(C, H, W)`` tensor. Without it, frames are converted to float
            tensors scaled to ``[0, 1]``.
        action_manager: Optional object with a ``get_id(action_dict)`` method.
            When given, items carry a discrete action id instead of the
            10-value action vector (see ``__getitem__``).

    NOTE(review): windows are taken over the flat sequence, so a window can
    span two different recordings if ``aligned_data`` concatenates several --
    confirm whether that is acceptable.
    """

    def __init__(self, aligned_data, num_frames=2, transform=None, action_manager=None):
        self.transform = transform
        self.data = aligned_data
        self.num_frames = num_frames
        self.action_manager = action_manager

        self.discrete_keys = ['forward', 'left', 'back', 'right', 'jump', 'sneak', 'sprint', 'attack']
        self.continuous_key = 'camera'

    def __len__(self):
        # Number of complete windows; never negative, and the last valid
        # window (starting at len(data) - num_frames) is included.
        return max(0, len(self.data) - self.num_frames + 1)

    def _frame_to_tensor(self, frame):
        """Convert one BGR frame to the (C, H, W) tensor fed to the model."""
        # Reversing the channel axis is the BGR -> RGB conversion; the copy
        # makes the array contiguous for downstream consumers.
        rgb = np.ascontiguousarray(frame[..., ::-1])
        if self.transform is not None:
            return self.transform(rgb)
        return torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0

    def __getitem__(self, idx):
        """Return the sample whose first frame is at position ``idx``.

        Returns:
            Without ``action_manager``: ``(frames, action_vector)`` where
            ``action_vector`` holds the 8 discrete values of the last frame
            followed by the camera values averaged over the window.

            With ``action_manager``: ``(frames, action_id, reward)`` where
            ``action_id`` is ``action_manager.get_id`` of the last frame's
            action (int64 scalar tensor) and ``reward`` is that step's
            ``'reward'`` entry (0.0 when absent) as a float32 scalar tensor.

            ``frames`` has shape ``(C * num_frames, H, W)``, oldest frame first.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``.
        """
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError("BCDataset index out of range")

        frames = []
        camera_actions = []
        for offset in range(self.num_frames):
            frame, action_dict = self.data[idx + offset]
            frames.append(self._frame_to_tensor(frame))

            camera_action = action_dict.get(f"action${self.continuous_key}", [0.0, 0.0])
            camera_actions.append(np.asarray(camera_action, dtype=np.float32).tolist())

        stacked_frames = torch.cat(frames, dim=0)  # Shape: (C * num_frames, H, W)

        # After the loop `action_dict` belongs to the last frame of the window,
        # which is the step the target is taken from.
        if self.action_manager is not None:
            action_id = torch.tensor(int(self.action_manager.get_id(action_dict)), dtype=torch.long)
            reward = torch.tensor(float(action_dict.get('reward', 0.0)), dtype=torch.float32)
            return stacked_frames, action_id, reward

        # Average camera value over frames for smoothness.
        avg_camera_action = torch.tensor(camera_actions, dtype=torch.float32).mean(dim=0)
        discrete_actions = [float(action_dict.get(f"action${k}", 0)) for k in self.discrete_keys]
        action_vector = torch.tensor(discrete_actions + avg_camera_action.tolist(), dtype=torch.float32)

        return stacked_frames, action_vector


In [None]:
# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 32
# NOTE(review): GAMMA and the EPS_* values are not referenced by any cell
# visible in this notebook -- confirm whether they are still needed.
GAMMA = 0.99
EPS_START = 1.0
EPS_END = 0.1
EPS_DECAY = 0.99
num_epochs = 40
lr = 0.0005
SEED = 0

# Seed every RNG in use so that shuffling, weight initialisation and the
# camera noise are repeatable across "Restart & Run All".
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-frame preprocessing: 64x64 RGB tensor with values in [-1, 1].
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

action_manager = ActionManager(device)
output_dim = len(action_manager.action_list)  # one logit per discrete action

# NOTE(review): this call requires BCDataset to accept an `action_manager`
# keyword -- confirm the class definition in use provides it.
dataset = BCDataset(data, action_manager=action_manager, transform=transform)
# Pinned host memory only helps (and only works without a warning) when a
# CUDA device is present.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                        pin_memory=torch.cuda.is_available())

# NOTE(review): this call requires ChopTreeAgent to accept an `input_channels`
# keyword and to handle 64x64 inputs -- confirm the class definition in use.
model = ChopTreeAgent(input_channels=3, output_dim=output_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)
# reduction='none' keeps one loss value per sample so it can be re-weighted.
loss_fn = nn.CrossEntropyLoss(reduction='none')

In [None]:

for epoch in range(num_epochs):

    model.train()
    running_loss = 0.0
    samples_seen = 0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # A batch is (frames, targets[, rewards, ...]); anything after the
        # rewards is ignored. Unpacking this way keeps the loop working
        # whether or not the dataset supplies rewards.
        frames, action_vectors, *extras = batch
        frames, action_vectors = frames.to(device), action_vectors.to(device)
        rewards = extras[0].to(device).float().reshape(-1) if extras else None

        optimizer.zero_grad()
        outputs = model(frames)

        # Per-sample loss (loss_fn uses reduction='none'). CrossEntropyLoss
        # needs integer class ids of shape (batch,) as targets.
        action_loss = loss_fn(outputs, action_vectors.long())

        # Up-weight samples taken from rewarded steps.
        if rewards is not None:
            action_loss = action_loss * (1 + rewards)
        loss = action_loss.mean()

        loss.backward()
        torch_utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

        optimizer.step()

        running_loss += loss.item() * frames.size(0)
        samples_seen += frames.size(0)

    # Sample-weighted mean of the batch losses.
    epoch_loss = running_loss / max(samples_seen, 1)

    print(f"Epoch {epoch+1}/{num_epochs} Loss: {epoch_loss:.4f}")

In [None]:
# File the trained weights are written to and read back from by the save/load cells below.
model_path = "minecraft_tree_agent.pth"

In [None]:
# Persist only the parameters (state dict), not the whole module object.
trained_weights = model.state_dict()
torch.save(trained_weights, model_path)
print(f"Model weights saved to {model_path}")

In [None]:
# Deserialise onto the CPU first so the checkpoint loads on machines without a
# GPU; load_state_dict then copies the values into the existing parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
cpu_device = torch.device('cpu')
restored_weights = torch.load(model_path, map_location=cpu_device)
model.load_state_dict(restored_weights)
print(f"Model weights loaded from {model_path}")

In [None]:
# Built once instead of on every environment step. The steps (64x64 resize,
# [0, 1] scaling, normalisation with mean/std 0.5) mirror the `transform`
# used for training in this notebook and must be kept in sync with it.
_OBS_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((64,64)),
    transforms.ToTensor(),  # scales pixels to [0,1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def preprocess_obs(obs):
    """
    Converts the environment observation into a 4D tensor suitable for the model

    Args:
        obs: Observation mapping; only ``obs["pov"]`` (the image) is used.

    Returns:
        Tensor of shape ``(1, C, 64, 64)`` holding the single preprocessed
        frame, with values normalised to ``[-1, 1]``.
    """
    frame = _OBS_TRANSFORM(obs["pov"])
    return frame.unsqueeze(0)  # Add a batch dimension


# Environment-side names of the binary actions copied over by build_action_dict.
action_keys = ['forward', 'left', 'back', 'right', 'jump', 'sneak', 'sprint', 'attack']


def build_action_dict(actions, env):
    """
    Translate an ``'action$<name>'`` dict into an environment action.

    Starts from the environment's no-op action and overwrites every entry in
    ``action_keys`` plus ``'camera'`` with the value stored under the matching
    ``'action$'``-prefixed key of ``actions``. Entries of the no-op action that
    are not listed keep their no-op value.
    """
    env_action = env.action_space.no_op()

    for name in action_keys:
        env_action[name] = actions[f"action${name}"]
    env_action['camera'] = actions["action$camera"]

    return env_action

In [None]:
model.eval()

# NOTE(review): the training data directory is named MineRLTreechop-v0 while
# the rollout uses the ObtainDiamondShovel environment -- confirm this is the
# intended evaluation environment.
env = gym.make('MineRLObtainDiamondShovel-v0')

env.seed(1)
obs = env.reset()

MAX_STEPS = 10000  # hard cap on the rollout length
hidden_state = None  # not used by the feed-forward model in this notebook
steps = 0
done = False

try:
    # The model consumes several consecutive frames stacked along the channel
    # axis; work out how many from the first convolution and keep a rolling
    # window, seeded by repeating the first observation.
    first_frame = preprocess_obs(obs)
    frames_per_state = max(1, model.conv[0].in_channels // first_frame.shape[1])
    frame_buffer = deque([first_frame] * frames_per_state, maxlen=frames_per_state)

    while not done:
        # Oldest frame first, matching the order used when building the
        # training samples; move to the device the model lives on.
        state = torch.cat(list(frame_buffer), dim=1).to(device)

        # Get model output and convert to actions
        with torch.no_grad():
            output = model(state)  # forward() returns a single tensor

        action_id = torch.argmax(output, dim=1).item()

        action = action_manager.get_action(action_id)

        print(action)

        action = build_action_dict(action, env)

        # Step the environment using the agent's action
        obs, reward, done, info = env.step(action)
        env.render()

        frame_buffer.append(preprocess_obs(obs))

        steps += 1
        if steps >= MAX_STEPS:
            done = True
finally:
    env.close()