In [4]:
import os
import time

import lightning as L
import numpy as np
import seisbench
import seisbench.data as sbd
import seisbench.generate as sbg
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from ml_collections import config_dict
from seisbench.util import worker_seeding
from torch.utils.data import DataLoader
from transformers import Wav2Vec2Config

from seisLM.data_pipeline import collator
from seisLM.model.foundation.pretrained_models import LitMultiDimWav2Vec2

In [5]:
# ---------------------------------------------------------------------------
# Model configuration: start from a published wav2vec2 config and override a
# few fields before handing it to the Lightning module.
# ---------------------------------------------------------------------------
model_name_or_path = "patrickvonplaten/wav2vec2-base-v2"

model_config = Wav2Vec2Config.from_pretrained(model_name_or_path)
# Disabled experiments with the convolutional feature-encoder geometry:
# model_config.conv_dim = [a//8 for a in model_config.conv_dim]
# model_config.conv_stride = [a * 2 for a in model_config.conv_stride]
# model_config.conv_kernel = [a * 2 for a in model_config.conv_kernel]
model_config.num_attention_heads = 8
model_config.diversity_loss_weight = 0.15
# NOTE(review): presumably the number of waveform components fed to
# LitMultiDimWav2Vec2 (the data is loaded with component_order='ZNE' in a
# later cell) -- confirm against the model implementation.
model_config.input_dim = 3


# ---------------------------------------------------------------------------
# Training configuration. The fields are consumed by LitMultiDimWav2Vec2 (not
# shown in this notebook); see that class for their exact semantics.
# ---------------------------------------------------------------------------
training_config = config_dict.ConfigDict()

# Masking.
training_config.mask_time_prob = 0.65
training_config.mask_time_length = 10

# Batch size, seed, optimizer and schedule.
training_config.global_batch_size = 4
training_config.seed = 42
training_config.warmup_frac_step = 0.2
training_config.learning_rate = 1e-4
training_config.weight_decay = 1e-4
training_config.num_train_epochs = 20
training_config.adam_beta1 = 0.9
training_config.adam_beta2 = 0.999
training_config.adam_epsilon = 1e-8

# Gumbel temperature bounds.
training_config.max_gumbel_temperature = 2.0
training_config.min_gumbel_temperature = 0.5

# Logging, data loading and checkpoint location.
training_config.log_every_n_steps = 100
training_config.logger_project_name = 'seisLM'
training_config.num_workers = 1
# The checkpoint directory used to be hard-coded to one user's home folder.
# It can now be redirected with the SEISLM_MODEL_SAVE_DIR environment variable;
# when the variable is unset the original path is used, so existing behaviour
# is unchanged.
training_config.model_save_dir = os.environ.get(
  'SEISLM_MODEL_SAVE_DIR',
  '/home/liu0003/Desktop/projects/seisLM/saved_models',
)

# Dataset split fractions.
training_config.num_train_fraction = 0.8
training_config.num_val_fraction = 0.1
training_config.num_test_fraction = 0.1

# Numeric precision and devices.
training_config.precision = "32"
training_config.gpu_devices = [0, 1]

# Seed the random number generators before the model is constructed.
seed_everything(training_config.seed)


model = LitMultiDimWav2Vec2(model_config, training_config)



Seed set to 42


In [7]:
# Display the model configuration object; its repr is the cell output.
model_config


Wav2Vec2Config {
  "activation_dropout": 0.0,
  "adapter_attn_dim": null,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForPreTraining"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 256,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": true,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.15,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.0,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.0,
  "hidden_act": "gelu",
  "hidden_dropo

In [21]:
# Shrink the convolutional feature encoder to three layers.
# conv_dim, conv_kernel and conv_stride are per-layer lists and must all have
# the same length. Wav2Vec2Config checks this when it is constructed, but not
# when the attributes are reassigned afterwards. conv_dim previously had only
# two entries while kernel/stride had three.
model_config.conv_dim = [256, 256, 256]
model_config.conv_kernel = [10, 3, 3]
model_config.conv_stride = [5, 2, 2]
# In the Hugging Face implementation num_feat_extract_layers is derived from
# len(conv_dim) only inside __init__, so refresh it after editing the lists.
model_config.num_feat_extract_layers = len(model_config.conv_dim)
# NOTE(review): `model` was built from model_config in an earlier cell, before
# these overrides. Re-create the model after running this cell if the smaller
# encoder is meant to be the one that is trained.

# model_config.conv_kernel



In [22]:
# model_config.conv_dim = model_config.conv_dim[:3]# [512]
# model_config.conv_kernel = model_config.conv_kernel[:3]# = [1]
# model_config.conv_stride = model_config.conv_stride[:3] #[1]



In [23]:
# Ask the wrapped model's private helper for the feature-encoder output length
# (presumably the number of frames) corresponding to a 3000-sample input.
# NOTE(review): the data pipeline cuts windows of 3001 samples
# (RandomWindow(windowlen=3001)) -- confirm 3000 is the intended length here.
model.model._get_feat_extract_output_lengths(3000)


tensor(149)

In [20]:
# numpy (`np`) is imported in the notebook's import cell at the top rather than
# here, so every dependency is declared in one place.

# Load STEAD with components ordered Z, N, E and drop the noise traces.
# The return value of filter() is not used: the dataset is filtered in place.
data = sbd.STEAD(component_order='ZNE')
data.filter(data.metadata["trace_category"] != 'noise')
train, dev, test = data.train_dev_test()
train_generator = sbg.GenericGenerator(train)
val_generator = sbg.GenericGenerator(dev)

# Metadata columns holding phase-arrival sample indices, mapped to a coarse
# phase label: every P-type arrival is labelled "P" and every S-type arrival
# "S". Only the keys are used in this cell, as the columns that
# WindowAroundSample may select an arrival from.
phase_dict = {
    "trace_p_arrival_sample": "P",
    "trace_pP_arrival_sample": "P",
    "trace_P_arrival_sample": "P",
    "trace_P1_arrival_sample": "P",
    "trace_Pg_arrival_sample": "P",
    "trace_Pn_arrival_sample": "P",
    "trace_PmP_arrival_sample": "P",
    "trace_pwP_arrival_sample": "P",
    "trace_pwPm_arrival_sample": "P",
    "trace_s_arrival_sample": "S",
    "trace_S_arrival_sample": "S",
    "trace_S1_arrival_sample": "S",
    "trace_Sg_arrival_sample": "S",
    "trace_SmS_arrival_sample": "S",
    "trace_Sn_arrival_sample": "S",
}


augmentations = [
    # Step 1: with relative weight 2, cut a 6000-sample window around a
    # randomly selected arrival; with relative weight 1, leave the trace as is.
    sbg.OneOf(
        [
            sbg.WindowAroundSample(
                list(phase_dict.keys()),
                samples_before=3000,
                windowlen=6000,
                selection="random",
                strategy="variable",
            ),
            sbg.NullAugmentation(),
        ],
        probabilities=[2, 1],
    ),
    # Step 2: draw a random 3001-sample window, padding traces that are shorter.
    sbg.RandomWindow(
        low=None,
        high=None,
        windowlen=3001,
        strategy="pad",
    ),
    # Step 3: cast to float32, then demean and peak-normalise along the last axis.
    sbg.ChangeDtype(np.float32),
    sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]

# Training and validation generators share the same augmentation pipeline.
train_generator.add_augmentations(augmentations)
val_generator.add_augmentations(augmentations)


In [14]:
# Fetch the first training sample through the augmentation pipeline and show
# the shape of its 'X' array.
train_generator[0]['X'].shape

(3, 3001)