In [None]:
import torch
import torch.nn as nn
from utils import UNetDatasetSingle, UNetDatasetMult
from unet import UNet1d, UNet2d

In [None]:
import h5py

# Quick inspection of the raw file: list its top-level keys, sorted.
# (Presumably one group per simulated sample/seed — verify against the dataset docs.)
data_path = "/data1/zhouziyang/datasets/pdebench/2D/Burgers/2D_Burgers_Nu0.001.hdf5"
# The context manager closes the file even if reading the keys raises.
with h5py.File(data_path, "r") as f:
    seed_list = sorted(f.keys())
print(seed_list)

In [None]:
# Parameters
flnm = "2D_Burgers_Nu0.001.hdf5"
base_path = "/data1/zhouziyang/datasets/pdebench/2D/Burgers/"
# Reduction factors forwarded to the dataset (1 = keep everything).
reduced_resolution = 1
reduced_resolution_t = 1
reduced_batch = 1
initial_step = 10  # number of past time steps fed to the model as input
t_train = 101  # number of time steps in a sample (later clipped to what the data provides)
unroll_step = 20  # unrolled time steps for the pushforward trick
batch_size = 8
learning_rate = 1e-4

# Fix the global torch RNG so model initialisation (done in a later cell) is reproducible.
seed = 0
torch.manual_seed(seed)

In [None]:
# Dataset and dataloader.
# All size/reduction settings come from the parameter cell above.
train_data = UNetDatasetMult(
    flnm,
    initial_step=initial_step,
    reduced_batch=reduced_batch,
    reduced_resolution_t=reduced_resolution_t,
    reduced_resolution=reduced_resolution,
    saved_folder=base_path,
)

# shuffle=False: batches are produced in dataset order on every pass.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    shuffle=False,
    batch_size=batch_size,
)

In [None]:
# Model, loss and optimizer.
gpu_index = 2
# Fall back to the CPU when the requested GPU does not exist, so the notebook still runs elsewhere.
device = torch.device(f"cuda:{gpu_index}" if torch.cuda.device_count() > gpu_index else "cpu")

# Number of physical variables per grid point (the trailing `v` axis of x / y).
# The training loops merge (initial_step, v) into the model's input channels and concatenate the
# model output back onto the time axis of x, so the output must have exactly v channels.
# TODO confirm 2 matches x.shape[-1] for this dataset.
num_channels = 2
model = UNet2d(initial_step * num_channels, num_channels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
loss_fn = nn.MSELoss(reduction="mean")

In [None]:
# Peek at one batch only to learn how many time steps the stored trajectories have.
x_sample, y_sample = next(iter(train_loader))

# Training config (pushforward).
# Never roll out beyond the number of time steps actually stored in the data.
t_train = min(t_train, y_sample.shape[-2])
# Always leave at least one time step of the solution as a target.
unroll_step = t_train - 1 if t_train - unroll_step < 1 else unroll_step
# Fail early instead of crashing on `loss.item()` when no step would contribute a gradient.
assert t_train > initial_step, "t_train must exceed initial_step to leave a prediction target"
assert unroll_step >= 1, "unroll_step must be >= 1 so at least one step is backpropagated"

# Train (one pass over the loader).
model.train()
train_l2_step = 0
train_l2_full = 0
for x, y in train_loader:
    # One optimizer iteration per batch.
    loss = 0
    x = x.to(device)  # input window: (bs, x1, ..., xd, initial_step, v)
    y = y.to(device)  # target trajectory: (bs, x1, ..., xd, t, v)
    n_batch = y.size(0)  # the last batch can be smaller than batch_size

    pred = y[..., :initial_step, :]  # seed the rollout with the ground-truth initial window

    # These shapes/permutations are the same for every step of the batch, so build them once.
    # Merge (time, variable) into a single channel axis: [b, x1, ..., xd, t_init*v]
    inp_shape = list(x.shape[:-2]) + [-1]
    n_dims = len(inp_shape)
    # channels-last -> channels-first for the model: [b, t_init*v, x1, ..., xd]
    to_model = [0, n_dims - 1] + list(range(1, n_dims - 1))
    # model output channels-first -> channels-last: [b, x1, ..., xd, v]
    from_model = [0] + list(range(2, n_dims)) + [1]

    # Autoregressive loop
    for t in range(initial_step, t_train):
        inp = x.reshape(inp_shape).permute(to_model)
        target = y[..., t:t+1, :]  # ground truth of the current step: (bs, x1, ..., xd, 1, v)
        if t < t_train - unroll_step:
            # Warm-up steps: roll the model forward without building an autograd graph.
            with torch.no_grad():
                im = model(inp).permute(from_model).unsqueeze(-2)
        else:
            # Unrolled steps: keep the graph and accumulate the per-step loss.
            im = model(inp).permute(from_model).unsqueeze(-2)
            loss += loss_fn(im.reshape(n_batch, -1), target.reshape(n_batch, -1))
        # Append the prediction to the rollout and slide the input window forward by one step.
        pred = torch.cat((pred, im), -2)
        x = torch.cat((x[..., 1:, :], im), dim=-2)

    train_l2_step += loss.item()  # step loss: sum of per-step MSE over the unrolled steps
    print(train_l2_step)
    _y = y[..., :t_train, :]
    # Full-trajectory MSE, for monitoring only. It is NOT the same as the step loss: it is a single
    # mean over all t_train steps, including the ground-truth initial window (zero error) and the
    # no-grad warm-up predictions.
    with torch.no_grad():
        l2_full = loss_fn(pred.reshape(n_batch, -1), _y.reshape(n_batch, -1))
    train_l2_full += l2_full.item()

    # Update the model weights. Only the unrolled steps' losses carry gradients; these also flow
    # back through the chain of unrolled predictions that were fed in as inputs.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Teacher forcing (no pushforward): at every step the input window is cut from the
# ground-truth trajectory y, not from the model's own earlier predictions.
x, y = next(iter(train_loader))
loss = 0
x = x.to(device)  # input window: (bs, x1, ..., xd, initial_step, v); used here for its shape
y = y.to(device)  # trajectory: (bs, x1, ..., xd, t, v)
# Use the actual batch size: a partial batch would make reshape(batch_size, -1) fail.
n_batch = y.size(0)
pred = y[..., :initial_step, :]  # start the prediction from the ground-truth initial window
# Shape that merges (time, variable) into one channel axis: [b, x1, ..., xd, -1]
inp_shape = list(x.shape)
inp_shape = inp_shape[:-2]
inp_shape.append(-1)
t_train = min(t_train, y.shape[-2])

# autoregressive loop
for t in range(initial_step, t_train):
    # Ground-truth window of the previous initial_step steps, channels-first for the model.
    inp = y[..., t-initial_step:t, :].reshape(inp_shape)  # (bs, x1, ..., xd, initial_step*v)
    temp_shape = [0, -1]
    temp_shape.extend([i for i in range(1, len(inp.shape)-1)])  # [0, -1, 1, ..., d]
    inp = inp.permute(temp_shape)  # (bs, initial_step*v, x1, ..., xd)

    target = y[..., t:t+1, :]  # (bs, x1, ..., xd, 1, v)

    temp_shape = [0]
    temp_shape.extend([i for i in range(2, len(inp.shape))])
    temp_shape.append(1)  # [0, 2, ..., d+1, 1]: move channels back to the last axis
    im = model(inp).permute(temp_shape).unsqueeze(-2)  # (bs, x1, ..., xd, 1, v)

    loss += loss_fn(im.reshape(n_batch, -1), target.reshape(n_batch, -1))

    pred = torch.cat((pred, im), -2)