In [1]:
# Re-import edited modules (e.g. the climart package) automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# If the kernel was started inside a `notebooks/` folder, step up one directory so the
# relative paths used later in this notebook are resolved from its parent.
if os.path.split(os.getcwd())[1] == "notebooks":
    os.chdir("..")

In [3]:
# Smoke test: confirms the kernel executes cells.
print('Hi', ':)')

Hi :)


In [6]:
import sys
import h5py
import json
import logging
import torch
from torch import Tensor
import numpy as np
from einops import repeat, rearrange
import time
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict
import climart.data_wrangling.constants as constants
from climart.models.interface import get_model, is_gnn, is_graph_net
from climart.models.column_handler import ColumnPreprocesser
from climart.models.GNs.constants import NODES, EDGES
from climart.data_wrangling.constants import LEVELS, LAYERS, GLOBALS, TRAIN_YEARS, TEST_YEARS

In [7]:
# Report which GPU the 'cuda' benchmarks below run on; do not crash on CPU-only machines.
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'no CUDA device available'
gpu_name

'Tesla T4'

In [8]:
# Display defaults: large matplotlib figures; numpy prints full arrays without scientific notation.
plt.rcParams.update({'figure.figsize': [20, 8], 'figure.dpi': 70})
np.set_printoptions(threshold=sys.maxsize, suppress=True)

In [14]:
# Directory holding one HDF5 input file per year. Set the CLIMART_INPUTS_DIR environment variable
# to point elsewhere instead of editing this user-specific absolute path.
hdf5_years_dir = os.environ.get("CLIMART_INPUTS_DIR", "/home/rtu0715/clim_test/climart/ClimART_DATA/inputs")
year = 2011
h5_path = os.path.join(hdf5_years_dir, f"{year}.h5")

