In [16]:
import os
import sys
import torch
import numpy as np
from tqdm import tqdm

In [None]:
# Command-line configuration. Expected invocation:
#     python <script> LEAD_TIME PARTITION      e.g.  python make_shards.py 1 train
# LEAD_TIME selects the target directory / shard suffix ("t<LEAD_TIME>");
# PARTITION selects the split list "<PARTITION>_files.txt".
# NOTE(review): inside a live Jupyter kernel sys.argv typically holds the kernel
# launcher's own flags (e.g. "-f <connection file>"), so this cell is only
# meaningful when the notebook is executed as a script. Fail fast with a usage
# message instead of indexing past argv or building paths from a stray flag.
if len(sys.argv) < 3 or sys.argv[1].startswith("-"):
    sys.exit(f"usage: {os.path.basename(sys.argv[0])} LEAD_TIME PARTITION")

LEAD_TIME = sys.argv[1]
PARTITION = sys.argv[2]

BASE_DIR = "/gws/nopw/j04/wiser_ewsa/mrakotomanga/Intercomparison"

# NOTE(review): RAW_INPUT_DIR is not referenced in the visible cells; input
# paths come verbatim from the split file. Confirm whether it is still needed.
RAW_INPUT_DIR  = f"{BASE_DIR}/raw/inputs_t0"
RAW_TARGET_DIR = f"{BASE_DIR}/raw/targets_t{LEAD_TIME}"

SPLIT_FILE = f"{BASE_DIR}/splits/{PARTITION}_files.txt"
SHARDS_DIR = f"{BASE_DIR}/shards/t{LEAD_TIME}/{PARTITION}_t{LEAD_TIME}"

# Output directory for the shard files; safe to re-run.
os.makedirs(SHARDS_DIR, exist_ok=True)

In [18]:
# Read the split list: one input-file path per line. Surrounding whitespace
# (including the newline) is trimmed and blank lines are dropped.
with open(SPLIT_FILE) as split_handle:
    trimmed_lines = (raw_line.strip() for raw_line in split_handle)
    files = [path for path in trimmed_lines if path]

In [None]:
# Shard size: how many input files are grouped into each saved shard file.
FILES_PER_SHARD = 1_000

In [None]:
# Main loop: pair every input file from the split list with its target file,
# buffer the samples, and write them out in shards of FILES_PER_SHARD samples.
#
# Each shard is a dict written with torch.save:
#   "X"  - float32 tensor of the stacked "input_tensor" arrays
#   "G"  - float32 tensor of the stacked "global_context" arrays
#   "Y"  - uint8 tensor of the stacked target "data" arrays
#   "ID" - list of the "nowcast_origin" values, in the same sample order


def _save_shard(inputs, globals_, targets, ids, index):
    """Stack the buffered samples into tensors and write shard number `index`.

    Args:
        inputs, globals_, targets: lists of per-sample numpy arrays, stacked
            along a new leading (sample) axis; arrays within one list must
            share a shape or ``np.stack`` raises.
        ids: per-sample identifiers, stored unchanged under "ID".
        index: shard number, used in the zero-padded file name.

    Returns:
        The path of the written shard file.

    Raises:
        Any error from ``np.stack`` / ``torch.save`` propagates: a shard that
        cannot be written means the dataset is incomplete, so it must not be
        swallowed like a single bad input file.
    """
    shard_path = os.path.join(SHARDS_DIR, f"shard_{index:03d}.pt")
    torch.save({
        "X": torch.tensor(np.stack(inputs), dtype=torch.float32),
        "G": torch.tensor(np.stack(globals_), dtype=torch.float32),
        "Y": torch.tensor(np.stack(targets), dtype=torch.uint8),
        "ID": ids,
    }, shard_path)
    # tqdm.write prints above the progress bar instead of splitting it.
    tqdm.write(f"Saved shard_{index:03d} ({len(inputs)} samples) → {shard_path}")
    return shard_path


shard_inputs, shard_globals, shard_targets, shard_ids = [], [], [], []
shard_index = 0
n_skipped = 0  # files dropped because the target is missing or loading failed

for fpath in tqdm(files, desc="Sharding inputs + targets"):
    # --- Load one (input, target) pair; any failure skips just this file. ---
    try:
        data_in = torch.load(fpath, map_location="cpu")
        x = data_in["input_tensor"].numpy()
        g = data_in["global_context"].numpy()
        nowcast_id = data_in["nowcast_origin"]

        # The target file name is the input's base name with "input-"
        # replaced by "target-", looked up in RAW_TARGET_DIR.
        fname = os.path.basename(fpath).replace("input-", "target-")
        target_path = os.path.join(RAW_TARGET_DIR, fname)
        if not os.path.exists(target_path):
            tqdm.write(f"Missing target file: {target_path}")
            n_skipped += 1
            continue

        data_out = torch.load(target_path, map_location="cpu")
        y = data_out["data"].numpy().astype(np.uint8)  # (350, 370)
    except Exception as e:
        tqdm.write(f"Error processing {fpath}: {e}")
        n_skipped += 1
        continue

    # --- Buffer the sample. ---
    shard_inputs.append(x)
    shard_globals.append(g)
    shard_targets.append(y)
    shard_ids.append(nowcast_id)

    # Flush on the number of *buffered samples*, not on the file index:
    # skipped files would otherwise shift shard boundaries and, if the last
    # file in the list were skipped, the final partial shard would be lost.
    if len(shard_inputs) >= FILES_PER_SHARD:
        _save_shard(shard_inputs, shard_globals, shard_targets, shard_ids, shard_index)
        shard_inputs, shard_globals, shard_targets, shard_ids = [], [], [], []
        shard_index += 1

# Write whatever remains as a final, smaller shard.
if shard_inputs:
    _save_shard(shard_inputs, shard_globals, shard_targets, shard_ids, shard_index)
    shard_inputs, shard_globals, shard_targets, shard_ids = [], [], [], []
    shard_index += 1

print(f"Finished creating {PARTITION.upper()} shards for LT{LEAD_TIME}.")
print(f"{shard_index} shard(s) written, {n_skipped} file(s) skipped.")

Sharding inputs + targets:  60%|██████    | 12/20 [00:00<00:00, 15.71it/s]

Saved shard_000 (10 samples) → /gws/nopw/j04/wiser_ewsa/mrakotomanga/Intercomparison/shards/t1/train_t1/shard_000.pt


Sharding inputs + targets: 100%|██████████| 20/20 [00:01<00:00, 14.58it/s]

Saved shard_001 (10 samples) → /gws/nopw/j04/wiser_ewsa/mrakotomanga/Intercomparison/shards/t1/train_t1/shard_001.pt
Finished creating TRAIN shards for LT1.



