In [None]:
import sys
from pathlib import Path

# PROJECT_ROOT is the directory two levels above the current working directory
# (parents[0] is the parent, parents[1] the grandparent), so the result depends
# on where the kernel was started.
PROJECT_ROOT = Path.cwd().parents[1]  # or resolve explicitly
# Put the project root first on the module search path.
# NOTE(review): presumably so project-local packages can be imported — confirm.
# Re-running this cell inserts the same entry again.
sys.path.insert(0, str(PROJECT_ROOT))
print(PROJECT_ROOT)

In [None]:
# Requires eofs and gpytorch to be installed (their imports below are currently commented out)
import xarray as xr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
#import gpytorch
import os
import glob
#from eofs.xarray import Eof

from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from typing import Dict, Optional, List, Callable, Tuple, Union

#import wandb
#from sklearn.model_selection import train_test_split, cross_val_score
#from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

In [None]:
def compute_per_channel_statistics(experiment_dict, gas_vars, log_transform=False, plot=False):
    """
    Compute mean and std for each channel separately, pooled over all experiments.

    Args:
        experiment_dict: mapping experiment name -> {variable name -> numpy array}.
        gas_vars: iterable of variable names to compute statistics for.
        log_transform: when True, 'BC', 'CH4', 'SO2' and 'CO2' are log10-transformed
            (after clipping at 1e-15) before the statistics are taken.
        plot: when True, show histograms of the pooled values and of their z-scores.

    Returns:
        dict mapping each variable to {'mean', 'std', 'use_log'}. 'use_log' is True
        only when the statistics were actually computed in log10 space, so the flag
        always describes the space 'mean' and 'std' live in.
    """
    log_vars = ('BC', 'CH4', 'SO2', 'CO2')
    stats = {}

    for var in gas_vars:
        # The flag must follow the transform that is really applied; otherwise a
        # consumer would log-transform data and normalise it with linear statistics.
        use_log = bool(log_transform) and var in log_vars

        all_values = []
        for exp in experiment_dict.keys():
            data = np.asarray(experiment_dict[exp][var])
            if use_log:
                # Log transform for gases (very small values)
                data = np.log10(np.clip(data, 1e-15, None) + 1e-15)
            all_values.append(data.ravel())

        all_values = np.concatenate(all_values)
        mean = np.mean(all_values)
        std = np.std(all_values)
        stats[var] = {
            'mean': mean,
            'std': std,
            'use_log': use_log
        }
        print(f"{var}: mean={stats[var]['mean']:.6f}, std={stats[var]['std']:.6f}, log={stats[var]['use_log']}")

        if plot:
            # The histograms pool every experiment, so the titles name only the variable.
            plt.title(f"{var} (all experiments)")
            plt.hist(all_values, bins=50)
            plt.show()
            if var in log_vars and not log_transform:
                # Preview of what the log transform would do to this channel.
                plt.title(f"{var} log-transformed (all experiments)")
                plt.hist(np.log10(np.clip(all_values, 1e-15, None) + 1e-15), bins=50)
                plt.show()

            z_norm = (all_values - mean) / std
            plt.title(f"{var} z-normed (all experiments)")
            plt.hist(z_norm, bins=50)
            plt.show()

    return stats

def compute_per_channel_statistics_linear_z(
    experiment_dict,
    gas_vars,
    plot=False,
):
    """
    Per-variable mean/std in linear (untransformed) space, pooled over experiments.

    Args:
        experiment_dict: mapping experiment name -> {variable name -> numpy array}.
        gas_vars: iterable of variable names to process.
        plot: when True, show a histogram of the raw pooled values followed by a
            histogram of their z-scores.

    Returns:
        dict mapping each variable to {"mean", "std", "use_log"}; "use_log" is
        always False here.
    """

    def _show_histogram(values, title):
        # One fresh figure per histogram.
        plt.figure()
        plt.title(title)
        plt.hist(values, bins=50)
        plt.show()

    stats = {}
    for var in gas_vars:
        # Flatten every experiment's array for this variable into one 1-D sample.
        pooled = np.concatenate(
            [per_exp[var].flatten() for per_exp in experiment_dict.values()]
        )
        mean, std = np.mean(pooled), np.std(pooled)
        stats[var] = {"mean": mean, "std": std, "use_log": False}

        print(f"{var}: mean={mean:.6e}, std={std:.6e}, log=False")
        if not plot:
            continue

        _show_histogram(pooled, f"{var} raw (linear)")
        _show_histogram((pooled - mean) / std, f"{var} z-score (linear)")

    return stats

