In [1]:
import os
import random

import torch
from torch.utils.data import DataLoader

from mpnn import MPNN
from utils import MPNDatasetSingle, PDE, GraphCreator
from utils import to_PDEBench_format

In [2]:
# Dataset location. Set PDEBENCH_ADVECTION_TRAIN_DIR to point at your copy instead of
# editing this machine-specific absolute path. os.path.join(..., "") keeps the trailing
# separator the original literal had, in case the dataset concatenates folder + file name.
file_name = "1D_Advection_Sols_beta1.0.hdf5"
saved_folder = os.path.join(
    os.environ.get(
        "PDEBENCH_ADVECTION_TRAIN_DIR",
        "/data1/zhouziyang/datasets/pdebench/1D/Advection/Train/",
    ),
    "",
)
# PDE coefficients; beta matches the "beta1.0" in file_name.
# NOTE: the batch-loading cell below rebinds `variables` to the DataLoader's batched version.
variables = {"beta": 1.}
batch_size = 1

# Seed the RNGs used below (random.choices for the start steps; torch presumably for
# model initialisation) so Restart & Run All reproduces the same outputs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# 1D advection on x in [0, 1], t in [0, 2]: 201 time steps and 1024 grid points,
# subsampled by 4 in space -> nx = 256 (matches the printed "nt: 201 nx: 256").
pde = PDE(
    "1D_Advection",
    temporal_domain=(0, 2),
    resolution_t=201,
    spatial_domain=[(0, 1)],
    resolution=[1024],
    variables=variables,
    reduced_resolution=4,
)
# Graph builder: 3 neighbours per node, windows of 25 time steps.
graph_creator = GraphCreator(pde=pde, neighbors=3, time_window=25)

nt: 201 nx: 256


In [4]:
# Load the trajectories with the same spatial subsampling (reduced_resolution=4) as the
# PDE definition above, presumably so nx lines up with graph_creator — confirm.
dataset = MPNDatasetSingle(
    file_name,
    saved_folder,
    reduced_resolution=4,
    variables=variables,
)
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

In [5]:
u, x, variables = next(iter(dataloader))
print(*(t.dtype for t in (u, x)))
# NOTE: this rebinds `variables`: each value is now a batched tensor, which comes back as
# float64 (see output) while u and x are float32.
print(variables)

torch.float32 torch.float32
{'beta': tensor([1.], dtype=torch.float64)}


In [6]:
# Candidate start steps t run from tw up to nt - tw - tw * unrolled_graphs (inclusive);
# with nt=201, tw=25, unrolled_graphs=1 that is 25..151. This presumably leaves room for
# one input window before t and (unrolled_graphs + 1) windows after it — confirm against
# GraphCreator.create_data.
unrolled_graphs = 1
steps = list(range(
    graph_creator.tw,
    graph_creator.nt - graph_creator.tw - graph_creator.tw * unrolled_graphs + 1,
))
# random.choices would otherwise fail with an opaque IndexError on an empty range.
assert steps, (
    f"no valid start steps for nt={graph_creator.nt}, "
    f"tw={graph_creator.tw}, unrolled_graphs={unrolled_graphs}"
)
random_steps = random.choices(steps, k=batch_size)
print("random_steps:", random_steps)

random_steps: [45]


In [7]:
# Build the model input (`data`) and target (`labels`) windows for each sampled start
# step; per the recorded output both are (batch_size, tw, nx).
data, labels = graph_creator.create_data(u, random_steps)
assert data.shape == labels.shape, (data.shape, labels.shape)
assert tuple(data.shape[:2]) == (batch_size, graph_creator.tw), data.shape
print("data:", data.shape)      # (batch_size, tw, nx)
print("labels:", labels.shape)  # (batch_size, tw, nx)

data: torch.Size([1, 25, 256])
labels: torch.Size([1, 25, 256])


In [8]:
# Assemble the graph from the input/target windows, `x` (presumably the spatial grid),
# the batched PDE variables and the sampled start steps.
graph = graph_creator.create_graph(
    data, labels, x, variables, random_steps
)

In [9]:
# Build the MPNN for this PDE and run one forward pass on the graph.
# TODO: `variables` is the DataLoader's batched dict whose values are float64 (see the
#       dtype printed when the batch was loaded); convert them to float32, ideally right
#       after loading so the graph and the model both see float32.
model = MPNN(pde, eq_variables=variables)
pred = model(graph)
# Recorded output is (256, 25) = (nx, tw); presumably one row per graph node
# (batch_size * nx) — confirm for batch_size > 1.
print("pred:", pred.shape)

pred: torch.Size([256, 25])


In [10]:
# Mean-squared error between the prediction and graph.y (presumably the target window).
# `reduction` is the supported argument; the deprecated `reduce` expects a bool, and
# passing the string "mean" to it only worked by truthiness (with a deprecation warning).
loss_fn = torch.nn.MSELoss(reduction="mean")
loss = loss_fn(pred, graph.y)
loss.item()



In [11]:
# Presumably builds the graph for the next time window from `pred`. NOTE: this rebinds
# `graph`, so re-running this cell (or the loss cell) afterwards uses the advanced graph.
graph = graph_creator.create_next_graph(
    graph, pred, labels, random_steps
)

In [12]:
# Convert the graph's node features back to PDEBench's tensor layout.
# Recorded output: (256, 25) -> (1, 256, 25, 1), i.e. presumably
# (batch_size, nx, tw, channels) — confirm against to_PDEBench_format.
print(graph.x.shape)
output = to_PDEBench_format(graph.x, batch_size, pde)
print(output.shape)

torch.Size([256, 25])
torch.Size([1, 256, 25, 1])
