In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from torchvision import transforms

from gridverse_torch_featureextractors.gridversefeatureextractor import GridVerseFeatureExtractor
from gridverse_utils.custom_gridverse_env import register_custom_functions
from gridverse_utils.gridversemaker import WorldMaker
import tree
import numpy as np

register_custom_functions()

In [2]:
# import wandb
# 
# wandb.init(project="self-supervised-memory-reactive")

In [3]:
GlobalHydra.instance().clear()
initialize(config_path="./config/hydra_conf", version_base=None)
cfg = compose(config_name="config")

In [4]:
def numpy_dict_to_tensor(data):
    return {k: torch.as_tensor(v) for k, v in data.items()}


In [5]:
from gym_gridverse.action import Action


def collect_episodes(env, num_episodes, max_step_len):
    episodes = []
    returns = []
    last_num_episodes = 0

    while len(episodes) < num_episodes:
        episode = [numpy_dict_to_tensor(env.reset()[0])]
        epi_return = 0

        for i in range(max_step_len - 1):
            #random action
            action = world.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            # obs = numpy_dict_to_tensor(obs)
            episode.append(obs)
            epi_return += reward

            if terminated or truncated:
                break
        if terminated and not truncated:
            # repeat last observation so that all episodes have max_step_len
            last_obs = episode[-1]
            while len(episode) < max_step_len:
                episode.append(last_obs)

            episode = tree.map_structure(lambda *steps_: np.stack(steps_, axis=0), *episode)

            episodes.append(episode)
            # Assumes that positive return means agent reached good exit otherwise it reached bad exit
            # label 1 means good exit, 0 means bad exit
            returns.append(1.0 if epi_return > 0 else 0.0)

        if len(episodes) % 100 == 0 and len(episodes) > last_num_episodes:
            print(f"\tCollected{len(episodes)}")
            last_num_episodes = len(episodes)

    episodes = tree.map_structure(lambda *episodes_: np.stack(episodes_, axis=0), *episodes)
    returns = np.stack(returns, axis=0)
    return episodes, returns


In [6]:
# episodes = tree.map_structure(lambda *episodes_: torch.stack(episodes_, dim=0), *episodes)

In [7]:
from gridverse_torch_featureextractors.stacked.transformerencoder import generate_square_subsequent_mask
from torch.nn import TransformerEncoderLayer, TransformerEncoder


class ReturnPredictorFromSequence(nn.modules.Module):

    def __init__(self, observation_space: gym.spaces.Dict, config: dict):
        super().__init__()

        self.gridverse_feature_extractor = GridVerseFeatureExtractor(observation_space, config.encoder)
        self.feature_dim = config.encoder.grid_encoder.output_dim + \
                           config.encoder.agent_id_encoder.output_dim + \
                           config.encoder.items_encoder.layers[-1]
        self.rnn_hidden_dim = config.lstm_cell_size
        self.rnn = nn.LSTM(input_size=self.feature_dim, hidden_size=self.rnn_hidden_dim, batch_first=True)

        # encoder_layer = TransformerEncoderLayer(d_model=self.feature_dim, nhead=8, batch_first=True)
        # self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)

        self.linear = nn.Sequential(
            nn.Linear(in_features=self.rnn_hidden_dim, out_features=128),
            nn.Dropout(0.5),
            nn.Linear(in_features=128, out_features=64),
            nn.Dropout(0.5),
            nn.Linear(in_features=64, out_features=32),
            nn.Dropout(0.5),
            nn.Linear(in_features=32, out_features=1)
        )

    def forward(self, episodes):
        # reversing the episode so RNN process steps from the end
        # (num_batch/ num_episode, num_timesteps,..)
        episodes = tree.map_structure(lambda dict_val: torch.flip(dict_val, dims=[1]), episodes)
        num_batch, num_steps, *_ = next(iter(episodes.values())).size()

        def combine_batch_and_time_dim(tensor):
            _, _, *feature_dims = tensor.size()
            return tensor.reshape(-1, *feature_dims)

        episodes = tree.map_structure(combine_batch_and_time_dim, episodes)

        features = self.gridverse_feature_extractor(episodes)
        _, *feature_dims = features.size()

        # batch_size, num_steps, feature_dim
        features = features.reshape(num_batch, num_steps, *feature_dims)

        # mask = generate_square_subsequent_mask(num_steps)
        # transformer_output = self.transformer_encoder(features, mask=mask)
        # 
        # aggregated_output = transformer_output[:, -1, :]  # Take the last timestep
        out, (rnn_h_n, c_n) = self.rnn(features)

        outputs = self.linear(out)

        return outputs

