In [1]:
import os
import sys

# Location of the local neuraloperator checkout that the imports below rely on.
# Set the NEURALOP_ROOT environment variable to point at a different checkout;
# the original author's path is kept as the fallback.
NEURALOP_ROOT = os.environ.get(
    'NEURALOP_ROOT',
    '/home/renbo/Desktop/New-Ops/neuraloperator-branches/neuraloperator',
)
# Guard against piling up duplicate entries when this cell is re-run.
if NEURALOP_ROOT not in sys.path:
    sys.path.append(NEURALOP_ROOT)

In [2]:
import torch

import os
import time
from configmypy import ConfigPipeline, YamlConfig, ArgparseConfig
from neuralop import get_model
from neuralop import Trainer
from neuralop.training import setup
from neuralop.datasets.navier_stokes import load_navier_stokes_pt
from neuralop.utils import get_wandb_api_key, count_params, get_project_root, set_seed
from neuralop import LpLoss, H1Loss
from neuralop.models.spectral_convolution import FactorizedSpectralConv
import torch.nn as nn
from torch.cuda import amp 


In [18]:
#from torch.ao.quantization import QConfigMapping
#from torch.ao.quantization.qconfig_mapping import get_default_qconfig_mapping
#from torch.ao.quantization.fx.custom_config import PrepareCustomConfig

# Note that this is temporary; we'll expose these functions to torch.ao.quantization after the official release
#from torch.quantization.quantize_fx import prepare_fx, convert_fx

# ignore complexhalf warnings
import warnings
warnings.filterwarnings("ignore")

def get_size_of_model(model):
    """Return the serialized size of ``model``'s state dict in megabytes.

    The state dict is written with ``torch.save`` to a uniquely named
    temporary file, measured, and the file is removed again. The size is
    also printed.

    Parameters
    ----------
    model : object exposing ``state_dict()``
        Typically a ``torch.nn.Module``.

    Returns
    -------
    float
        File size in MB, computed as bytes / 1e6.
    """
    import tempfile

    # A unique temp file avoids overwriting (and then deleting) an unrelated
    # "temp.p" that may already exist in the working directory.
    fd, tmp_path = tempfile.mkstemp(suffix=".p")
    os.close(fd)
    try:
        torch.save(model.state_dict(), tmp_path)
        size = os.path.getsize(tmp_path) / 1e6
    finally:
        # Always clean up, even if serialization fails part-way.
        os.remove(tmp_path)
    print('Size (MB):', size)
    return size

def replace_layers(model, old, new):
    """Recursively replace every submodule of type ``old`` inside ``model``.

    ``model`` is modified in place.

    Parameters
    ----------
    model : torch.nn.Module
        Root module whose children are searched.
    old : type or tuple of types
        Passed to ``isinstance`` to select the modules to replace.
    new : torch.nn.Module, callable, or None
        * A module instance (or ``None``) is installed as-is at every match.
          The *same* object is then shared by all replaced positions, so any
          parameters it holds are tied.
        * Any other callable is treated as a factory: it is called with the
          module being replaced and its return value is installed, giving
          each position its own replacement.
    """
    # Snapshot the children so reassigning attributes cannot disturb iteration.
    for name, module in list(model.named_children()):
        if isinstance(module, old):
            if callable(new) and not isinstance(new, torch.nn.Module):
                replacement = new(module)
            else:
                replacement = new
            setattr(model, name, replacement)
        else:
            # Only descend into modules that are kept; a replaced module's
            # children are discarded along with it.
            replace_layers(module, old, new)


# Locate and read the experiment configuration
config_name = 'default'
#config_folder = os.path.join(get_project_root(), 'config')
config_folder = os.path.join('..', 'config')
config_file_name = 'load_8layer_config.yaml'

# The pipeline steps are applied in order: YAML file, command-line arguments, then a second YAML step
pipeline_steps = [
    YamlConfig(config_file_name, config_name=config_name, config_folder=config_folder),
    ArgparseConfig(infer_types=True, config_name=None, config_file=None),
    YamlConfig(config_folder=config_folder),
]
pipe = ConfigPipeline(pipeline_steps)
config = pipe.read_conf()
config_name = pipe.steps[-1].config_name

# Seed the RNGs when the config provides a (truthy) seed
if 'seed' in config and config.seed:
    print('setting seed to', config.seed)
    set_seed(config.seed)

# Set up distributed communication (if used); yields the device and whether this process logs
device, is_logger = setup(config)

# Only the logging process should be verbose
config.verbose = config.verbose and is_logger

# Build the Navier-Stokes train/test loaders and output encoder from the `data` config section
data_cfg = config.data
train_loader, test_loaders, output_encoder = load_navier_stokes_pt(
    data_cfg.folder,
    train_resolution=data_cfg.train_resolution,
    n_train=data_cfg.n_train,
    batch_size=data_cfg.batch_size,
    positional_encoding=data_cfg.positional_encoding,
    test_resolutions=data_cfg.test_resolutions,
    n_tests=data_cfg.n_tests,
    test_batch_sizes=data_cfg.test_batch_sizes,
    encode_input=data_cfg.encode_input,
    encode_output=data_cfg.encode_output,
    num_workers=data_cfg.num_workers,
    pin_memory=data_cfg.pin_memory,
    persistent_workers=data_cfg.persistent_workers,
)

