In [1]:
import hydra
import matplotlib.pyplot as plt
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf
from tqdm import tqdm
import xarray as xr
from utils import get_filesystem
OmegaConf.register_new_resolver("eval", eval)
from graphcast_datapipes import SeqZarrDatapipe_GraphCast
from normalisation_wrapper import InputsAndResiduals

from modulus.distributed import DistributedManager
from graphcast_reordering import *
from loss_weights import WeightedMSELoss, get_weights

In [2]:
def check_alignment(input_list, index_array, intended_output_list):
    """
    Checks if applying an index array to an input list aligns with an intended output list.

    Args:
        input_list: The original list of features (e.g., ORIGINAL_ORDER_INPUTS_178).
        index_array: The array of indices to apply to the input_list.
        intended_output_list: The expected output list after applying the index_array.

    Prints detailed information about mismatches, including feature names and indices.
    Returns True if the lists align, False otherwise.
    """
    # indexed_list = [input_list[i] for i in index_array if i < len(input_list)]  # Apply indexing
    indexed_list = []
    for i in index_array:
        if i is not None and i < len(input_list):
            indexed_list.append(input_list[i])
        else:
            indexed_list.append(('total_precipitation', ('time', 0)))
    if len(indexed_list) != len(intended_output_list):
        print(f"Error: Length mismatch! Indexed list length: {len(indexed_list)}, Intended output length: {len(intended_output_list)}")
        return False

    aligned = True
    for i, (indexed_feature, intended_feature) in enumerate(zip(indexed_list, intended_output_list)):
        if indexed_feature != intended_feature:
            print(f"Mismatch at index {i}:")
            print(f"  Indexed feature: {indexed_feature} (from input index {index_array[i]})")
            print(f"  Intended feature: {intended_feature}")
            aligned = False

    if aligned:
        print("Alignment check passed: The indexed list matches the intended output list.")
        return True
    else:
        print("Alignment check failed: There are mismatches between the indexed list and the intended output list.")
        return False
check_alignment(ORIGINAL_ORDER_INPUTS_176, original_176_to_original_83, ORIGINAL_ORDER_OUTPUTS_83)
check_alignment(ORIGINAL_ORDER_INPUTS_178, original_178_to_original_83, ORIGINAL_ORDER_OUTPUTS_83)
check_alignment(reordered_features_178, reorder_178_to_original_178, ORIGINAL_ORDER_INPUTS_178)
check_alignment(reordered_features_178, reorder_178_to_original_176, ORIGINAL_ORDER_INPUTS_176)
check_alignment(reordered_features_176, reorder_176_to_original_176, ORIGINAL_ORDER_INPUTS_176)
check_alignment(reordered_features_list_outputs_83, reorder_output_to_original_output, ORIGINAL_ORDER_OUTPUTS_83)

Alignment check passed: The indexed list matches the intended output list.
Alignment check passed: The indexed list matches the intended output list.
Alignment check passed: The indexed list matches the intended output list.
Alignment check passed: The indexed list matches the intended output list.
Alignment check passed: The indexed list matches the intended output list.
Alignment check passed: The indexed list matches the intended output list.


True

In [3]:
from modulus.models.graphcast.graph_cast_net import GraphCastNet
DistributedManager.initialize()
dist = DistributedManager()

