In [1]:
import torch
import torch.nn as nn
import yaml
import h5py
from utils import TransformerOperatorDataset
import torch.nn.functional as F
from tqdm import tqdm
import math
import time
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import os
import shutil

from models.pitt import PhysicsInformedTokenTransformer
from models.pitt import StandardPhysicsInformedTokenTransformer

from models.oformer import Encoder1D, STDecoder1D, OFormer1D, PointWiseDecoder1D
from models.fno import FNO1d
from models.deeponet import DeepONet1D

import sys

from train_pitt import evaluate, get_transformer, get_data, get_neural_operator

ModuleNotFoundError: No module named 'torch'

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Load the run configuration; only its 'args' section is used by this notebook.
with open("./configs/pitt_config.yaml", 'r') as stream:
    config = yaml.safe_load(stream)
train_args = config['args']

# Run prefix: <flnm>_<data_name up to its first '_'>_<train_style>_<embedding>.
prefix = "_".join([
    train_args['flnm'],
    train_args['data_name'].split("_")[0],
    train_args['train_style'],
    train_args['embedding'],
])
# The seed is also part of the checkpoint filename built in the next cell.
train_args.update(prefix=prefix, seed=1)

In [None]:
# Results directory for this run: <results_dir><transformer>_<neural_operator>_<prefix>.
path = f"{train_args['results_dir']}{train_args['transformer']}_{train_args['neural_operator']}_{prefix}"
# NOTE(review): this HDF5 handle is never closed in the notebook; it is handed to
# get_data, which presumably reads from it lazily — confirm before adding a close().
f = h5py.File(f"{train_args['base_path']}{train_args['data_name']}", 'r')
# Checkpoint file inside the results directory: <flnm>_<transformer>_<seed>.pt.
model_name = train_args['flnm'] + f"_{train_args['transformer']}_{train_args['seed']}.pt"
model_path = path + "/" + model_name

In [None]:
# get_data (from train_pitt) returns three loaders built from the open HDF5 file
# and the run arguments; only test_loader is used further down.
data_loaders = get_data(f, train_args)
train_loader, val_loader, test_loader = data_loaders

In [None]:
# Build the neural-operator model first, since get_transformer takes it as an
# argument; both architectures are chosen by name from train_args.
operator_kind = train_args['neural_operator']
neural_operator = get_neural_operator(operator_kind, train_args)
transformer_kind = train_args['transformer']
transformer = get_transformer(transformer_kind, neural_operator, train_args)

In [None]:
# Restore trained weights. map_location remaps stored tensors onto `device`, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
checkpoint = torch.load(model_path, map_location=device)
# Make sure the model lives on the same device the inference inputs are sent to.
transformer.to(device=device)
load_report = transformer.load_state_dict(checkpoint['model_state_dict'])
load_report

In [None]:
# Mean absolute error, averaged over every element of the compared tensors.
loss_fn = torch.nn.L1Loss(reduction="mean")

In [None]:
# Score the restored model on the test loader using the L1 loss defined above.
# NOTE(review): how per-batch losses are aggregated is decided inside
# train_pitt.evaluate — confirm there before interpreting the number.
test_value = evaluate(test_loader, transformer, loss_fn)

In [None]:
# Display the test score returned by evaluate.
test_value

In [None]:
# Draw a single batch (index 1) from the test loader. Stopping the iteration at the
# wanted index avoids `list(test_loader)`, which would load every test batch into
# memory just to keep one. If the loader shuffles, the batch differs between runs.
sample_batch_index = 1
for batch_number, batch in enumerate(test_loader):
    if batch_number == sample_batch_index:
        break
else:
    raise IndexError(f"test_loader yields fewer than {sample_batch_index + 1} batches")
(x0, y, grid, tokens, t) = batch

In [None]:
# Inspect the grid tensor's shape; the plotting cell below indexes it as grid[ibatch, :].
grid.shape

In [None]:
# Inference only: eval() switches layers such as dropout to inference behaviour,
# and no_grad() skips building an autograd graph for this forward pass.
transformer.eval()
with torch.no_grad():
    y_pred = transformer(
        grid.to(device=device),
        tokens.to(device=device),
        x0.to(device=device),
        t.to(device=device),
    )

In [None]:
# Compare target, prediction and the model input x0 for one sample of the batch.
ibatch = 1  # sample index within the batch drawn above
fig, ax = plt.subplots(dpi=300)
sample_grid = grid[ibatch, :].cpu()
ax.plot(sample_grid, y[ibatch, :, 0].cpu(), label='target')
ax.plot(sample_grid, y_pred.detach()[ibatch, :].cpu(), label='prediction')
# x0 is transposed so each of its rows is drawn as a separate dashed curve;
# label only the first so the legend gets a single entry for them.
x0_lines = ax.plot(sample_grid, x0[ibatch, :, :].transpose(0, 1).cpu(), 'k--')
if x0_lines:
    x0_lines[0].set_label('input x0')
ax.set_xlabel('grid')
ax.set_title(f'Test batch sample {ibatch}')
ax.legend()
plt.show()

In [None]:
# Select channel 0 of the target into a new name instead of rebinding `y`, so this
# cell can be re-run safely and `y` keeps the shape the plotting cell indexes.
y_target = y[..., 0].to(device=device)
# L1 loss between the prediction and the target for this single batch.
loss_fn(y_pred, y_target).item()