In [None]:
import healpix
def interpolate_dh_to_hp(nside, variable: xr.DataArray):
    """
    Sample `variable` at the (lon, lat) positions of every HEALPix pixel.

    Pixel positions come from `healpix.pix2ang` for all `nside2npix(nside)` pixel
    indices, requested with lonlat=True and nest=True. Longitudes are wrapped into
    [0, 360) before interpolation.

    Args:
        nside: HEALPix resolution parameter passed to the `healpix` package.
        variable: xr.DataArray with `lat` and `lon` coordinates to interpolate along.

    Returns:
        The interpolated values as a float32 numpy array. The target points share
        a new dimension named "z".
    """
    pixel_ids = np.arange(healpix.nside2npix(nside))
    lon_deg, lat_deg = healpix.pix2ang(nside, pixel_ids, lonlat=True, nest=True)

    # Both targets are indexed by the same "z" dimension, so xarray interpolates
    # point-wise rather than on the outer product of lat and lon.
    targets = {
        "lat": xr.DataArray(lat_deg, dims="z"),
        "lon": xr.DataArray(np.mod(lon_deg, 360), dims="z"),
    }

    # NOTE(review): fill_value=None is forwarded to the underlying interpolator,
    # presumably to extrapolate outside the source grid — confirm.
    sampled = variable.interp(**targets, kwargs={"fill_value": None})
    return np.array(sampled.to_numpy(), dtype=np.float32)

def interpolate_dh_to_hp_output(nside, variable: xr.DataArray):
    """
    Sample `variable` at the (lon, lat) positions of every HEALPix pixel.

    Same procedure as `interpolate_dh_to_hp`, but for arrays whose horizontal
    coordinates are named `y` (fed the pixel latitudes) and `x` (fed the pixel
    longitudes, wrapped into [0, 360)).

    Args:
        nside: HEALPix resolution parameter passed to the `healpix` package.
        variable: xr.DataArray with `y` and `x` coordinates to interpolate along.

    Returns:
        The interpolated values as a float32 numpy array. The target points share
        a new dimension named "z".
    """
    pixel_ids = np.arange(healpix.nside2npix(nside))
    lon_deg, lat_deg = healpix.pix2ang(nside, pixel_ids, lonlat=True, nest=True)

    # Both targets are indexed by the same "z" dimension, so xarray interpolates
    # point-wise rather than on the outer product of y and x.
    targets = {
        "y": xr.DataArray(lat_deg, dims="z"),
        "x": xr.DataArray(np.mod(lon_deg, 360), dims="z"),
    }

    # NOTE(review): fill_value=None is forwarded to the underlying interpolator,
    # presumably to extrapolate outside the source grid — confirm.
    sampled = variable.interp(**targets, kwargs={"fill_value": None})
    return np.array(sampled.to_numpy(), dtype=np.float32)

