In [1]:
import os

# Backend switches, exported before the framework imports that follow
# (presumably so the libraries pick them up when they are first loaded).
os.environ.update(
    {
        "TORCHDYNAMO_INLINE_INBUILT_NN_MODULES": "1",
        "TF_CPP_MIN_LOG_LEVEL": "3",
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    }
)
# Optional switches, currently disabled:
#   os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # to avoid memory fragmentation
#   os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
#   os.environ["TORCH_USE_CUDA_DSA"] = "1"
from typing import Tuple, Callable, Dict, Optional, Union
import math
import random
import time
from collections import deque
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
import wandb
import torchinfo
import importlib.resources
import copy
import pyrallis
from tensordict import TensorDict, from_module, from_modules
from tensordict.nn import TensorDictModule
from torchrl.data import ReplayBuffer, LazyMemmapStorage
import traceback
import saris
from saris.utils import utils, pytorch_utils, running_mean
from saris.drl.agents import sac

from saris.drl.envs import register_envs

# Run the project's environment registration hook (imported from saris.drl.envs);
# presumably this is what makes the env id used later resolvable by `gym.make`.
register_envs()
# Select the "high" float32 matmul precision mode for PyTorch.
torch.set_float32_matmul_precision("high")

In [2]:
# Machine-specific locations exported as environment variables
# (presumably consumed by the saris / Blender tooling — confirm against the package).
os.environ.update(
    SCRIPT_DIR="/home/hieule/research/saris",
    BLENDER_DIR="/home/hieule/blender",
    SOURCE_DIR="/home/hieule/research/saris",
    ASSETS_DIR="/home/hieule/research/saris/local_assets",
    BLENDER_APP="/home/hieule/blender/blender-3.3.14-linux-x64/blender",
    TMP_DIR="/home/hieule/research/saris/tmp",
)

In [3]:
@dataclass
class Config:
    """Settings shared by the cells of this notebook.

    Declared as a dataclass so that individual fields can be overridden at
    construction time (e.g. ``Config(num_envs=4)``) and so the instance has a
    readable repr; ``Config()`` still yields the defaults below.
    """

    # Base seed for training environments; `make_env` adds the env index to it.
    seed: int = 0
    # Step limit per training episode.
    ep_len: int = 1000
    # Step limit per evaluation episode.
    eval_ep_len: int = 1000
    # Base seed for evaluation environments.
    eval_seed: int = 0
    # Environment id handed to `gym.make`.
    env_id: str = "wireless-sigmap-v0"
    # Path forwarded to the environment as `sionna_config_file`.
    sionna_config_file: str = "/home/hieule/research/saris/configs/sionna_L_multi_users.yaml"
    # Number of environments placed in the vector env.
    num_envs: int = 2
    # Forwarded to the environment as `log_string`.
    name: str = "sac"
    # Directory of a previously saved replay buffer to load.
    load_replay_buffer: str = "/home/hieule/research/saris/local_assets/replay_buffers/SAC__L_shape_static__wireless-sigmap-v0__68763e89"
    # Size passed to the replay-buffer storage.
    buffer_size: int = 80_000
    # Batch size passed to the replay buffer.
    batch_size: int = 256


config = Config()

In [4]:
def make_env(config, idx: int, eval_mode: bool) -> Callable:
    """Build a zero-argument factory that creates one wrapped environment.

    Args:
        config: Object providing ``seed``, ``eval_seed``, ``ep_len``,
            ``eval_ep_len``, ``env_id``, ``sionna_config_file`` and ``name``.
        idx: Index of the environment. It is added to the base seed and also
            forwarded to ``gym.make``.
        eval_mode: When true, ``eval_seed`` / ``eval_ep_len`` are used instead
            of ``seed`` / ``ep_len``; the flag is also forwarded to ``gym.make``.

    Returns:
        A callable that constructs the environment each time it is invoked.
    """

    def thunk() -> gym.Env:

        seed = config.seed if not eval_mode else config.eval_seed
        max_episode_steps = config.ep_len if not eval_mode else config.eval_ep_len
        # Offset by the env index so every environment gets its own seed.
        seed += idx
        env = gym.make(
            config.env_id,
            idx=idx,
            sionna_config_file=config.sionna_config_file,
            log_string=config.name,
            eval_mode=eval_mode,
            seed=seed,
            max_episode_steps=max_episode_steps,
        )
        env = gym.wrappers.RecordEpisodeStatistics(env)
        # NOTE(review): `gym.make(..., max_episode_steps=...)` above already
        # requests a step limit, so this extra wrapper looks redundant —
        # confirm before removing it.
        env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)
        # env = gym.wrappers.FlattenObservation(env)
        # Seed the spaces with the per-env seed. Using the raw `config.seed`
        # here would ignore both `idx` and `eval_seed`, giving every
        # environment's spaces the same sampling stream.
        env.action_space.seed(seed)
        env.observation_space.seed(seed)

        return env

    return thunk

