In [15]:
import os

# Process-level runtime configuration. These assignments sit above `import jax`
# so the variables are already set when JAX/XLA is imported and initialised.

# Expose only the CUDA device with index 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Show full JAX tracebacks instead of the filtered (shortened) ones.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"
# Do not reserve a large block of GPU memory at start-up; allocate on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# NOTE(review): the JAX docs describe MEM_FRACTION as the size of the
# preallocated pool; with preallocation disabled on the line above, confirm
# this setting still has the intended effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.4"

from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import jax
import corrector_src.loss.sgs_turb_loss as loss
from functools import partial
import numpy as np
from corrector_src.data.dataset import dataset


In [16]:
# Compose the Hydra config called "config"; `config_path` is given relative to
# this notebook. Only the composition needs the initialised Hydra context.
with initialize(version_base=None, config_path="../../configs"):
    cfg = compose(config_name="config")

# Dump the composed configuration so the run settings are recorded in the output.
print(cfg)


{'experiment': {'name': 'turbulence_force_corrector'}, 'data': {'debug': '${training.debug}', 'hr_res': 160, 'downscaling_factor': 2, 'num_checkpoints': 10, 'num_timesteps': 1000, 'generate_data_on_fly': True, 'precomputed_data': False, 'fixed_timestep': False, 'dt_max': 0.1, 't_end': 1.0, 'snapshot_timepoints': [1.0], 'num_snapshots': 1, 'use_specific_snapshot_timepoints': True, 'return_snapshots': True, 'scenarios': [1]}, 'training': {'epochs': 300, 'n_look_behind': 10, 'learning_rate': 0.001, 'return_full_sim': False, 'return_full_sim_epoch_interval': 10, 'rng_key': 60, 'debug': False, 'mse_loss': 1, 'spectral_energy_loss': 0.01, 'rate_of_strain_loss': 0.01, 'early_stopping': True}, 'models': {'_target_': 'corrector_src.model.fno_hd_force_corrector.TurbulenceSGSForceCorrectorFNO', '_name_': 'fno', 'hidden_channels': 16, 'n_fourier_layers': 1, 'fourier_modes': 10, 'shifting_modes': 2}}


