In [1]:
"""testing the multidim wav2vec model against the reference model"""
import torch
import numpy as np
from lightning.pytorch import seed_everything
import seisbench.data as sbd

from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
from transformers import Wav2Vec2Config
from transformers import Wav2Vec2ForPreTraining as RefWav2Vec2ForPreTraining
from seisLM.model.multidim_wav2vec2 import MultiDimWav2Vec2ForPreTraining

data = sbd.STEAD()
waveforms = data.get_waveforms(1265656)
input_values = torch.Tensor(waveforms).unsqueeze(0)





In [2]:
model_output = {}
# for model_name in MODEL_NAMES:
config = Wav2Vec2Config.from_pretrained(
  '/scicore/home/dokman0000/liu0003/projects/seisLM/seisLM/configs/pretrain/model_config_4xdownsample_sinkhorn.json'
)


# config.sinkhorn_quantization_iters = 1

for sinkhorn_quantization_iter in [0, 1, 3, 5]:
  print('------------')
  print(f"Sinkhorn quantization iter: {sinkhorn_quantization_iter}")
  config.sinkhorn_quantization_iters = sinkhorn_quantization_iter
  model = MultiDimWav2Vec2ForPreTraining(config)

  # compute masked indices
  batch_size, num_channels, raw_sequence_length = input_values.shape
  sequence_length = model._get_feat_extract_output_lengths(
    raw_sequence_length).item()

  seed_everything(0)
  mask_time_indices = _compute_mask_indices(
      shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
  )
  sampled_negative_indices = _sample_negative_indices(
      features_shape=(batch_size, sequence_length),
      num_negatives=model.config.num_negatives,
      mask_time_indices=mask_time_indices,
  )
  mask_time_indices = torch.tensor(
    data=mask_time_indices, device=input_values.device, dtype=torch.long)
  sampled_negative_indices = torch.tensor(
    data=sampled_negative_indices, device=input_values.device,
    dtype=torch.long
  )

  with torch.no_grad():
    outputs = model(input_values, mask_time_indices=mask_time_indices,
                    sampled_negative_indices=sampled_negative_indices)


  print(outputs.codevector_perplexity)
  # model_output[f'{model_name}_{model_type}'] = outputs


------------
Sinkhorn quantization iter: 0


Seed set to 0