In [11]:
def get_one_snapshot(h5_path: str, batch_size: int, exp_type='pristine', n_points: int = 8192):
    """Load the first ``n_points`` columns of a yearly HDF5 file and split them into equal batches.

    Only the requested rows are read from disk (the HDF5 datasets are sliced before conversion),
    instead of materializing the full-year arrays in memory first.

    Args:
        h5_path: path to the yearly ``.h5`` file.
        batch_size: number of columns per batch; must be positive.
        exp_type: ``'pristine'`` keeps only the first 14 layer features; any other value keeps all of them.
        n_points: number of leading rows to load. The default of 8192 is presumably one snapshot
            (see the function name) -- TODO confirm against the data layout.

    Returns:
        A list of ``n_points // batch_size`` dicts mapping GLOBALS / LAYERS / LEVELS to numpy arrays.
        Trailing rows that do not fill a complete batch are dropped.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    with h5py.File(h5_path, 'r') as h5f:
        # Slice the datasets directly so h5py reads just the first `n_points` rows.
        globs = h5f[GLOBALS][:n_points]
        if exp_type == 'pristine':
            lays = h5f[LAYERS][:n_points, ..., :14]  # pristine: keep only the first 14 layer features
        else:
            lays = h5f[LAYERS][:n_points]
        levs = h5f[LEVELS][:n_points]
    # Never emit empty/partial batches, even if the file holds fewer than n_points rows.
    n_batches = len(globs) // batch_size
    data = []
    for i in range(n_batches):
        dslice = slice(i * batch_size, (i + 1) * batch_size)
        data.append({GLOBALS: globs[dslice], LAYERS: lays[dslice], LEVELS: levs[dslice]})
    print(f"{n_batches} batches of size {batch_size}, amounting to {sum([d[GLOBALS].shape[0] for d in data])} data points.")
    return data


In [12]:
def gn_input_dict_renamer_level_nodes(batch: Dict[str, Tensor], device):
    """Add graph-network inputs to ``batch`` in place and return the same dict.

    Despite the annotation, the values are expected to be array-likes accepted by ``torch.FloatTensor``
    (the numpy batches produced by ``get_one_snapshot``). Level features become NODES, globals are
    converted under the same key, and layer features become EDGES, duplicated along the edge axis
    (one copy per edge direction). All three are float32 tensors on ``device``. The LEVELS and LAYERS
    entries are left in the dict.
    """
    batch[NODES] = torch.FloatTensor(batch[LEVELS]).to(device)
    batch[GLOBALS] = torch.FloatTensor(batch[GLOBALS]).to(device)
    edge_feats = torch.FloatTensor(batch[LAYERS])
    edge_feats = repeat(edge_feats, "b e d -> b (repeat e) d", repeat=2)  # bidirectional edges
    batch[EDGES] = edge_feats.to(device)
    return batch

def mlp_transform(batch: Dict[str, Tensor], device):
    """Flatten every entry of ``batch`` per sample and concatenate them into a (batch, features) tensor.

    Entries are concatenated in the dict's iteration order; the result is float32 on ``device``.
    """
    flat_parts = []
    for arr in batch.values():
        n_samples = arr.shape[0]
        flat_parts.append(torch.FloatTensor(arr).reshape((n_samples, -1)).to(device))
    return torch.cat(flat_parts, dim=1)

def to_torch(batch, device):
    """Return a new dict with every value of ``batch`` converted to a float32 tensor on ``device``."""
    converted = {}
    for key, value in batch.items():
        converted[key] = torch.FloatTensor(value).to(device)
    return converted

def to_torch2(batch, device):
    """Convert only the LEVELS, GLOBALS and LAYERS entries of ``batch`` to float32 tensors on ``device``.

    Any other keys are dropped; the returned dict lists its keys in the order LEVELS, GLOBALS, LAYERS.
    """
    return {key: torch.FloatTensor(batch[key]).to(device) for key in (LEVELS, GLOBALS, LAYERS)}

def cnn_transform(batch, device):
    """Stack level, layer and global features into one CNN input of shape (batch, features, levels).

    The layer features get one extra row prepended along the level axis (reflect padding) so that
    they line up with the levels, and the global features are broadcast to every level. The number
    of levels is taken from ``batch[LEVELS]`` rather than hard-coded.
    """
    X_levels = torch.FloatTensor(batch[LEVELS])  # 'b c f'
    n_levels = X_levels.shape[1]

    # Reflect-pad one row at the start of the level axis (dim -2). The temporary leading unit axis
    # turns the input into a 4-D tensor so F.pad applies a 2-D (last two dims) padding.
    X_layers = rearrange(
        torch.nn.functional.pad(
            rearrange(torch.FloatTensor(batch[LAYERS]), 'b c f -> () b c f'),
            (0, 0, 1, 0),
            mode='reflect',
        ),
        '() b c f -> b c f',
    )
    # Broadcast the per-column globals to every level so all three tensors share dim 1 for the concat.
    X_global = repeat(torch.FloatTensor(batch[GLOBALS]), 'b f -> b c f', c=n_levels)

    X = torch.cat((X_levels, X_layers, X_global), -1)
    return rearrange(X, 'b c f -> b f c').to(device)

# Maps a lower-cased model type (with any '+readout' suffix stripped) to the function that turns a
# numpy batch into that model's input.
in_transform_funcs = dict(
    mlp=mlp_transform,
    gn=gn_input_dict_renamer_level_nodes,
    gcn=to_torch2,
    cnn=cnn_transform,
)

In [13]:
def _time_one_batch(model, transform, batch, device, use_cuda):
    """Run ``transform`` then ``model`` on one batch and return ``(forward_seconds, total_seconds)``.

    ``total_seconds`` spans the input transform plus the forward pass; ``forward_seconds`` only the latter.
    On CUDA, CUDA events are used and the device is synchronized before reading them; on CPU a wall
    clock (``time.perf_counter``) is used, so no CUDA runtime is required.
    """
    batch = batch.copy()  # shallow copy: the GN transform adds keys to the dict it receives
    if use_cuda:
        total_start = torch.cuda.Event(enable_timing=True)
        start_t = torch.cuda.Event(enable_timing=True)
        end_t = torch.cuda.Event(enable_timing=True)
        total_start.record()
        batch_model = transform(batch, device=device)
        start_t.record()
        with torch.no_grad():
            model(batch_model)
        end_t.record()
        torch.cuda.synchronize()  # events must have completed before elapsed_time can be read
        return start_t.elapsed_time(end_t) / 1000, total_start.elapsed_time(end_t) / 1000

    total_start = time.perf_counter()
    batch_model = transform(batch, device=device)
    start_t = time.perf_counter()
    with torch.no_grad():
        model(batch_model)
    end_t = time.perf_counter()
    return end_t - start_t, end_t - total_start


def reload_and_speed_test_model(ckpt: str, data, device, model_dir='scripts/out', init_runs=2, avg_runs_over=10):
    """Reload a trained checkpoint and benchmark its inference speed on ``data``.

    The checkpoint ``{model_dir}/{ckpt}.pkl`` must contain the keys ``'hyper_params'``,
    ``'model_params'`` and ``'model'`` (the state dict). A full pass over every batch in ``data`` is
    repeated ``init_runs + avg_runs_over`` times. The first ``init_runs`` passes are warm-up runs:
    they are printed but excluded from the statistics.

    Args:
        ckpt: checkpoint file name without the ``.pkl`` extension.
        data: list of batches, as returned by ``get_one_snapshot``.
        device: torch device string, e.g. ``'cpu'`` or ``'cuda'``.
        model_dir: directory that contains the checkpoint.
        init_runs: number of warm-up passes excluded from the statistics.
        avg_runs_over: number of timed passes the mean/std are computed over.

    Prints one line per pass, then a summary with the mean and std (in seconds per pass) of
    (a) the forward pass alone and (b) the forward pass plus the input transform.
    Side effect: sets every registered logger to WARNING level.
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    model_ckpt = torch.load(f"{model_dir}/{ckpt}.pkl", map_location=torch.device(device))
    params = model_ckpt['hyper_params']
    net_params = model_ckpt['model_params']
    model_type = params['model']

    model_kwargs = dict()
    if is_gnn(model_type) or is_graph_net(model_type):
        model_kwargs['column_preprocesser'] = ColumnPreprocesser(
            n_layers=49, input_dims=net_params['input_dim'], **params['preprocessing_dict']
        )
    model = get_model(model_name=model_type, device=device, verbose=False, **net_params, **model_kwargs)
    model.load_state_dict(model_ckpt['model'])
    model = model.to(device).float()
    model.eval()

    transform = in_transform_funcs[model_type.lower().replace('+readout', '')]

    # Silence INFO-level logging so it does not interleave with the timing output.
    # Materialize the list first: getLogger() may register new entries in loggerDict.
    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
    for logger in loggers:
        logger.setLevel(logging.WARNING)

    use_cuda = torch.device(device).type == 'cuda'
    times, total_times = [], []
    for it in range(init_runs + avg_runs_over):
        batch_time = batch_total_time = 0.0
        for batch in data:
            forward_t, total_t = _time_one_batch(model, transform, batch, device, use_cuda)
            batch_time += forward_t
            batch_total_time += total_t
        print(f"Forward pass needed {batch_time} out of {batch_total_time} sec. for {len(data)} batches.")
        if it >= init_runs:  # passes 0 .. init_runs-1 are warm-up only
            times.append(batch_time)
            total_times.append(batch_total_time)

    if not times:
        print(f"No timed runs for {model_type} (avg_runs_over={avg_runs_over}); nothing to report.")
        return

    times, total_times = np.array(times), np.array(total_times)
    mean1, std1 = times.mean(), times.std()
    mean2, std2 = total_times.mean(), total_times.std()
    # Take the batch size from the data itself instead of a notebook-global variable.
    batch_size = data[0][GLOBALS].shape[0] if data else 0

    print(f"Forward time for {model_type} on {device}: ${mean1:.4f} \\pm {std1:.3f}$, Total time with reshaping: ${mean2:.4f} \\pm {std2:.3f}$ (batch-size={batch_size})")