In [17]:
def make_loss_function(cfg_training, debug_print=True):
    """Build the weighted training loss and its bookkeeping helpers.

    Args:
        cfg_training: Mapping indexed with ``"mse_loss"`` and
            ``"rate_of_strain_loss"``; each value is the weight of the
            corresponding loss term. A term is *active* only when its weight
            is strictly positive. The same rule is used by every object
            returned from this factory.
        debug_print: When True (the default, matching the previous behaviour)
            each active component is emitted with ``jax.debug.print`` on every
            call of the returned ``loss_function``. Pass False to silence it,
            e.g. inside a training loop.

    Returns:
        A tuple ``(loss_function, compute_loss_from_components,
        active_loss_indices)``:

        * ``loss_function(predicted, ground_truth, config,
          registered_variables, params)`` is wrapped in ``jax.jit`` with
          ``config`` and ``registered_variables`` as static arguments and
          returns ``(total, components)``, where ``components`` maps the name
          of every active loss to its unweighted value and ``total`` is their
          weighted sum.
        * ``compute_loss_from_components(loss_components)`` recombines already
          computed component values into the weighted total.
        * ``active_loss_indices`` maps the position of each active loss in the
          loss table to its name.
    """

    def _is_active(weight):
        # Single source of truth for "is this loss term switched on?".
        # Previously the index/weight tables used `!= 0` while the jitted loss
        # used `> 0`, so a negative weight was reported as active but never
        # actually computed.
        return weight > 0

    # Loss table: {name: (weight, fn)}. Every fn takes the same five arguments
    # so the jitted loss can call them uniformly.
    loss_fns = {
        "mse": (
            cfg_training["mse_loss"],
            lambda pred, gt, config, registered_vars, params: loss.mse_loss(
                pred, gt
            ),
        ),
        "strain": (
            cfg_training["rate_of_strain_loss"],
            lambda pred, gt, config, registered_vars, params: (
                loss.rate_of_strain_loss(pred, gt, config, registered_vars)
            ),
        ),
        # Disabled for now; kept for reference.
        # "spectral": (
        #     cfg_training["spectral_energy_loss"],
        #     lambda pred, gt, config, registered_vars, params: (
        #         loss.spectral_energy_loss(
        #             pred, gt, config, registered_vars, params
        #         )
        #     ),
        # ),
    }

    # NOTE(review): indices are positions in `loss_fns` (inactive entries keep
    # their slot), whereas the `components` dict returned by `loss_function`
    # only holds active losses. Confirm callers lay out `loss_components`
    # with these positions.
    active_loss_indices = {
        i: name.replace("_loss", "")
        for i, (name, (w, _)) in enumerate(loss_fns.items())
        if _is_active(w)
    }
    active_weights = {
        i: w for i, (_, (w, _)) in enumerate(loss_fns.items()) if _is_active(w)
    }

    def compute_loss_from_components(loss_components):
        """Return the weighted sum of pre-computed loss components.

        Args:
            loss_components: Array-like with a ``shape`` attribute. A 1-D
                input is read as one value per loss index; otherwise the
                first axis is treated as a batch axis and the second axis as
                the loss index.

        Returns:
            A scalar for 1-D input, or a float64 ``numpy.ndarray`` with one
            total per row for batched input.
        """
        if len(loss_components.shape) == 1:
            total_loss = 0.0
            for i, weight in active_weights.items():
                total_loss += loss_components[i] * weight
            return total_loss

        # Batched case: accumulate one weighted column at a time instead of
        # looping over every row in Python.
        total_loss = np.zeros(loss_components.shape[0])
        for i, weight in active_weights.items():
            total_loss += np.asarray(loss_components[:, i]) * weight
        return total_loss

    @partial(jax.jit, static_argnames=["config", "registered_variables"])
    def loss_function(
        predicted, ground_truth, config, registered_variables, params
    ):
        """Evaluate every active loss term and return ``(total, components)``."""
        total = 0.0
        components = {}

        # `weight` is a plain Python number taken from the config, so this
        # branch is resolved while tracing rather than inside the compiled code.
        for name, (weight, fn) in loss_fns.items():
            if not _is_active(weight):
                continue
            val = fn(predicted, ground_truth, config, registered_variables, params)
            components[name] = val
            total += weight * val

        if debug_print:
            for name, value in components.items():
                jax.debug.print("{name}: {value}", name=name, value=value)

        return total, components

    return loss_function, compute_loss_from_components, active_loss_indices


In [18]:
# Build the project dataset wrapper from the scenario list [1] and the `data`
# section of the composed Hydra config.
dataset_test = dataset([1], cfg.data)
# NOTE(review): 16 is presumably the grid resolution (it matches the 16^3 test
# arrays built later in this notebook) — confirm against `sim_initializator`.
# `config` and `reg_vars` are later passed to the loss as static jit arguments.
state, config, params, helper, reg_vars, seed = dataset_test.sim_initializator(16)

Returning snapshots with specific snapshots [1.0]


In [19]:
# Unpack, in order: the jitted loss, the helper that re-weights pre-computed
# components, and the {index: name} map of the active loss terms.
loss_function, compute_loss_from_components, active_loss_indices = make_loss_function(cfg.training)


In [20]:
# Constant dummy fields for a smoke test: every entry differs by 0.1, so the
# squared difference per entry is 0.01.
# NOTE(review): the (3, 16, 16, 16) layout is presumably (component, x, y, z) — confirm.
predicted = jax.numpy.zeros((3, 16, 16, 16)) + 0.1
ground_truth = jax.numpy.zeros((3, 16, 16, 16)) + 0.2

# Evaluate the loss once; `config` and `reg_vars` are the static jit arguments.
total_loss, components = loss_function(predicted, ground_truth, config, reg_vars, params)

mse: 0.010000000000000005
strain: 0.0


In [21]:
# Weighted total returned as the first element of loss_function's result.
print(total_loss)

0.010000000000000005


In [22]:
# Unweighted per-term values, keyed by loss name.
print(components)

{'mse': Array(0.01, dtype=float64), 'strain': Array(0., dtype=float64)}
