In [None]:
%pip install xarray tornado_helper git+https://github.com/mit-ll/tornet.git huggingface_hub netcdf4 aria2
%pip uninstall tensorflow -y
%pip install --upgrade tensorflow[and-cuda] xarray[io]

In [None]:
import os 
import tensorflow as tf 
from tensorflow.keras.mixed_precision import set_global_policy

TORNET_ROOT = "/data_tornet"
GOES_ROOT   = "/data_goes"
GOES_BANDS = ["CMI_C01", "CMI_C02", "CMI_C03", "CMI_C04"]
TARGET_H, TARGET_W = 64, 64
BATCH_SIZE = 16
AUTOTUNE = tf.data.AUTOTUNE

os.environ['TORNET_ROOT'] = TORNET_ROOT
os.environ['GOES_ROOT'] = GOES_ROOT

gpus = tf.config.list_physical_devices('GPU')
tf.config.set_visible_devices(gpus, 'GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

strategy = tf.distribute.MirroredStrategy()
set_global_policy("mixed_float16")  # Mixed precision

In [5]:
# Download Files 
from tornado_helper import TorNet, GOES
tornet = TorNet()
tornet.download(2017)

goes = GOES()
goes.download(2017)

Aria Download:   0%|          | 0.00/15.1G [00:00<?, ?B/s]

Aria Download:   0%|          | 0.00/2.38G [00:00<?, ?B/s]

['data_goes/2017/train/2017/NUL_170402_013436_KSJT_1078362n_V2.nc',
 'data_goes/2017/train/2017/NUL_170420_213054_KIND_1078518n_J1.nc',
 'data_goes/2017/train/2017/NUL_170621_143122_KLIX_1079126n_W4.nc',
 'data_goes/2017/train/2017/WRN_170810_190336_KGLD_1079367n_I5.nc',
 'data_goes/2017/train/2017/WRN_170828_205033_KHGX_1079613n_S4.nc',
 'data_goes/2017/train/2017/WRN_170905_184606_KLWX_1079715n_R9.nc',
 'data_goes/2017/train/2017/NUL_170522_191209_KJAX_1078864n_S3.nc',
 'data_goes/2017/train/2017/WRN_170327_055300_KSHV_1078298n_Q0.nc',
 'data_goes/2017/train/2017/NUL_170618_194948_KBUF_695190s_R6.nc',
 'data_goes/2017/train/2017/WRN_170623_010533_KLZK_1079167n_I7.nc',
 'data_goes/2017/train/2017/NUL_170802_175753_KDIX_705285s_Y7.nc',
 'data_goes/2017/train/2017/NUL_170612_145048_KFSD_712377s_P4.nc',
 'data_goes/2017/train/2017/NUL_170814_033528_KUEX_718028s_L4.nc',
 'data_goes/2017/train/2017/WRN_170624_110324_KDIX_1079187n_M6.nc',
 'data_goes/2017/train/2017/NUL_171008_191215_KGSP_1

In [10]:
tc = tornet.catalog()
tc.to_csv(f'{TORNET_ROOT}/catalog.csv')

In [None]:
import os
import numpy as np
import tensorflow as tf
import xarray as xr
from huggingface_hub import hf_hub_download
from tornet.models.keras.layers import CoordConv2D
from tornet.data.loader import query_catalog, TornadoDataLoader
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, GlobalAveragePooling2D, Concatenate, Dense

model_path = hf_hub_download("tornet-ml/tornado_detector_baseline_v1", "tornado_detector_baseline.keras")
base = tf.keras.models.load_model(model_path, custom_objects={"CoordConv2D": CoordConv2D}, compile=False)
base.trainable = False

model_vars = [inp.name.split(":")[0] for inp in base.inputs]

t_paths = query_catalog(TORNET_ROOT, "train", [2017], random_state=42)
filtered_t, filtered_g = [], []
for p in t_paths:
    rel = os.path.relpath(p, TORNET_ROOT)
    year = rel.split(os.sep)[1]
    g_path = os.path.join(GOES_ROOT, year, *rel.split(os.sep))
    if os.path.exists(g_path):
        filtered_t.append(p)
        filtered_g.append(g_path)

labels = np.array([1 if os.path.basename(p).startswith("TOR") else 0 for p in filtered_t], dtype=np.int32)

with xr.open_dataset(filtered_t[0]) as ds:
    available_vars = [v for v in model_vars if v in ds.variables]

# --- tf.data helpers ---
def generator():
    for r_path, g_path, label in zip(filtered_t, filtered_g, labels):
        try:
            with xr.open_dataset(g_path) as ds_g:
                bands = [ds_g[b].values.astype(np.float32) for b in GOES_BANDS if b in ds_g]
                goes = np.stack(bands, axis=-1)
                goes = np.nan_to_num(goes, nan=0.0)
                goes = tf.image.resize(goes, (TARGET_H, TARGET_W)).numpy()
        except Exception:
            continue
        
        radar_loader = TornadoDataLoader([r_path], variables=available_vars, n_frames=1, shuffle=False)
        radar_sample = next(iter(radar_loader))

        inputs = {}
        for v in model_vars:
            if v in radar_sample:
                arr = radar_sample[v]
                if arr.ndim == 2:
                    arr = np.expand_dims(arr, -1)
                if arr.ndim == 4:
                    arr = arr[0]
                arr = tf.image.resize(arr, (TARGET_H, TARGET_W)).numpy()
                arr = np.nan_to_num(arr, nan=0.0)
            else:
                arr = np.zeros((TARGET_H, TARGET_W, 2), dtype=np.float32)
            inputs[v] = arr

        inputs["goes_input"] = goes

        yield inputs, np.float32(label)

input_spec = {v: tf.TensorSpec(shape=(TARGET_H, TARGET_W, 2), dtype=tf.float32) for v in model_vars}
input_spec["goes_input"] = tf.TensorSpec(shape=(TARGET_H, TARGET_W, len(GOES_BANDS)), dtype=tf.float32)

ds = tf.data.Dataset.from_generator(
    generator,
    output_signature=(input_spec, tf.TensorSpec(shape=(), dtype=tf.float32))
).shuffle(len(filtered_t) // 10).batch(BATCH_SIZE).prefetch(AUTOTUNE).cache().repeat()

steps_per_epoch = len(filtered_t) // BATCH_SIZE

# --- Model ---
with strategy.scope():
    goes_input = Input(shape=(TARGET_H, TARGET_W, len(GOES_BANDS)), name="goes_input", dtype=tf.float32)
    goes_features = Conv2D(16, 3, padding="same", activation="relu")(goes_input)
    goes_features = GlobalAveragePooling2D()(goes_features)

    radar_inputs = [Input(shape=(TARGET_H, TARGET_W, 2), name=v) for v in model_vars]
    radar_features = base({v: inp for v, inp in zip(model_vars, radar_inputs)})

    combined = Concatenate()([radar_features, goes_features])
    output = Dense(1, activation="sigmoid", dtype=tf.float32)(combined)

    model = Model(inputs={**{v: inp for v, inp in zip(model_vars, radar_inputs)}, "goes_input": goes_input}, outputs=output)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# --- Train ---
model.fit(ds, epochs=10, steps_per_epoch=steps_per_epoch)