In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from torchvision import transforms

from gridverse_torch_featureextractors.gridversefeatureextractor import GridVerseFeatureExtractor
from gridverse_utils.custom_gridverse_env import register_custom_functions
from gridverse_utils.gridversemaker import WorldMaker
import tree
import numpy as np

# Run the project's registration hook once, before any GridVerse environment is
# created further down. NOTE(review): presumably this makes custom functions
# referenced by the YAML environment configs resolvable -- confirm in
# gridverse_utils.custom_gridverse_env.
register_custom_functions()

In [2]:
# import wandb
# 
# wandb.init(project="self-supervised-memory-reactive")

In [3]:
# Drop any Hydra state left over from a previous execution of this cell so that
# `initialize` can be called again without restarting the kernel.
GlobalHydra.instance().clear()
# Point Hydra at the project's config directory.
initialize(config_path="./config/hydra_conf", version_base=None)
# Compose the root config; `cfg.algorithm` is passed to the model constructor later on.
cfg = compose(config_name="config")

In [4]:
def numpy_dict_to_tensor(data):
    """Return a new dict whose values are ``torch.as_tensor`` of the values in ``data``.

    Keys (and their order) are kept; the input mapping itself is not modified.
    """
    converted = {}
    for key, value in data.items():
        converted[key] = torch.as_tensor(value)
    return converted


In [5]:
from gym_gridverse.action import Action


def collect_episodes(env, num_episodes, max_step_len, test=False):
    """Roll out a random policy in ``env`` and keep only episodes that terminate.

    Episodes that are truncated, or that do not terminate within
    ``max_step_len - 1`` steps, are discarded and collection continues until
    ``num_episodes`` terminated episodes have been gathered. Kept episodes are
    padded by repeating their last observation so each has exactly
    ``max_step_len`` observations.

    Args:
        env: Environment following the Gymnasium API: ``reset()`` returns
            ``(obs, info)``, ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``, and
            ``action_space.sample()`` draws a random action. Observations must
            be structures that ``tree.map_structure`` can traverse.
        num_episodes: Number of terminated episodes to collect.
        max_step_len: Number of observations stored per episode (initial
            observation included). Must be at least 2.
        test: When True, the first four actions of every episode are
            ``Action.TURN_RIGHT`` instead of random ones.

    Returns:
        A pair ``(episodes, returns)``. ``episodes`` has the structure of a
        single observation with every leaf stacked to
        ``(num_episodes, max_step_len, ...)``. ``returns`` is a 1-D array with
        1.0 for episodes whose summed reward is positive and 0.0 otherwise.

    Raises:
        ValueError: If ``max_step_len`` is smaller than 2; no action would ever
            be taken, so no episode could terminate and the loop would not end.
    """
    if max_step_len < 2:
        raise ValueError("max_step_len must be at least 2 so that at least one step is taken")

    episodes = []
    returns = []
    last_num_episodes = 0

    while len(episodes) < num_episodes:
        # Use the environment that was passed in (not a notebook-level global).
        # Observations are kept as returned by the env so that every step of an
        # episode has the same type when stacked below.
        initial_obs, _ = env.reset()
        episode = [initial_obs]
        epi_return = 0
        # Reset per episode so the check after the loop never sees stale flags.
        terminated = truncated = False

        for i in range(max_step_len - 1):
            if test and i < 4:
                # Cannot reach exit by turning but definitely sees beacon in 5x5 grid
                action = Action.TURN_RIGHT.value
            else:
                # random action
                action = env.action_space.sample()
            obs, reward, terminated, truncated, _ = env.step(action)
            episode.append(obs)
            epi_return += reward

            if terminated or truncated:
                break

        if terminated and not truncated:
            # repeat last observation so that all episodes have max_step_len
            last_obs = episode[-1]
            while len(episode) < max_step_len:
                episode.append(last_obs)

            # Stack the per-step observations leaf by leaf: (max_step_len, ...).
            episode = tree.map_structure(lambda *steps_: np.stack(steps_, axis=0), *episode)

            episodes.append(episode)
            # Assumes that positive return means agent reached good exit otherwise it reached bad exit
            # label 1 means good exit, 0 means bad exit
            returns.append(1.0 if epi_return > 0 else 0.0)

        # Report progress once per hundred kept episodes (not once per attempt).
        if len(episodes) % 100 == 0 and len(episodes) > last_num_episodes:
            print(f"\tCollected{len(episodes)}")
            last_num_episodes = len(episodes)

    # Stack the episodes leaf by leaf: (num_episodes, max_step_len, ...).
    episodes = tree.map_structure(lambda *episodes_: np.stack(episodes_, axis=0), *episodes)
    returns = np.stack(returns, axis=0)
    return episodes, returns


