# File splits

This notebook is used to create the file lists defining the training, validation, and testing splits of the SPR dataset.

In [56]:
%load_ext autoreload
%autoreload 2
    
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [100]:
from ipwgml.utils import get_median_time
ref_sensor = "gmi"
geometry = "on_swath"
domain = "conus"

files = sorted(list(Path(f"/home/simon/data/ipwgml_new/spr_new/training_testing_validation/{ref_sensor}/{geometry}/").glob("**/*_202001*.nc")))

ancillary_files = {
    get_median_time(path): path for path in files if path.name.startswith("ancillary")
}
obs_files = {
    get_median_time(path): path for path in files if path.name.startswith(f"{ref_sensor}")
}
target_files = {
    get_median_time(path): path for path in files if path.name.startswith("target")
}
len(ancillary_files)

17

In [101]:
test_days = [1, 2, 3]
validation_days = [3, 4, 5]

training_files = {}
validation_files = {}
test_files = {}

times = set.intersection(set(ancillary_files.keys()), set(obs_files.keys()), set(target_files.keys()))
for time in times:
    if time.hour in test_days:
        test_files.setdefault(f"{ref_sensor}", []).append(obs_files[time])
        test_files.setdefault("ancillary", []).append(ancillary_files[time])
        test_files.setdefault("target_files", []).append(target_files[time])
    elif time.hour in validation_days:
        validation_files.setdefault(f"{ref_sensor}", []).append(obs_files[time])
        validation_files.setdefault("ancillary", []).append(ancillary_files[time])
        validation_files.setdefault("target_files", []).append(target_files[time])
    else:
        training_files.setdefault(f"{ref_sensor}", []).append(obs_files[time])
        training_files.setdefault("ancillary", []).append(ancillary_files[time])
        training_files.setdefault("target_files", []).append(target_files[time])

## Organize files
### Training, validation, and testing splits

In [102]:
target_folder = Path(f"/home/simon/data/ipwgml_test/spr/")

In [103]:
import shutil

splits = {
    "training": training_files,
    "validation": validation_files,
    "testing": test_files
}

for split, files in splits.items():
    times = np.array([get_median_time(path) for path in next(iter(files.values()))])
    subsets = np.logspace(-2, 0, 5)
    indices = np.random.permutation(times.size)
    time_splits = []
    start = 0
    for frac in subsets:
        n_samples = max(1, int(frac * times.size))
        subset_indices = indices[start:start + n_samples]
        time_splits.append(subset_indices)
        start += n_samples

    assert set(range(times.size)) == set.union(*[set(time_split) for time_split in time_splits])

    for subset, inds in zip(["xs", "s", "m", "l", "xl"], time_splits):
        for source, source_files in files.items():
            source_files = np.array(source_files)
            
            for path in source_files[inds]:
                time = get_median_time(path)
                year = time.year
                month = time.month
                day = time.day
                output_path = target_folder / ref_sensor / split / subset / geometry / f"{year:04}" / f"{month:02}" / f"{day:02}"
                output_path.mkdir(exist_ok=True, parents=True)
                shutil.copyfile(path, output_path / path.name)
        
        


### Evaluation files

In [106]:
%ls /home/simon/data/ipwgml_new/spr_new/evaluation