# Instantiate the model described by the config and move it to the target device
model = get_model(config).to(device)

setting seed to 123
UnitGaussianNormalizer init on 10000, reducing over [0, 1, 2, 3], samples of shape [1, 128, 128].
   Mean and std of shape torch.Size([1, 1, 1]), eps=1e-05
Given argument key='skip' that is not in TFNO2d's signature.
Keyword argument non_linearity not specified for model TFNO2d, using default=<built-in function gelu>.
Keyword argument fno_skip not specified for model TFNO2d, using default=linear.
Keyword argument mlp_skip not specified for model TFNO2d, using default=soft-gating.
Keyword argument decomposition_kwargs not specified for model TFNO2d, using default={}.


In [4]:
# Losses reported during evaluation, keyed by the name used in the printed summary
h1loss = H1Loss(d=2)
l2loss = LpLoss(d=2, p=2)
eval_losses = dict(h1=h1loss, l2=l2loss)

In [19]:
# Restore the trained weights from a saved checkpoint
checkpoint_path = os.path.join('../checkpoints', 'full_precision.pt')

# Trainer configured from the `opt`, `patching`, `wandb` and `distributed` config sections
trainer = Trainer(
    model,
    n_epochs=config.opt.n_epochs,
    device=device,
    mg_patching_levels=config.patching.levels,
    mg_patching_padding=config.patching.padding,
    mg_patching_stitching=config.patching.stitching,
    wandb_log=config.wandb.log,
    amp_autocast=config.opt.amp_autocast,
    precision_schedule=config.opt.precision_schedule,
    log_test_interval=config.wandb.log_test_interval,
    log_output=config.wandb.log_output,
    use_distributed=config.distributed.use_distributed,
    verbose=config.verbose and is_logger,
)

# The optimizer is created here because the checkpoint loader below takes it as an argument
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.opt.learning_rate,
    weight_decay=config.opt.weight_decay,
)

# Load the checkpoint into `model`/`optimizer`.
# NOTE(review): the meaning of epoch -1 is defined by Trainer.load_model_checkpoint; confirm there.
model_load_epoch = -1
trainer.load_model_checkpoint(model_load_epoch, model, optimizer, load_path=checkpoint_path)

Training on regular inputs (no multi-grid patching).


-1

#### We run inference with the full-precision model, first as-is and then with AMP enabled. AMP makes inference slightly faster.

In [8]:
# Evaluate the model on every test loader and print one summary line per loader
for loader_name, loader in test_loaders.items():
    to_log_output = True
    errors = trainer.evaluate(model, eval_losses, loader, output_encoder, log_prefix=loader_name)

    # Each metric contributes a ", name=value" fragment (4 decimal places)
    parts = []
    for loss_name, loss_value in errors.items():
        parts.append(f', {loss_name}={loss_value:.4f}')
    msg = ''.join(parts)

    print(msg)


, 128_h1=0.0105, 128_l2=0.0041


In [9]:
# Evaluate the model on every test loader with AMP autocast enabled
for loader_name, loader in test_loaders.items():
    msg = ''
    with amp.autocast(enabled=True):
        errors = trainer.evaluate(model, eval_losses, loader, output_encoder, log_prefix=loader_name)

    for loss_name, loss_value in errors.items():
        msg += f', {loss_name}={loss_value:.4f}'
    # Print inside the loop so every loader is reported (not just the last one)
    # and an empty `test_loaders` does not hit an undefined `msg`.
    print(msg)

, 128_h1=0.0112, 128_l2=0.0041


#### Here, we cast the model to half precision and run inference with AMP enabled. The results degrade significantly in half precision when the model was trained in full precision.

In [15]:
import copy 
# Work on a deep copy so the flags set below do not affect `model`
model_fp16 = copy.deepcopy(model)
# NOTE(review): presumably these flags select half precision for the spectral
# convolution's forward / inverse transforms; confirm against the conv implementation.
model_fp16.fno_blocks.convs.half_prec_fourier = False
model_fp16.fno_blocks.convs.half_prec_inverse = True

# Evaluate the modified copy on every test loader with AMP autocast enabled
for loader_name, loader in test_loaders.items():
    msg = ''
    with amp.autocast(enabled=True):
        errors = trainer.evaluate(model_fp16, eval_losses, loader, output_encoder, log_prefix=loader_name)
    for loss_name, loss_value in errors.items():
        msg += f', {loss_name}={loss_value:.4f}'
    # Print inside the loop so every loader is reported (not just the last one)
    # and an empty `test_loaders` does not hit an undefined `msg`.
    print(msg)

, 128_h1=0.0690, 128_l2=0.0577


In [21]:
# Try the same evaluation of `model_fp16` without AMP autocast
for loader_name, loader in test_loaders.items():
    msg = ''
    errors = trainer.evaluate(model_fp16, eval_losses, loader, output_encoder, log_prefix=loader_name)
    for loss_name, loss_value in errors.items():
        msg += f', {loss_name}={loss_value:.4f}'
    # Print inside the loop so every loader is reported (not just the last one)
    # and an empty `test_loaders` does not hit an undefined `msg`.
    print(msg)

KeyboardInterrupt: 