In [6]:
# episodes = tree.map_structure(lambda *episodes_: torch.stack(episodes_, dim=0), *episodes)

In [7]:
from gridverse_torch_featureextractors.stacked.transformerencoder import generate_square_subsequent_mask
from torch.nn import TransformerEncoderLayer, TransformerEncoder


class ReturnPredictorFromSequence(nn.modules.Module):
    """Predict one logit per episode from its full (time-reversed) observation sequence.

    Every observation is encoded by ``GridVerseFeatureExtractor``, the per-step
    features are passed through a masked transformer encoder, and the encoder
    output at the last sequence position is mapped to a single logit.

    Args:
        observation_space: Observation space handed to the feature extractor.
        config: Config object accessed by attribute; must provide ``encoder``
            (with ``grid_encoder.output_dim``, ``agent_id_encoder.output_dim``
            and ``items_encoder.layers``) and ``lstm_cell_size``.
    """

    def __init__(self, observation_space: gym.spaces.Dict, config: dict):
        super().__init__()

        self.gridverse_feature_extractor = GridVerseFeatureExtractor(observation_space, config.encoder)
        self.feature_dim = config.encoder.grid_encoder.output_dim + \
                           config.encoder.agent_id_encoder.output_dim + \
                           config.encoder.items_encoder.layers[-1]
        hidden_dim = config.lstm_cell_size
        # Kept so the attribute remains available; `forward` does not use it.
        self.rnn = nn.LSTM(input_size=self.feature_dim, hidden_size=hidden_dim, batch_first=True)

        # `forward` runs the features through `self.transformer_encoder`, so it
        # has to be constructed here; without it the first forward pass raises
        # AttributeError.
        # NOTE(review): multi-head attention requires d_model to be divisible
        # by nhead, i.e. feature_dim % 8 == 0 -- confirm against the encoder config.
        encoder_layer = TransformerEncoderLayer(d_model=self.feature_dim, nhead=8, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)

        # The head consumes the transformer output, whose width is feature_dim.
        self.linear = nn.Sequential(
            nn.Linear(in_features=self.feature_dim, out_features=1),
        )

    def forward(self, episodes):
        """Return a ``(num_batch, 1)`` tensor of logits.

        Args:
            episodes: Dict of tensors whose first two axes are
                ``(num_batch, num_timesteps)``.
        """
        # Reverse the time axis so the sequence is processed from its end.
        # (num_batch/ num_episode, num_timesteps,..)
        episodes = tree.map_structure(lambda dict_val: torch.flip(dict_val, dims=[1]), episodes)
        num_batch, num_steps, *_ = next(iter(episodes.values())).size()

        def combine_batch_and_time_dim(tensor):
            # Merge (batch, time) into one leading axis so the feature
            # extractor sees one observation per row.
            _, _, *feature_dims = tensor.size()
            return tensor.reshape(-1, *feature_dims)

        episodes = tree.map_structure(combine_batch_and_time_dim, episodes)

        features = self.gridverse_feature_extractor(episodes)
        _, *feature_dims = features.size()

        # batch_size, num_steps, feature_dim
        features = features.reshape(num_batch, num_steps, *feature_dims)

        mask = generate_square_subsequent_mask(num_steps)
        transformer_output = self.transformer_encoder(features, mask=mask)

        aggregated_output = transformer_output[:, -1, :]  # Take the last timestep

        outputs = self.linear(aggregated_output)

        return outputs



