In [1]:
import json
import os
import ml_collections
import torch
import numpy as np
from lightning.pytorch import seed_everything
from seisLM.model.foundation.pretrained_models import LitMultiDimWav2Vec2
from seisLM.data_pipeline import collator
from seisLM.data_pipeline import pretrain_dataloaders as dataloaders
import matplotlib.pyplot as plt
import tqdm
from collections import defaultdict

DEFAULT_NUM_WORKERS = 4

config_path = '/scicore/home/dokman0000/liu0003/projects/seisLM/seisLM/configs/pretrain/pretrain_config_layernorm_std_small_batch_6_datasets.json'

def get_loaders(config_path):

  with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)
  config = ml_collections.ConfigDict(config)
  config.data_config.local_batch_size = 16

  seed_everything(config.seed)
  model = LitMultiDimWav2Vec2(config)


  data_collator = \
    collator.DataCollatorForWav2Vec2PretrainingConcatChannelsNoPadding(
        config=config.model_config,
        mask_time_prob=config.training_config.mask_time_prob,
        mask_time_length=config.training_config.mask_time_length,
  )

  config.data_config.num_workers = int(
      os.environ.get('SLURM_CPUS_PER_TASK', DEFAULT_NUM_WORKERS)
  )

  seed_everything(42)
  train_loader, dev_loader = dataloaders.prepare_pretrain_dataloaders(
    model=model,
    training_fraction=config.data_config.training_fraction,
    data_names=['ETHZ'],
    missing_components='copy',
    batch_size=config.data_config.local_batch_size,
    num_workers=config.data_config.num_workers,
    prefetch_factor=config.data_config.prefetch_factor,
    collator=data_collator,
    cache=config.data_config.cache_dataset,
  )
  return train_loader, dev_loader['ETHZ']

In [2]:
train_loader, dev_loader = get_loaders(config_path)

Seed set to 42
Seed set to 42
Seed set to 42
Seed set to 42


In [3]:
with open(config_path, "r", encoding="utf-8") as f:
  config = json.load(f)

config = ml_collections.ConfigDict(config)
config.data_config.local_batch_size = 16


In [12]:
div_loss_unscaled = {}
div_loss_scaled = {}

for scale_logits in [True, False]:

  print(f"Scale logits: {scale_logits}")
  config.model_config.scale_logits_in_quantization = scale_logits

  for last_conv_dim in [1, 2, 4, 8, 16, 32, 64, 128]:
    config.model_config.conv_dim[-1] = last_conv_dim
    seed_everything(config.seed)
    model = LitMultiDimWav2Vec2(config)
    model = model.train()
    model.to('cuda');

    # if scale_logits:
    #   model.model.set_gumbel_temperature(2 / last_conv_dim)
    # else:
    #   model.model.set_gumbel_temperature(2)

    print('model.model.quantizer', model.model.quantizer.scale_logits_in_quantization)

    total_diversity_loss = 0
    total_num_losses = 0

    # Initialize the tqdm progress bar
    progress_bar = tqdm.tqdm(dev_loader, desc="Processing", postfix={'div_loss': 0.0})

    for batch in progress_bar:
      batch = {k: v.to(model.device) for k, v in batch.items()}
      with torch.no_grad():
        num_losses = batch["mask_time_indices"].sum().float()
        output = model.model(**batch)
        total_diversity_loss += output.diversity_loss
        total_num_losses += num_losses

        # Calculate the current average diversity loss
        diversity_loss = total_diversity_loss / total_num_losses

        # Update the progress bar with the current diversity loss
        progress_bar.set_postfix({'div_loss': diversity_loss.item()})

    # Final diversity loss
    diversity_loss = total_diversity_loss / total_num_losses
    if scale_logits:
      div_loss_scaled[last_conv_dim] = diversity_loss.item()
    else:
      div_loss_unscaled[last_conv_dim] = diversity_loss.item()
    print(f"Final diversity_loss: {diversity_loss}")

Seed set to 42


Scale logits: True
model.model.quantizer False


Processing:  19%|█▉        | 44/227 [00:06<00:28,  6.45it/s, div_loss=3.21e-5]


KeyboardInterrupt: 