In [15]:
# Checkpoints are loaded from `<model_dir>/<ckpt>.pkl`; an alternative location used before was "scripts/out".
model_dir1 = './out'

# Pristine, CPU, 512 batch-size

In [30]:
batch_size = 512
data = get_one_snapshot(h5_path, batch_size=batch_size, exp_type='pristine')  # 8192 points -> 16 batches

16 batches of size 512, amounting to 8192 data points.


In [31]:
# Checkpoint file stems (under model_dir1) for each architecture in the pristine setting; the names
# appear to encode validation RMSE, epoch, train/val years, seed, timestamp and run id.
mlp_model_ckpt, gn_model_ckpt, gcn_model_ckpt, cnn_model_ckpt = (
    "0.4956valRMSE_97ep_MLP_1990+1999+2003train_2005val_Z_7seed_17h56m_on_Jul_02_3bmatnru",
    "0.2939valRMSE_98ep_GN+READOUT_1990+1999+2003train_2005val_Z_7seed_14h44m_on_Jul_04_i55gi6yh",
    "0.4915valRMSE_99ep_GCN+READOUT_1990+1999+2003train_2005val_Z_7seed_04h04m_on_Jul_03_710kxisa",
    "0.2313valRMSE_69ep_CNN_1990+1999+2003train_2005val_Z_7seed_12h35m_on_Jul_01_1wig78v3",
)