In [8]:
class ReturnPredictorFromMemoryAndReactiveObs(nn.modules.Module):
    """Predict one logit from a (memory observation, reactive observation) pair.

    Both observations are encoded by the same ``GridVerseFeatureExtractor``;
    their feature vectors are concatenated and fed to a stack of linear layers
    with dropout in between that ends in a single logit.
    """

    def __init__(self, observation_space: gym.spaces.Dict, config: dict):
        super().__init__()

        encoder_cfg = config.encoder
        self.gridverse_feature_extractor = GridVerseFeatureExtractor(observation_space, encoder_cfg)
        # Width of the feature vector produced for a single observation.
        self.feature_dim = (
            encoder_cfg.grid_encoder.output_dim
            + encoder_cfg.agent_id_encoder.output_dim
            + encoder_cfg.items_encoder.layers[-1]
        )

        # Input is two concatenated feature vectors; output is one logit.
        widths = (2 * self.feature_dim, 128, 64, 32, 1)
        last_layer = len(widths) - 2
        head = []
        for position, (in_width, out_width) in enumerate(zip(widths[:-1], widths[1:])):
            head.append(nn.Linear(in_features=in_width, out_features=out_width))
            if position != last_layer:
                head.append(nn.Dropout(0.25))
        self.linear = nn.Sequential(*head)

    def forward(self, memory_reactive_obs):
        """Return a ``(num_batch, 1)`` tensor of logits.

        Args:
            memory_reactive_obs: Dict of tensors shaped ``(N, 2, ...)``: axis 1
                holds the observations of one pair.
        """
        first_leaf = next(iter(memory_reactive_obs.values()))
        num_batch, num_obs, *_ = first_leaf.size()

        def _fold_pair_axis(leaf):
            # (N, 2, ...) -> (N * 2, ...): one observation per row for the extractor.
            _batch, _pair, *trailing = leaf.size()
            return leaf.reshape(-1, *trailing)

        flat_obs = tree.map_structure(_fold_pair_axis, memory_reactive_obs)
        features = self.gridverse_feature_extractor(flat_obs)

        _, *per_obs_dims = features.size()

        # Restore the pair axis, then concatenate the features of each pair.
        paired = features.reshape(num_batch, num_obs, *per_obs_dims)
        flattened = paired.reshape(num_batch, -1)

        return self.linear(flattened)



In [9]:
from torch.utils.data import Dataset


class EpisodeDataset(Dataset):
    """Dataset of (step observation, final observation) pairs with a per-episode label.

    One sample is produced for every step of every episode: the observation at
    that step stacked with the last observation of the same episode. The label
    is the label of the episode the step belongs to.

    Args:
        episodes: Structure whose leaves are arrays shaped
            ``(num_episodes, rollout_len, ...)``; it must contain a ``'grid'``
            entry, which is used to read the two leading sizes.
        labels: Array indexed by episode.
        transform: Optional callable applied to each sample before it is
            returned. When ``None`` the stacked numpy sample is returned as is.
    """

    def __init__(self, episodes: dict, labels: np.array, transform=None):
        self.episodes = episodes
        self.num_episodes, self.rollout_len, *_ = episodes['grid'].shape
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # One sample per (episode, step) combination.
        return self.num_episodes * self.rollout_len

    def __getitem__(self, idx):
        # Flat index -> (episode, step within the episode).
        episode_index, rollout_index = divmod(idx, self.rollout_len)
        final_index = self.rollout_len - 1

        # Each leaf becomes (2, ...): [observation at the step, final observation].
        memory_step_and_reactive_step = tree.map_structure(
            lambda episodes_value: np.stack(
                [episodes_value[episode_index, rollout_index], episodes_value[episode_index, final_index]],
                axis=0),
            self.episodes)

        # `transform` is optional (defaults to None), so only call it when given.
        if self.transform is None:
            sample = memory_step_and_reactive_step
        else:
            sample = self.transform(memory_step_and_reactive_step)
        label = self.labels[episode_index]
        return sample, label