def e5_to_numpy_hp(e5xr, nside: int, normalized: bool):
    """
    Interpolate the `surface` and `upper` fields of `e5xr` onto a HEALPix grid.

    Args:
        e5xr: object exposing `surface` and `upper` attributes; each is passed to
            `interpolate_dh_to_hp` (so each should be an xr.DataArray).
        nside: HEALPix resolution parameter.
        normalized: when True, the interpolated arrays are additionally passed
            through `normalize_sample` with the statistics returned by
            `deserialize_dataset_statistics(nside)`.

    Returns:
        Tuple `(hp_surface, hp_upper)`.

    NOTE(review): `deserialize_dataset_statistics` and `normalize_sample` are not
    defined in the visible part of this notebook — confirm they are in scope
    before calling with normalized=True.
    """
    surface_hp = interpolate_dh_to_hp(nside, e5xr.surface)
    upper_hp = interpolate_dh_to_hp(nside, e5xr.upper)

    if not normalized:
        return surface_hp, upper_hp

    dataset_stats = deserialize_dataset_statistics(nside)
    surface_hp, upper_hp = normalize_sample(dataset_stats.item(), surface_hp, upper_hp)
    return surface_hp, upper_hp

def get_input_paths(exp, input_dir, fire_type):
    """
    Collect the NetCDF input files of one experiment, grouped by gas.

    For each gas the directory `<input_dir>/<exp>/<gas>_sum/250_km/mon` is searched
    recursively for `*.nc` files. Files are kept when their path contains
    `fire_type`; CO2 files are always kept because they carry no fire type.

    Args:
        exp: experiment name (sub-directory of `input_dir`).
        input_dir: root directory of the input data.
        fire_type: substring that BC/CH4/SO2 file paths must contain.

    Returns:
        dict mapping 'BC', 'CH4', 'SO2', 'CO2' to a sorted list of file paths
        (an empty list when nothing matched).
    """
    input_gasses = {
        "BC":  "BC_sum",
        "CH4": "CH4_sum",
        "SO2": "SO2_sum",
        "CO2": "CO2_sum",
    }
    # Build the result from the local mapping: relying on a notebook-level global
    # here breaks on a fresh kernel and can drift out of sync with `input_gasses`.
    gas_files = {gas: [] for gas in input_gasses}

    for gas, folder_name in input_gasses.items():
        var_dir = os.path.join(input_dir, exp, folder_name, "250_km", "mon")
        pattern = os.path.join(var_dir, "**", "*.nc")
        # Every match lies below var_dir, so it already contains folder_name.
        # Sorting makes the order independent of the filesystem.
        for path in sorted(glob.glob(pattern, recursive=True)):
            if gas == "CO2" or fire_type in path:  # CO2 does not have fire_type
                gas_files[gas].append(path)

    # Report how many files were found per gas so mismatched counts are visible.
    for k, v in gas_files.items():
        print(k, len(v))

    return gas_files

def get_output_paths(exp, target_dir, mod, ensembles, variables):
    """
    Collect the NetCDF target files of one experiment, grouped by variable.

    For each variable the directory
    `<target_dir>/<mod>/<ensembles>/<exp>/<variable>/250_km/mon` is searched
    recursively for `*.nc` files.

    Args:
        exp: experiment name.
        target_dir: root directory of the model output data.
        mod: model name (directory below `target_dir`).
        ensembles: ensemble member name (directory below `mod`).
        variables: iterable of output variable names.

    Returns:
        dict mapping each variable to the list of matching file paths, in the
        order returned by `glob.glob`.
    """
    paths_by_var = {name: [] for name in variables}
    for name in variables:
        search_root = os.path.join(target_dir, mod, ensembles, exp, name, '250_km', 'mon')
        paths_by_var[name].extend(glob.glob(f"{search_root}/**/*.nc", recursive=True))

    return paths_by_var

