In [2]:
import re
import os
from collections import defaultdict

In [3]:
# Directory holding the .pt input files, and directory the split lists are written to.
# Both can be overridden via environment variables so the notebook is usable on
# another machine without editing code; the defaults are the original paths.
INPUT_DIR = os.environ.get("OB_INPUT_DIR", "/work/scratch-nopw2/mendrika/OB/raw/inputs_t0")
SAVE_DIR  = os.environ.get("OB_SPLIT_DIR", "/home/users/mendrika/Object-Based-LSTMConv/outputs/data-split")
os.makedirs(SAVE_DIR, exist_ok=True)

In [4]:
# YYYYMMDD_HHMM timestamp that is NOT embedded in a longer run of digits.
# Without the lookarounds, e.g. "2018123012_1200" would silently yield year 1812.
_TIMESTAMP_RE = re.compile(r"(?<!\d)(\d{4})\d{2}\d{2}_\d{4}(?!\d)")


def extract_year(filename):
    """
    Extract the 4-digit year from a timestamped filename.

    Example:
        input-20181230_1200.pt → 2018

    Parameters
    ----------
    filename : str
        File name (or path) containing a ``YYYYMMDD_HHMM`` timestamp.

    Returns
    -------
    int or None
        The year, or ``None`` when no standalone timestamp is present.
    """
    match = _TIMESTAMP_RE.search(filename)
    return int(match.group(1)) if match else None

In [5]:
# Full paths of every .pt entry in INPUT_DIR, in lexicographic order.
all_files = sorted(
    os.path.join(INPUT_DIR, entry)
    for entry in os.listdir(INPUT_DIR)
    if entry.endswith(".pt")
)

In [6]:
# Total number of .pt inputs discovered; compare with the three split sizes
# printed by save_list below to confirm no file was dropped.
len(all_files)

162169

In [7]:
# Group input paths by the year in their timestamp. Files whose name has no
# parsable timestamp are collected (not silently discarded) and reported.
by_year = defaultdict(list)
unparsed_files = []
for path in all_files:
    year = extract_year(os.path.basename(path))
    if year is None:
        unparsed_files.append(path)
    else:
        by_year[year].append(path)

if unparsed_files:
    print(
        f"WARNING: {len(unparsed_files):,} file(s) without a parsable timestamp were skipped, "
        f"e.g. {os.path.basename(unparsed_files[0])}"
    )

In [8]:
# Chronological split (range end is exclusive):
#   train 2004–2018, validation 2019, test 2020–2024.
train_years = range(2004, 2019)
val_years   = [2019]
test_years  = range(2020, 2025)

# Guard against temporal leakage if these ranges are edited: each split must
# lie strictly after the previous one, which also makes them disjoint.
assert max(train_years) < min(val_years) and max(val_years) < min(test_years), \
    "train/val/test years must be non-overlapping and chronologically ordered"

In [9]:
# Materialise each split by concatenating the per-year file lists.
train_files = [f for y in train_years for f in by_year.get(y, [])]
val_files   = [f for y in val_years   for f in by_year.get(y, [])]
test_files  = [f for y in test_years  for f in by_year.get(y, [])]

# Years present in the data but covered by no split would otherwise be
# dropped without notice; report them so the ranges can be adjusted.
assigned_years = set(train_years) | set(val_years) | set(test_years)
unassigned_years = sorted(set(by_year) - assigned_years)
if unassigned_years:
    n_dropped = sum(len(by_year[y]) for y in unassigned_years)
    print(f"WARNING: {n_dropped:,} file(s) from years {unassigned_years} are not in any split")

In [10]:
def save_list(file_list, name, save_dir=None):
    """
    Write one split's paths to ``<save_dir>/<name>_files.txt``, sorted, one per line.

    Parameters
    ----------
    file_list : iterable of str
        Paths belonging to the split.
    name : str
        Split name used in the output filename and log line (e.g. "train").
    save_dir : str, optional
        Output directory; defaults to the notebook-level ``SAVE_DIR``.
    """
    if save_dir is None:
        save_dir = SAVE_DIR
    path = os.path.join(save_dir, f"{name}_files.txt")
    # Count the sorted copy so one-shot iterables (e.g. generators) also work.
    lines = sorted(file_list)
    # Explicit UTF-8 so the file does not depend on the platform's default locale.
    with open(path, "w", encoding="utf-8") as fh:
        fh.writelines(f"{item}\n" for item in lines)
    print(f"Saved {name} split → {len(lines):,} files")

# Write one list per split, in train → val → test order.
for split_name, split_files in (
    ("train", train_files),
    ("val", val_files),
    ("test", test_files),
):
    save_list(split_files, split_name)

Saved train split → 113,883 files
Saved val split → 7,227 files
Saved test split → 41,059 files
