In [1]:
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
from datasets import load_dataset
from lightning.pytorch import seed_everything
from transformers import Wav2Vec2Config
import seisbench.data as sbd

from transformers import Wav2Vec2ForPreTraining
# from seisLM.model.wav2vec2 import Wav2Vec2ForPreTraining

# Index of the single STEAD trace used as the probe input in this notebook.
TRACE_INDEX = 1265656

# Open the STEAD dataset through seisbench.
data = sbd.STEAD()

# The model is constructed from the config only (no pretrained weights are
# loaded here); seeding immediately before construction makes the initial
# weights repeatable across runs.
# NOTE(review): no `.eval()` call is made, so the model presumably stays in
# training mode (stochastic layers active) — confirm this is intended.
config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")
seed_everything(0)
model = Wav2Vec2ForPreTraining(config)

# Take the first entry of the trace and add a leading batch axis.
# `torch.tensor(..., dtype=torch.float32)` replaces the legacy
# `torch.Tensor(...)` constructor: it yields the same float32 copy of the
# data but cannot be mistaken for the "allocate by size" overload.
waveforms = data.get_waveforms(TRACE_INDEX)
input_values = torch.tensor(waveforms[0], dtype=torch.float32).unsqueeze(0)

# Build the time-step mask and the negative-sample indices for the forward pass.
# The masking helpers work on the feature-level shape, so first ask the model
# for the feature-extractor output length that matches the raw sample count.
batch_size, raw_sequence_length = input_values.shape
sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
features_shape = (batch_size, sequence_length)

# Re-seed right before sampling so the mask and the negatives are repeatable.
# The mask must be drawn first: the negative sampler takes it as an input.
seed_everything(0)
mask_time_indices = _compute_mask_indices(shape=features_shape, mask_prob=0.2, mask_length=2)
sampled_negative_indices = _sample_negative_indices(
    features_shape=features_shape,
    num_negatives=model.config.num_negatives,
    mask_time_indices=mask_time_indices,
)

# Wrap both results as torch.long tensors on the same device as the input.
long_on_input_device = {"device": input_values.device, "dtype": torch.long}
mask_time_indices = torch.tensor(mask_time_indices, **long_on_input_device)
sampled_negative_indices = torch.tensor(sampled_negative_indices, **long_on_input_device)

# Run a single forward pass with gradient tracking disabled, feeding the
# precomputed mask and negative indices alongside the waveform.
with torch.no_grad():
    outputs = model(
        input_values,
        mask_time_indices=mask_time_indices,
        sampled_negative_indices=sampled_negative_indices,
    )

# Reduce each output of interest to a scalar summary for quick comparison.
projected_states_mean = outputs.projected_states.mean()
projected_quantized_states_mean = outputs.projected_quantized_states.mean()
codevector_perplexity_mean = outputs.codevector_perplexity.mean()

# The first line shows the raw tensor repr; the labelled lines follow.
print(projected_states_mean)
print(f'projected_states mean {projected_states_mean}')
print(f'projected_quantized_states mean {projected_quantized_states_mean}')
print(f'codevector_perplexity mean {codevector_perplexity_mean}')


Seed set to 0
Seed set to 0


tensor(-0.0238)
projected_states mean -0.02380690723657608
projected_quantized_states mean 0.017036495730280876
codevector_perplexity mean 11.194181442260742




In [2]:
# Display the negative-sample index tensor built in the previous cell.
# NOTE(review): the all-zero rows presumably correspond to time steps that
# were not masked — confirm against the sampler's documentation.
sampled_negative_indices

tensor([[[ 0,  0,  0,  ...,  0,  0,  0],
         [11, 11,  2,  ..., 10, 11, 11],
         [11, 11, 11,  ...,  1,  1, 11],
         ...,
         [ 0,  0,  0,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  0,  0]]])