In [20]:
import math
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader
from data_loader import LoadData, LoadDataShort
import numpy as np
import sys 
from data_loader import LoadDataLSTM, LoadData_StartEndLong
from networks.rnn import LSTM
from networks.mlp import MLP
from networks.trans import MLP_MULTIDECODE
import time 
import copy 
import matplotlib.pyplot as plt
from utils import load_data, create_model

In [21]:
class args:
    """Plain container for model hyper-parameters.

    Attributes:
        hidden_dim: value passed as ``hidden_dims`` (presumably the hidden width — confirm).
        layers: value passed as ``layers`` (presumably the layer count — confirm).
        PREVIOUS_FRaMES: value passed as ``PREVIOUS_FRaMES``.
    """

    def __init__(self, hidden_dims=32, layers=3, PREVIOUS_FRaMES=5):
        # Store on the instance; plain assignments would only create locals
        # that are discarded when __init__ returns.
        self.hidden_dim = hidden_dims
        self.layers = layers
        self.PREVIOUS_FRaMES = PREVIOUS_FRaMES

def reject_outliers(data, m=2):
    """Return the elements of ``data`` lying strictly within ``m`` standard deviations of its mean.

    Args:
        data: numpy array of values.
        m: half-width of the accepted band, in standard deviations.

    Returns:
        The elements of ``data`` selected by the boolean mask.
    """
    center = np.mean(data)
    band = m * np.std(data)
    keep = np.abs(data - center) < band
    return data[keep]

In [22]:
# Seed NumPy's global RNG from the current wall-clock second.
# NOTE(review): a clock-based seed makes runs non-reproducible; use a fixed integer if that matters.
np.random.seed(int(time.time()))

In [29]:
## ---- Configuration ----
# Start / end frame indices.
s, e = 0, 60
step_size = 10.0 / 511
print(step_size)
interval = 1
t_start = 0

batch_size = 500   # seeds per DataLoader batch
dim = 2            # spatial dimension of the dataset
minval, maxval = 0, 1  # target range of the min-max normalization
dataset = "double_gyre"

model_dir = "./checkpoints/double_gyre_short/double_gyre_short_1000_MLP/model_final.pth"
# NOTE(review): s, e, step_size, interval, t_start and index are not read anywhere else in this cell.
index = "4400"

## Load the test data: ground-truth trajectories in domain coordinates (not normalized).
## Indexed as [time step, seed, coordinate]; the recorded run printed shape (61, 500, 2).
test_data_dir = "./data/500_short.npy"
gt = np.load(test_data_dir)
print("test data shape", gt.shape, gt[0, :, 0].min(), gt[0, :, 0].max())

num_fm = gt.shape[0] - 1  # number of steps to predict after the seed positions

## Domain bounds of the double gyre, plus a box shrunk by `offset` on every side
## (intended for removing seeds within `offset` of the boundary, if needed).
offset = 0.0
lower = [0, 0]
upper = [2, 1]
bbox_lower = [lo + offset for lo in lower]
bbox_upper = [hi - offset for hi in upper]

## Domain bounds taken from the ground truth; used for min-max normalization.
x_min = np.min(gt[:, :, 0])
x_max = np.max(gt[:, :, 0])
y_min = np.min(gt[:, :, 1])
y_max = np.max(gt[:, :, 1])
minval = 0  # normalized coordinates live in [minval, maxval]
maxval = 1
t_min = 0   # time-step index range, used to normalize time in the inference loop
t_max = gt.shape[0] - 1

seeds = gt[0, :, 0:dim]  # raw start positions (a view into gt)
# Normalize a *copy*: slicing returns a view, so writing into `seeds` would also
# overwrite the ground truth gt[0] and lose the raw seeds used for fms[0] and plotting.
seeds_normalized = seeds.copy()
# Each axis is normalized from its own coordinate with its own bounds.
seeds_normalized[:, 0] = (seeds[:, 0] - x_min) / (x_max - x_min) * (maxval - minval) + minval
seeds_normalized[:, 1] = (seeds[:, 1] - y_min) / (y_max - y_min) * (maxval - minval) + minval
num_seeds = seeds.shape[0]

## Load the trained model
start = time.time()
print("start loading")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MLP()
# map_location remaps tensors in the checkpoint onto `device`, so a checkpoint saved on a GPU
# also loads on a CPU-only machine; one call covers both cases.
model.load_state_dict(torch.load(model_dir, map_location=device))
model.to(device)
# Inference only: switch any train/eval-dependent layers (e.g. dropout, batch-norm) to eval mode.
model.eval()
print("Model Loaded!")
end = time.time()
print(f"Runtime of the loading model is {end - start}")