In [8]:
class ReturnPredictorFromMemoryAndReactiveObs(nn.modules.Module):

    def __init__(self, observation_space: gym.spaces.Dict, config: dict):
        super().__init__()

        self.gridverse_feature_extractor = GridVerseFeatureExtractor(observation_space, config.encoder)
        self.feature_dim = config.encoder.grid_encoder.output_dim + \
                           config.encoder.agent_id_encoder.output_dim + \
                           config.encoder.items_encoder.layers[-1]

        self.linear = nn.Sequential(
            nn.Linear(in_features=2 * self.feature_dim, out_features=128),
            nn.Dropout(0.25),
            nn.Linear(in_features=128, out_features=64),
            nn.Dropout(0.25),
            nn.Linear(in_features=64, out_features=32),
            nn.Dropout(0.25),
            nn.Linear(in_features=32, out_features=1),
        )

    def forward(self, memory_reactive_obs):
        #(N,2,grid_verse_dim..)
        num_batch, num_obs, *_ = next(iter(memory_reactive_obs.values())).size()

        def combine_batch_and_time_dim(tensor):
            _, _, *feature_dims = tensor.size()
            return tensor.reshape(-1, *feature_dims)

        memory_reactive_obs = tree.map_structure(combine_batch_and_time_dim, memory_reactive_obs)

        features = self.gridverse_feature_extractor(memory_reactive_obs)

        _, *feature_dims = features.size()

        # batch_size, num_steps, feature_dim
        features = features.reshape(num_batch, num_obs, *feature_dims)

        # concatenating the features
        features = features.reshape(num_batch, -1)

        outputs = self.linear(features)
        return outputs



In [18]:
# env param
max_step_len = 50
num_train_episodes = 5000
num_valid_episodes = 1000
num_test_episodes = 1000
env_name = 'gv_memory.5x5'
worldMaker = WorldMaker(f'./config/gridverse_conf/{env_name}.yaml')
world = worldMaker.make_env()

# hyper param
batch_size = 128  # this is number of episodes when using sequence and num time steps when not using sequence
num_epochs = 1
learning_rate = 1e-3

# params
aggregate_stats_every_n_batch = 50
model_path = f'{env_name}.self_supervised_memory.pt'


Loading gridverse using YAML in ./config/gridverse_conf/gv_memory.5x5.yaml


In [10]:
from torch.utils.data import Dataset

from gym_gridverse.grid_object import Beacon

world.reset()
beacon_index = None
state = world.outer_env.inner_env.state.grid
for i in range(state.shape.width):
    for j in range(state.shape.width):
        if isinstance(state.objects[i][j], Beacon):
            beacon_index = state.objects[i][j].type_index()
            break

In [11]:
class EpisodeDataset(Dataset):
    def __init__(self, episodes: dict, labels: np.array, transform=None, get_suffix_sequence=False,
                 filter_by_beacon=False):
        self.episodes = episodes
        self.num_episodes, self.rollout_len, *_ = episodes['grid'].shape
        self.labels = labels
        self.transform = transform
        self.get_suffix_sequence = get_suffix_sequence
        self.filter_by_beacon = filter_by_beacon

    def __len__(self):
        return self.num_episodes * (self.rollout_len - 1)

    def __getitem__(self, idx):
        (episode_index, rollout_index) = idx

        if not self.get_suffix_sequence:
            # this is memory and reactive observation pair
            rollout_sub_seq = tree.map_structure(
                lambda episodes_value: np.stack(
                    [episodes_value[episode_index, rollout_index], episodes_value[episode_index, self.rollout_len - 1]],
                    axis=0),
                self.episodes)
        else:
            rollout_sub_seq = tree.map_structure(
                lambda episodes_value: episodes_value[episode_index, rollout_index:],
                self.episodes)

        sample = self.transform(rollout_sub_seq)

        if self.filter_by_beacon:
            if beacon_index in sample['grid'][..., 0]:
                label = self.labels[episode_index]
            else:
                # if there's no beacon 0.5
                label = 0.5
        else:
            label = self.labels[episode_index]

        return sample, label

In [12]:
from typing import List
from torch.utils.data import Sampler