In [5]:

def normalize_obs(
    flat_obs: torch.Tensor,
    real_channel_rms: running_mean.RunningMeanStd,
    imag_channel_rms: running_mean.RunningMeanStd,
    epsilon: float = 1e-8,
):
    """Standardise the two leading channel segments of a flat observation.

    The last axis of ``flat_obs`` is read as three consecutive pieces: a
    segment as long as ``real_channel_rms.mean``, a segment as long as
    ``imag_channel_rms.mean``, and whatever remains (called ``pos`` in the
    original code). The first two are shifted by the matching running mean and
    divided by ``sqrt(var + epsilon)``; the remainder is passed through as is.

    Args:
        flat_obs: Observation tensor; only its last axis is sliced.
        real_channel_rms: Statistics (``mean`` / ``var``) for the first segment.
        imag_channel_rms: Statistics (``mean`` / ``var``) for the second segment.
        epsilon: Added to the variance before the square root.

    Returns:
        Tensor with the same layout as ``flat_obs`` along the last axis.
    """
    device = flat_obs.device
    pieces = []
    start = 0
    for rms in (real_channel_rms, imag_channel_rms):
        mean = rms.mean.to(device)
        var = rms.var.to(device)
        stop = start + rms.mean.shape[0]
        pieces.append((flat_obs[..., start:stop] - mean) / torch.sqrt(var + epsilon))
        start = stop

    # Trailing entries after both channel segments are left untouched.
    pieces.append(flat_obs[..., start:])
    return torch.cat(pieces, dim=-1)


def update_channel_rmss(
    flat_obs: torch.Tensor,
    real_channel_rms: running_mean.RunningMeanStd,
    imag_channel_rms: running_mean.RunningMeanStd,
):
    """Feed the channel segments of flat observations into the running statistics.

    Args:
        flat_obs: Tensor whose last axis starts with the real-channel entries,
            immediately followed by the imaginary-channel entries. Anything
            after those two segments is ignored. (It is sliced with
            ``[..., a:b]``, so it must be a single tensor, not a tuple.)
        real_channel_rms: Statistics object receiving the first segment; the
            segment length is the number of elements in its ``mean``.
        imag_channel_rms: Statistics object receiving the second segment; the
            segment length is the number of elements in its ``mean``.

    Nothing is returned; each statistics object is updated through its own
    ``update`` method.
    """
    # math.prod yields a plain int, which is always a valid slice bound.
    real_channel_len = math.prod(real_channel_rms.mean.shape)
    imag_channel_len = math.prod(imag_channel_rms.mean.shape)
    imag_channel_end = real_channel_len + imag_channel_len

    real_channel_rms.update(flat_obs[..., :real_channel_len])
    imag_channel_rms.update(flat_obs[..., real_channel_len:imag_channel_end])

In [6]:
# One training-mode factory per environment index, all wrapped in a single
# synchronous vector environment.
env_fns = []
for env_idx in range(config.num_envs):
    env_fns.append(make_env(config, env_idx, eval_mode=False))