In [10]:
# def custom_dataloader_collate(batch):
#     combined_dict = {}
#     samples, labels = zip(*batch)
#     for key in samples[0].keys():
#         combined_dict[key] = torch.stack([sample[key] for sample in samples])
#     return combined_dict, labels

In [11]:
# Environment and data-collection parameters.
# Every stored episode holds exactly `max_step_len` observations (shorter
# episodes are padded by `collect_episodes`).
max_step_len = 50
# Number of terminated episodes collected for each split.
num_train_episodes = 5000
num_valid_episodes = 1000
num_test_episodes = 1000
# Name of the GridVerse YAML config; also used as prefix for the dataset and model files.
env_name = 'gv_memory_no_beacon.5x5'
worldMaker = WorldMaker(f'./config/gridverse_conf/{env_name}.yaml')
# Single environment instance shared by data collection and model construction.
world = worldMaker.make_env()

# Training hyper-parameters.
batch_size = 128  # number of episodes when training on sequences, number of time steps otherwise
num_epochs = 20
learning_rate = 1e-3
# At test time, probabilities below `cutoff_mag` count as class 0, above
# `1 - cutoff_mag` as class 1, and everything in between as "unsure".
cutoff_mag = 0.1

# Bookkeeping.
# `train_one_epoch` prints the mean loss once per this many batches.
aggregate_stats_every_n_batch = 50
# File that receives the state dict of the model with the best validation loss.
model_path = f'{env_name}.self_supervised_memory.pt'


Loading gridverse using YAML in ./config/gridverse_conf/gv_memory_no_beacon.5x5.yaml


In [12]:
class ToTensor(object):
    """Transform that maps a dict of array-likes to a dict of tensors."""

    def __call__(self, sample):
        tensors = {}
        for name, array in sample.items():
            # as_tensor avoids a copy where the input allows it.
            tensors[name] = torch.as_tensor(array)
        return tensors


# Transform pipeline handed to every EpisodeDataset; at the moment it only
# converts the numpy sample dict to tensors.
composed = transforms.Compose([
    ToTensor()
])

In [13]:
def collect_and_save_episodes():
    """Collect the train, validation and test episodes and save each split to disk.

    Every split is rolled out in the notebook-level ``world`` environment,
    wrapped in an ``EpisodeDataset`` using the ``composed`` transform and
    written with ``torch.save``. Only the test split uses the scripted opening
    moves (``test=True``).
    """
    # (progress message, number of episodes, scripted opening?, output file)
    plan = (
        ("Collecting train episodes...", num_train_episodes, False, f'{env_name}.train_dataset.pt'),
        ("Collecting validation episodes...", num_valid_episodes, False, f'{env_name}.validation_dataset.pt'),
        ("Collecting test episodes...", num_test_episodes, True, f'{env_name}.test_dataset.pt'),
    )

    for banner, episode_count, scripted_start, out_path in plan:
        print(banner)
        split_episodes, split_returns = collect_episodes(
            world, episode_count, max_step_len, test=scripted_start)
        split_dataset = EpisodeDataset(split_episodes, split_returns, transform=composed)
        torch.save(split_dataset, out_path)


# Collects all three splits (5000 + 1000 + 1000 terminated episodes with the
# parameters above) and writes the dataset files that the cells below load.
collect_and_save_episodes()

Collecting train episodes...
	Collected100
	Collected200
	Collected300
	Collected400
	Collected500
	Collected600
	Collected700
	Collected800
	Collected900
	Collected1000
	Collected1100
	Collected1200
	Collected1300
	Collected1400
	Collected1500
	Collected1600
	Collected1700
	Collected1800
	Collected1900
	Collected2000
	Collected2100
	Collected2200
	Collected2300
	Collected2400
	Collected2500
	Collected2600
	Collected2700
	Collected2800
	Collected2900
	Collected3000
	Collected3100
	Collected3200
	Collected3300
	Collected3400
	Collected3500
	Collected3600
	Collected3700
	Collected3800
	Collected3900
	Collected4000
	Collected4100
	Collected4200
	Collected4300
	Collected4400
	Collected4500
	Collected4600
	Collected4700
	Collected4800
	Collected4900
	Collected5000