[0m[01;34mgridded[0m/  [01;34mon_swath[0m/


In [127]:
from ipwgml.utils import get_median_time
ref_sensor = "gmi"
geometry = "on_swath"
domain = "conus"

files = sorted(list(Path(f"/home/simon/data/ipwgml_new/spr_new/evaluation/{geometry}").glob("**/*.nc")))

ancillary_files = {
    get_median_time(path): path for path in files if path.name.startswith("ancillary")
}
obs_files = {
    get_median_time(path): path for path in files if path.name.startswith(f"{ref_sensor}")
}
target_files = {
    get_median_time(path): path for path in files if path.name.startswith("target")
}
len(target_files)

8

In [128]:
import shutil

split = "evaluation"
files = {
   f"{ref_sensor}": list(obs_files.values()),
    "ancillary": list(ancillary_files.values()),
    "target": list(target_files.values())
}

times = np.array([get_median_time(path) for path in next(iter(files.values()))])
subsets = np.logspace(-2, 0, 5)
indices = np.random.permutation(times.size)
time_splits = []
start = 0
for frac in subsets:
    n_samples = max(1, int(frac * times.size))
    subset_indices = indices[start:start + n_samples]
    time_splits.append(subset_indices)
    start += n_samples

assert set(range(times.size)) == set.union(*[set(time_split) for time_split in time_splits])

for source, source_files in files.items():
    source_files = np.array(source_files)
    
    for path in source_files[inds]:
        time = get_median_time(path)
        year = time.year
        month = time.month
        day = time.day
        output_path = target_folder / ref_sensor / split / domain / geometry / f"{year:04}" / f"{month:02}" / f"{day:02}"
        output_path.mkdir(exist_ok=True, parents=True)
        shutil.copyfile(path, output_path / path.name)



### Check consistency

The ancillary, target, and gmi files in all folders must correspond to the same median times.

In [129]:
def check_consistency(path: Path):
    """
    Check consistency of files in given folder.
    """
    anc_times = [get_median_time(fle) for fle in path.glob("ancillary*.nc")]
    gmi_times = [get_median_time(fle) for fle in path.glob("gmi*.nc")]
    target_times = [get_median_time(fle) for fle in path.glob("target*.nc")]

    assert set(anc_times) == set(gmi_times)
    assert set(anc_times) == set(target_times)

for split in splits:
    for size in ["xs", "s", "m", "l", "xl"]:
        check_consistency(target_folder / "gmi" / "gridded" / split / size / "2020" / "01" / "01")
    

## 

## Get files

In [130]:
%env IPWGML_DATA_PATH=/home/simon/data/ipwgml_test

env: IPWGML_DATA_PATH=/home/simon/data/ipwgml_test


In [131]:
from ipwgml.data import get_files

get_files(reference_sensor="gmi", geometry="gridded", split="training", subset="xl")

{'gmi': [PosixPath('/home/simon/data/ipwgml_test/spr/gmi/training/xs/gridded/2020/01/01/gmi_20200101105109.nc'),
  PosixPath('/home/simon/data/ipwgml_test/spr/gmi/training/s/gridded/2020/01/01/gmi_20200101091316.nc'),
  PosixPath('/home/simon/data/ipwgml_test/spr/gmi/training/m/gridded/2020/01/01/gmi_20200101122157.nc'),
  PosixPath('/home/simon/data/ipwgml_test/spr/gmi/training/l/gridded/2020/01/01/gmi_20200101073914.nc'),
  PosixPath('/home/simon/data/ipwgml_test/spr/gmi/training/l/gridded/2020/01/01/gmi_20200101091548.nc'),
  PosixPath('/home/simon/data/ipwgml_test/spr/gmi/training/l/gridded/2020/01/01/gmi_20200101104346.nc'),
  PosixPath('/home/simon/data/ipwgml_test/spr/gmi/training/xl/gridded/2020/01/01/gmi_20200101091029.nc'),
  PosixPath('/home/simon/data/ipwgml_test/spr/gmi/training/xl/gridded/2020/01/01/gmi_20200101104548.nc'),
  PosixPath('/home/simon/data/ipwgml_test/spr/gmi/training/xl/gridded/2020/01/01/gmi_20200101104754.nc'),
  PosixPath('/home/simon/data/ipwgml_test/sp

In [141]:
files = {}
for reference_sensor in ["gmi"]:
    files[reference_sensor] = {}
    for split in ["training", "validation", "testing", "evaluation"]:
        files[reference_sensor][split] = {}
        if split != "evaluation":
            for subset in ["xs", "s", "m", "l", "xl"]:
                files[reference_sensor][split][subset] = {}
                for geometry in ["gridded", "on_swath"]:
                    source_files = get_files(
                        reference_sensor=reference_sensor,
                        geometry=geometry,
                        split=split,
                        subset=subset,
                        relative_to="/home/simon/data/ipwgml_test"
                    )
                    source_files = {name: [str(path) for path in fls] for name, fls in source_files.items()}
                    files[reference_sensor][split][domain] = source_files
                    
        else:
            for domain in ["conus", "korea", "austria"]:
                files[reference_sensor][split][domain] = {}
                for geometry in ["gridded", "on_swath"]:
                    source_files = get_files(
                        reference_sensor=reference_sensor,
                        geometry=geometry,
                        split=split,
                        domain=domain,
                        relative_to="/home/simon/data/ipwgml_test"
                    )
                    source_files = {name: [str(path) for path in fls] for name, fls in source_files.items()}
                    files[reference_sensor][split][domain] = source_files
                    

In [142]:
files

{'gmi': {'training': {'xs': {},
   'austria': {'gmi': ['spr/gmi/training/xs/on_swath/2020/01/01/gmi_20200101074019.nc',
     'spr/gmi/training/xs/on_swath/2020/01/01/gmi_20200101104417.nc',
     'spr/gmi/training/xs/on_swath/2020/01/01/gmi_20200101121953.nc',
     'spr/gmi/training/xs/on_swath/2020/01/01/gmi_20200101122304.nc',
     'spr/gmi/training/s/on_swath/2020/01/01/gmi_20200101074019.nc',
     'spr/gmi/training/s/on_swath/2020/01/01/gmi_20200101091219.nc',
     'spr/gmi/training/s/on_swath/2020/01/01/gmi_20200101104949.nc',
     'spr/gmi/training/s/on_swath/2020/01/01/gmi_20200101122304.nc',
     'spr/gmi/training/m/on_swath/2020/01/01/gmi_20200101091607.nc',
     'spr/gmi/training/m/on_swath/2020/01/01/gmi_20200101104949.nc',
     'spr/gmi/training/m/on_swath/2020/01/01/gmi_20200101121953.nc',
     'spr/gmi/training/l/on_swath/2020/01/01/gmi_20200101074019.nc',
     'spr/gmi/training/l/on_swath/2020/01/01/gmi_20200101091219.nc',
     'spr/gmi/training/l/on_swath/2020/01/01/gmi_

In [145]:
target_folder

PosixPath('/home/simon/data/ipwgml_test/spr')

In [149]:
import json
with open(target_folder.parent / "files.json", "w") as output:
    output.write(json.dumps(files, indent=2))

In [None]:
json.lo

In [147]:
target_folder 


PosixPath('/home/simon/data/ipwgml_test/spr')