def get_hp_dataset(files : dict, nside : int, output : bool = False):
    """
    Load the files of every variable and interpolate them onto a HEALPix grid.

    For each variable the files are opened, concatenated along "time", sorted by
    time, reduced to a single DataArray and interpolated with
    `interpolate_dh_to_hp_output` (when `output` is True) or `interpolate_dh_to_hp`.

    Args:
        files: mapping variable name -> list of NetCDF file paths.
        nside: HEALPix resolution parameter.
        output: selects the interpolation routine (see above).

    Returns:
        dict mapping each variable name to the interpolated numpy array.

    Raises:
        ValueError: if a variable has no files, or (from `squeeze`) if the
            concatenated dataset holds more than one data variable.
    """
    hp_gas_dict = {}
    for var, var_files in files.items():
        if not var_files:
            # Fail with the variable name instead of xr.concat's generic message.
            raise ValueError(f"no NetCDF files found for variable '{var}'")

        # xr.open_mfdataset(var_files, concat_dim="time", combine="nested") would
        # require dask, which is not available in the current apptainer image,
        # so the files are opened one by one and concatenated manually.
        datasets = []
        try:
            for path in var_files:
                datasets.append(xr.open_dataset(path))
            ds = xr.concat(datasets, dim="time").sortby("time")
            arr = ds.to_array().squeeze("variable")  # remove var dim since 1
            # Pull everything into memory so the file handles can be released.
            arr = arr.load()
        finally:
            for opened in datasets:
                opened.close()

        interpolate = interpolate_dh_to_hp_output if output else interpolate_dh_to_hp
        hp_gas_dict[var] = interpolate(nside, arr)

    return hp_gas_dict

import healpy
def healpix_plotting(hp_data, data_shift=0, norm=None, nest=True):
    """
    Draw a HEALPix map with `healpy.visufunc.cartview`.

    Args:
        hp_data: map values; `data_shift` is added before plotting.
        data_shift: constant offset added to `hp_data`.
        norm: forwarded unchanged to `cartview`.
        nest: forwarded unchanged to `cartview`.

    Other `healpy.visufunc` views (orthview, mollview, gnomview) accept the same
    call shape if a different projection is wanted.
    """
    healpy.visufunc.cartview(hp_data + data_shift, nest=nest, norm=norm)

In [None]:
# Root of the ClimateSet data tree. Defaults to the original cluster location;
# set the CLIMATESET_DIR environment variable to run the notebook elsewhere.
datapath = os.environ.get("CLIMATESET_DIR", "/proj/heal_pangu/users/x_tagty/climateset")
input_dir = os.path.join(datapath, "inputs", "input4mips")
target_dir = os.path.join(datapath, "outputs", "CMIP6")

# Selection of what to load.
fire_type = 'all-fires'
output_vars = ['tas', "pr"]

mod = 'CAS-ESM2-0'
ensembles = 'r3i1p1f1'
experiments = ["ssp585", "ssp126", "ssp370"]

# Gas name -> input4mips folder name (module-level mapping).
gas_patterns = {
    "BC":  "BC_sum",
    "CH4": "CH4_sum",
    "SO2": "SO2_sum",
    "CO2": "CO2_sum",
}

# HEALPix resolution used for both inputs and targets.
nside = 32

# experiment name -> {variable -> HEALPix array}
experiment_input_dict = {}
experiment_output_dict = {}
for exp in experiments:
    input_paths = get_input_paths(exp, input_dir, fire_type)
    print("got input paths")
    hp_input_dict = get_hp_dataset(input_paths, nside)
    print("hp input done")
    # TODO : Add concatenate or similar if multiple experiments
    experiment_input_dict[exp] = hp_input_dict

    print("checking outputs")
    output_paths = get_output_paths(exp, target_dir, mod, ensembles, output_vars)
    print("got output paths")
    hp_target_dict = get_hp_dataset(output_paths, nside, output=True)
    # TODO : Add concatenate or similar  if multiple experiments
    experiment_output_dict[exp] = hp_target_dict