Collecting validation episodes...
	Collected100
	Collected200
	Collected300
	Collected400
	Collected500
	Collected600
	Collected700
	Collected800
	Collected900
	Collected1000
Collecting test episodes...
	Collected100
	Collected200

In [None]:
def train_one_epoch(epoch_index):
    """Run one optimisation pass over ``train_dataloader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``train_dataloader``. Every ``aggregate_stats_every_n_batch`` batches the
    mean loss of that window is printed.

    Args:
        epoch_index: Index of the current epoch. Currently unused; kept for
            the caller's signature.

    Returns:
        The mean loss of the last complete reporting window, or 0.0 if the
        epoch had fewer batches than one window.
    """
    running_loss = 0.
    last_loss = 0.

    for i, batch in enumerate(train_dataloader):
        inputs, labels = batch
        optimizer.zero_grad()
        # Drop only the trailing logit axis. A bare squeeze() would also drop
        # the batch axis of a one-sample batch and then mismatch the labels.
        outputs = model(inputs).squeeze(-1)
        # Labels may arrive as float64 (numpy floats after collation); the loss
        # needs them in the same dtype and on the same device as the logits.
        labels = labels.to(device=outputs.device, dtype=outputs.dtype)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Report at the end of every full window of batches.
        if i % aggregate_stats_every_n_batch == (aggregate_stats_every_n_batch - 1):
            last_loss = running_loss / aggregate_stats_every_n_batch  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    return last_loss

In [None]:
# Load the datasets written by `collect_and_save_episodes`. The files contain
# whole pickled EpisodeDataset objects, so EpisodeDataset and ToTensor must be
# defined before this cell runs.
# NOTE(review): recent PyTorch releases restrict `torch.load` to plain weights
# by default; if loading fails, check the `weights_only` argument for the
# installed version.
train_dataset = torch.load(f'{env_name}.train_dataset.pt')
validation_dataset = torch.load(f'{env_name}.validation_dataset.pt')

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, )
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=True, )

model = ReturnPredictorFromMemoryAndReactiveObs(world.observation_space, cfg.algorithm)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# The model ends in a plain linear layer (raw logits), hence the with-logits loss.
criterion = nn.BCEWithLogitsLoss()

# Number of completed epochs; advanced by the training cell.
epoch_number = 0

In [None]:
# Training
print("Starting training...")
best_vloss = float("inf")

for _ in range(num_epochs):
    print('EPOCH {}:'.format(epoch_number + 1))
    model.train(True)
    avg_loss = train_one_epoch(epoch_number)

    # Validation pass: no dropout, no gradients.
    model.eval()
    running_vloss = 0.0
    num_vbatches = 0

    with torch.no_grad():
        for vinputs, vlabels in validation_dataloader:
            # Drop only the trailing logit axis so a one-sample batch keeps its batch axis.
            voutputs = model(vinputs).squeeze(-1)
            # Match the labels to the logits' dtype/device before computing the loss.
            vlabels = vlabels.to(device=voutputs.device, dtype=voutputs.dtype)
            running_vloss += criterion(voutputs, vlabels).item()
            num_vbatches += 1

    # Average over the batches actually seen; NaN (never "best") if there were none.
    avg_vloss = running_vloss / num_vbatches if num_vbatches else float("nan")
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Keep the weights of the epoch with the lowest validation loss.
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        torch.save(model.state_dict(), model_path)

    # wandb.log({'train/loss': avg_loss, 'valid/loss': avg_vloss}, epoch_number + 1)
    epoch_number += 1