In [37]:
# A second GCN+READOUT checkpoint (0.4620 vs 0.4915 valRMSE in the file names).
# NOTE(review): not benchmarked by any cell below -- TODO benchmark or remove.
lgcn_model_ckpt = (
    '0.4620valRMSE_98ep_GCN+READOUT_1990+1999+2003train_2005val_Z_7seed_10h24m_on_Jul_05_1u722tu4'
)

In [32]:
reload_and_speed_test_model(mlp_model_ckpt, data, device='cpu', model_dir=model_dir1)  # MLP, CPU, bs=512

Forward pass needed 0.2282984638214111 out of 0.23498835182189942 sec. for 16 batches.
Forward pass needed 0.22360643100738525 out of 0.22769027328491212 sec. for 16 batches.
Forward pass needed 0.22023692893981933 out of 0.22412825489044189 sec. for 16 batches.
Forward pass needed 0.22410038471221924 out of 0.22865635204315185 sec. for 16 batches.
Forward pass needed 0.22252521514892576 out of 0.22639648056030273 sec. for 16 batches.
Forward pass needed 0.23427791976928714 out of 0.23842531490325927 sec. for 16 batches.
Forward pass needed 0.22952774620056157 out of 0.23342845058441158 sec. for 16 batches.
Forward pass needed 0.22650313758850096 out of 0.23054457473754883 sec. for 16 batches.
Forward pass needed 0.2409109449386596 out of 0.24494028759002687 sec. for 16 batches.
Forward pass needed 0.22100764846801757 out of 0.22484057521820067 sec. for 16 batches.
Forward pass needed 0.22682713603973387 out of 0.2308333120346069 sec. for 16 batches.
Forward pass needed 0.2240093107223

In [33]:
reload_and_speed_test_model(gn_model_ckpt, data, device='cpu', model_dir=model_dir1)  # GN+READOUT, CPU, bs=512

Forward pass needed 6.73901318359375 out of 6.746077270507811 sec. for 16 batches.
Forward pass needed 6.8713291015625 out of 6.878454833984374 sec. for 16 batches.
Forward pass needed 6.88237322998047 out of 6.8885115661621095 sec. for 16 batches.
Forward pass needed 6.852412048339843 out of 6.860756500244141 sec. for 16 batches.
Forward pass needed 6.956266845703125 out of 6.965012115478515 sec. for 16 batches.
Forward pass needed 7.109958709716799 out of 7.117385253906249 sec. for 16 batches.
Forward pass needed 6.606150787353516 out of 6.612471893310547 sec. for 16 batches.
Forward pass needed 6.9238864746093745 out of 6.931957275390624 sec. for 16 batches.
Forward pass needed 6.999532989501954 out of 7.007055755615234 sec. for 16 batches.
Forward pass needed 7.1657098693847665 out of 7.1747335510253905 sec. for 16 batches.
Forward pass needed 7.050624267578126 out of 7.059027160644531 sec. for 16 batches.
Forward pass needed 7.058336700439453 out of 7.06601644897461 sec. for 16 ba

In [34]:
reload_and_speed_test_model(gcn_model_ckpt, data, device='cpu', model_dir=model_dir1)  # GCN+READOUT, CPU, bs=512

Forward pass needed 4.292908050537109 out of 4.294265930175781 sec. for 16 batches.
Forward pass needed 4.198937057495116 out of 4.200250701904297 sec. for 16 batches.
Forward pass needed 4.291575408935547 out of 4.292938537597656 sec. for 16 batches.
Forward pass needed 4.242414306640624 out of 4.244147369384765 sec. for 16 batches.
Forward pass needed 4.224200836181641 out of 4.225582580566405 sec. for 16 batches.
Forward pass needed 4.319510223388672 out of 4.320876647949218 sec. for 16 batches.
Forward pass needed 4.314929595947266 out of 4.316792449951172 sec. for 16 batches.
Forward pass needed 4.300395660400391 out of 4.302194244384766 sec. for 16 batches.
Forward pass needed 4.266609985351562 out of 4.268330596923828 sec. for 16 batches.
Forward pass needed 4.2894942016601565 out of 4.2908681640625 sec. for 16 batches.
Forward pass needed 4.402841583251953 out of 4.404250091552734 sec. for 16 batches.
Forward pass needed 4.32027685546875 out of 4.322093048095704 sec. for 16 bat

In [36]:
reload_and_speed_test_model(cnn_model_ckpt, data, device='cpu', model_dir=model_dir1)  # CNN, CPU, bs=512

