In [16]:
import tqdm
import time
import math
import random
import torch
import rasterio
import importlib


import os
from pathlib import Path


PROJECT_FOLDER = Path(os.getenv("PROJECT_ROOT"))
DATA_FOLDER = PROJECT_FOLDER / "data"
src_path = PROJECT_FOLDER / "src"

import sys
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

In [17]:
import tortoise.dataset as dataset   # import the module
importlib.reload(dataset)            # reload the module

import tortoise.normalizer as normalizer   # import the module
importlib.reload(normalizer)      # reload the module

import tortoise.dataloader as dataloader   # import the module
importlib.reload(dataloader)            # reload the module


from tortoise.dataset import *
from tortoise.normalizer import *
from tortoise.dataloader import *

In [18]:
dn = DataNormalizer(DATA_FOLDER / "normalization_stats.json", preloaded=True)


In [19]:
from pathlib import Path
from tortoise.dataloader import build_dataloaders
from tortoise.model import U_Net

DATA_FOLDER = PROJECT_FOLDER / "data"
tiles_dir = DATA_FOLDER / "tiles"

# dn is your previously-computed normalizer
train_loader, val_loader, test_loader = build_dataloaders(
    tiles_dir=tiles_dir,
    batch_size= 128,
    normalizer=dn,
    seed=98,
    train_ratio=0.01,
    val_ratio=0.01,
    test_ratio=0.01,
    save_aug_map_path= DATA_FOLDER / "aug_map.json",
)

In [20]:
len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset)

(582, 582, 582)

In [15]:
from tqdm import tqdm
def extract_tile_versions(loader):
    pairs = set()
    for batch in tqdm(loader):
        tids = batch["tile_id"]
        vers = batch["version"]
        for t, v in zip(tids, vers):
            pairs.add((t, v))
    return pairs

train_pairs = extract_tile_versions(train_loader)
val_pairs   = extract_tile_versions(val_loader)
test_pairs  = extract_tile_versions(test_loader)

# Convert to tile_id-only sets
train_ids = set(t for t, _ in train_pairs)
val_ids   = set(t for t, _ in val_pairs)
test_ids  = set(t for t, _ in test_pairs)

print("Train IDs:", len(train_ids))
print("Val IDs:", len(val_ids))
print("Test IDs:", len(test_ids))

# --- Check 1: No overlap ---
print("Train ∩ Val =", train_ids & val_ids)
print("Train ∩ Test =", train_ids & test_ids)
print("Val ∩ Test =", val_ids & test_ids)

# --- Check 2: All 3 versions exist in each split ---
def count_versions(pairs):
    from collections import defaultdict
    d = defaultdict(set)
    for tid, ver in pairs:
        d[tid].add(ver)
    return {tid: len(vers) for tid, vers in d.items()}

print("Train version counts (should all be 3):", set(count_versions(train_pairs).values()))
print("Val version counts (should all be 3):", set(count_versions(val_pairs).values()))
print("Test version counts (should all be 3):", set(count_versions(test_pairs).values()))

100%|██████████| 5/5 [00:11<00:00,  2.29s/it]
100%|██████████| 5/5 [00:11<00:00,  2.29s/it]
100%|██████████| 5/5 [00:11<00:00,  2.29s/it]

Train IDs: 194
Val IDs: 194
Test IDs: 194
Train ∩ Val = set()
Train ∩ Test = set()
Val ∩ Test = set()
Train version counts (should all be 3): {3}
Val version counts (should all be 3): {3}
Test version counts (should all be 3): {3}





In [27]:
batch = next(iter(train_loader))
print(batch.keys())
# print(batch["version"])
batch['ms']

dict_keys(['tile_id', 'version', 'ms', 'label', 'mask'])


tensor([[[[0.1214, 0.1214, 0.1214,  ..., 0.1333, 0.1333, 0.1333],
          [0.1214, 0.1214, 0.1214,  ..., 0.1337, 0.1337, 0.1333],
          [0.1214, 0.1214, 0.1214,  ..., 0.1337, 0.1337, 0.1337],
          ...,
          [0.1543, 0.1543, 0.1543,  ..., 0.1099, 0.1099, 0.1099],
          [0.1543, 0.1543, 0.1543,  ..., 0.1019, 0.1019, 0.1019],
          [0.1575, 0.1575, 0.1575,  ..., 0.1019, 0.1019, 0.1019]],

         [[0.0764, 0.0735, 0.0717,  ..., 0.0915, 0.0910, 0.0875],
          [0.0641, 0.0666, 0.0690,  ..., 0.0912, 0.0922, 0.0806],
          [0.0715, 0.1001, 0.0764,  ..., 0.0910, 0.0880, 0.0868],
          ...,
          [0.1198, 0.1233, 0.1181,  ..., 0.0589, 0.0614, 0.0634],
          [0.1109, 0.1156, 0.1230,  ..., 0.0587, 0.0668, 0.0648],
          [0.1082, 0.1146, 0.1255,  ..., 0.0604, 0.0666, 0.0710]],

         [[0.1007, 0.1026, 0.0897,  ..., 0.0985, 0.0980, 0.1012],
          [0.0885, 0.0869, 0.0869,  ..., 0.1010, 0.0985, 0.0906],
          [0.0962, 0.1232, 0.0894,  ..., 0