In [1]:
# Load IPython's autoreload extension and select mode 2, so that edits to
# imported modules (e.g. the local `source` package used below) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import xarray as xr

# Open the toy ERA5 Zarr store with xarray. The store's file name suggests it
# covers 2004-2023 on a 180x360 grid.
ds = xr.open_zarr("/mnt/jua-shared-1/jua-bronze-layer/scratch/alex/erax5-toy-2004-2023-180x360.zarr")
# Bare last expression so the notebook renders the dataset summary.
ds

In [3]:
# Shape of the `air_temperature_2m` variable. The recorded output is
# (175320, 180, 360); presumably (time, lat, lon) — confirm against ds.dims.
ds["air_temperature_2m"].shape

(175320, 180, 360)

In [4]:
from source.data import ToyEra5Dataset

# Build the dataset that the DataLoader and model smoke tests below consume.
# The same three variables are requested as both inputs and outputs.
# NOTE(review): the recorded `dataset.global_stats` output keys levels as e.g.
# '50000.0' / '100000.0' (these look like Pa), while `levels` here looks like
# hPa — confirm that ToyEra5Dataset maps between the two before looking up
# per-level statistics.
dataset = ToyEra5Dataset(
    # Data store and statistics file.
    zarr_path="/mnt/jua-shared-1/jua-bronze-layer/scratch/alex/erax5-toy-2004-2023-180x360.zarr",
    stats_path="/mnt/jua-shared-1/jua-silver-layer/all_variables_stats_together_renamed.json",
    use_era5_stats=True,
    # Date range and lead-time set.
    start_date="2004-01-01",
    end_date="2023-12-31",
    lead_time_set=[1],
    # Levels and variable selection.
    levels=[100, 500, 1000],
    input_variable_names=[
        "air_pressure_at_mean_sea_level",
        "air_temperature",
        "air_temperature_2m",
    ],
    output_variable_names=[
        "air_pressure_at_mean_sea_level",
        "air_temperature",
        "air_temperature_2m",
    ],
)

In [5]:
# Inspect the statistics exposed by the dataset. In the recorded output each
# variable has 'mean'/'std' entries, and for `air_temperature` the 'mean'
# entry is a dict keyed by level.
dataset.global_stats

