In [1]:
import time
from ml_collections import config_dict
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import WandbLogger
from transformers import Wav2Vec2Config
from seisLM.model import LitWav2Vec2
from seisLM.data_pipeline import collator
import seisbench
import seisbench.data as sbd
import seisbench.generate as sbg
from seisbench.util import worker_seeding

# from earthquakeLM.utils import datadir

In [2]:
model_name_or_path = "patrickvonplaten/wav2vec2-base-v2"

model_config = Wav2Vec2Config.from_pretrained(model_name_or_path)
model_config.conv_dim = [a//8 for a in model_config.conv_dim]
model_config.conv_stride = [a * 2 for a in model_config.conv_stride]
model_config.conv_kernel = [a * 2 for a in model_config.conv_kernel]
model_config.num_attention_heads = 8
model_config.diversity_loss_weight = 0.15


training_config = config_dict.ConfigDict()
training_config.mask_time_prob = 0.65
training_config.mask_time_length = 10
training_config.global_batch_size = 4
training_config.seed = 42
training_config.warmup_frac_step = 0.2
training_config.learning_rate = 1e-4
training_config.weight_decay = 1e-4
training_config.num_train_epochs = 20
training_config.adam_beta1 = 0.9
training_config.adam_beta2 = 0.999
training_config.adam_epsilon = 1e-8
training_config.max_gumbel_temperature = 2.0
training_config.min_gumbel_temperature = 0.5
training_config.log_every_n_steps = 100
training_config.logger_project_name = 'seisLM'
training_config.num_workers = 1
training_config.model_save_dir = \
  '/home/liu0003/Desktop/projects/seisLM/saved_models'
training_config.num_train_fraction = 0.8
training_config.num_val_fraction = 0.1
training_config.num_test_fraction = 0.1
training_config.precision = "32"
training_config.gpu_devices = [0, 1]
seed_everything(training_config.seed)


Seed set to 42


42

In [3]:
model = LitWav2Vec2(model_config, training_config)


data = sbd.STEAD()
mask = data.metadata["trace_category"] != 'noise'  # Only select events with magnitude above 2.5
data.filter(mask)
train, dev, test = data.train_dev_test()
train_generator = sbg.GenericGenerator(train)
val_generator = sbg.GenericGenerator(dev)

augmentations = [
    sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]
train_generator.add_augmentations(augmentations)
val_generator.add_augmentations(augmentations)





In [4]:

data_collator = \
  collator.DataCollatorForWav2Vec2PretrainingConcatChannelsNoPadding(
      model=model.model,
      mask_time_prob=training_config.mask_time_prob,
      mask_time_length=training_config.mask_time_length,
  )

dataloaders = {
  'train': DataLoader(
    train_generator, batch_size=training_config.global_batch_size, shuffle=True,
    num_workers=training_config.num_workers, worker_init_fn=worker_seeding,
    collate_fn=data_collator,
    ),
  'val': DataLoader(
    val_generator, batch_size=training_config.global_batch_size, shuffle=False,
    num_workers=training_config.num_workers, worker_init_fn=worker_seeding,
    collate_fn=data_collator,
    ),
}


In [5]:
next(iter(dataloaders['train']))

sample [{'X': array([[-0.00250112, -0.00242765, -0.00232805, ..., -0.00258183,
        -0.00250649, -0.00250112],
       [-0.0002024 , -0.00018344, -0.0001075 , ..., -0.00026062,
        -0.00020431, -0.0002024 ],
       [-0.0007699 , -0.00098384, -0.00133656, ..., -0.00073578,
        -0.00077023, -0.0007699 ]], dtype=float32)}, {'X': array([[ 4.8519173e-04,  1.7385147e-04, -3.7199727e-04, ...,
         5.0855981e-04,  4.9532048e-04,  4.8519173e-04],
       [ 1.1266283e-03,  6.7102473e-04, -5.7241112e-05, ...,
         1.1602479e-03,  1.1231112e-03,  1.1266283e-03],
       [-5.6192646e-04, -5.1241036e-04, -4.5429208e-04, ...,
        -5.0737435e-04, -5.4877612e-04, -5.6192646e-04]], dtype=float32)}, {'X': array([[ 9.9496015e-07,  5.9140109e-07, -3.6345948e-06, ...,
        -1.4180571e-06,  3.7517412e-07,  9.9496015e-07],
       [ 5.9281234e-07, -6.0712568e-07,  1.2508367e-07, ...,
        -4.9233995e-07,  2.7009852e-07,  5.9281234e-07],
       [-9.8071723e-07, -2.8264001e-06, -4.29350

sample [{'X': array([[-2.78615829e-04, -2.79700849e-04, -2.84453505e-04, ...,
        -2.85115675e-04, -2.80331180e-04, -2.78615829e-04],
       [ 2.09014815e-05,  1.35420105e-05,  1.69109990e-06, ...,
         1.72002237e-05,  1.98968446e-05,  2.09014815e-05],
       [-8.49768185e-06,  1.67743011e-08,  1.40911225e-05, ...,
        -1.55234684e-05, -1.01339574e-05, -8.49768185e-06]], dtype=float32)}, {'X': array([[ 1.9627585e-05,  1.6639333e-05,  1.1219428e-05, ...,
         2.1407373e-05,  1.9687899e-05,  1.9627585e-05],
       [ 6.7946380e-06,  5.9320623e-06,  4.2022662e-06, ...,
         8.9733649e-06,  7.2297394e-06,  6.7946380e-06],
       [-1.4328667e-05, -1.4711262e-05, -1.4691180e-05, ...,
        -1.4785780e-05, -1.4176428e-05, -1.4328667e-05]], dtype=float32)}, {'X': array([[-5.8528025e-07, -5.2913742e-07, -2.3162794e-07, ...,
         2.3721846e-06,  3.3994432e-08, -5.8528025e-07],
       [-3.6013242e-07, -3.3963201e-07, -3.2642905e-07, ...,
         6.9608558e-07, -5.041637

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/liu0003/miniconda3/envs/seisbench/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/liu0003/miniconda3/envs/seisbench/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/liu0003/Desktop/projects/seisLM/seisLM/data_pipeline/collator.py", line 61, in __call__
    batch_size, seq_length = features.shape
ValueError: too many values to unpack (expected 2)


In [None]:


# # Prepare data collator and dataloader
# dataset = costa_rica_dataset.EarthquakeRecordingDataset(
#     training_config.dataset_path
# )


# data_collator = \
#   costa_rica_dataset.DataCollatorForWav2Vec2PretrainingConcatChannelsNoPadding(
#       model=model.model,
#       mask_time_prob=training_config.mask_time_prob,
#       mask_time_length=training_config.mask_time_length,
#       normalization_type=training_config.data_normalization_type,
#   )

# dataloaders = costa_rica_dataset.get_dataloaders_costa_rica(
#     dataset,
#     data_collator=data_collator,
#     train_batch_size=training_config.global_batch_size,
#     eval_batch_size=training_config.global_batch_size,
#     num_train_fraction=training_config.num_train_fraction,
#     num_val_fraction=training_config.num_val_fraction,
#     num_test_fraction=training_config.num_test_fraction,
#     num_workers=training_config.num_workers
# )

training_config.max_train_steps = training_config.num_train_epochs * len(
  dataloaders['train'] )


# Training loop
checkpoint_callback = ModelCheckpoint(
    monitor='val/loss',
    save_top_k=1,
    mode='min',
    save_last=True,
)

lr_monitor = LearningRateMonitor(logging_interval='step')

formatted_time = time.strftime(
  "%Y-%m-%d-%Hh-%Mm-%Ss", time.localtime(time.time())
)
run_name = f"{training_config.seed}__{formatted_time}"

logger = WandbLogger(
    project=training_config.logger_project_name,
    save_dir=training_config.model_save_dir,
    name=run_name,
    id=run_name,
)



logger.log_hyperparams(model.hparams)
logger.log_hyperparams(training_config.to_dict())


trainer = L.Trainer(
    profiler='simple',
    logger=logger,
    log_every_n_steps=training_config.log_every_n_steps,
    devices=training_config.gpu_devices,
    accelerator='gpu',
    strategy='ddp',
    max_epochs=training_config.num_train_epochs,
    callbacks=[
      checkpoint_callback, lr_monitor,
    ],
    default_root_dir=training_config.model_save_dir,
    precision=training_config.precision,
)


trainer.fit(
    model,
    train_dataloaders=dataloaders['train'],
    val_dataloaders=dataloaders['val']
)

