In [1]:
import time
from ml_collections import config_dict
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import WandbLogger
from transformers import Wav2Vec2Config
from seisLM.model.foundation.pretrained_models import LitMultiDimWav2Vec2
from seisLM.data_pipeline import collator
import seisbench
import seisbench.data as sbd
import seisbench.generate as sbg
from seisbench.util import worker_seeding
import json

In [2]:
# Paths to the model-configuration and training-configuration JSON files.
model_config_path = '/scicore/home/dokman0000/liu0003/projects/seisLM/seisLM/configs/pretrain/model_config_4xdownsample_scale_logits_quantization.json'
training_config_path = '/scicore/home/dokman0000/liu0003/projects/seisLM/seisLM/configs/pretrain/training_config.json'

# Build the Wav2Vec2 configuration from the JSON file, asking for the "sdpa"
# attention implementation.
model_config = Wav2Vec2Config.from_pretrained(model_config_path, attn_implementation="sdpa")

# Parse the training settings and wrap them in a ConfigDict so that the Lightning
# module receives a ConfigDict rather than a plain dict.
with open(training_config_path, mode="r", encoding="utf-8") as config_file:
  training_config = config_dict.ConfigDict(json.load(config_file))

# Lightning module built from the two configurations above.
model = LitMultiDimWav2Vec2(model_config, training_config)



In [4]:
import torch

# Compile the model exactly once. `torch.compile` returns a wrapper that keeps the
# original module under `_orig_mod`; because this cell rebinds `model` to that
# wrapper, re-running it without a guard would nest one compiled wrapper inside
# another (NOTE(review): Lightning's Trainer appears to unwrap only a single level,
# so a nested wrapper may be rejected — confirm against the installed version).
if not hasattr(model, "_orig_mod"):
  model = torch.compile(model)

In [6]:
import numpy as np

# Load STEAD with component order 'ZNE' and keep only the traces whose
# `trace_category` is not 'noise'.
data = sbd.STEAD(component_order='ZNE')
is_not_noise = data.metadata["trace_category"] != 'noise'
data.filter(is_not_noise)

# Split the dataset and wrap the train / dev parts in generators.
train, dev, test = data.train_dev_test()
train_generator = sbg.GenericGenerator(train)
val_generator = sbg.GenericGenerator(dev)

# Metadata columns holding arrival samples, grouped by the label they map to.
# Every P-type pick is labelled "P" and every S-type pick is labelled "S";
# the individual sub-phases are not distinguished from one another.
p_arrival_keys = [
    "trace_p_arrival_sample",
    "trace_pP_arrival_sample",
    "trace_P_arrival_sample",
    "trace_P1_arrival_sample",
    "trace_Pg_arrival_sample",
    "trace_Pn_arrival_sample",
    "trace_PmP_arrival_sample",
    "trace_pwP_arrival_sample",
    "trace_pwPm_arrival_sample",
]
s_arrival_keys = [
    "trace_s_arrival_sample",
    "trace_S_arrival_sample",
    "trace_S1_arrival_sample",
    "trace_Sg_arrival_sample",
    "trace_SmS_arrival_sample",
    "trace_Sn_arrival_sample",
]
phase_dict = {
    **dict.fromkeys(p_arrival_keys, "P"),
    **dict.fromkeys(s_arrival_keys, "S"),
}

# Step 1: either cut a 6000-sample window around a randomly selected pick or leave
# the trace untouched; the two options are passed to OneOf with weights 2 and 1.
window_around_pick = sbg.WindowAroundSample(
    list(phase_dict),
    samples_before=3000,
    windowlen=6000,
    selection="random",
    strategy="variable",
)
window_or_passthrough = sbg.OneOf(
    [window_around_pick, sbg.NullAugmentation()],
    probabilities=[2, 1],
)

# Steps 2-4: random 3001-sample window (strategy "pad"), cast to float32, then
# demean and peak-normalise along the last axis.
augmentations = [
    window_or_passthrough,
    sbg.RandomWindow(low=None, high=None, windowlen=3001, strategy="pad"),
    sbg.ChangeDtype(np.float32),
    sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]

# Both generators receive the same augmentation list.
for split_generator in (train_generator, val_generator):
    split_generator.add_augmentations(augmentations)


In [14]:
# Sanity check: draw the first training sample and display the shape of its 'X'
# array (the last dimension should equal the RandomWindow length of 3001).
first_sample = train_generator[0]
first_sample['X'].shape

(3, 3001)