{'air_pressure_at_mean_sea_level': {'mean': 100955.2435763889,
  'std': 1336.6274441189237},
 'air_temperature': {'mean': {'50000.0': 252.9856492784288,
   '22500.0': 220.2857720269097,
   '35000.0': 235.71905958387586,
   '2000.0': 220.06371222601996,
   '10000.0': 208.3711151123047,
   '100000.0': 281.05535820855033,
   '700.0': 231.8156511094835,
   '70000.0': 267.42711249457466,
   '12500.0': 210.56849161783853,
   '40000.0': 242.18450249565973,
   '65000.0': 264.4604749891493,
   '80000.0': 272.581399875217,
   '300.0': 249.90770534939236,
   '17500.0': 215.83446689181858,
   '95000.0': 278.4341349283854,
   '100.0': 261.1414367675781,
   '77500.0': 271.4174106174045,
   '75000.0': 270.1652784559462,
   '97500.0': 279.63818901909724,
   '200.0': 257.1461300320095,
   '1000.0': 226.9903550889757,
   '85000.0': 274.5937438964844,
   '15000.0': 213.30576714409722,
   '7000.0': 209.44116583930122,
   '500.0': 238.413719007704,
   '45000.0': 247.94634229871963,
   '60000.0': 261.169144

In [6]:
# Get all input variable names in order
input_vars = dataset.get_input_variable_names()
print("Input variables:", input_vars)
# In the recorded run, `air_temperature` appears once per requested level with
# the level as a suffix (e.g. 'air_temperature_100'); the other variables keep
# their plain names.

# Get all output variable names in order
output_vars = dataset.get_output_variable_names()
print("Output variables:", output_vars)

Input variables: ['air_pressure_at_mean_sea_level', 'air_temperature_100', 'air_temperature_500', 'air_temperature_1000', 'air_temperature_2m']
Output variables: ['air_pressure_at_mean_sea_level', 'air_temperature_100', 'air_temperature_500', 'air_temperature_1000', 'air_temperature_2m']


In [7]:
# Index the dataset directly to fetch its first item (no DataLoader involved).
# Note: `batch` is reassigned further down from the DataLoader iterator.
batch = dataset[0]

In [52]:
# Test the config
from source.config import TrainingConfig
from pathlib import Path

# Load the training configuration from YAML. The path is relative, so it is
# presumably resolved against the kernel's working directory — run the
# notebook from the directory that contains `configs/`.
config = TrainingConfig.from_yaml(Path("configs/small_model.yaml"))
# Bare last expression so the notebook shows the parsed config's repr.
config

TrainingConfig(dry=False, num_epochs=100, num_warmup_epochs=5, warmup_steps=2000, enable_compilation=False, skip_training=False, log_to_wandb=True, log_validation_metrics=True, optimizer_type=<OptimizerType.ADAMW: 'adamw'>, learning_rate=0.0001, min_learning_rate=1e-05, optimizer_betas=(0.9, 0.95), optimizer_epsilon=1e-08, optimizer_weight_decay=0.1, optimizer_fused=True, gradient_accumulation_steps=1, use_lat_weights=False, report_sample_interval=5, decoder_depth=2, img_size=(180, 360), patch_size=4, embed_dim=256, depth=8, num_heads=8, mlp_ratio=4.0, lead_time_set=[6], drop_path=0.1, drop_rate=0.1, batch_size=1, num_workers=8, shuffle=True, pin_memory=True, prefetch_factor=2, persistent_workers=True, multiprocessing_context='spawn', datasets_weight=None, stats_path='/mnt/jua-shared-1/jua-silver-layer/all_variables_stats_together.json', zarr_path='/mnt/jua-shared-1/jua-bronze-layer/scratch/alex/erax5-toy-2004-2023-180x360.zarr', start_date='1979-01-01', end_date='2019-12-31', levels=[

In [9]:
from torch.utils.data import DataLoader as TorchDataLoader
# Wrap the dataset in a plain torch DataLoader for a quick smoke test.
# Only the batch size is taken from the config; loading stays in the main
# process (num_workers=0) and every other loader option keeps its torch
# default, regardless of the loader settings stored in `config`.
training_dataloader = TorchDataLoader(
    dataset,
    num_workers=0,
    batch_size=config.batch_size,
)

# Pull a single batch; the iterator is kept around so further batches can be
# drawn with next(training_iterator).
training_iterator = iter(training_dataloader)
batch = next(training_iterator)

In [10]:
# Test the model
from source.model import ClimaX
# Build a ClimaX model: variable lists come from the dataset, all remaining
# hyperparameters come from the loaded config.
model = ClimaX(
    # Variable names, in the order the dataset reports them.
    input_vars=dataset.get_input_variable_names(),
    output_vars=dataset.get_output_variable_names(),
    # Image and patch size.
    img_size=config.img_size,
    patch_size=config.patch_size,
    # Width, heads and depth settings.
    embed_dim=config.embed_dim,
    num_heads=config.num_heads,
    mlp_ratio=config.mlp_ratio,
    depth=config.depth,
    decoder_depth=config.decoder_depth,
    # Drop rates.
    drop_rate=config.drop_rate,
    drop_path=config.drop_path,
)

# Run one forward pass on the batch drawn from the DataLoader.
output = model(lead_times=batch["lead_time"], x=batch["input"])

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
# Create the accelerator through the project helper, passing the
# gradient-accumulation step count from the loaded config.
from source.trainer import get_accelerator
accelerator = get_accelerator(config.gradient_accumulation_steps)


In [54]:
# initialize wandb run
from source.trainer import initialize_wandb_run
# Start the wandb run for this notebook session. Both the run id and the
# group are the fixed string "test", and a single process is declared.
initialize_wandb_run(
    config=config,
    run_id="test",
    group="test",
    num_processes=1,
)

[34m[1mwandb[0m: Currently logged in as: [33mmodel-engineering-team[0m ([33mjua[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [None]:
# create a trainer
from source.trainer import Trainer
# Assemble the trainer from the loaded config and the accelerator created
# earlier; the run id matches the one passed to initialize_wandb_run.
trainer = Trainer(
    run_id="test",
    accelerator=accelerator,
    config=config,
)

# Start training. In the recorded run this call was stopped manually
# (KeyboardInterrupt) after a few hundred timesteps.
trainer.train()

  0%|          | 0/140227 [13:50<?, ?timestep/s]
  0%|          | 0/140227 [11:31<?, ?timestep/s]
  0%|          | 0/140227 [10:31<?, ?timestep/s]
  0%|          | 0/140227 [09:15<?, ?timestep/s]
  0%|          | 0/140227 [05:56<?, ?timestep/s]
  0%|          | 0/140227 [04:19<?, ?timestep/s]
  0%|          | 0/140227 [03:04<?, ?timestep/s]
  0%|          | 0/140227 [02:36<?, ?timestep/s]
  0%|          | 5/140227 [02:08<1003:26:56, 25.76s/timestep]
  1%|          | 759/140227 [02:00<5:54:45,  6.55timestep/s]

KeyboardInterrupt: 

  1%|          | 760/140227 [02:15<5:54:45,  6.55timestep/s]