envs = gym.vector.SyncVectorEnv(env_fns)

            Sentinel is not a public part of the traitlets API.
            It was published by mistake, and may be removed in the future.
            
  warn(


In [7]:
# Running mean/variance trackers used for observation normalisation: one per
# channel part, each sized to the flattened shape of the matching entry of the
# single-env observation space (index 0 -> "real", index 1 -> "imag").
single_obs_space = envs.single_observation_space
real_channel_len = math.prod(single_obs_space[0].shape)
imag_channel_len = math.prod(single_obs_space[1].shape)
real_channel_rms, imag_channel_rms = (
    running_mean.RunningMeanStd(shape=(channel_len,))
    for channel_len in (real_channel_len, imag_channel_len)
)
obs_rmss = (real_channel_rms, imag_channel_rms)

In [8]:
# Display the statistics as created, before any update has been applied.
obs_rmss

(RunningMeanStd(mean=tensor([0., 0., 0.,  ..., 0., 0., 0.]), var=tensor([1., 1., 1.,  ..., 1., 1., 1.]), count=1e-15),
 RunningMeanStd(mean=tensor([0., 0., 0.,  ..., 0., 0., 0.]), var=tensor([1., 1., 1.,  ..., 1., 1., 1.]), count=1e-15))

In [9]:
# Directory handed to the memory-mapped storage as its scratch location.
rb_dir = "/home/hieule/research/saris/local_assets/replay_buffers/SAC__L_shape_static__wireless-sigmap-v0__d8fa5bea"
rb_storage = LazyMemmapStorage(config.buffer_size, scratch_dir=rb_dir)
rb = ReplayBuffer(storage=rb_storage, batch_size=config.batch_size)
# Restore the previously saved buffer content from disk.
rb.loads(config.load_replay_buffer)

# Pull batches from the buffer's iterator, stopping after `len(rb)` of them
# (but always taking at least one), and keep only their observations.
num_batches = max(len(rb), 1)
stored_obs = []
for i, data in zip(range(num_batches), rb):
    stored_obs.append(data["observations"])
stored_obs = np.concatenate(stored_obs, axis=0)

In [10]:
# NOTE(review): `tmp_stored_obs` is not read by any cell visible here; the
# update below consumes `stored_obs` instead — confirm whether this line is
# still needed.
tmp_stored_obs = np.asarray(rb.storage.get("observations"))
# Feed the observations gathered from the replay buffer into both statistics.
update_channel_rmss(torch.tensor(stored_obs), obs_rmss[0], obs_rmss[1])

In [None]:
# Display the statistics after the replay-buffer observations were added.
obs_rmss

(RunningMeanStd(mean=tensor([-8.4409e-08,  9.9712e-08, -1.2190e-07,  ...,  4.8128e-09,
         -4.6164e-09,  4.4355e-09]), var=tensor([5.3593e-14, 7.5684e-14, 1.1504e-13,  ..., 4.6910e-15, 4.3194e-15,
         3.9904e-15]), count=160.0),
 RunningMeanStd(mean=tensor([ 8.1348e-08, -9.6128e-08,  1.1751e-07,  ..., -1.9953e-09,
          1.9012e-09, -1.8154e-09]), var=tensor([3.3745e-14, 4.6791e-14, 6.9376e-14,  ..., 3.1586e-15, 2.9083e-15,
         2.6867e-15]), count=160.0))

In [12]:
def _flatten_vector_obs(vector_obs) -> np.ndarray:
    """Flatten every part of a vectorised observation per env and join the parts.

    Each element of ``vector_obs`` is reshaped to ``(part.shape[0], -1)`` and
    the results are concatenated along the last axis into a new array.
    """
    return np.concatenate(
        [part.reshape(part.shape[0], -1) for part in vector_obs], axis=-1
    )


# Number of random-action steps taken after the initial reset.
NUM_RANDOM_STEPS = 5

stored_obs = []

obs, _ = envs.reset()
obs = _flatten_vector_obs(obs)
stored_obs.append(obs)

for _ in range(NUM_RANDOM_STEPS):
    # One independently sampled action per environment.
    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)
    # np.concatenate allocates a fresh array, so no defensive copy of
    # `next_obs` is needed before flattening.
    obs = _flatten_vector_obs(next_obs)
    stored_obs.append(obs)
stored_obs = np.concatenate(stored_obs, axis=0)

In [13]:
# Feed the observations collected from the environments into both statistics.
update_channel_rmss(torch.tensor(stored_obs), obs_rmss[0], obs_rmss[1])

In [14]:
# Display the statistics after the environment observations were added.
obs_rmss

(RunningMeanStd(mean=tensor([-8.0655e-08,  9.5228e-08, -1.1633e-07,  ...,  6.9872e-09,
         -6.7038e-09,  6.4426e-09]), var=tensor([5.3297e-14, 7.5260e-14, 1.1439e-13,  ..., 4.5517e-15, 4.1913e-15,
         3.8721e-15]), count=172.0),
 RunningMeanStd(mean=tensor([ 8.2770e-08, -9.7888e-08,  1.1980e-07,  ..., -2.4721e-09,
          2.3606e-09, -2.2587e-09]), var=tensor([3.4917e-14, 4.8484e-14, 7.2024e-14,  ..., 3.0438e-15, 2.8027e-15,
         2.5892e-15]), count=172.0))

In [15]:
# NOTE(review): this is the same call as the earlier update, and `stored_obs`
# is not reassigned in between in the visible cells, so the same rows are fed
# into the statistics a second time — confirm this double counting is intended.
update_channel_rmss(torch.tensor(stored_obs), obs_rmss[0], obs_rmss[1])