In [None]:
%pip install xarray tornado_helper git+https://github.com/mit-ll/tornet.git huggingface_hub netcdf4 aria2
%pip uninstall tensorflow -y
%pip install --upgrade tensorflow[and-cuda] xarray[io]

In [None]:
import os 
import tensorflow as tf 
from tensorflow.keras.mixed_precision import set_global_policy

TORNET_ROOT = "/data_tornet"
GOES_ROOT   = "/data_goes"
GOES_BANDS = ["CMI_C01", "CMI_C02", "CMI_C03", "CMI_C04", "CMI_C07", "CMI_C08", "CMI_C13", "CMI_C15"]
TARGET_H, TARGET_W = 64, 64
BATCH_SIZE = 32
AUTOTUNE = tf.data.AUTOTUNE
YEARS = [2017, 2018, 2019, 2020, 2021, 2022]
TF_ROOT = "/data_combined"
NUM_THREADS = 16

os.environ['TORNET_ROOT'] = TORNET_ROOT
os.environ['GOES_ROOT'] = GOES_ROOT

gpus = tf.config.list_physical_devices('GPU')
tf.config.set_visible_devices(gpus, 'GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

strategy = tf.distribute.MirroredStrategy()
set_global_policy("mixed_float16")

In [None]:
# Download Files 
from tornado_helper import TorNet, GOES
tornet = TorNet()
tornet.download(YEARS)

goes = GOES()
goes.download(YEARS)

Aria Download:   0%|          | 0.00/100G [00:00<?, ?B/s]

Aria Download:   0%|          | 0.00/14.7G [00:00<?, ?B/s]

['data_goes/2017/train/2017/NUL_170402_013436_KSJT_1078362n_V2.nc',
 'data_goes/2017/train/2017/NUL_170420_213054_KIND_1078518n_J1.nc',
 'data_goes/2017/train/2017/NUL_170621_143122_KLIX_1079126n_W4.nc',
 'data_goes/2017/train/2017/WRN_170810_190336_KGLD_1079367n_I5.nc',
 'data_goes/2017/train/2017/WRN_170828_205033_KHGX_1079613n_S4.nc',
 'data_goes/2017/train/2017/WRN_170905_184606_KLWX_1079715n_R9.nc',
 'data_goes/2017/train/2017/NUL_170522_191209_KJAX_1078864n_S3.nc',
 'data_goes/2017/train/2017/WRN_170327_055300_KSHV_1078298n_Q0.nc',
 'data_goes/2017/train/2017/NUL_170618_194948_KBUF_695190s_R6.nc',
 'data_goes/2017/train/2017/WRN_170623_010533_KLZK_1079167n_I7.nc',
 'data_goes/2017/train/2017/NUL_170802_175753_KDIX_705285s_Y7.nc',
 'data_goes/2017/train/2017/NUL_170612_145048_KFSD_712377s_P4.nc',
 'data_goes/2017/train/2017/NUL_170814_033528_KUEX_718028s_L4.nc',
 'data_goes/2017/train/2017/WRN_170624_110324_KDIX_1079187n_M6.nc',
 'data_goes/2017/train/2017/NUL_171008_191215_KGSP_1

In [10]:
tc = tornet.catalog()
tc.to_csv(f'{TORNET_ROOT}/catalog.csv')

In [None]:
import os
import numpy as np
import tensorflow as tf
import xarray as xr
from huggingface_hub import hf_hub_download
from tornet.models.keras.layers import CoordConv2D
from tornet.data.loader import query_catalog, TornadoDataLoader
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

os.makedirs(TF_ROOT, exist_ok=True)

# --- Load base model ---
model_path = hf_hub_download("tornet-ml/tornado_detector_baseline_v1", "tornado_detector_baseline.keras")
base = tf.keras.models.load_model(model_path, custom_objects={"CoordConv2D": CoordConv2D}, compile=False)
base.trainable = False

model_vars = [inp.name.split(":")[0] for inp in base.inputs]

# --- TFRecord Serializer ---
def serialize_example(inputs, label):
    feature = {k: tf.train.Feature(bytes_list=tf.train.BytesList(value=[v.tobytes()])) for k, v in inputs.items()}
    feature["label"] = tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

# --- Full safe loader (single threaded) ---
def load_full_sample(r_path, g_path, year, available_vars):
    try:
        with xr.open_dataset(g_path) as ds_g:
            bands = [ds_g[b].values.astype(np.float32) for b in GOES_BANDS if b in ds_g]
            if len(bands) == 0:
                return None
            goes = np.stack(bands, axis=-1)
            if goes.size == 0:
                return None
            goes = np.nan_to_num(goes, nan=0.0)
            goes = tf.image.resize(goes, (TARGET_H, TARGET_W)).numpy()
    except Exception:
        return None

    try:
        radar_loader = TornadoDataLoader([r_path], variables=available_vars, n_frames=1, shuffle=False)
        radar_sample = next(iter(radar_loader))
    except Exception:
        return None

    inputs = {}
    for v in model_vars:
        if v in radar_sample:
            arr = radar_sample[v]
            if arr.ndim == 2:
                arr = np.expand_dims(arr, -1)
            if arr.ndim == 4:
                arr = arr[0]
            arr = tf.image.resize(arr, (TARGET_H, TARGET_W)).numpy()
            arr = np.nan_to_num(arr, nan=0.0)
        else:
            arr = np.zeros((TARGET_H, TARGET_W, 2), dtype=np.float32)
        inputs[v] = arr

    inputs["goes_input"] = goes
    return inputs

# --- Write function ---
def write_tfrecord(name, paths):
    samples = []
    for p in paths:
        rel = os.path.relpath(p, TORNET_ROOT)
        year = rel.split(os.sep)[1]
        g_path = os.path.join(GOES_ROOT, year, *rel.split(os.sep))
        if os.path.exists(g_path):
            samples.append((p, g_path, year))

    labels = np.array([1 if os.path.basename(p[0]).startswith("TOR") else 0 for p in samples], dtype=np.float32)

    with xr.open_dataset(samples[0][0]) as ds:
        available_vars = [v for v in model_vars if v in ds.variables]

    loaded_samples = []

    print("Loading full dataset into RAM...")
    for (r_path, g_path, year), label in tqdm(zip(samples, labels), total=len(samples), desc="Loading Samples"):
        loaded = load_full_sample(r_path, g_path, year, available_vars)
        if loaded is not None:
            loaded_samples.append((loaded, label, year))

    print(f"Loaded {len(loaded_samples)} samples into RAM.")

    print("Saving TFRecords in parallel...")

    # Now parallel write
    def write_one(index_tuple):
        idx, (inputs, label, year) = index_tuple
        year_dir = os.path.join(TF_ROOT, name, year)
        os.makedirs(year_dir, exist_ok=True)

        tfrecord_path = os.path.join(year_dir, f"{idx:05d}.tfrecord")
        serialized = serialize_example(inputs, label)
        with tf.io.TFRecordWriter(tfrecord_path) as writer:
            writer.write(serialized)

    with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
        list(tqdm(executor.map(write_one, enumerate(loaded_samples)), total=len(loaded_samples), desc="Writing TFRecords"))

    print(f"Finished writing {name} set.")

# --- Run ---
train_paths = query_catalog(TORNET_ROOT, "train", YEARS, random_state=42)
test_paths = query_catalog(TORNET_ROOT, "test", YEARS, random_state=42)

write_tfrecord("train", train_paths)
write_tfrecord("test", test_paths)

Loading full dataset into RAM...


Loading Samples: 100%|██████████| 14234/14234 [18:11<00:00, 13.04it/s]


Loaded 14234 samples into RAM.
Saving TFRecords in parallel...


Writing TFRecords: 100%|██████████| 14234/14234 [00:09<00:00, 1579.12it/s]


Finished writing train set.
Loading full dataset into RAM...


Loading Samples: 100%|██████████| 2840/2840 [03:40<00:00, 12.85it/s]


Loaded 2840 samples into RAM.
Saving TFRecords in parallel...


Writing TFRecords: 100%|██████████| 2840/2840 [00:02<00:00, 1155.20it/s]

Finished writing test set.





In [None]:
import os
import numpy as np
import tensorflow as tf
import xarray as xr
from huggingface_hub import hf_hub_download
from tornet.models.keras.layers import CoordConv2D
from tornet.data.loader import query_catalog, TornadoDataLoader
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, GlobalAveragePooling2D, Dropout, Dense, Concatenate, Add
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.losses import BinaryFocalCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import AUC, Precision, Recall

# --- Load base model ---
model_path = hf_hub_download("tornet-ml/tornado_detector_baseline_v1", "tornado_detector_baseline.keras")
base = tf.keras.models.load_model(model_path, custom_objects={"CoordConv2D": CoordConv2D}, compile=False)
base.trainable = False

model_vars = [inp.name.split(":")[0] for inp in base.inputs]

# --- TFRecord parser ---
def parse_example(example_proto):
    feature_description = {v: tf.io.FixedLenFeature([], tf.string) for v in model_vars}
    feature_description["goes_input"] = tf.io.FixedLenFeature([], tf.string)
    feature_description["label"] = tf.io.FixedLenFeature([], tf.float32)

    parsed = tf.io.parse_single_example(example_proto, feature_description)

    inputs = {}
    for v in model_vars:
        inputs[v] = tf.io.decode_raw(parsed[v], tf.float32)
        inputs[v] = tf.reshape(inputs[v], (TARGET_H, TARGET_W, 2))
    inputs["goes_input"] = tf.io.decode_raw(parsed["goes_input"], tf.float32)
    inputs["goes_input"] = tf.reshape(inputs["goes_input"], (TARGET_H, TARGET_W, len(GOES_BANDS)))

    return inputs, parsed["label"]

# --- Count number of examples ---
def count_examples(files):
    count = 0
    for f in files:
        raw_dataset = tf.data.TFRecordDataset(f)
        for _ in raw_dataset:
            count += 1
    return count

# --- Prepare datasets ---
train_files = tf.io.gfile.glob(f"{TF_ROOT}/train/*/*.tfrecord")
val_files = tf.io.gfile.glob(f"{TF_ROOT}/test/*/*.tfrecord")

train_paths = query_catalog(TORNET_ROOT, "train", YEARS, random_state=42)
test_paths = query_catalog(TORNET_ROOT, "test", YEARS, random_state=42)

pos_count = sum(1 for p in train_paths if os.path.basename(p).startswith("TOR"))
neg_count = len(train_paths) - pos_count

# Class weights
n_train = len(train_paths)
n_val = len(test_paths)

total = len(train_paths)
class_weight = {
    0: total / (2 * neg_count),
    1: total / (2 * pos_count)
}

train_ds = tf.data.TFRecordDataset(train_files)
train_ds = train_ds.shuffle(n_train)         # Shuffle first
train_ds = train_ds.repeat()               # Then repeat
train_ds = train_ds.map(parse_example, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.prefetch(AUTOTUNE)

val_ds = tf.data.TFRecordDataset(val_files)
val_ds = val_ds.map(parse_example, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.batch(BATCH_SIZE)
val_ds = val_ds.prefetch(AUTOTUNE)
val_ds = val_ds.repeat()

steps_per_epoch = n_train // BATCH_SIZE
validation_steps = n_val // BATCH_SIZE

print(f"Training samples: {n_train}, Validation samples: {n_val}")
print(f"Steps per epoch: {steps_per_epoch}, Validation steps: {validation_steps}")

# --- Model ---
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # GOES input branch (super light CNN)
    goes_input = Input(shape=(TARGET_H, TARGET_W, len(GOES_BANDS)), name="goes_input", dtype=tf.float32)
    
    # First (and only) block
    x = Conv2D(8, (3, 3), padding="same", activation="relu")(goes_input)
    x = BatchNormalization()(x)

    # Skip heavy stacking
    goes_features = GlobalAveragePooling2D()(x)
    goes_features = Dropout(0.2)(goes_features)
    
    # Radar inputs branch
    radar_inputs = [Input((TARGET_H, TARGET_W, 2), name=v) for v in model_vars]
    radar_features = base({v: inp for v, inp in zip(model_vars, radar_inputs)})
    
    # Fusion + tiny classifier
    fused = Concatenate()([radar_features, goes_features])
    y = Dense(32, activation="relu")(fused)
    y = Dropout(0.3)(y)
    out = Dense(1, activation="sigmoid")(y)

    model = Model(inputs={**{v: inp for v, inp in zip(model_vars, radar_inputs)}, "goes_input": goes_input}, outputs=out)
    model.compile(
        optimizer=Adam(1e-4),
        loss=BinaryFocalCrossentropy(alpha=0.75, gamma=2.0, label_smoothing=0.0),
        metrics=["accuracy", AUC(name="auc"), Precision(name="precision"), Recall(name="recall")]
    )

# --- Train ---
early_stop = EarlyStopping(monitor='val_precision', patience=15, restore_best_weights=True)
lr_scheduler = ReduceLROnPlateau(monitor='val_precision', factor=0.5, patience=5, min_lr=1e-6)

model.fit(
    train_ds,
    epochs=50,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps,
    callbacks=[early_stop, lr_scheduler],
    verbose=1,
    class_weight=class_weight
)

from sklearn.metrics import classification_report, confusion_matrix, f1_score

y_true = []
y_pred = []

for batch in val_ds.take(validation_steps):
    inputs, labels = batch
    preds = model.predict_on_batch(inputs)
    preds = preds.flatten()

    y_true.extend(labels.numpy())
    y_pred.extend(preds)

y_pred_binary = [1 if p >= 0.5 else 0 for p in y_pred]

print(classification_report(y_true, y_pred_binary, digits=4))
print(confusion_matrix(y_true, y_pred_binary))
print("F1 Score:", f1_score(y_true, y_pred_binary))

Training samples: 112209, Validation samples: 21495
Steps per epoch: 3506, Validation steps: 671
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


2025-04-26 17:58:56.592153: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] fused(ShuffleDatasetV3:4501985,RepeatDataset:4501986): Filling up shuffle buffer (this may take a while): 22307 of 112209
2025-04-26 17:59:06.592089: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] fused(ShuffleDatasetV3:4501985,RepeatDataset:4501986): Filling up shuffle buffer (this may take a while): 44956 of 112209
2025-04-26 17:59:16.592430: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] fused(ShuffleDatasetV3:4501985,RepeatDataset:4501986): Filling up shuffle buffer (this may take a while): 67543 of 112209
2025-04-26 17:59:36.328471: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 1/50
[1m3506/3506[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 15ms/step - accuracy: 0.4443 - auc: 0.4990 - loss: 0.1910 - precision: 0.0668 - recall: 0.5639 - val_accuracy: 0.2499 - val_auc: 0.6375 - val_loss: 0.2107 - val_precision: 0.0536 - val_recall: 0.8907 - learning_rate: 1.0000e-04
Epoch 2/50
[1m3506/3506[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m52s[0m 15ms/step - accuracy: 0.5991 - auc: 0.6023 - loss: 0.1706 - precision: 0.0905 - recall: 0.5434 - val_accuracy: 0.6251 - val_auc: 0.6342 - val_loss: 0.1637 - val_precision: 0.0711 - val_recall: 0.5750 - learning_rate: 1.0000e-04
Epoch 3/50
[1m3506/3506[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m53s[0m 15ms/step - accuracy: 0.6158 - auc: 0.6177 - loss: 0.1686 - precision: 0.0966 - recall: 0.5530 - val_accuracy: 0.5206 - val_auc: 0.6355 - val_loss: 0.1788 - val_precision: 0.0660 - val_recall: 0.6939 - learning_rate: 1.0000e-04
Epoch 4/50
[1m3506/3506[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m52s

In [None]:
model.save("full5.keras")
model.summary()