In [1]:
# prepare_phase_pick_evaluation_loader
from seisLM.data_pipeline.dataloaders import get_dataset_by_name
import seisbench.generate as sbg
from pathlib import Path
import pandas as pd
from seisLM.model import supervised_models
from torch.utils.data import DataLoader


# Maps the lowercase name of a targets directory (e.g. ".../targets/ethz")
# to the dataset name that is handed to `get_dataset_by_name`.
data_aliases = dict(
    ethz="ETHZ",
    geofon="GEOFON",
    stead="STEAD",
    neic="NEIC",
    instance="InstanceCountsCombined",
    iquique="Iquique",
    lendb="LenDB",
    scedc="SCEDC",
)



# model_name = 'PhaseNet'
# targets = '/scicore/home/dokman0000/liu0003/projects/seisLM/data/targets/ethz'
# # targets = '/scicore/home/dokman0000/liu0003/projects/seisLM/data/targets/instance'
# targets = Path(targets)
# cache = None

# sampling_rate=100

# dataset = get_dataset_by_name(data_aliases[targets.name])(
#     sampling_rate=sampling_rate, component_order="ZNE", dimension_order="NCW",
#     cache=cache,
# )

# model_cls = supervised_models.__getattribute__(model_name + "Lit")
# model = model_cls()

# batch_size = 1024
# num_workers = 8

In [13]:
def _identify_instance_dataset_border(task_targets):
  '''
  Find the positional index where the noise rows start in `instance` targets.

  Ported from the pick-benchmark `eval.py` that
  `prepare_phase_pick_evaluation_loader` is based on.

  NOTE(review): assumes every `trace_name` looks like `bucket<N>$...` and that
  the bucket number only decreases once, at the signal/noise boundary --
  confirm against the actual target files.

  Args:
    task_targets: DataFrame with a `trace_name` column.

  Returns:
    Position of the first row whose bucket number is smaller than the one
    before it, or None when the bucket numbers never decrease.
  '''
  # "bucket12$5,:3,:12000" -> "bucket12" -> 12
  buckets = task_targets["trace_name"].apply(
      lambda x: int(x.split("$")[0][6:])
  )

  last_bucket = 0
  for i, bucket in enumerate(buckets):
    if bucket < last_bucket:
      return i
    last_bucket = bucket
  return None


def prepare_phase_pick_evaluation_loader(
  model,
  targets,
  batch_size, num_workers,
  sampling_rate=100,
  component_oder="ZNE",
  dimension_order="NCW",
  cache=None
  ):
  '''
  Build evaluation DataLoaders for the phase-picking benchmark tasks.

  Taken from:
  https://github.com/seisbench/pick-benchmark/blob/main/benchmark/eval.py

  For each of the `dev` and `test` splits and each task in ("1", "23"), the
  file `<targets>/task<task>.csv` is read (if it exists), restricted to the
  rows of that split, and wrapped in a `SteeredGenerator` carrying the
  model's evaluation augmentations.

  Args:
    model: Object exposing `get_eval_augmentations()`.
    targets: Path (str or Path) of the targets directory. Its final component
      must be a key of `data_aliases`; it selects the dataset.
    batch_size: Batch size passed to every DataLoader.
    num_workers: Worker count passed to every DataLoader.
    sampling_rate: Sampling rate requested from the dataset. When it differs
      from 100, the sample-based target columns are rescaled to it.
    component_oder: Forwarded to the dataset as `component_order`. (The
      parameter name is misspelled; it is kept for backward compatibility.)
    dimension_order: Forwarded to the dataset as `dimension_order`.
    cache: Forwarded to the dataset as `cache`; when truthy the waveforms of
      each split are preloaded.

  Returns:
    Dict mapping `"task_<task>_<eval_set>"` to a non-shuffling DataLoader.
    Tasks whose CSV file is missing are omitted.
  '''
  # Imported locally so the function does not depend on a file-level import.
  import logging

  targets = Path(targets)
  loaders = {}

  dataset = get_dataset_by_name(data_aliases[targets.name])(
      sampling_rate=sampling_rate,
      component_order=component_oder,
      dimension_order=dimension_order,
      cache=cache,
  )

  for eval_set in ['dev', 'test']:
    split = dataset.get_split(eval_set)

    # There are some subtleties in the `instance` dataset
    # TODO: understand and explain this better.
    if targets.name == "instance":
      logging.warning(
          "Overwriting noise trace_names to allow correct identification"
      )
      # Replace trace names for noise entries: the trailing
      # len(split.datasets[-1]) metadata rows get a "noise_" prefix.
      split._metadata["trace_name"].values[
          -len(split.datasets[-1]) :
      ] = split._metadata["trace_name"][-len(split.datasets[-1]) :].apply(
          lambda x: "noise_" + x
      )
      # Rebuild the name -> index lookup so it reflects the renamed traces.
      split._build_trace_name_to_idx_dict()

    if cache:
      split.preload_waveforms(pbar=True)

    for task in ["1", "23"]:

      task_csv = targets / f"task{task}.csv"

      if not task_csv.is_file():
        continue

      task_targets = pd.read_csv(task_csv)
      # Explicit copy: the frame is modified below, and writing into a
      # filtered view would raise SettingWithCopy warnings.
      task_targets = task_targets[
          task_targets["trace_split"] == eval_set
      ].copy()

      if task == "1" and targets.name == "instance":
        # Mirror the "noise_" renaming applied to the split metadata above.
        # NOTE(review): if no border is found `border` is None and the slice
        # below covers every row.
        border = _identify_instance_dataset_border(task_targets)
        task_targets["trace_name"].values[border:] = task_targets["trace_name"][
            border:
        ].apply(lambda x: "noise_" + x)

      if sampling_rate != 100:
        # Convert sample-based columns from the rate recorded in the CSV to
        # the requested rate.
        for key in ["start_sample", "end_sample", "phase_onset"]:
          if key not in task_targets.columns:
              continue
          task_targets[key] = (
              task_targets[key]
              * sampling_rate
              / task_targets["sampling_rate"]
          )
        # Record the new rate in the "sampling_rate" column (previously this
        # wrote to a column labelled with the integer rate itself).
        task_targets["sampling_rate"] = sampling_rate

      generator = sbg.SteeredGenerator(split, task_targets)
      generator.add_augmentations(model.get_eval_augmentations())

      loader = DataLoader(
        generator, batch_size=batch_size, shuffle=False, num_workers=num_workers
      )
      loaders[f"task_{task}_{eval_set}"] = loader
  return loaders

In [14]:
# Build the evaluation loaders for the ETHZ targets with a PhaseNet model,
# relying on the function's defaults for the remaining options.
model = supervised_models.PhaseNetLit()
loaders = prepare_phase_pick_evaluation_loader(
  model=model,
  targets='/scicore/home/dokman0000/liu0003/projects/seisLM/data/targets/ethz',
  num_workers=8,
  batch_size=1024,
)




In [12]:
# Show the mapping from "task_<task>_<eval_set>" keys to their DataLoaders.
loaders

{'task_1_dev': <torch.utils.data.dataloader.DataLoader at 0x7f70bb12b370>,
 'task_23_dev': <torch.utils.data.dataloader.DataLoader at 0x7f70bb12b2b0>,
 'task_1_test': <torch.utils.data.dataloader.DataLoader at 0x7f70bad91f10>,
 'task_23_test': <torch.utils.data.dataloader.DataLoader at 0x7f70bae9d0a0>}