In [None]:
# Load the test split (a pickled EpisodeDataset written by `collect_and_save_episodes`).
test_dataset = torch.load(f'{env_name}.test_dataset.pt')
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, )
# Rebuild the model and restore the checkpoint with the best validation loss.
model = ReturnPredictorFromMemoryAndReactiveObs(world.observation_space, cfg.algorithm)
model.load_state_dict(torch.load(model_path))
# Switch to evaluation mode so dropout is disabled.
model.eval()

In [None]:
# Collects the predicted probability of every test sample; filled by the
# evaluation cell below. Re-run this cell before re-running the evaluation,
# otherwise the probabilities are appended a second time.
res = []

In [None]:
with torch.no_grad():
    total_correct = 0
    total_samples = 0

    # Probabilities below `neg_threshold` are read as class 0, above
    # `pos_threshold` as class 1; everything in between is "unsure".
    pos_threshold = 1 - cutoff_mag
    neg_threshold = cutoff_mag

    for inputs, labels in test_dataloader:
        # Drop only the trailing logit axis so a one-sample batch stays 1-D
        # (a 0-d result could not be iterated by `res.extend` below).
        probabilities = torch.sigmoid(model(inputs).squeeze(-1))
        res.extend(probabilities.cpu().numpy())

        # Build the class decisions in a separate tensor so the probabilities
        # themselves are left untouched: 0 / 1 when confident, 0.5 when unsure.
        predictions = torch.full_like(probabilities, 0.5)
        predictions[probabilities < neg_threshold] = 0
        predictions[probabilities > pos_threshold] = 1

        # Unsure samples (0.5) never equal a 0/1 label, so they count as
        # incorrect while still being part of the denominator.
        total_correct += torch.sum(predictions == labels).item()
        total_samples += predictions.numel()

    one_or_zero_class_acc = total_correct / total_samples
    # wandb.log({"test/one_or_zero_class_acc": one_or_zero_class_acc})
    print(one_or_zero_class_acc)


In [None]:
# Turn the collected per-sample probabilities into one array for the histogram below.
res = np.array(res)

In [None]:
# Sanity check: number of probabilities collected from the test set.
res.shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Count the predicted probabilities in 100 equal-width bins over [0, 1].
hist, bin_edges = np.histogram(res, bins=100, range=(0, 1))

In [None]:
# Bar chart of the histogram computed above, one bar per probability bin.
fig, ax = plt.subplots(figsize=(10, 6))
bin_widths = np.diff(bin_edges)
ax.bar(bin_edges[:-1], hist, width=bin_widths, align="edge", edgecolor="black")
ax.set_xlabel("Value")
ax.set_ylabel("Frequency")
ax.set_title("Frequency Distribution")
plt.show()

In [None]:
# wandb.finish()

In [None]:
# Inputs for a back-of-the-envelope estimate of how many training pairs carry
# a learnable signal. The values repeat `max_step_len` and `num_train_episodes`
# from the parameter cell.
rollout_len = 50
num_episodes = 5000

In [None]:
# Assumed number of (memory, reactive) pairs per episode from which the label can be learned.
num_learnable_memory_reactive_pair_per_episode = 1  # An assumption; the true number could be higher or lower.
# Pairs per episode counted here: every step except the last one.
# NOTE(review): EpisodeDataset yields `rollout_len` samples per episode (the
# last step is paired with itself) -- confirm which count is intended.
num_memory_reactive_pair_per_episode = rollout_len - 1

In [None]:
# Scale the per-episode counts to the whole training set.
tot_learnable_pair = num_learnable_memory_reactive_pair_per_episode * num_episodes
tot_pair = num_memory_reactive_pair_per_episode * num_episodes

In [None]:
# Display: learnable pairs, all pairs, and the learnable fraction.
tot_learnable_pair, tot_pair, tot_learnable_pair / tot_pair

In [None]:
# Fraction of pairs that, under the assumption above, carry no learnable signal.
noise = 1 - tot_learnable_pair / tot_pair
f'noise: {noise * 100:0.2f} %'