Forward pass needed 3.1916438751220704 out of 3.2218335113525387 sec. for 16 batches.
Forward pass needed 3.080736129760742 out of 3.108499145507812 sec. for 16 batches.
Forward pass needed 3.0425771026611326 out of 3.0699398040771477 sec. for 16 batches.
Forward pass needed 3.0204265899658203 out of 3.048200881958008 sec. for 16 batches.
Forward pass needed 3.053827392578125 out of 3.081831481933594 sec. for 16 batches.
Forward pass needed 3.0105337829589844 out of 3.0380930023193353 sec. for 16 batches.
Forward pass needed 3.029181121826172 out of 3.0565822296142575 sec. for 16 batches.
Forward pass needed 3.0309021911621095 out of 3.0588779144287113 sec. for 16 batches.
Forward pass needed 3.020282302856445 out of 3.0476426239013668 sec. for 16 batches.
Forward pass needed 3.018077804565429 out of 3.0454842071533204 sec. for 16 batches.
Forward pass needed 3.025551055908203 out of 3.052905197143555 sec. for 16 batches.
Forward pass needed 3.0506876678466797 out of 3.07842822265625 s

# Pristine, GPU, 512 batch-size

In [None]:
# Load from model_dir1 like every other benchmark cell (the function's default is 'scripts/out').
reload_and_speed_test_model(mlp_model_ckpt, data, model_dir=model_dir1, device='cuda')

In [76]:
reload_and_speed_test_model(gn_model_ckpt, data, device='cuda', model_dir=model_dir1)  # GN+READOUT, GPU, bs=512

INFO:GN_readout_MLP: No inverse normalization for outputs is used.


_------------------------- True
Forward pass needed 3.5268044738769526 out of 3.544875119209289 sec. for 16 batches.
Forward pass needed 0.21128582382202146 out of 0.23019414138793948 sec. for 16 batches.
Forward pass needed 0.1954006071090698 out of 0.21209398365020754 sec. for 16 batches.
Forward pass needed 0.19211660861969 out of 0.20890521812438964 sec. for 16 batches.
Forward pass needed 0.1913344955444336 out of 0.20809833717346188 sec. for 16 batches.
Forward pass needed 0.19368988990783692 out of 0.211979266166687 sec. for 16 batches.
Forward pass needed 0.1910831346511841 out of 0.20786380767822263 sec. for 16 batches.
Forward pass needed 0.19085414409637452 out of 0.20764262294769287 sec. for 16 batches.
Forward pass needed 0.1915508804321289 out of 0.20820479869842523 sec. for 16 batches.
Forward pass needed 0.1938556776046753 out of 0.21040230464935303 sec. for 16 batches.
Forward pass needed 0.19059436702728272 out of 0.20732924842834471 sec. for 16 batches.
Forward pass 

In [100]:
reload_and_speed_test_model(gcn_model_ckpt, data, device='cuda', model_dir=model_dir1)  # GCN+READOUT, GPU, bs=512

INFO:globals_MLP_projector: No inverse normalization for outputs is used.
INFO:levels_MLP_projector: No inverse normalization for outputs is used.
INFO:layers_MLP_projector: No inverse normalization for outputs is used.
INFO:GCN_Readout_MLP: No inverse normalization for outputs is used.


Forward pass needed 2.1776693878173825 out of 2.190135307312011 sec. for 16 batches.
Forward pass needed 2.0077733459472658 out of 2.0186623992919923 sec. for 16 batches.
Forward pass needed 2.011904182434082 out of 2.022813705444336 sec. for 16 batches.
Forward pass needed 2.011231727600098 out of 2.022241271972656 sec. for 16 batches.
Forward pass needed 2.0112457733154296 out of 2.022254585266113 sec. for 16 batches.
Forward pass needed 2.0136365661621096 out of 2.024816650390625 sec. for 16 batches.
Forward pass needed 2.043749099731445 out of 2.0566937713623044 sec. for 16 batches.
Forward pass needed 2.015239624023437 out of 2.026893302917481 sec. for 16 batches.
Forward pass needed 2.0128231201171873 out of 2.023723014831543 sec. for 16 batches.
Forward pass needed 2.00696110534668 out of 2.0178728866577154 sec. for 16 batches.
Forward pass needed 2.007321876525879 out of 2.0183091278076173 sec. for 16 batches.
Forward pass needed 2.0186029052734376 out of 2.030065673828125 sec.

# Pristine, GPU, 2048 batch-size

In [18]:
batch_size = 2048
data_b2048 = get_one_snapshot(h5_path, batch_size=batch_size, exp_type='pristine')  # 8192 points -> 4 batches

4 batches of size 2048, amounting to 8192 data points.