hidden_states.std: 0.9999865889549255
weight_proj.std: 0.9989061951637268
hidden_states.std after weight_proj: 15.974702835083008
codevector_probs.shape [B * L, G, V]: torch.Size([1499, 2, 320])
codevector_probs sum over B*L [[  0   0   5   0   8   0   0   7  10   3   0   0   6   8   0   1   1   1
    1   3   0   0  12   0   1   3  17   0   0   0   1   0   2   2   0   1
    0   0   3  26   5   0   0   8  12  16   4   0   2   0  10   4   2   0
    0   0  19   0   0   5  10   0   0   2   0  15   0   0   0   0   5   0
    1   0   0   0   0   0   1  13   0   0   0   2   0   1   0   0   5  19
    6   5   0   4 258   0   3   0   2   0   1   1   1 128   8   0   0   0
    2   8   0   0   0   1   0   0   3   4   3   0   1   1   9   9   9   0
    1   0   0   0   0   3   0   0   1   0   0   3  13   0   0   0   0   0
    7   0  18   6   3   2   1   4   0   0   5   2   0   2   0   2  12  14
    0   4   0   2   4   3   1   1   4   1   0   0   3   0   4   0   2   0
    0   0   1   5   1  10   0   0  

Seed set to 0


hidden_states.std: 0.9999868273735046
weight_proj.std: 1.0001189708709717
hidden_states.std after weight_proj: 16.074504852294922
codevector_probs.shape [B * L, G, V]: torch.Size([1499, 2, 320])
codevector_probs sum over B*L [[10  9  3  1  2  2  1  4 11 10  0  3  1  2  4  2  3  2  3  2  3  4  1 23
   2  2 13  8  7  7  3  0 14  6  6  3  2  2  1  4  7 13  3  2  2  3  0  8
   4  5  2  2  7  2  5  4  7  3 10  4  6  3  3  2  1  4  3  1 25  1  2  2
   3  2  3  5  5  1  3  2  2  5  3  4  8  4  5  4  5  1  2  3  4  2  1  2
   8  4  4  3  2  0  6  2  2  3  7  1  1  2  1  1  1  6  3  2  9 17  2  3
  11  4  2 10  4  1  5  2  5  1 16  4  1 10  2 10 12  3  3  3  1  3  2  5
   1  5  8 10 10  6  1  2  5  2  1  3 11 27 16  3  5  4  3  2  1 17  6  3
   3  4  2  3  7  2 11  2  2  1  4  5  1 13  2  3  2  2  3  1  3  3 10  6
   4  8 13 15  2  4  5  4  4  1  4  6  0  3  3  2  2 54  1  2  3  3 15  3
   7  2  1  6  6  3  2  4  3  4  1  5  6  2  2  3  3 15 10  2  2 11  1  1
   9  4  3  1  2  7  4  2  4  2  4 

Seed set to 0


hidden_states.std: 0.9999868273735046
weight_proj.std: 1.0001189708709717
hidden_states.std after weight_proj: 16.074504852294922
codevector_probs.shape [B * L, G, V]: torch.Size([1499, 2, 320])
codevector_probs sum over B*L [[ 6  7  4  6  4  6  3  4  8  7  5  4  3  3  5  3  5  3  4  3  3  5  7 12
   3  2  7  6  6  5  4  5  7  7  5  4  3  3  4  6  5 10  4  4  5  3  3  5
   4  5  6  4  6  4  6  5  7  7  7  4  4  3  5  2  1  5  3  5  8  2  3  4
   6  4  2  5  5  4  3  4  2  5  6  4  5  6  7  5  5  4  2  3  4  4  4  5
   5  4  5  3  4  1  4  3  4  3  7  2  4  7  3  5  6  4  4  7  7 11  2  4
   7  3  4  4  7  3  6  5  5  3 10  4  3  4  3  5  8  3  6  3  3  6  3  6
   1  4  7  6  8  5  4  2  4  4  2  5  3  8  7  5  5  6  4  3  3  8  6  4
   5  6  5  5  6  3  5  3  3  2  7  6  4  9  3  5  3  3  4  1  6  5  8  5
   4  7  7  7  2  6  5  5  6  2  4  5  3  4  4  4  4 14  3  3  4  4  7  6
   5  4  3  5  6  4  7  4  4  6  3  4  6  3  3  3  7  8  6  2  3  7  3  3
   8  3  5  3  4  5  4  5  5  4  5 

Seed set to 0


hidden_states.std: 0.9999868273735046
weight_proj.std: 1.0001189708709717
hidden_states.std after weight_proj: 16.074504852294922
codevector_probs.shape [B * L, G, V]: torch.Size([1499, 2, 320])
codevector_probs sum over B*L [[ 5  5  5  6  4  5  5  3  7  7  5  4  4  4  6  4  6  5  4  3  4  5  7 10
   6  5  5  5  5  4  5  5  6  6  5  5  3  4  3  7  5  8  5  3  5  5  3  6
   5  4  6  4  4  4  6  6  5  5  7  5  6  4  4  3  3  5  3  6  7  3  4  4
   6  4  4  5  5  4  4  4  2  5  6  5  5  5  7  6  5  3  2  3  4  5  2  5
   5  4  5  3  4  3  5  4  4  3  5  3  5  6  4  5  5  3  4  6  6  5  3  5
   4  3  4  5  6  4  5  4  4  4  9  5  4  4  4  4  7  4  6  4  3  6  3  5
   2  4  6  6  6  6  3  4  3  5  5  5  2  5  6  5  5  7  5  3  4  6  5  6
   5  6  5  6  6  3  4  3  3  3  7  6  6  6  3  5  6  5  4  4  3  4  5  5
   6  6  5  6  2  6  5  5  5  4  4  5  4  4  4  4  6  8  5  6  3  4  5  6
   5  4  3  5  4  4  7  5  4  5  3  4  7  3  3  4  5  7  6  2  3  5  4  3
   6  3  6  3  5  4  5  5  6  4  6 

In [3]:
model_output = {}
# for model_name in MODEL_NAMES:
config = Wav2Vec2Config.from_pretrained(
  '/scicore/home/dokman0000/liu0003/projects/seisLM/seisLM/configs/pretrain/model_config_4xdownsample_sinkhorn.json'
)
config.scale_logits_in_quantization = True

for sinkhorn_quantization_iter in [0, 1, 3, 5]:
  print('------------')
  print(f"Sinkhorn quantization iter: {sinkhorn_quantization_iter}")
  config.sinkhorn_quantization_iters = sinkhorn_quantization_iter
  model = MultiDimWav2Vec2ForPreTraining(config)

  # compute masked indices
  batch_size, num_channels, raw_sequence_length = input_values.shape
  sequence_length = model._get_feat_extract_output_lengths(
    raw_sequence_length).item()

  seed_everything(0)
  mask_time_indices = _compute_mask_indices(
      shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
  )
  sampled_negative_indices = _sample_negative_indices(
      features_shape=(batch_size, sequence_length),
      num_negatives=model.config.num_negatives,
      mask_time_indices=mask_time_indices,
  )
  mask_time_indices = torch.tensor(
    data=mask_time_indices, device=input_values.device, dtype=torch.long)
  sampled_negative_indices = torch.tensor(
    data=sampled_negative_indices, device=input_values.device,
    dtype=torch.long
  )

  with torch.no_grad():
    outputs = model(input_values, mask_time_indices=mask_time_indices,
                    sampled_negative_indices=sampled_negative_indices)


  print(outputs.codevector_perplexity)
  # model_output[f'{model_name}_{model_type}'] = outputs

------------
Sinkhorn quantization iter: 0


Seed set to 0


hidden_states.std: 0.9999868273735046
weight_proj.std: 1.0001189708709717
hidden_states.std after weight_proj: 16.074504852294922
hidden_states.std after scaling: 1.0046565532684326
denom: 16.0
codevector_probs.shape [B * L, G, V]: torch.Size([1499, 2, 320])
codevector_probs sum over B*L [[ 6  5  4  3  6  3  1  5  9 10  1  5  4  2  6  1  4  1  1  7  2  3  6  9
   4  1  4  6  5  8  1  8  8  6  9  1  4  5  6  3 10  7  4  3  7  4  6  7
   2  6  2  1  6  6  6  2  5  7  3  8  8  2  4  4  4  0  3  4 13  3  1  3
   3  8  6  7  8  1  4  5  5  0  5  5  7  8  3  4  7  4  3  7  3  3  3  4
   7  3  1  3  2  4  5  2  2  5  6  1  4  6  3  7  4  4  2  0  9 10  3  1
   3  1  3  6  3  3  3 11  2  3  6  6  1  7  1 17  9  4  2  3  0  6  0  5
   3  5  6  4  9  6  8  5  7  2  3  8 10  9  4  7  7  4  3  7  1  6  5  1
   3  3  1  2  5  4  6  2  4  8  3  6  3  5  2  4  4  3  2  1  4  3  6  5
   4  7  7  8  4  3  4  3  7  5  4  8  2  3  5  3  4 14  5  1  2  5  7  1
   6  5  3  3  7  5  4  5 13  5  3 10  6  5  

Seed set to 0


hidden_states.std: 0.9999868273735046
weight_proj.std: 1.0001189708709717
hidden_states.std after weight_proj: 16.074504852294922
hidden_states.std after scaling: 1.0046565532684326
denom: 16.0
codevector_probs.shape [B * L, G, V]: torch.Size([1499, 2, 320])
codevector_probs sum over B*L [[ 3  6  4  2  7  4  2  3  7  8  5  5  7  3  7  3  4  1  1  6  2  6  7  5
   5  4  4  6  4  6  3  6  6  5  9  2  4  8  6  3  5  3  5  3  7  8  3  6
   2  6  5  3  6  6  2  4  4  7  3  6 10  4  4  7  4  0  3  9  3  7  2  3
   2  8  6  3  5  5  4  5  9  2  5  5  4  7  6  4  7  8  3  8  2  3  6  3
   6  4  3  5  4  4  5  4  3  4  4  4  7  5  5  5  7  4  4  3  6  7  4  2
   3  1  3  6  2  6  4  9  3  4  5  5  4  5  1  7  8  5  3  3  4  6  4  5
   3  8  6  3  4  4  8  6  4  4  6  8  9  7  2  5  5  4  6  7  2  5  2  2
   6  2  4  2  3  5  3  2  3  7  6  5  4  4  3  3  5  4  3  5  5  5  5  6
   7  5  5  5  4  5  3  4  7  4  4  6  6  2  5  4  5  5  4  2  3  5  5  4
   4  6  4  4  7  4  5  5  8  2  5  6  3  6  

Seed set to 0


hidden_states.std: 0.9999868273735046
weight_proj.std: 1.0001189708709717
hidden_states.std after weight_proj: 16.074504852294922
hidden_states.std after scaling: 1.0046565532684326
denom: 16.0
codevector_probs.shape [B * L, G, V]: torch.Size([1499, 2, 320])
codevector_probs sum over B*L [[ 3  6  4  2  7  4  2  3  7  8  5  5  7  3  7  3  4  1  1  6  2  6  7  5
   5  4  4  6  4  6  3  6  6  5  9  3  4  8  6  3  5  3  5  3  7  8  3  6
   2  6  5  3  6  5  2  4  4  7  3  5 10  4  4  7  4  0  3  9  3  7  2  3
   2  8  6  3  5  5  4  4  9  2  5  6  4  7  6  4  7  8  3  8  2  3  6  3
   6  4  3  5  4  4  5  4  3  4  4  4  7  5  5  5  7  4  4  3  6  7  4  2
   3  1  3  6  2  6  4  9  3  4  5  5  4  5  1  7  8  5  3  3  4  6  4  5
   3  8  6  3  4  4  8  6  4  4  6  8  9  7  2  5  5  4  6  7  2  5  2  2
   6  2  4  3  3  5  3  2  3  7  6  5  4  4  3  3  5  4  3  5  5  5  5  6
   7  5  5  5  4  5  3  4  7  4  4  6  6  2  5  4  5  5  4  2  3  5  5  4
   4  6  4  4  7  4  5  5  8  2  5  6  3  6  

Seed set to 0


hidden_states.std: 0.9999868273735046
weight_proj.std: 1.0001189708709717
hidden_states.std after weight_proj: 16.074504852294922
hidden_states.std after scaling: 1.0046565532684326
denom: 16.0
codevector_probs.shape [B * L, G, V]: torch.Size([1499, 2, 320])
codevector_probs sum over B*L [[ 3  6  4  2  7  4  2  3  7  8  5  5  7  3  7  3  4  1  1  6  2  6  7  5
   5  4  4  6  4  6  3  6  6  5  9  3  4  8  6  3  5  3  5  3  7  8  3  6
   2  6  5  3  6  5  2  4  4  7  3  5 10  4  4  7  4  0  3  9  3  7  2  3
   2  8  6  3  5  5  4  4  9  2  5  6  4  7  6  4  7  8  3  8  2  3  6  3
   6  4  3  5  4  4  5  4  3  4  4  4  7  5  5  5  7  4  4  3  6  7  4  2
   3  1  3  6  2  6  4  9  3  4  5  5  4  5  1  7  8  5  3  3  4  6  4  5
   3  8  6  3  4  4  8  6  4  4  6  8  9  7  2  5  5  4  6  7  2  5  2  2
   6  2  4  3  3  5  3  2  3  7  6  5  4  4  3  3  5  4  3  5  5  5  5  6
   7  5  5  5  4  5  3  4  7  4  4  6  6  2  5  4  5  5  4  2  3  5  5  4
   4  6  4  4  7  4  5  5  8  2  5  6  3  6  