In [1]:
import re
import os
from collections import defaultdict

In [2]:
# Directory that is scanned for the ".pt" input sample files.
INPUT_DIR = "/gws/nopw/j04/wiser_ewsa/mrakotomanga/Intercomparison/raw/inputs_t0"
# Directory that receives the per-split file lists; created here if it is missing.
SAVE_DIR  = "/gws/nopw/j04/wiser_ewsa/mrakotomanga/Intercomparison/splits"
os.makedirs(SAVE_DIR, exist_ok=True)

In [3]:
def extract_year(filename):
    """
    Return the 4-digit year embedded in a timestamped file name.

    The name must contain a ``YYYYMMDD_HHMM`` stamp somewhere, e.g.
    ``input-20181230_1200.pt`` gives ``2018``.

    Returns the year as an ``int``, or ``None`` when no such stamp is found.
    """
    stamp = re.search(r"(\d{4})\d{2}\d{2}_\d{4}", filename)
    return int(stamp.group(1)) if stamp else None

In [4]:
# Full paths of every ".pt" entry directly inside INPUT_DIR, in sorted order.
all_files = sorted(
    os.path.join(INPUT_DIR, entry)
    for entry in os.listdir(INPUT_DIR)
    if entry.endswith(".pt")
)

In [5]:
# Number of ".pt" paths collected in all_files.
len(all_files)

103771

In [6]:
# Group the file paths by the year parsed from each file name.
by_year = defaultdict(list)
unparsed_files = []  # paths whose name carries no recognisable YYYYMMDD_HHMM stamp
for f in all_files:
    year = extract_year(os.path.basename(f))
    if year is not None:
        by_year[year].append(f)
    else:
        unparsed_files.append(f)

# Report skipped files instead of dropping them silently, so that a naming
# problem cannot quietly shrink the splits built from by_year.
if unparsed_files:
    print(
        f"Warning: {len(unparsed_files):,} files have no parsable year and were skipped, "
        f"e.g. {unparsed_files[:5]}"
    )

In [7]:
# Year-based split definition. range() excludes its stop value.
train_years = range(2004, 2019)  # 2004 through 2018
val_years   = [2019]
test_years  = range(2020, 2025)  # 2020 through 2024

In [8]:
# Flatten the per-year lists into one list per split, keeping the year order.
# Years with no files contribute nothing (and are not added to by_year).
train_files, val_files, test_files = (
    [path for year in years for path in by_year.get(year, [])]
    for years in (train_years, val_years, test_years)
)

In [9]:
def save_list(file_list, name):
    """Write ``file_list`` to ``<SAVE_DIR>/<name>_files.txt``, one entry per line.

    Entries are written in sorted order; ``file_list`` itself is not modified.
    A one-line summary with the number of entries is printed afterwards.
    """
    out_path = os.path.join(SAVE_DIR, f"{name}_files.txt")
    with open(out_path, "w") as handle:
        handle.writelines(f"{entry}\n" for entry in sorted(file_list))
    print(f"Saved {name} split → {len(file_list):,} files")

# Write one list file per split, in the order train, val, test.
for split_name, split_files in (
    ("train", train_files),
    ("val", val_files),
    ("test", test_files),
):
    save_list(split_files, split_name)

Saved train split → 76,138 files
Saved val split → 4,872 files
Saved test split → 22,761 files


In [10]:
import torch

In [11]:
# Spot check: load a handful of input files and print the shape of the
# "input_tensor" entry stored in each one.
path = INPUT_DIR  # reuse the configured inputs directory instead of repeating the literal
files = [
    "input-20200215_1300.pt",
    "input-20201204_1515.pt",
    "input-20210102_0230.pt",
    "input-20230126_1900.pt",
]

for f in files:
    p = os.path.join(path, f)
    # map_location="cpu" keeps the check usable on a machine without a GPU,
    # even if the tensors were saved from a CUDA device.
    d = torch.load(p, map_location="cpu")
    print(f, d["input_tensor"].shape)

input-20200215_1300.pt torch.Size([288, 10])
input-20201204_1515.pt torch.Size([288, 10])
input-20210102_0230.pt torch.Size([288, 10])
input-20230126_1900.pt torch.Size([288, 10])


In [12]:
import torch, os

# Spot check: load a handful of target files for one lead time and print the
# shape of the "data" entry stored in each one.
lead_time = "1"  # selects the targets_t<lead_time> directory
base = f"/gws/nopw/j04/wiser_ewsa/mrakotomanga/Intercomparison/raw/targets_t{lead_time}"

check_files = [
    "target-20200215_1300.pt",
    "target-20201204_1515.pt",
    "target-20210102_0230.pt",
    "target-20230126_1900.pt",
]

for f in check_files:
    p = os.path.join(base, f)
    # map_location="cpu" keeps the check usable on a machine without a GPU,
    # even if the tensors were saved from a CUDA device.
    d = torch.load(p, map_location="cpu")
    print(f, d["data"].shape)


target-20200215_1300.pt torch.Size([350, 370])
target-20201204_1515.pt torch.Size([350, 370])
target-20210102_0230.pt torch.Size([350, 370])
target-20230126_1900.pt torch.Size([350, 370])


In [13]:
import torch, os, numpy as np

# Scan every input file that has a matching target file and flag the pairs
# that cannot be loaded or whose shapes differ from the expected ones.
# NOTE(review): this loads every input/target pair from disk, so it is likely
# slow on a large directory — consider running it on a subset first.
base = "/gws/nopw/j04/wiser_ewsa/mrakotomanga/Intercomparison/raw"
inputs_dir = f"{base}/inputs_t0"
targets_dir = f"{base}/targets_t2"

# Shapes a sample must have to pass the check.
EXPECTED_INPUT_SHAPE = (288, 10)
EXPECTED_TARGET_SHAPE = (350, 370)

bad = []          # names of the input files whose sample failed the check
bad_reasons = {}  # input file name -> short description of why it was flagged

# sorted() makes the scan order, and therefore bad[:20], reproducible;
# os.listdir() returns entries in arbitrary order.
for fname in sorted(os.listdir(inputs_dir)):
    if not fname.endswith(".pt"):
        continue
    tname = fname.replace("input-", "target-")
    tpath = os.path.join(targets_dir, tname)
    if not os.path.exists(tpath):
        # Inputs without a target at this lead time are not part of the check.
        continue
    try:
        x = torch.load(os.path.join(inputs_dir, fname), map_location="cpu")["input_tensor"]
        y = torch.load(tpath, map_location="cpu")["data"]
        # Compare shapes directly: converting with .numpy() raises for CUDA or
        # requires-grad tensors, which would flag otherwise valid samples.
        x_shape, y_shape = tuple(x.shape), tuple(y.shape)
        if x_shape != EXPECTED_INPUT_SHAPE or y_shape != EXPECTED_TARGET_SHAPE:
            bad.append(fname)
            bad_reasons[fname] = f"shape mismatch: input {x_shape}, target {y_shape}"
    except Exception as e:
        # Unreadable file, missing key, etc.: flag the sample and keep the
        # reason instead of discarding it.
        bad.append(fname)
        bad_reasons[fname] = f"{type(e).__name__}: {e}"
print("Bad samples:", len(bad))
print(bad[:20])

KeyboardInterrupt: 