In [28]:
reload_and_speed_test_model(gn_model_ckpt, data_b2048, device='cuda', model_dir=model_dir1)  # GN+READOUT, GPU, bs=2048

Forward pass needed 0.7218496704101562 out of 0.7399539642333984 sec. for 4 batches.
Forward pass needed 0.6532830352783203 out of 0.6707641906738281 sec. for 4 batches.
Forward pass needed 0.6475794525146484 out of 0.6647284393310547 sec. for 4 batches.
Forward pass needed 0.6557327117919922 out of 0.6731878509521485 sec. for 4 batches.
Forward pass needed 0.6496030731201171 out of 0.66685693359375 sec. for 4 batches.
Forward pass needed 0.6537760009765625 out of 0.670913558959961 sec. for 4 batches.
Forward pass needed 0.6516338348388673 out of 0.6695718383789062 sec. for 4 batches.
Forward pass needed 0.6538504333496094 out of 0.6704374694824219 sec. for 4 batches.
Forward pass needed 0.6548609466552735 out of 0.6717781829833984 sec. for 4 batches.
Forward pass needed 0.6531750183105469 out of 0.6698208618164062 sec. for 4 batches.
Forward pass needed 0.6563542175292969 out of 0.6726766815185546 sec. for 4 batches.
Forward pass needed 0.656718505859375 out of 0.6729977569580077 sec.

In [29]:
reload_and_speed_test_model(gcn_model_ckpt, data_b2048, device='cuda', model_dir=model_dir1)  # GCN+READOUT, GPU, bs=2048

Forward pass needed 3.7061569824218754 out of 3.714858642578125 sec. for 4 batches.
Forward pass needed 2.896785217285156 out of 2.905533264160156 sec. for 4 batches.
Forward pass needed 2.8897801513671877 out of 2.8985588378906253 sec. for 4 batches.
Forward pass needed 2.8932310180664063 out of 2.9020509033203123 sec. for 4 batches.
Forward pass needed 2.8897975463867187 out of 2.8986334228515624 sec. for 4 batches.
Forward pass needed 2.9021739501953125 out of 2.911158630371094 sec. for 4 batches.
Forward pass needed 2.8977039184570312 out of 2.9065388793945313 sec. for 4 batches.
Forward pass needed 2.9006817626953123 out of 2.9097777709960937 sec. for 4 batches.
Forward pass needed 2.915938903808594 out of 2.92475146484375 sec. for 4 batches.
Forward pass needed 2.897484924316406 out of 2.9064420166015625 sec. for 4 batches.
Forward pass needed 2.914097717285156 out of 2.9229306030273436 sec. for 4 batches.
Forward pass needed 2.8998575439453127 out of 2.9086227416992188 sec. for 

In [29]:
reload_and_speed_test_model(cnn_model_ckpt, data_b2048, device='cuda', model_dir=model_dir1)  # CNN, GPU, bs=2048

Forward pass needed 0.01688489627838135 out of 0.1570764808654785 sec. for 4 batches.
Forward pass needed 0.017217887878417968 out of 0.15595417785644533 sec. for 4 batches.
Forward pass needed 0.017074079990386962 out of 0.15110553359985351 sec. for 4 batches.
Forward pass needed 0.01698102378845215 out of 0.1521530876159668 sec. for 4 batches.
Forward pass needed 0.01702668809890747 out of 0.1574676475524902 sec. for 4 batches.
Forward pass needed 0.01701107215881348 out of 0.1590118408203125 sec. for 4 batches.
Forward pass needed 0.017023072242736817 out of 0.17266381072998047 sec. for 4 batches.
Forward pass needed 0.01707263994216919 out of 0.16868659210205078 sec. for 4 batches.
Forward pass needed 0.017040223598480224 out of 0.17019494628906248 sec. for 4 batches.
Forward pass needed 0.016937184333801272 out of 0.1638901786804199 sec. for 4 batches.
Forward pass needed 0.017051104068756102 out of 0.16162303924560545 sec. for 4 batches.
Forward pass needed 0.017090943813323975 o

# Pristine, GPU, 8192 batch-size

In [40]:
batch_size = 8192
data_b8192 = get_one_snapshot(h5_path, batch_size=batch_size, exp_type='pristine')  # all 8192 points in a single batch

1 batches of size 8192, amounting to 8192 data points.


In [41]:
reload_and_speed_test_model(gn_model_ckpt, data_b8192, device='cuda', model_dir=model_dir1)  # GN+READOUT, GPU, bs=8192