class EpisodeAndSequenceIndexBatchSampler(Sampler[List[int]]):
    def __init__(self, num_episodes, rollout_len, batch_size, keep_rollout_index_constant_across_batch=True):
        super().__init__()
        self.num_episodes = num_episodes
        self.rollout_len = rollout_len
        self.batch_size = batch_size
        self.keep_rollout_index_constant_across_batch = keep_rollout_index_constant_across_batch

    def __len__(self):
        return self.num_episodes * (self.rollout_len - 1)

    def __iter__(self):
        def get_idx(num_batch, batch_size):
            if self.keep_rollout_index_constant_across_batch:
                rollout_idx = np.random.randint(0, self.rollout_len - 1, (num_batch, 1))
                rollout_idx = np.broadcast_to(rollout_idx, (num_batch, batch_size))
            else:
                rollout_idx = np.random.randint(0, self.rollout_len - 1, (num_batch, batch_size))
            episode_idx = np.random.randint(0, self.num_episodes, (num_batch, batch_size,))
            idx = np.stack([episode_idx, rollout_idx], axis=-1)
            return idx

        num_batch = len(self) // self.batch_size
        idx = get_idx(num_batch, self.batch_size)

        for _ in range(len(self) // self.batch_size):
            yield from idx.tolist()
        remaining = get_idx(1, len(self) % self.batch_size).squeeze()
        yield remaining



In [13]:
class ToTensor(object):
    """Convert dict ndarrays in sample to dict Tensors."""

    def __call__(self, sample):
        return {k: torch.as_tensor(v) for k, v in sample.items()}


composed = transforms.Compose([
    ToTensor()
])

In [14]:
def collect_and_save_episodes():
    print("Collecting train episodes...")
    train_episodes, returns = collect_episodes(world, num_train_episodes, max_step_len)
    train_dataset = EpisodeDataset(train_episodes, returns, transform=composed, get_suffix_sequence=True,
                                   filter_by_beacon=False)
    torch.save(train_dataset, f'{env_name}.train_dataset.pt')

    print("Collecting validation episodes...")
    validation_episodes, returns = collect_episodes(world, num_valid_episodes, max_step_len)
    validation_dataset = EpisodeDataset(validation_episodes, returns, transform=composed, get_suffix_sequence=True,
                                        filter_by_beacon=False)
    torch.save(validation_dataset, f'{env_name}.validation_dataset.pt')

    print("Collecting test episodes...")
    test_episodes, returns = collect_episodes(world, num_test_episodes, max_step_len)
    test_dataset = EpisodeDataset(test_episodes, returns, transform=composed, get_suffix_sequence=True,
                                  filter_by_beacon=False)
    torch.save(test_dataset, f'{env_name}.test_dataset.pt')

# collect_and_save_episodes()

In [19]:
from torch._C._profiler import ProfilerActivity


def train_one_epoch(epoch_index):
    running_loss = 0.
    last_loss = 0.

    with torch.profiler.profile(activities=[ProfilerActivity.CPU],
                                record_shapes=True,
                                profile_memory=True,
                                with_stack=True, ) as prof:
        for i, batch in enumerate(train_dataloader):
            inputs, labels = batch
            optimizer.zero_grad()
            outputs = model(inputs).squeeze()
            # unsqueeze and expand labels for time dim
            labels = labels[..., None].expand(*outputs.shape)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            prof.step()
            running_loss += loss.item()

            if i % aggregate_stats_every_n_batch == (aggregate_stats_every_n_batch - 1):
                last_loss = running_loss / aggregate_stats_every_n_batch  # loss per batch
                print('  batch {} loss: {}'.format(i + 1, last_loss))
                running_loss = 0.

    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

    return last_loss

In [20]:
train_dataset = torch.load(f'{env_name}.train_dataset.pt')
train_sampler = EpisodeAndSequenceIndexBatchSampler(num_train_episodes, max_step_len, batch_size=batch_size,
                                                    keep_rollout_index_constant_across_batch=True)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_sampler)

validation_dataset = torch.load(f'{env_name}.validation_dataset.pt')
validation_sampler = EpisodeAndSequenceIndexBatchSampler(num_valid_episodes, max_step_len, batch_size=batch_size,
                                                         keep_rollout_index_constant_across_batch=True)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_sampler=validation_sampler)

model = ReturnPredictorFromSequence(world.observation_space, cfg.algorithm)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCEWithLogitsLoss()