model = GraphCastNet(input_dim_grid_nodes=184, output_dim_grid_nodes=83)
# model = Module.instantiate(
#         {
#             "__name__": cfg.model.name,
#             "__args__": {
#                 k: tuple(v) if isinstance(v, ListConfig) else v
#                 for k, v in cfg.model.args.items()
#             },  # TODO: maybe mobe this conversion to resolver?
#         }
#     )
model = model.to(dist.device)
model.load('../../../../gc_weights/graphcast_0.25_13.mdlus')

  warn(
  model_dict = torch.load(


In [4]:
data = xr.open_zarr('unified_recipe_datasets/arco_era5.zarr')
hydra.core.global_hydra.GlobalHydra.instance().clear()  # Clear previous hydra instances
hydra.initialize(config_path="conf")  
cfg = hydra.compose(config_name="config") 

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(config_path="conf")


In [5]:
fs = get_filesystem(
        cfg.filesystem.type,
        cfg.filesystem.key,
        cfg.filesystem.endpoint_url,
        cfg.filesystem.region_name,
    )
train_dataset_mapper = fs.get_mapper(cfg.curated_dataset.train_dataset_filename)
val_dataset_mapper = fs.get_mapper(cfg.curated_dataset.val_dataset_filename)

In [6]:
val_datapipe = SeqZarrDatapipe_GraphCast(
        file_mapping=val_dataset_mapper,
        variable_groups=cfg.curated_dataset.variable_groups,
        batch_size=cfg.validation.batch_size,
        num_steps=cfg.validation.num_steps + cfg.training.nr_input_steps,
        shuffle=False,
        device=dist.device,
        process_rank=dist.rank,
        world_size=dist.world_size,
        batch=cfg.datapipe.batch,
        parallel=cfg.datapipe.parallel,
        num_threads=cfg.datapipe.num_threads,
        prefetch_queue_depth=cfg.datapipe.prefetch_queue_depth,
        py_num_workers=cfg.datapipe.py_num_workers,
        py_start_method=cfg.datapipe.py_start_method,
    )


In [7]:
for j, data in tqdm(enumerate(val_datapipe)):
    break

0it [00:00, ?it/s]


In [8]:
constants = data[0]['constants']
inputs_surface = data[0]['inputs_surface']
inputs_pressure_levels = torch.reshape(data[0]['inputs_pressure_levels'], (cfg.validation.batch_size, cfg.validation.num_steps + cfg.training.nr_input_steps, 78, 721, 1440))
forcings = data[0]['forcings'].permute((0, 1, 2, 4, 3))
node_features = data[0]['node_features']

In [9]:
input = (torch.concat((constants[0][0], forcings[0][0], inputs_surface[0][0], 
               inputs_pressure_levels[0][0], forcings[0][1], inputs_surface[0][1], 
               inputs_pressure_levels[0][1]), dim=-3))
first_target = (torch.concat((constants[0][0], forcings[0][1], inputs_surface[0][1], 
               inputs_pressure_levels[0][1], forcings[0][2], inputs_surface[0][2], 
               inputs_pressure_levels[0][2]), dim=-3))


For Model Input:
1. Subtract input mean 
2. Divide by input stddev


To get next weather state: 
1. If target variable was a model input then multiply by the output stddev and add to corresponding input value. 
2. Else, multiply by input stddev and add input mean


To get label for loss comparison from the next weather state to be pre-computed
1. Get normalised input
2. If target variable was a model input then subtract the corresponding input value to return to residual and divide by output stddev
3. If target variable was not an input, subtract input mean and divide by input stddev

In [10]:
input_mean = xr.load_dataset('../../../../mean_by_level.nc')
input_std = xr.load_dataset('../../../../stddev_by_level.nc')
output_std = xr.load_dataset('../../../../diffs_stddev_by_level.nc')

input_mean = input_mean.rename({'total_precipitation_6hr': 'total_precipitation'})
input_std = input_std.rename({'total_precipitation_6hr': 'total_precipitation'})
output_std = output_std.rename({'total_precipitation_6hr': 'total_precipitation'})

In [11]:
levels_by_order = []
variable_weights = {
    "10m_u_component_of_wind": 0.1,
    "10m_v_component_of_wind": 0.1,
    "mean_sea_level_pressure": 0.1,
    "total_precipitation": 0.1,
}
per_variable_weight_mapping = {}
for idx, variable in enumerate(ORIGINAL_ORDER_OUTPUTS_83):
    if isinstance(variable, str):
        name = variable
        levels_by_order.append(None)
    elif len(variable) == 2:
        name, _ = variable
        levels_by_order.append(None)
    else:
        name, _, level = variable
        levels_by_order.append(int(level[1]))
    if name in variable_weights.keys():
        per_variable_weight_mapping[idx] = variable_weights[name]

latitude = xr.open_zarr(cfg.curated_dataset.train_dataset_filename).coords['latitude'].values

loss_weights = get_weights((83, len(latitude), forcings.shape[-1]), latitude, levels_by_order, per_variable_weight_mapping)

In [12]:
model.to('cpu')
optimizer = torch.optim.Adam(model.parameters())
criterion = WeightedMSELoss(loss_weights)

In [13]:
wrapped_model = InputsAndResiduals(model, input_std, input_mean, output_std, 
                                   ORIGINAL_ORDER_INPUTS_176, ORIGINAL_ORDER_OUTPUTS_83, 
                                   reorder_178_to_original_176, original_176_to_original_83, 
                                   reorder_178_to_original_output)

In [14]:
with torch.no_grad():
    outputs = wrapped_model(input.to('cpu'), forcings[0][2].to('cpu'), node_features[0][0].to('cpu'))

In [15]:
with torch.no_grad():
    targets_extracted_outputs = wrapped_model._outputs_from_input_tensor(first_target.to('cpu'), wrapped_model.outputs_from_full_input_order)        
    inputs = input[..., wrapped_model.input_permutation, :, :].to('cpu') # Inputs now has size 176

    norm_actual_residuals = wrapped_model._subtract_input_and_normalize_target(inputs, targets_extracted_outputs)
    norm_predicted_residuals = wrapped_model._subtract_input_and_normalize_target(inputs, outputs.squeeze())
            
    loss = criterion(norm_predicted_residuals.squeeze(), norm_actual_residuals.squeeze())
loss

tensor(0.0015)

In [16]:
input_timestep_2 = (torch.concat((constants[0][0], forcings[0][2], inputs_surface[0][1], 
               inputs_pressure_levels[0][1], forcings[0][2], outputs.to('cuda').squeeze()[original_output_to_reorder_output]), dim=-3))
second_target = (torch.concat((constants[0][0], forcings[0][2], inputs_surface[0][2], 
               inputs_pressure_levels[0][2], forcings[0][3], inputs_surface[0][3], 
               inputs_pressure_levels[0][3]), dim=-3))

In [17]:
with torch.no_grad():
    outputs_2 = wrapped_model(input_timestep_2.to('cpu'), forcings[0][3].to('cpu'), node_features[0][0].to('cpu'))

In [18]:
with torch.no_grad():
    targets_extracted_outputs_2 = wrapped_model._outputs_from_input_tensor(second_target.to('cpu'), wrapped_model.outputs_from_full_input_order)        
    inputs_timestep_2 = input_timestep_2[..., wrapped_model.input_permutation, :, :].to('cpu') # inputs_timestep_2 now has size 176

    norm_actual_residuals_2 = wrapped_model._subtract_input_and_normalize_target(inputs_timestep_2, targets_extracted_outputs_2)
    norm_predicted_residuals_2 = wrapped_model._subtract_input_and_normalize_target(inputs_timestep_2, outputs.squeeze())
            
    loss = criterion(norm_predicted_residuals_2.squeeze(), norm_actual_residuals_2.squeeze())
loss

tensor(0.0023)

In [19]:
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

def interactive_image_comparison(array1, array2, title_array):
    """
    Create an interactive plot that displays images from two arrays side by side.
    
    Parameters:
    -----------
    array1 : numpy.ndarray
        First array of images with shape (n, height, width) or (n, height, width, channels)
    array2 : numpy.ndarray
        Second array of images with shape (n, height, width) or (n, height, width, channels)
    title_array : list
        List of strings to be used as titles
        
    Returns:
    --------
    None : Displays the interactive plot in the notebook
    """
    # Ensure arrays have the same number of images
    if len(array1) != len(array2) or len(array1) != len(title_array):
        raise ValueError("All input arrays must have the same length")
    
    # Create function to update the plot
    def update_plot(index):
        
        max1 = np.max(array1[index])
        max2 = np.max(array2[index])
        min1 = np.min(array1[index])
        min2 = np.min(array2[index])
        max_val = np.max((max1, max2))
        min_val = np.min((min1, min2))
        
        # Create a figure with proper layout
        fig = plt.figure(figsize=(14, 7))
        
        # Plot first image
        ax1 = plt.subplot(1, 2, 1)
        im1 = ax1.imshow(array1[index], vmin=min_val, vmax=max_val)
        ax1.axis('off')
        
        # Plot second image
        ax2 = plt.subplot(1, 2, 2)
        im2 = ax2.imshow(array2[index], vmin=min_val, vmax=max_val)
        ax2.axis('off')
        
        # Add a colorbar that applies to both images
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
        fig.colorbar(im2, cax=cbar_ax)
        
        # plt.tight_layout(rect=[0, 0, 0.9, 1])  # Adjust layout to make room for colorbar
        plt.suptitle(title_array[index], fontsize=14)
        plt.show()
    
    # Create widgets
    slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(title_array)-1,
        step=1,
        description='Image:',
        continuous_update=False
    )
    
    dropdown = widgets.Dropdown(
        options=[(title, i) for i, title in enumerate(title_array)],
        value=0,
        description='Select:'
    )
    
    # Create output widget to display the plot
    output = widgets.Output()
    
    # Link the dropdown and slider
    def on_dropdown_change(change):
        if change['type'] == 'change' and change['name'] == 'value':
            slider.value = change['new']
    
    def on_slider_change(change):
        if change['type'] == 'change' and change['name'] == 'value':
            dropdown.value = change['new']
            with output:
                output.clear_output(wait=True)
                update_plot(change['new'])
    
    dropdown.observe(on_dropdown_change)
    slider.observe(on_slider_change)
    
    # Display initial plot
    with output:
        update_plot(0)
    
    # Display widgets and output
    display(widgets.HBox([slider, dropdown]))
    display(output)

# # Alternative version for matplotlib-only environments (like non-Jupyter environments)
# def interactive_image_comparison_matplotlib(array1, array2, title_array):
#     """
#     Create an interactive matplotlib plot that displays images from two arrays side by side.
    
#     Parameters:
#     -----------
#     array1 : numpy.ndarray
#         First array of images with shape (n, height, width) or (n, height, width, channels)
#     array2 : numpy.ndarray
#         Second array of images with shape (n, height, width) or (n, height, width, channels)
#     title_array : list
#         List of strings to be used as titles
        
#     Returns:
#     --------
#     None : Displays the interactive plot in a matplotlib window
#     """
#     if len(array1) != len(array2) or len(array1) != len(title_array):
#         raise ValueError("All input arrays must have the same length")
    
#     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
#     plt.subplots_adjust(bottom=0.25)
    
#     # Initial images
#     img1 = ax1.imshow(array1[0])
#     ax1.set_title(f"Image 1: {title_array[0]}")
#     ax1.axis('off')
    
#     img2 = ax2.imshow(array2[0])
#     ax2.set_title(f"Image 2: {title_array[0]}")
#     ax2.axis('off')
    
#     # Add slider
#     ax_slider = plt.axes([0.25, 0.1, 0.65, 0.03])
#     slider = Slider(
#         ax=ax_slider,
#         label='Image Index',
#         valmin=0,
#         valmax=len(array1)-1,
#         valinit=0,
#         valstep=1
#     )
    
#     # Create a custom dropdown-like widget
#     # (matplotlib doesn't have a built-in dropdown, so we're approximating with buttons)
#     class Dropdown(AxesWidget):
#         def __init__(self, ax, labels, active=0):
#             AxesWidget.__init__(self, ax)
#             self.labels = labels
#             self.active = active
#             self.buttons = []
#             self.cnt = 0
            
#             # Create a button for showing/hiding options
#             self.main_button = Button(plt.axes([0.25, 0.05, 0.65, 0.03]), f"Select: {labels[active]}")
#             self.shown = False
#             self.observers = {}
#             self.main_button.on_clicked(self._show_hide)
            
#         def _show_hide(self, event):
#             if self.shown:
#                 for b in self.buttons:
#                     b.ax.set_visible(False)
#                 self.shown = False
#             else:
#                 y_pos = 0.05
#                 for i, label in enumerate(self.labels):
#                     if not self.buttons:
#                         button_ax = plt.axes([0.25, y_pos - 0.04 * (i + 1), 0.65, 0.03])
#                         button = Button(button_ax, label)
#                         self.buttons.append(button)
#                         button.on_clicked(self._make_callback(i))
#                     else:
#                         self.buttons[i].ax.set_visible(True)
#                 self.shown = True
#             plt.draw()
                    
#         def _make_callback(self, index):
#             def callback(event):
#                 if self.active != index:
#                     self.active = index
#                     self.main_button.label.set_text(f"Select: {self.labels[index]}")
#                     for b in self.buttons:
#                         b.ax.set_visible(False)
#                     self.shown = False
#                     # Notify observers
#                     for cid, func in self.observers.items():
#                         func(index)
#                 plt.draw()
#             return callback
            
#         def on_changed(self, func):
#             """Register a callback to receive slider events."""
#             cid = self.cnt
#             self.observers[cid] = func
#             self.cnt += 1
#             return cid
            
#     # Create dropdown
#     ax_dropdown = plt.axes([0.1, 0.025, 0.8, 0.04])  # This is just a placeholder
#     ax_dropdown.set_visible(False)  # Hide the actual axes
#     dropdown = Dropdown(ax_dropdown, title_array)
    
#     # Update function
#     def update(val):
#         index = int(slider.val)
#         img1.set_data(array1[index])
#         img2.set_data(array2[index])
#         ax1.set_title(f"Image 1: {title_array[index]}")
#         ax2.set_title(f"Image 2: {title_array[index]}")
#         fig.canvas.draw_idle()
    
#     def dropdown_update(index):
#         slider.set_val(index)
    
#     slider.on_changed(update)
#     dropdown.on_changed(dropdown_update)
    
#     plt.show()


In [20]:
interactive_image_comparison(targets_extracted_outputs.squeeze().numpy(), outputs.squeeze().cpu().numpy(), ORIGINAL_ORDER_OUTPUTS_83)

HBox(children=(IntSlider(value=0, continuous_update=False, description='Image:', max=82), Dropdown(description…

Output()

In [21]:
interactive_image_comparison(targets_extracted_outputs_2.squeeze().numpy(), outputs_2.squeeze().cpu().numpy(), ORIGINAL_ORDER_OUTPUTS_83)

HBox(children=(IntSlider(value=0, continuous_update=False, description='Image:', max=82), Dropdown(description…

Output()

In [22]:
def unroll(model, constants, inputs, forcings, node_features, num_steps = 1):
    # Get number of steps to unroll
    if forcings.shape[0] < 3:
        raise ValueError("Need forcings at at least 3 different timesteps to make predictions")
    max_steps = forcings.shape[0] - 2
    model_pred_i_minus_1 = inputs[0]
    model_pred_i_0 = inputs[1]
    model_predicted = []
    for i in range(min(num_steps, max_steps)):
        # Create Input
        input = torch.concat((constants, forcings[i], model_pred_i_minus_1.squeeze(), forcings[i+1], model_pred_i_0.squeeze()), dim=0)
        
        # Store Predictions and update next steps for rollout
        model_pred_i_minus_1 = model_pred_i_0
        model_pred_i_0 = model(input, forcings[i+2], node_features)[original_output_to_reorder_output]
        model_predicted.append(model_pred_i_0)

    # Stack predictions
    model_predicted = torch.stack(model_predicted, dim=1)

    return model_predicted

In [23]:
# unrolled_predictions = unroll(wrapped_model, constants.squeeze()[0].cpu(), torch.concat((inputs_surface, inputs_pressure_levels), dim=-3).squeeze().cpu(), 
#        forcings.squeeze().cpu(), node_features.squeeze()[0].cpu(), num_steps=2)

In [24]:
def eval_forward(model, constants, inputs_surface, inputs_pressure_levels, forcings, node_features, num_steps = 1):
    # Forward pass
    combined_inputs = torch.concat((inputs_surface, inputs_pressure_levels), dim=-3)
    with torch.no_grad:
        net_predicted_variables = unroll(model, constants.squeeze()[0].cpu(), inputs.squeeze(), forcings[0].cpu(), node_features[0, 0].cpu(), num_steps=num_steps)

        # l2 loss
        label = combined_inputs[:num_steps][..., reorder_output_to_original_output, :, :]
        loss = (torch.mean(torch.pow(net_predicted_variables - label)))
    return loss, net_predicted_variables