Forward pass needed 0.7889976196289062 out of 0.8039715576171875 sec. for 1 batches.
Forward pass needed 0.6793284912109375 out of 0.694761474609375 sec. for 1 batches.
Forward pass needed 0.6762897338867188 out of 0.6920478515625 sec. for 1 batches.
Forward pass needed 0.6800123901367188 out of 0.6956384887695313 sec. for 1 batches.
Forward pass needed 0.6851010131835937 out of 0.7004994506835938 sec. for 1 batches.
Forward pass needed 0.6800621948242187 out of 0.6965616455078125 sec. for 1 batches.
Forward pass needed 0.6897576904296875 out of 0.7064658203125 sec. for 1 batches.
Forward pass needed 0.6823429565429687 out of 0.6980708618164062 sec. for 1 batches.
Forward pass needed 0.6887998657226563 out of 0.7045361938476562 sec. for 1 batches.
Forward pass needed 0.6839943237304688 out of 0.6999013061523438 sec. for 1 batches.
Forward pass needed 0.6908799438476563 out of 0.7066085205078125 sec. for 1 batches.
Forward pass needed 0.6837371826171875 out of 0.6998486938476562 sec. fo

In [42]:
reload_and_speed_test_model(gcn_model_ckpt, data_b8192, device='cuda', model_dir=model_dir1)  # GCN+READOUT, GPU, bs=8192

Forward pass needed 3.03328662109375 out of 3.04058740234375 sec. for 1 batches.
Forward pass needed 2.874069580078125 out of 2.881867431640625 sec. for 1 batches.
Forward pass needed 2.8738310546875 out of 2.881159423828125 sec. for 1 batches.
Forward pass needed 2.881129638671875 out of 2.888446044921875 sec. for 1 batches.
Forward pass needed 2.875958251953125 out of 2.8832568359375 sec. for 1 batches.
Forward pass needed 2.87265673828125 out of 2.880024658203125 sec. for 1 batches.
Forward pass needed 2.880484130859375 out of 2.8879716796875 sec. for 1 batches.
Forward pass needed 2.879921875 out of 2.887448974609375 sec. for 1 batches.
Forward pass needed 2.883208984375 out of 2.89071923828125 sec. for 1 batches.
Forward pass needed 2.882713134765625 out of 2.890021240234375 sec. for 1 batches.
Forward pass needed 2.88280126953125 out of 2.89028515625 sec. for 1 batches.
Forward pass needed 2.879398681640625 out of 2.88668408203125 sec. for 1 batches.
Forward time for GCN+Readout 

In [43]:
reload_and_speed_test_model(cnn_model_ckpt, data_b8192, device='cuda', model_dir=model_dir1)  # CNN, GPU, bs=8192

Forward pass needed 0.8122559204101563 out of 0.883565185546875 sec. for 1 batches.
Forward pass needed 0.38626947021484376 out of 0.45689306640625 sec. for 1 batches.
Forward pass needed 0.38705331420898437 out of 0.45968405151367187 sec. for 1 batches.
Forward pass needed 0.3892257385253906 out of 0.46108642578125 sec. for 1 batches.
Forward pass needed 0.38709637451171874 out of 0.45921588134765623 sec. for 1 batches.
Forward pass needed 0.38807293701171874 out of 0.4608348083496094 sec. for 1 batches.
Forward pass needed 0.38875082397460936 out of 0.4611448669433594 sec. for 1 batches.
Forward pass needed 0.3900560302734375 out of 0.463880126953125 sec. for 1 batches.
Forward pass needed 0.389434814453125 out of 0.46273739624023436 sec. for 1 batches.
Forward pass needed 0.39090365600585936 out of 0.46180966186523437 sec. for 1 batches.
Forward pass needed 0.39091598510742187 out of 0.46190383911132815 sec. for 1 batches.
Forward pass needed 0.3902356872558594 out of 0.461490173339

In [44]:
reload_and_speed_test_model(mlp_model_ckpt, data_b8192, device='cuda', model_dir=model_dir1)  # MLP, GPU, bs=8192