epoch_number = 0

  train_dataset = torch.load(f'{env_name}.train_dataset.pt')
  validation_dataset = torch.load(f'{env_name}.validation_dataset.pt')


In [None]:
# Training
print("Starting training...")
best_vloss = float("inf")

for i in range(num_epochs):
    print('EPOCH {}:'.format(epoch_number + 1))
    model.train(True)
    avg_loss = train_one_epoch(epoch_number)

    running_vloss = 0.0
    model.eval()

    with torch.no_grad():
        for i, data in enumerate(validation_dataloader):
            inputs, labels = data

            outputs = model(inputs).squeeze()

            # unsqueeze and expand labels for time dim
            labels = labels[..., None].expand(*outputs.shape)

            loss = criterion(outputs, labels)
            running_vloss += loss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        torch.save(model.state_dict(), model_path)

    # wandb.log({'train/loss': avg_loss, 'valid/loss': avg_vloss}, epoch_number + 1)
    epoch_number += 1
    # epoch_number += i


Starting training...
EPOCH 1:
  batch 50 loss: 0.6936407166607319
  batch 100 loss: 0.6926678948316791
  batch 150 loss: 0.5647560906545629
  batch 200 loss: 0.43778205525324554
  batch 250 loss: 0.41061081452039505
  batch 300 loss: 0.4143646131538398
  batch 350 loss: 0.421618992147714
  batch 400 loss: 0.4052708839831194
  batch 450 loss: 0.39908213363340833
  batch 500 loss: 0.4078004080862766
  batch 550 loss: 0.39698186146109987
  batch 600 loss: 0.40140254549615917
  batch 650 loss: 0.3992457102310621
  batch 700 loss: 0.3916446151132194
  batch 750 loss: 0.39861348596009444
  batch 800 loss: 0.3996219066660499


In [None]:
test_dataset = torch.load(f'{env_name}.test_dataset.pt')
test_sampler = EpisodeAndSequenceIndexBatchSampler(num_test_episodes, max_step_len, batch_size=batch_size,
                                                   keep_rollout_index_constant_across_batch=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_sampler=test_sampler)
model = ReturnPredictorFromMemoryAndReactiveObs(world.observation_space, cfg.algorithm)
model.load_state_dict(torch.load(model_path))
model.eval()

In [None]:
res = []

In [None]:
with torch.no_grad():
    total_correct = 0
    total_samples = 0
    for i, batch in enumerate(test_dataloader):
        inputs, labels = batch

        # # removing data items with 0.5 label
        # inputs = tree.map_structure(
        #     lambda _inputs: _inputs[labels != 0.5], inputs)
        # 
        # labels = labels[labels != 0.5]

        outputs = model(inputs)

        # unsqueeze and expand labels for time dim
        labels = labels[..., None].expand(*outputs.shape)

        # Get the model's prediction after it has seen the first step.
        # The test dataset definitely contains the Memory observation.
        outputs = torch.sigmoid(outputs.squeeze())
        res.extend(outputs.cpu().numpy())

        total_correct += (labels * outputs + (1 - labels) * (1 - outputs)).sum().item()
        total_samples += outputs.numel()

    acc = total_correct / total_samples
    # wandb.log({"test/one_or_zero_class_acc": one_or_zero_class_acc})
    print(f'accuracy: {acc * 100:0.2f} %')
    print(total_samples)


In [None]:
res = np.array(res)

In [None]:
res.shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
hist, bin_edges = np.histogram(res, bins=100, range=(0, 1))

In [None]:
plt.figure(figsize=(10, 6))
plt.bar(bin_edges[:-1], hist, width=np.diff(bin_edges), align="edge", edgecolor="black")
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.title("Frequency Distribution")
plt.show()

In [None]:
# wandb.finish()

In [None]:
rollout_len = 50
num_episodes = 5000

In [None]:
num_learnable_memory_reactive_pair_per_episode = 1  # This could be more or less than this. 
num_memory_reactive_pair_per_episode = rollout_len - 1

In [None]:
tot_learnable_pair = num_learnable_memory_reactive_pair_per_episode * num_episodes
tot_pair = num_memory_reactive_pair_per_episode * num_episodes

In [None]:
tot_learnable_pair, tot_pair, tot_learnable_pair / tot_pair

In [None]:
noise = 1 - tot_learnable_pair / tot_pair
f'noise: {noise * 100:0.2f} %'