In [2]:
from google.colab import drive
drive.mount('/content/drive')

FOLDERNAME = r"Neural_CDE/applications-of-NDE"

%cd drive/MyDrive/Neural_CDE/applications-of-NDE
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))

Mounted at /content/drive
/content/drive/MyDrive/Neural_CDE/applications-of-NDE


In [3]:
!pip install gymnasium

#Setting up pygame for colab
!python -m pip install pygame
import os
os.environ['SDL_VIDEODRIVER']='dummy'
import pygame
pygame.display.set_mode((640,480))


# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt


#for neural CDE
!pip install torchcde

import torchcde
import torch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gymnasium
  Downloading gymnasium-0.28.1-py3-none-any.whl (925 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 925.5/925.5 KB 26.0 MB/s eta 0:00:00
Collecting farama-notifications>=0.0.1
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Collecting jax-jumpy>=1.0.0
  Downloading jax_jumpy-1.0.0-py3-none-any.whl (20 kB)
Installing collected packages: farama-notifications, jax-jumpy, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-0.28.1 jax-jumpy-1.0.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
pygame 2.3.0 (SDL 2.24.2, Python 3.9.16)
Hello from the pygame community. https://www.pygame.org/contribute.html
Installing mediapy:
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m64.7 MB/s[0m eta [36m0:00:00[0m

In [4]:
from environment import GridWorldEnv

In [15]:
# Roll out a uniformly random policy in the grid world and show the episode as a video.
size = 50
world = GridWorldEnv(render_mode="rgb_array",size = size)

world.reset()

framerate = 5
max_steps = 1000  # upper bound on the rollout length
frames = []

for t in range(max_steps):
    # Capture the frame *before* acting, so the video starts at the initial state.
    frame = world.render()
    frames.append(frame)
    action = world.action_space.sample()
    observation, reward, terminated, truncated, info = world.step(action)

    if terminated:
        print("Finished after {} timesteps".format(t+1))
    # Stop on truncation as well: once an episode has ended either way, the environment
    # must be reset before it is stepped again.
    if terminated or truncated:
        break

media.show_video(frames, fps=framerate)

0
This browser does not support the video tag.


In [None]:
from models import Policy
import models
import utils
import data_utils

In [None]:
# Policy used to choose the agent's actions in `world` for data collection and testing.
# NOTE(review): the two 4s are presumably the observation and action dimensions — confirm against models.Policy.
policy1 = Policy(4,4)

In [None]:
# Policy parameters as a single tensor (presumably flattened to 1-D, going by the helper's name — verify in utils).
# NOTE(review): no later cell shown here reads this global; `tester` recomputes the vector locally.
policy_weights = utils.get_weights_as_vec(policy1)

In [None]:
# Sample a time series by running `policy1` in `world`.
# NOTE(review): presumably each step is the observation concatenated with the policy weights,
# mirroring how `tester` builds its sequences — confirm in data_utils.
state_policy_time_series = data_utils.get_state_time_series_from_env(world,policy1)

Finished sampling states


In [None]:
# Split the sampled series into model inputs and regression targets.
# Per the recorded outputs: a (10, 21, 584) series yields train_X of shape (10, 20, 584) and
# train_y of shape (10, 584) — presumably the first 20 steps and the final step; confirm in data_utils.
train_X, train_y = data_utils.split_basec_on_markov(state_policy_time_series)

torch.Size([10, 21, 584])


In [None]:
# Sanity check: display the shapes of the training inputs and targets.
train_X.shape, train_y.shape

(torch.Size([10, 20, 584]), torch.Size([10, 584]))

In [None]:
# Train on the first CUDA GPU when one is visible to PyTorch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cpu


In [None]:
num_epochs = int(1e2)

######################
# This is a regression problem (MSE loss below): the model reads a path whose channels are the
# features of `train_X` and predicts a vector shaped like one row of `train_y`. Both channel
# counts are read from the data (584 each for the dataset recorded in this notebook).
# hidden_channels=8 is the number of hidden channels for the evolving z_t, which we get to choose.
######################
model = models.DynamicsFunction(
    input_channels=train_X.size(-1),
    hidden_channels=8,
    output_channels=train_y.size(-1),
)
model = model.to(device)


optimizer = torch.optim.Adam(model.parameters())


######################
# Now we turn our dataset into a continuous path. We do this here via Hermite cubic spline interpolation.
# The resulting `train_coeffs` is a tensor describing the path.
# For most problems, it's probably easiest to save this tensor and treat it as the dataset.
######################
train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(train_X)

# Tensor.to returns a new tensor rather than moving the tensor in place, so the result must be
# kept; otherwise the data stays on the CPU while the model may live on the GPU.
train_coeffs = train_coeffs.to(device)
train_y = train_y.to(device)

train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch_coeffs, batch_y = batch

        # squeeze(-1) only has an effect if the model emits a single output channel.
        pred_y = model(batch_coeffs).squeeze(-1)
        loss = torch.nn.functional.mse_loss(pred_y, batch_y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    # Reports the loss of the last batch of the epoch.
    print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))

# test_X, test_y = get_data()
# test_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(test_X)
# pred_y = model(test_coeffs).squeeze(-1)
# binary_prediction = (torch.sigmoid(pred_y) > 0.5).to(test_y.dtype)
# prediction_matches = (binary_prediction == test_y).to(test_y.dtype)
# proportion_correct = prediction_matches.sum() / test_y.size(0)
# print('Test Accuracy: {}'.format(proportion_correct))

In [None]:
def get_action(policy,obs):
  """Return the greedy action for a single observation as a plain Python int.

  The observation gets a leading batch axis and is cast to float before being
  passed to `policy`; the result is the index of the largest value in the
  policy's (flattened) output.
  """
  batched_obs = obs[None,:].to(torch.float)
  scores = policy(batched_obs)
  best_index = torch.argmax(scores)
  return best_index.cpu().detach().item()


def tester(env,model,policy,t_len=20,n=10):
  """Roll out `policy` in `env` and print the model's MSE on the final step of each rollout.

  For each of `n` episodes the environment is reset and stepped up to `t_len` times with
  the greedy action from `get_action`. Every recorded step is the current observation
  concatenated with the policy's weight vector, giving sequences of length `t_len + 1`.
  The first `t_len` steps are interpolated and fed to `model`; its prediction is compared
  with the last step using mean squared error.

  Args:
    env: environment exposing `reset()`, `step(action)` and `_get_obs()`.
    model: trained model, called on the Hermite cubic coefficients of the input steps.
    policy: policy passed to `get_action` and to `utils.get_weights_as_vec`.
    t_len: number of environment steps per episode.
    n: number of episodes.

  Returns:
    None. The tensor shapes and the loss are printed.
  """
  # Detached so the rollout tensors do not drag the policy's autograd graph along.
  policy_weights = utils.get_weights_as_vec(policy).detach()
  test_data = []
  for i in range(n):

    env.reset()
    observation = torch.from_numpy(env._get_obs()[0])

    data = [torch.hstack((observation,policy_weights))]

    for j in range(t_len):

      action = get_action(policy,observation)

      _, _, terminated, truncated, _ = env.step(action)
      # Stop on truncation as well: an ended episode must not be stepped again.
      if terminated or truncated:
        # `j`, not `t`: `t` is not defined here and used to resolve to a stale notebook global.
        print("Finished after {} timesteps".format(j+1))

        # Pad with the last recorded state so every sequence has exactly t_len + 1 steps.
        # NOTE(review): the observation reached by the final step is not fetched, so the
        # padding repeats the state from before that step — confirm this is intended.
        data.extend([torch.hstack((observation,policy_weights))] * (t_len - j))

        break

      observation = torch.from_numpy(env._get_obs()[0])
      data.append(torch.hstack((observation,policy_weights)))

    data = torch.stack(data,dim=0)
    print(data.shape)
    test_data.append(data)

  test_data = torch.stack(test_data,dim=0)

  # Match the model's device and dtype so evaluation also works on GPU and with
  # non-float32 observations.
  reference_param = next(iter(model.parameters()), None)
  if reference_param is not None:
    test_data = test_data.to(device=reference_param.device, dtype=reference_param.dtype)

  print(test_data.shape)
  test_X = test_data[:,:-1]
  test_Y = test_data[:,-1]

  print(test_X.shape,test_Y.shape)

  # Evaluation only: no gradients needed.
  with torch.no_grad():
    # Interpolate the inputs only. Building the path from `test_data` would include the
    # target step and leak the answer into the model's input.
    test_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(test_X)
    pred = model(test_coeffs)
    loss = torch.nn.functional.mse_loss(pred, test_Y)

  print(loss)
    

In [None]:
# Evaluate the trained model on fresh rollouts of `policy1` in `world`
# (tester defaults: n=10 episodes of t_len=20 steps each).
tester(world,model,policy1)