# Imports

In [None]:
from models.NatureVisualEncoder import NatureVisualEncoder, conv_output_shape, pool_out_shape
# from encoders.ResNetVisualEncoder import ResNetVisualEncoder, ResNetBlock, Swish
# from src.encoders.NatureVisualAttentionEncoder import NatureVisualAttnEncoder
# from utils.encoder_utils import conv_output_shape, pool_out_shape
# from utils.rl_utils import RunningMeanStd
import torch.nn as nn
import torch
import sys
import os
import gc
import numpy as np
from utils.utils import RunningMeanStdTorch as RunningMeanStd

# Decoder Definition

#### Nature Visual Decoder

In [None]:
class NatureVisualDecoder(nn.Module):
    """Transposed-convolution decoder mirroring the Nature-CNN visual encoder.

    Maps feature vectors (last dimension ``output_size``) back to images laid
    out channels-last, ``(batch, H', W', initial_channels)``, with values in
    ``[0, 1]``.

    The size of the feature map that seeds the deconvolution stack is derived
    from ``(height, width)`` with the same kernel/stride sequence (8/4, 4/2,
    3/1) as the encoder. For an 84x84 input this is 7x7 and the deconvolution
    stack reproduces an 84x84 image. For other sizes, ``H'``/``W'`` may differ
    slightly from ``height``/``width``, because transposed convolutions do not
    undo the encoder's floor division.

    Args:
        height: Height of the encoder's input image.
        width: Width of the encoder's input image.
        initial_channels: Number of channels of the reconstructed image.
        output_size: Size of the feature vector produced by the encoder.
    """

    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        super().__init__()
        self.h_size = output_size
        conv_1_hw = conv_output_shape((height, width), 8, 4)
        conv_2_hw = conv_output_shape(conv_1_hw, 4, 2)
        conv_3_hw = conv_output_shape(conv_2_hw, 3, 1)
        # Keep the seed feature-map size so forward() reshapes to it rather
        # than assuming the 7x7 map that only an 84x84 input produces.
        self.conv_3_hw = tuple(conv_3_hw)
        self.final_flat = conv_3_hw[0] * conv_3_hw[1] * 64

        # Two-layer MLP expanding the features to the flattened 64-channel
        # seed map. (Single-Linear and BatchNorm variants were tried before.)
        self.dense = nn.Sequential(
            nn.Linear(self.h_size, 256),
            nn.ReLU(),
            nn.Linear(256, self.final_flat),
            nn.ReLU(),
        )

        # Mirror of the encoder's conv stack; Tanh bounds the output to
        # [-1, 1], which forward() rescales to [0, 1].
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(64, 64, [3, 3], [1, 1]),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, [4, 4], [2, 2]),
            nn.ReLU(),
            nn.ConvTranspose2d(32, initial_channels, [8, 8], [4, 4]),
            nn.Tanh(),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Decode ``features`` into channels-last images with values in [0, 1]."""
        hidden = self.dense(features)
        # Unflatten to (batch, 64, h, w) using the size computed in __init__.
        hidden = hidden.view(-1, 64, *self.conv_3_hw)
        hidden = self.deconv_layers(hidden)
        hidden = hidden.permute([0, 2, 3, 1])  # NCHW -> NHWC
        # Rescale the tanh output from [-1, 1] to [0, 1].
        hidden = (hidden + 1) / 2
        return hidden

    def inverse_z_score_norm(self, input):
        """Compute ``input * sigma + mu`` with statistics taken from ``input``.

        ``mu`` is the mean and ``sigma`` the *variance* along the last
        dimension of ``input`` itself.
        NOTE(review): inverting a z-score normally multiplies by the standard
        deviation of the *original* data; confirm intent before using this.
        It is not called by forward().
        """
        mu = input.mean(dim=-1, keepdim=True)
        sigma = input.var(dim=-1, keepdim=True)

        z = (input * sigma) + mu
        return z

    def inverse_min_max_norm(self, input, min, max):
        """Return ``input * max + min`` (unused by forward())."""
        shifted = input * max
        output = shifted + min

        return output



# Loss Function and Optimiser

In [None]:
# Reconstruction objective: mean squared error between decoded and input pixels.
# (torch.nn.CrossEntropyLoss was tried as an alternative.)
loss_fn = torch.nn.MSELoss()

# Optimiser hyper-parameters.
lr = 5e-4
alpha = 0.99  # not passed to the optimiser below; mirrors RMSprop's `alpha`
eps = 1e-5

# Fix the global torch RNG so runs are reproducible.
torch.manual_seed(0)

# Optimiser *class*; it is instantiated later with the model parameters.
# (torch.optim.RMSprop was tried as an alternative.)
optimizer = torch.optim.Adam


# Setup

#### Using the Nature CNN encoder and decoder

In [None]:
# Build the autoencoder: Nature-CNN encoder and the matching decoder defined
# above, both for 84x84 inputs with 3 channels and 256 features, moved to CUDA.
encoder = NatureVisualEncoder(84, 84, 3, 256).cuda()
decoder = NatureVisualDecoder(84, 84, 3, 256).cuda()
# obs_mean = RunningMeanStd(shape=(2,84,84,3))

encoder.set_is_pretraining()

# Optional warm start from previously saved weights (currently disabled).
# Only a missing checkpoint counts as "no data"; any other failure (e.g. a
# state_dict that doesn't match the architecture) should surface, not be hidden.
try:
    # encoder.load_state_dict(torch.load('encoder_NEW.pth'))
    # decoder.load_state_dict(torch.load('decoder_NEW.pth'))
    pass
except FileNotFoundError:
    print("No data")

# One parameter group per network so they can be tuned separately later.
parameters = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
]

# optimiser = optimizer(parameters, lr, weight_decay=1e-05)
optimiser = optimizer(params=parameters, lr=lr, eps=eps)

### Generating a dataset to validate that the implementations work:

In [None]:
# from PIL import Image
# import numpy as np
# encoder.eval()
# decoder.eval()
# # Load the image
# # test_image = Image.open("test_img_2.png")
# # test_image = np.load("test_img_2.npy")
# test_image = np.load("test_images_1.npy")[-1]

# data0 = Image.fromarray(np.uint8(test_image[:,:,:]*255), mode = "RGB")
# data0.save("input_obs.png")
# display(data0)
# # display(test_image)
# print(test_image)
# # test_image_array = np.array(test_image.getdata())
# # print(test_image_array.shape)
# # test_image_array = np.array(test_image.getdata()).reshape(test_image.size[0], test_image.size[1], 3)

# # Convert it to float from uint8
# # test_image_array = test_image_array/255
# # test_image_array shape is now (84, 84, 3)

# # encode
# # encoded_test_img, minval, maxval = encoder(test_image_array)
# # print(test_image_array.shape)
# encoded_test_img = encoder.forward(test_image)
# print(encoded_test_img.shape)

# # decode
# decoded_test_img = decoder(encoded_test_img)
# print(decoded_test_img.shape)


# # convert to numpy and remove first dimension
# decoded_test_img_ready = decoded_test_img.squeeze(dim=0).cpu().detach().numpy()
# print(decoded_test_img_ready.shape)

# data0 = Image.fromarray(np.uint8(decoded_test_img_ready[:,:,:]*255), mode = "RGB")
# display(data0)
# data0.save("output_obs.png")

# print(encoded_test_img)
# print(decoded_test_img)
# # print(np.max(decoded_test_img, axis=-1))

# Setting Up Dataset Generation

In [None]:
# import mlagents
# from mlagents_envs.environment import UnityEnvironment
# from mlagents_envs.side_channel.engine_configuration_channel import (EngineConfigurationChannel,)
# from wrappers.UnityParallelEnvWrapper_Torch import UnityWrapper
# from mlagents_envs.base_env import ActionTuple
# import pdb
# from collections import deque
# import numpy as np
# import gc



# cwd = os.getcwd()
# def get_worker_id(filename="worker_id.dat"):
#     with open(filename, 'a+') as f:
#         f.seek(0)
#         val = int(f.read() or 0) + 1
#         f.seek(0)
#         f.truncate()
#         f.write(str(val))
#         return val
# config_channel = EngineConfigurationChannel()
# config_channel.set_configuration_parameters(time_scale=10.0)


# # This is for training in the editor. For training using an executable, set file_name to the executable's path
# # env=UnityEnvironment(file_name=None, seed=1, side_channels=[config_channel], worker_id=0)
# # env.close()

# try:
#     env.close()
#     unity_env.close()
# except:
#     print("No envs open")


# unity_env = UnityEnvironment(file_name='./environmentExecutables/DiscreteCurriculum/DiscreteCurriculum.x86_64', worker_id=get_worker_id(), seed=np.int32(0), side_channels=[config_channel])
# # unity_env = UnityEnvironment(file_name='./unity/envs/Discrete_NoCur/Discrete_NoCur.x86_64', worker_id=get_worker_id())
# unity_env.reset()

# behaviour_name = list(unity_env.behavior_specs)[0]
# behaviour_specs = unity_env.behavior_specs[behaviour_name]

# # print(behaviour_specs.observation_specs)
# env = unity_env
# env.reset()
# # decision_steps, _ = env.get_steps(behaviour_name)
# # # np.zeros(decision_steps[0].obs[0].shape).astype(np.float32)
# # print(decision_steps[0].obs[0].shape)

# # nvec = np.zeros(decision_steps[0].obs[0].shape).astype(np.float64)
# # nvec = np.zeros(behaviour_specs.action_spec.discrete_size)
# # print(nvec)
# # empt_act = behaviour_specs.action_spec[1]
# # print(empt_act)

# decision_steps, terminal_steps = env.get_steps(behaviour_name)

# continuous = np.zeros((2,behaviour_specs.action_spec.continuous_size)).astype(np.float32)

# episode_length = 300
# num_episodes = 25
# num_datasets = 12
# episodes = deque()
# for dataset in range(num_datasets):
#     episodes.clear()
#     for ep in range(num_episodes):
#         # reset the environment
#         env.reset()
#         episode_observations = deque()
#         done = False
#         decision_steps, terminal_steps = env.get_steps(behaviour_name)
#         while not done:
#             # print(f"Step before reset: {step}")
#             # if reset_next:
#             #     env.reset()
#             #     reset_next=False
#             #     break
               

#             # Generate a random action and step the environment
#             rand1 = np.random.random_integers(0, 6)
#             rand2 = np.random.random_integers(0, 6)

#             actions = np.array([[rand1], [rand2]])
#             # action = ActionTuple(continuous=continuous, discrete=actions)
#             action = ActionTuple()
#             action.add_discrete(actions)

#             env.set_actions(behaviour_name, action)
#             env.step()
            
#             # Get an observation from the environment
#             decision_steps, terminal_steps = env.get_steps(behaviour_name)
            
#             # print(decision_steps.agent_id_to_index)
#             if len(terminal_steps)>0:
#                 steps_to_use = terminal_steps
#                 done =True
#                 # reset_next = True
                
#             else:
#                 steps_to_use = decision_steps
#                 # reset_next = False
                
#             episode_observations.append(np.float16(steps_to_use.obs[0]))

            

            
#         # episode_observations_array = np.array(episode_observations)
#         episodes.append(episode_observations)



#     # eps = []
#     # for ep in episodes:
#     #     print(ep.shape)
#     #     if ep.shape != (2000, 2, 84, 84, 3):
#     #         print(f"Episode is wrong shape")
#     #         # raise AssertionError
#     #         # print(episodes_array[i].shape)
#     #     else:
#     #         eps.append(ep)

#     # episodes_array=np.array(eps)
#     file = f"Datasets/Curriculum_Dataset_{dataset}"

#     np.save(file, episodes)

In [None]:
# Best-effort cleanup of Unity environments left open by earlier cells.
# `except Exception` (rather than a bare `except`) keeps Ctrl-C /
# KeyboardInterrupt working; a NameError here just means no env was created.
try:
    env.close()
    unity_env.close()
except Exception:
    print("No envs open")

# Generating a few images to use for visual testing

In [None]:
# import mlagents
# from mlagents_envs.environment import UnityEnvironment
# from mlagents_envs.side_channel.engine_configuration_channel import (EngineConfigurationChannel,)
# from src.wrappers.UnityParallelEnvWrapper_Torch import UnityWrapper
# from mlagents_envs.base_env import ActionTuple
# import pdb
# from collections import deque
# import numpy as np
# from PIL import Image

# def get_worker_id(filename="worker_id.dat"):
#     with open(filename, 'a+') as f:
#         f.seek(0)
#         val = int(f.read() or 0) + 1
#         f.seek(0)
#         f.truncate()
#         f.write(str(val))
#         return val
# config_channel = EngineConfigurationChannel()
# config_channel.set_configuration_parameters(time_scale=100.0)


# # This is for training in the editor. For training using an executable, set file_name to the executable's path
# # env=UnityEnvironment(file_name=None, seed=1, side_channels=[config_channel], worker_id=0)
# # env.close()

# try:
#     env.close()
#     unity_env.close()
# except:
#     print("No envs open")


# unity_env = UnityEnvironment(file_name='./src/environmentExecutables/DiscreteCurriculum/DiscreteCurriculum.x86_64', worker_id=get_worker_id(), seed=np.int32(0), side_channels=[config_channel])
# # unity_env = UnityEnvironment(file_name='./unity/envs/Discrete_NoCur/Discrete_NoCur.x86_64', worker_id=get_worker_id())
# unity_env.reset()

# behaviour_name = list(unity_env.behavior_specs)[0]
# behaviour_specs = unity_env.behavior_specs[behaviour_name]

# # print(behaviour_specs.observation_specs)
# env = unity_env
# env.reset()

# decision_steps, terminal_steps = env.get_steps(behaviour_name)

# continuous = np.zeros((2,behaviour_specs.action_spec.continuous_size)).astype(np.float32)

# episode_length = 1500
# num_episodes = 1
# episodes = deque()
# for ep in range(num_episodes):
#     # reset the environment
#     env.reset()
#     for step in range (episode_length):
        
#         # Get an observation from the environment
#         decision_steps, terminal_steps = env.get_steps(behaviour_name)
#         if len(terminal_steps)>0:
#             steps_to_use = terminal_steps
#         else:
#             steps_to_use = decision_steps
        
#         if step%50 ==0:
#             observations = steps_to_use.obs[0]
#             data0 = Image.fromarray(np.uint8(observations[1,:,:,:]*255), mode = "RGB")
#             data0.save(f"./test_imgs/out{np.uint8(step/50)}.png")
            

#         # Generate a random action and step the environment
#         rand1 = np.random.random_integers(0, 8)
#         rand2 = np.random.random_integers(0, 8)

#         actions = np.array([[rand1], [rand2]])
#         action = ActionTuple(continuous=continuous, discrete=actions)

#         env.set_actions(behaviour_name, action)
#         env.step()

# env.close()

# Loading the dataset from the saved file:

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
import random
from torchvision.utils import make_grid
from matplotlib import pyplot as plt

In [None]:
def train_one_set(encoder, decoder, device, training_set, loss_fn, optimiser, grayscale, batch_size=4096):
    """Run one training pass of the autoencoder over ``training_set``.

    Args:
        encoder: Module mapping an image batch to features.
        decoder: Module mapping features back to images.
        device: Device each batch is moved to (e.g. ``"cuda:0"``).
        training_set: Array of images. Each shuffled mini-batch is fed to
            ``encoder`` as-is and is also the reconstruction target.
        loss_fn: Reconstruction loss, called as ``loss_fn(decoded, target)``.
        optimiser: Optimiser over the encoder/decoder parameters.
        grayscale: If true, convert ``training_set`` with
            ``convert_to_grayscale`` first.
        batch_size: Mini-batch size (default 4096, the former hard-coded value).

    Returns:
        List with one loss value (0-d numpy array) per mini-batch.
    """
    encoder.train()
    decoder.train()

    if grayscale:
        training_set = convert_to_grayscale(training_set)

    # Apply the same 8-bit quantisation that we do during RL training, for
    # consistency.
    # NOTE(review): values outside [0, 1] (e.g. z-score normalised inputs) are
    # not preserved by the uint8 cast; confirm before enabling normalisation.
    training_set = np.float32(np.uint8(training_set * 255)) / 255

    train_loss = []
    batchloader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
    for image_batch in batchloader:
        image_batch = image_batch.to(device)
        decoded_data = decoder(encoder(image_batch))
        loss = loss_fn(decoded_data, image_batch.to(torch.float32))

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        batch_loss = loss.detach().cpu().numpy()
        print('\t partial train loss (single batch): %f' % batch_loss)
        train_loss.append(batch_loss)

    return train_loss

            
def test_one_set(encoder, decoder, device, testing_set, loss_fn, grayscale, batch_size=4096):
    """Evaluate the reconstruction loss over the whole ``testing_set``.

    Outputs and targets from every batch are gathered on the CPU, and
    ``loss_fn`` is applied once to the concatenated tensors.

    Args:
        encoder: Module mapping an image batch to features.
        decoder: Module mapping features back to images.
        device: Device each batch is moved to (e.g. ``"cuda:0"``).
        testing_set: Array of images used both as input and target.
        loss_fn: Reconstruction loss, called as ``loss_fn(decoded, target)``.
        grayscale: If true, convert ``testing_set`` with
            ``convert_to_grayscale`` first.
        batch_size: Evaluation batch size (default 4096, as before).

    Returns:
        The loss tensor (detached) over the full set.
    """
    encoder.eval()
    decoder.eval()

    if grayscale:
        testing_set = convert_to_grayscale(testing_set)

    # Same 8-bit quantisation as train_one_set, so train and test losses are
    # measured on identically pre-processed images.
    testing_set = np.float32(np.uint8(testing_set * 255)) / 255

    with torch.no_grad():
        conc_out = []
        conc_label = []

        # Outputs and labels stay paired, so batch order does not affect the
        # aggregate loss; no shuffling needed.
        batchloader = DataLoader(testing_set, batch_size=batch_size, shuffle=False)

        for image_batch in batchloader:
            image_batch = image_batch.to(device)
            decoded_data = decoder(encoder(image_batch))

            conc_out.append(decoded_data.cpu())
            conc_label.append(image_batch.cpu())

        val_loss = loss_fn(torch.cat(conc_out), torch.cat(conc_label))
        print('\t partial test loss (single set): %f' % (val_loss.data))

    return val_loss.data

def save_validation_images(test_images, encoder, decoder, epoch):
    """Reconstruct ``test_images`` and save the result for later inspection.

    The decoded images are written with ``np.save`` to
    ``encoder_training_results/test_images/decoded_images_{epoch}.npy``.
    The output directory is created if needed.

    Args:
        test_images: Input passed directly to ``encoder``.
        encoder: Encoder module.
        decoder: Decoder module.
        epoch: Epoch number; only used to name the output file.
    """
    # Inference only: no autograd graph is needed for the saved reconstruction.
    with torch.no_grad():
        decoded_images = decoder(encoder(test_images))

    out_dir = "encoder_training_results/test_images"
    os.makedirs(out_dir, exist_ok=True)
    np.save(f"{out_dir}/decoded_images_{epoch}", decoded_images.detach().cpu().numpy())


def ownmin(x, y):
    """Return ``x`` if it is strictly smaller than ``y``; otherwise return ``y``.

    On ties or unordered values (e.g. NaN) the second argument wins.
    """
    return x if x < y else y
    
def convert_to_grayscale(obs):
    """Collapse the trailing RGB axis of ``obs`` into one luma channel.

    Uses the weights (0.299, 0.587, 0.114) and keeps the channel axis with
    size 1, i.e. ``(..., 3) -> (..., 1)``.
    """
    luma_weights = np.array([0.299, 0.587, 0.114])
    weighted = obs * luma_weights
    return np.sum(weighted, axis=-1, keepdims=True)
    

def _load_flat_observations(dataset_name: str, obs_shape=(2, 84, 84, 3)):
    """Load a saved sequence of episodes and stack every step along one axis.

    Each episode in the file is a sequence of per-step observations of shape
    ``obs_shape``. Episodes of different lengths are stored by numpy as an
    object array, which is why pickle support is enabled when loading.

    Returns:
        float64 array of shape ``(total_steps,) + obs_shape``.
    """
    # Pass allow_pickle explicitly instead of temporarily monkey-patching the
    # global np.load: the old patch stayed in place whenever loading raised.
    # NOTE: allow_pickle can execute code embedded in the file; only load
    # trusted datasets.
    dataset = np.load(dataset_name, allow_pickle=True)

    total_steps = sum(len(ep) for ep in dataset)
    eparr = np.empty((total_steps,) + tuple(obs_shape))

    index = 0
    for ep in dataset:
        ep_len = len(ep)
        eparr[index:index + ep_len] = ep
        index += ep_len
    return eparr


def process_dataset(dataset_name: str, running_mean, normalise_input_obs=False, grayscale=False):
    """Load ``dataset_name`` and return all its observations as one array.

    Args:
        dataset_name: Path to a ``.npy`` file of episodes.
        running_mean: Object exposing ``mean`` and ``var``; only read when
            ``normalise_input_obs`` is true.
        normalise_input_obs: If true, standardise as ``(x - mean) / sqrt(var)``.
        grayscale: Not used here (grayscale conversion happens in
            ``train_one_set`` / ``test_one_set``); kept for call compatibility.

    Returns:
        float64 array of shape ``(total_steps, 2, 84, 84, 3)``.
    """
    eparr = _load_flat_observations(dataset_name)

    if normalise_input_obs:
        # NOTE(review): assumes running_mean.mean / .var are numpy-compatible
        # (not CUDA tensors) and var contains no zeros; confirm before enabling.
        eparr = (eparr - running_mean.mean) / np.sqrt(running_mean.var)

    return eparr


def generate_means_from_dataset(dataset_name: str, running_mean):
    """Feed every observation in ``dataset_name`` to ``running_mean.update``.

    Returns None; statistics accumulate inside ``running_mean``.
    """
    running_mean.update(_load_flat_observations(dataset_name))




In [None]:
num_epochs = 10
num_training_sets = 10
num_testing_sets = 2
all_train_losses = []  # nested as [epoch][training set][mini-batch]
all_test_losses = []   # nested as [epoch][testing set]
encoder.train()
decoder.train()
obs_rms = RunningMeanStd(shape=(2,84,84,3))
normalise_obs = False
grayscale = False
device = "cuda:0"

# Per-image shape handed to train_one_set / test_one_set. The stored datasets
# are RGB, and those functions do the grayscale conversion themselves on the
# trailing channel axis. So keep the 3 colour channels here even when
# grayscale=True; reshaping RGB data to (-1, 84, 84, 1) would scramble pixels.
# NOTE(review): grayscale=True presumably also requires the encoder/decoder to
# be built with 1 channel instead of 3.
shape = (-1, 84, 84, 3)

# test_set_for_thesis = np.load("encoder_training_results/test_images/test_image_set.npy")

# Generate the RMS stuff first and do not modify again
if normalise_obs:
    generate_means_from_dataset(f"Datasets/Curriculum_Dataset_{0}.npy", obs_rms)


for epoch in range(num_epochs):
    epoch_train_losses = []
    print("============================================================")
    print(f"Current epoch: {epoch}")
    print("============================================================")
    for set_number in range(num_training_sets):
        dataset = process_dataset(f"Datasets/Curriculum_Dataset_{set_number}.npy", obs_rms, normalise_obs, grayscale=grayscale)

        # process_dataset returns (num_steps, 2, 84, 84, 3); fold the axis of
        # size 2 (presumably one view per agent) into the batch to get
        # (num_images, 84, 84, 3).
        dataset = np.reshape(dataset, shape).astype(np.float32)
        print(dataset.shape)
        set_loss = train_one_set(encoder, decoder, device, dataset, loss_fn, optimiser, grayscale)
        epoch_train_losses.append(set_loss)
        # Release this set before the next (large) one is loaded.
        del dataset
    all_train_losses.append(epoch_train_losses)

    epoch_test_losses = []
    for test_set_num in range(num_testing_sets):
        testing_set = process_dataset(f"Datasets/testing_set_{test_set_num}.npy", obs_rms, normalise_obs, grayscale=grayscale)
        testing_set = np.reshape(testing_set, shape).astype(np.float32)
        # save_validation_images(test_set_for_thesis, encoder, decoder, epoch)
        test_set_loss = test_one_set(encoder, decoder, device, testing_set, loss_fn, grayscale)
        epoch_test_losses.append(test_set_loss)
        del testing_set
    all_test_losses.append(epoch_test_losses)


torch.save(encoder.state_dict(), "encoder_NEW.pth")
torch.save(decoder.state_dict(), "decoder_NEW.pth")
np.save("all_train_losses", all_train_losses)
np.save("all_test_losses", all_test_losses)


In [None]:
# Persist the trained weights and the nested loss histories.
torch.save(obj=encoder.state_dict(), f="encoder_NEW.pth")
torch.save(obj=decoder.state_dict(), f="decoder_NEW.pth")
np.save(file="all_train_losses", arr=all_train_losses)
np.save(file="all_test_losses", arr=all_test_losses)

In [None]:
import matplotlib.pyplot as plt

train_losses_arr = np.array(all_train_losses)

print(train_losses_arr.shape)
print(train_losses_arr[0])
# all_train_losses is nested as [epoch][training set][mini-batch]; flatten it
# into one chronological sequence of per-batch losses. (The previous
# reshape((10, -1)) hard-coded num_epochs=10 before flattening anyway.)
train_losses_per_dataset = train_losses_arr.reshape(-1)

print(train_losses_per_dataset.shape)

# all_test_losses is nested as [epoch][testing set].
val_losses_arr = np.array(all_test_losses).reshape(-1)

print(val_losses_arr.shape)
# Create the figure and axis objects
fig, ax = plt.subplots()

# Plot the data (labels are required for ax.legend() to show anything)
# ax.plot(train_losses_per_dataset, label="train (per batch)")
ax.plot(val_losses_arr, label="test (per set)")

# Add axis labels and title
ax.set_xlabel('Evaluation index (epoch x test set)')
ax.set_ylabel('Reconstruction loss')
ax.set_title('Test loss during training')
ax.legend()

# Show the plot
plt.show()

In [None]:
fig, ax = plt.subplots()
# Label the line so ax.legend() has an entry to show.
ax.plot(val_losses_arr, label="test (per set)")
ax.set_xlabel('Evaluation index (epoch x test set)')
ax.set_ylabel('Reconstruction loss')
ax.set_title('Test loss during training')
ax.legend()

# Show the plot
plt.show()

# Checking the encoder input and output after training:

In [None]:
import os

from PIL import Image
import numpy as np

os.makedirs("./output_images", exist_ok=True)

# Inference only: evaluation mode and no autograd bookkeeping.
encoder.eval()
decoder.eval()

with torch.no_grad():
    for i in range(30):
        # Load a saved test frame and scale it to [0, 1].
        test_image_name = f"./test_imgs/out{i}.png"
        with Image.open(test_image_name) as test_image:
            # convert("RGB") drops any alpha channel; np.asarray yields
            # (height, width, 3). The old reshape(size[0], size[1], 3) swapped
            # height and width (PIL's size is (width, height)).
            test_image_array = np.asarray(test_image.convert("RGB")) / 255

        # encode
        # NOTE(review): assumes encoder returns only the feature tensor, as in
        # the commented validation cell earlier; the old code unpacked
        # (features, minval, maxval).
        encoded_test_img = encoder(test_image_array)
        print(encoded_test_img.shape)

        # decode: NatureVisualDecoder.forward takes only the features
        decoded_test_img = decoder(encoded_test_img)
        print(decoded_test_img.shape)

        # convert to numpy and remove first dimension
        decoded_test_img_ready = decoded_test_img.squeeze(dim=0).cpu().detach().numpy()
        print(decoded_test_img_ready.shape)

        data0 = Image.fromarray(np.uint8(decoded_test_img_ready[:, :, :] * 255), mode="RGB")
        data0.save(f"./output_images/out{i}.png")

# Save the Encoder for use with my RL Problem:

In [None]:

# Write the trained weights to disk for reuse in the RL pipeline.
torch.save(obj=encoder.state_dict(), f="encoder_NEW.pth")
torch.save(obj=decoder.state_dict(), f="decoder_NEW.pth")


In [None]:
# The Unity environment may never have been created in this session (its
# creation cell is commented out) or may already be closed; treat that as
# nothing to do rather than an error.
try:
    env.close()
except Exception:
    print("No envs open")

# Generating Encoder Training Results

In [None]:
import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torch as th

In [None]:
# Fixed image set used to compare reconstructions across training epochs.
test_set = np.load(file="encoder_training_results/test_images/test_image_set.npy")

# TODO: pass test_set through the autoencoder with the trained weights and
# save the output after each epoch.


In [None]:
# NOTE(review): `test` was never defined in this notebook; the image set loaded
# earlier is `test_set`. The title implies the second half of the grid should
# be the reconstructions; until those are wired in, both halves show originals.
images = th.tensor(np.concatenate((test_set, test_set), axis=0))
grid = make_grid(images.permute(0, 3, 1, 2), nrow=5, padding=1)  # NHWC -> NCHW
grid = grid.permute(1, 2, 0)  # CHW -> HWC for imshow
plt.figure(dpi=170)
plt.title('Original/Reconstructed')
plt.imshow(grid)
plt.axis('off')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
from PIL import Image
import time
import torch

from IPython.display import display, clear_output
obs = np.load("obs_500.npy")
print(obs.shape)
# Select index 0 on each of the two leading axes, keeping the remaining
# three (or more) axes whole.
obsqz = obs[0][0][:, :, :]



In [None]:
print(obsqz.shape)
# fig, ax = plt.subplots()

# clear_output(wait=True)

#     # Display the current image
# plt.imshow(obsqz)
obstens = torch.tensor(obsqz)

# Experiment: can a plain reshape stand in for the channels-last ->
# channels-first permute?
# NOTE(review): torch.reshape keeps the flat element order while
# torch.permute reorders axes, so for real image data the two results differ
# and the assert below is expected to fail. The reshape target also hard-codes
# (2, 3, 84, 84); this assumes obsqz is (2, 84, 84, 3), so confirm.
obs_reshape = torch.reshape(obstens, (2,3,84,84))
obs_permute = torch.permute(obstens, (0,3,1,2))

assert torch.all(obs_reshape == obs_permute)




In [None]:
fig, ax = plt.subplots()
for image in obsqz:
    # Clear the previously displayed frame from the cell output, and the
    # previous image from the axes so AxesImage objects don't pile up.
    clear_output(wait=True)
    ax.clear()

    # Draw the current image
    ax.imshow(image, cmap="gray")

    # Show the current image
    display(fig)

    # Pause for a short duration (adjust as needed, e.g., 0.1 seconds)
    # time.sleep(0.1)

# Close the figure so the notebook backend doesn't render it once more after
# the loop.
plt.close(fig)
    


# Debugging Torch

In [18]:
import torch
import numpy as np

def print_iteravely(to_print):
    """Print, one per line, the sum of each slice of ``to_print`` along axis 0."""
    n_rows = to_print.shape[0]
    for idx in range(n_rows):
        row_total = np.sum(to_print[idx])
        print(row_total)

In [32]:
# Load intermediate arrays dumped to .npy files, to debug the TD error.
# Note: `td_ero` holds the contents of "td_er.npy"; `td_error` is a separate
# file ("td_error.npy").
td_ero = np.load("td_er.npy")
mask = np.load("mask.npy")
masked_td_error = np.load("masked_td_error.npy")
rewards = np.load("rewards.npy")
terminated = np.load("terminated.npy")
target_max_qvals = np.load("target_max_qvals.npy")
td_error = np.load("td_error.npy")
chosen_action_qvals = np.load("chosen_action_qvals.npy")
targets =np.load("targets.npy")

# Index (along the first axis) of the sample being inspected.
error_index = 16

# print(td_ero[error_index])
# print(np.sum(masked_td_error**2, axis=(1,2))[error_index]) # the corresponding value is already wrong here
# print(np.sum(mask, axis = (1,2))[error_index]) # The corresponding value is already wrong here
# print(td_error[error_index]) # Values here seem off, all values are same and is quite large
# print(chosen_action_qvals[error_index])# Values here seem off, all values are same and is quite small


# td_error is (chosen_action_qvals - targets.detach()) so check both of those
# targets comes from the n-step bellman. If the issue is not one of the above, it lies there. Save the values used to calculate bellman and try to replicate 
# NOTE(review): print() only emits an empty line; the array output recorded
# below presumably came from one of the prints above in an earlier run.
print()

[[0.00646253]
 [0.00646259]
 [0.00646261]
 [0.00646262]
 [0.00646265]
 [0.00646265]
 [0.00646266]
 [0.00646266]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00646267]
 [0.00