## Calculating the inference
# The model is applied step by step: the predicted end positions of one step become
# the start positions of the next step.
start_time = time.time()

fms = np.zeros((num_fm + 1, num_seeds, dim))  # predicted trajectories in domain coordinates
fms[0, :, :] = seeds

current = copy.deepcopy(seeds_normalized)  # normalized start positions for the current step
with torch.no_grad():  # no gradients are needed for prediction
    for ts in range(num_fm):
        # Normalized start time of this step.
        # Note: check that this matches the time normalization used to generate the training data.
        time_step = (ts - t_min) / (t_max - t_min) * (maxval - minval) + minval
        # Same shape/dtype as the initial seeds, so every step feeds LoadDataShort consistently.
        results = np.zeros_like(current)

        # Note: the dataloader is rebuilt for each time step; for many seeds it could be precomputed.
        # NOTE(review): the recorded run failed in a worker with NotImplementedError raised by
        # torch's Dataset.__getitem__, which suggests LoadDataShort does not override
        # __getitem__ — check data_loader.py.
        data = LoadDataShort(current, time_step)
        dataloader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=False)

        for d, batch_data in enumerate(dataloader):
            start = batch_data[0].to(device)
            t = batch_data[1].to(device)
            pred = model(start, t)
            results[d * batch_size:(d + 1) * batch_size] = pred.cpu().numpy()

        # Only after every batch is written do we hold the full set of end positions:
        # those (not just the last batch) seed the next step.
        current = results

        # Map back from [minval, maxval] to domain coordinates, each axis with its own bounds.
        fms[ts + 1, :, 0] = (results[:, 0] - minval) / (maxval - minval) * (x_max - x_min) + x_min
        fms[ts + 1, :, 1] = (results[:, 1] - minval) / (maxval - minval) * (y_max - y_min) + y_min

end_time = time.time()

print(f"Runtime of the predictiong is {end_time - start_time}")

    ## Calculate errors
# Per-seed error: mean Euclidean distance between predicted and ground-truth positions,
# averaged once per seed over the time steps whose prediction lies inside [lower, upper].
pred_xy = fms[:, :, 0:dim]
in_domain = (
    (pred_xy[:, :, 0] >= lower[0]) & (pred_xy[:, :, 0] <= upper[0])
    & (pred_xy[:, :, 1] >= lower[1]) & (pred_xy[:, :, 1] <= upper[1])
)  # shape (num_fm + 1, num_seeds)
dists = np.linalg.norm(gt[:, :, 0:dim] - pred_xy, axis=2)  # shape (num_fm + 1, num_seeds)
counts = in_domain.sum(axis=0)
# Seeds whose prediction is never inside the domain have no points to average; skip them
# instead of dividing by zero (so `error` may be shorter than num_seeds).
has_points = counts > 0
error = np.where(in_domain, dists, 0.0).sum(axis=0)[has_points] / counts[has_points]
# Optionally drop outliers: error = reject_outliers(error, m=3)
print("model error: ")
print("max:", np.max(error))
print("min:", np.min(error))
print("mean:", np.mean(error))
print("median:", np.median(error))

## Violin plot of the per-seed error distribution
fig, ax = plt.subplots()
ax.violinplot([error], showmeans=True)
ax.set(ylabel="mean position error per seed", title="Trajectory error distribution")
plt.show()


## Plot 2D trajectories: prediction (blue) vs ground truth (red) for every second seed.
fig, ax = plt.subplots()
ax.set_aspect('equal')
# NOTE(review): the last 5 seeds are skipped — confirm this is intentional.
for n in range(0, seeds.shape[0] - 5, 2):
    ax.scatter(seeds[n, 0], seeds[n, 1])
    # Label only the first pair so the legend gets one entry per curve type.
    ax.plot(fms[:, n, 0], fms[:, n, 1], color='tab:blue', linewidth=2,
            label='prediction' if n == 0 else None)
    ax.plot(gt[:, n, 0], gt[:, n, 1], color='tab:red', linewidth=1,
            label='ground truth' if n == 0 else None)
ax.set(xlabel='x', ylabel='y', title='Predicted vs ground-truth trajectories')
if seeds.shape[0] > 5:
    ax.legend()
plt.show()

0.019569471624266144
test data shape (61, 500, 2) 0.00231162 1.99635
start loading
Model Loaded!
Runtime of the loading model is 0.2725977897644043


NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/maanav/.local/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/maanav/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/maanav/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/maanav/.local/lib/python3.10/site-packages/torch/utils/data/dataset.py", line 53, in __getitem__
    raise NotImplementedError
NotImplementedError