In [None]:
def prepare_manual_batches(experiment_input_dict, experiment_output_dict, 
                           gas_vars=['BC', 'CH4', 'SO2', 'CO2'],
                           output_vars=["tas", "pr"],
                           batch_size=8):
    """
    Manually create batches without using DataLoader.

    One sample is built per (experiment, timestep): the arrays of `gas_vars` at that
    timestep are stacked along a new leading channel axis, and likewise for
    `output_vars` when targets are given. The number of timesteps of an experiment
    is taken from its first variable in `gas_vars`.

    Args:
        experiment_input_dict: experiment name -> {input variable -> array indexed
            by timestep on axis 0}.
        experiment_output_dict: same layout for the targets, or None/empty when no
            targets are available.
        gas_vars: input variable names, in channel order.
        output_vars: target variable names, in channel order.
        batch_size: maximum number of samples per batch (the last batch may be smaller).

    Returns:
        Tuple `(batches, all_inputs, all_outputs)`:
          * batches: list of dicts with keys 'input', 'output', 'metadata'
            (list of {'exp', 'time'}) and 'batch_idx';
          * all_inputs: array of all samples, shape (n_samples, n_channels, ...);
          * all_outputs: matching array of targets, or an empty list when no
            targets were given (each batch's 'output' is then an empty list).

    Raises:
        ValueError: if `batch_size` is not positive, or if targets are given for
            some experiments but missing for others (samples and targets would
            otherwise be paired with the wrong rows).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    has_targets = bool(experiment_output_dict)

    all_inputs = []
    all_outputs = []
    all_metadata = []

    # Collect all samples
    for exp, exp_inputs in experiment_input_dict.items():
        exp_targets = None
        if has_targets:
            if exp not in experiment_output_dict:
                # Skipping targets for one experiment would shift every later
                # target relative to its input when the arrays are sliced below.
                raise ValueError(
                    f"experiment '{exp}' has inputs but no entry in experiment_output_dict"
                )
            exp_targets = experiment_output_dict[exp]

        n_timesteps = exp_inputs[gas_vars[0]].shape[0]
        for t in range(n_timesteps):
            # Stack input channels -> (C, N)
            all_inputs.append(np.stack([exp_inputs[var][t] for var in gas_vars], axis=0))

            # Stack output channels (if doing supervised learning)
            if exp_targets is not None:
                all_outputs.append(np.stack([exp_targets[var][t] for var in output_vars], axis=0))

            all_metadata.append({'exp': exp, 'time': t})

    # Convert to numpy arrays
    all_inputs = np.array(all_inputs)  # (N_samples, C, N_pixels)
    if all_outputs:
        all_outputs = np.array(all_outputs)

    # Create batches
    n_samples = len(all_inputs)
    batches = []
    for batch_idx, start_idx in enumerate(range(0, n_samples, batch_size)):
        end_idx = min(start_idx + batch_size, n_samples)
        batches.append({
            'input': all_inputs[start_idx:end_idx],
            'output': all_outputs[start_idx:end_idx],
            'metadata': all_metadata[start_idx:end_idx],
            'batch_idx': batch_idx
        })

    return batches, all_inputs, all_outputs


def apply_log_transform(data, global_mean=None, global_std=None, eps=1e-12):
    """
    Log10-transform `data` and standardise the result.

    Values are clipped from below at 1e-15 and offset by 1e-15 before taking log10.
    The logged values are then centred and scaled with `global_mean`/`global_std`
    when both are supplied, otherwise with the mean and std of the logged data
    itself. The scale is floored at `eps` to avoid division by zero.
    """
    log_values = np.log10(np.clip(data, 1e-15, None) + 1e-15)

    # Fall back to the sample's own statistics unless both globals are provided.
    if global_mean is None or global_std is None:
        center, spread = log_values.mean(), log_values.std()
    else:
        center, spread = global_mean, global_std

    return (log_values - center) / max(spread, eps)

In [None]:
# Channel order of the model inputs and targets; these names must exist as keys
# in every per-experiment dict built in the loading cell.
input_vars = ['BC', 'CH4', 'SO2', 'CO2']
output_vars = ["tas", "pr"]
batch_size = 64

# 1. Prepare batches manually
# `batches` is the list of per-batch dicts; `all_inputs` / `all_outputs` are the
# unbatched sample arrays returned alongside it.
print("Preparing batches...")
batches, all_inputs, all_outputs = prepare_manual_batches(
    experiment_input_dict, 
    experiment_output_dict,
    gas_vars= input_vars,
    output_vars = output_vars,
    batch_size=batch_size
)