Forward pass needed 0.006568096160888672 out of 0.014071840286254883 sec. for 1 batches.
Forward pass needed 0.006566400051116943 out of 0.014201760292053223 sec. for 1 batches.
Forward pass needed 0.006565567970275879 out of 0.01394041633605957 sec. for 1 batches.
Forward pass needed 0.00656713581085205 out of 0.013878527641296387 sec. for 1 batches.
Forward pass needed 0.006570240020751953 out of 0.013969887733459472 sec. for 1 batches.
Forward pass needed 0.0065640959739685055 out of 0.013800736427307128 sec. for 1 batches.
Forward pass needed 0.006565663814544678 out of 0.0136014404296875 sec. for 1 batches.
Forward pass needed 0.006569151878356934 out of 0.01374556827545166 sec. for 1 batches.
Forward pass needed 0.0065718722343444825 out of 0.013707103729248047 sec. for 1 batches.
Forward pass needed 0.006565951824188233 out of 0.013614399909973145 sec. for 1 batches.
Forward pass needed 0.006568511962890625 out of 0.013701696395874024 sec. for 1 batches.
Forward pass needed 0.00

# Clear-sky, GPU, 512 batch-size

In [80]:
# Clear-sky ("_CS") checkpoint file stems (under model_dir1) for the two graph models.
gn_model_ckpt_cs, gcn_model_ckpt_cs = (
    "0.3491valMAE_97ep_GN+READOUT_CS_1990+1999+2003train_2005val_Z_7seed_12h58m_on_Aug_25_3emesh6i",
    "0.5462valMAE_189ep_GCN+READOUT_CS_1990+1999+2003train_2005val_Z_7seed_03h09m_on_Aug_23_1gdj0tl7",
)

In [83]:
batch_size = 512
data_cs = get_one_snapshot(h5_path, batch_size=batch_size, exp_type='clear_sky')  # non-pristine: keeps all layer features

16 batches of size 512, amounting to 8192 data points.


In [84]:
reload_and_speed_test_model(gn_model_ckpt_cs, data_cs, device='cuda', model_dir=model_dir1)  # GN+READOUT clear-sky, GPU, bs=512

INFO:GN_readout_MLP: No inverse normalization for outputs is used.


_------------------------- True
Forward pass needed 0.23414691162109375 out of 0.2700646438598633 sec. for 16 batches.
Forward pass needed 0.20768921661376955 out of 0.23975526523590088 sec. for 16 batches.
Forward pass needed 0.19524076747894287 out of 0.23038259220123292 sec. for 16 batches.
Forward pass needed 0.19199423980712887 out of 0.22427030277252197 sec. for 16 batches.
Forward pass needed 0.1913643503189087 out of 0.2247454414367676 sec. for 16 batches.
Forward pass needed 0.19221068668365482 out of 0.22442188835144045 sec. for 16 batches.
Forward pass needed 0.19364995193481446 out of 0.2267863044738769 sec. for 16 batches.
Forward pass needed 0.19259417629241937 out of 0.2246492147445679 sec. for 16 batches.
Forward pass needed 0.1904752330780029 out of 0.22599372768402098 sec. for 16 batches.
Forward pass needed 0.19192265605926515 out of 0.2244075508117676 sec. for 16 batches.
Forward pass needed 0.19205471992492676 out of 0.2242693452835083 sec. for 16 batches.
Forward 

In [106]:
reload_and_speed_test_model(gcn_model_ckpt_cs, data_cs, device='cuda', model_dir=model_dir1)  # GCN+READOUT clear-sky, GPU, bs=512

INFO:globals_MLP_projector: No inverse normalization for outputs is used.
INFO:levels_MLP_projector: No inverse normalization for outputs is used.
INFO:layers_MLP_projector: No inverse normalization for outputs is used.
INFO:GCN_Readout_MLP: No inverse normalization for outputs is used.


Forward pass needed 2.972647628784179 out of 2.9919323577880865 sec. for 16 batches.
Forward pass needed 2.8854375152587894 out of 2.9033256988525387 sec. for 16 batches.
Forward pass needed 2.884224716186523 out of 2.901990356445312 sec. for 16 batches.
Forward pass needed 2.884748825073242 out of 2.90565121459961 sec. for 16 batches.
Forward pass needed 2.884719268798828 out of 2.9026273193359375 sec. for 16 batches.
Forward pass needed 2.8858316345214847 out of 2.9038018798828125 sec. for 16 batches.
Forward pass needed 2.886290420532226 out of 2.904282043457031 sec. for 16 batches.
Forward pass needed 2.8837029724121095 out of 2.901633026123047 sec. for 16 batches.
Forward pass needed 2.890815261840821 out of 2.908788726806641 sec. for 16 batches.
Forward pass needed 2.8844036407470703 out of 2.9022075042724613 sec. for 16 batches.
Forward pass needed 2.885490646362305 out of 2.903402496337891 sec. for 16 batches.
Forward pass needed 2.897818984985351 out of 2.9158461608886714 sec.