In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [19]:
from collections import defaultdict, Counter
import pickle
from pathlib import Path

import seaborn as sns
import torch
from tqdm.auto import tqdm, trange
import transformers

from src.datasets.speech_equivalence import \
    SpeechEquivalenceDataset, SpeechHiddenStateDataset, make_timit_equivalence_dataset
from src.models.integrator import ContrastiveEmbeddingModel, ContrastiveEmbeddingModelConfig, prepare_dataset
from src.utils.timit import load_or_prepare_timit_corpus

In [17]:
model_name = "facebook/wav2vec2-base"

equivalence_classer = "phoneme_within_word_prefix"
num_frames_per_phoneme = 1

layer = 6
output_dim = 32

equiv_dataset_path = f"data/timit_equiv_{equivalence_classer}_{num_frames_per_phoneme}.pkl"
output_dir = f"out/ce_model_{equivalence_classer}_{layer}_{output_dim}"

## Prepare equivalence class dataset

In [5]:
tokenizer = transformers.Wav2Vec2Tokenizer.from_pretrained("charsiu/tokenizer_en_cmu")
feature_extractor = transformers.Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
processor = transformers.Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'Wav2Vec2CTCTokenizer'. 
The class this function is called from is 'Wav2Vec2Tokenizer'.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
model: transformers.Wav2Vec2Model = transformers.Wav2Vec2Model.from_pretrained(model_name)
model = model.to("cuda")



In [7]:
dataset = load_or_prepare_timit_corpus("data/timit_phoneme", "data/timit_raw",
                                       processor)

def add_indices(item, idx):
    item["idx"] = idx
    return item
dataset = dataset.map(add_indices, batched=True, batch_size=2000, with_indices=True)

In [13]:
dev_dataset = dataset["train"].select(range(1000))

In [14]:
equiv_dataset = make_timit_equivalence_dataset(
    f"timit_phoneme/{equivalence_classer}",
    dev_dataset, model,
    equivalence_classer,
    num_frames_per_phoneme=num_frames_per_phoneme)

Extracting hidden states:   0%|          | 0/1000 [00:00<?, ? examples/s]

Computing start frames:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [21]:
equiv_dataset

SpeechEquivalenceDataset(timit_phoneme/phoneme_within_word_prefix, 8879 classes, 31712 instances, with SpeechHiddenStateDataset(facebook/wav2vec2-base, 1000 items, 150670 frames, 13 layers, 768 hidden size))

In [None]:
# TODO compute how many positive examples each Q lines up. we want to make sure we have a minimal
# number of positive examples for each Q, even the sparse word-level ones

In [18]:
# Pick a max length that accommodates the majority of the samples, excluding outlier lengths
evident_lengths = equiv_dataset.lengths
target_length = int(torch.quantile(evident_lengths, 0.95).item())
sns.displot(evident_lengths.numpy(), kde=True)
target_length

In [18]:
dataset = prepare_dataset(equiv_dataset, target_length)

  0%|          | 0/31712 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
Path(output_dir).mkdir(exist_ok=True, parents=True)
dataset.save_to_disk(Path(output_dir) / "dataset")

In [None]:
ce_model = ContrastiveEmbeddingModel(
    input_dim=equiv_dataset.hidden_state_dataset.hidden_dim,
    hidden_dim=32,
    output_dim=output_dim, tau=0.1)

In [None]:
training_args = transformers.TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=2,
    per_device_train_batch_size=32,
    save_steps=100,
    eval_steps=100,
    save_total_limit=5,
    logging_steps=10,
    logging_dir=f"{output_dir}/logs",
    evaluation_strategy="steps",
    logging_first_step=True,
    load_best_model_at_end=True,
    greater_is_better=False,
    remove_unused_columns=False,
)

dataset_split = dev_dataset.train_test_split(test_size=0.1, shuffle=True)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
trainer = transformers.Trainer(
    model=ce_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # compute_metrics=compute_metrics,
    # data_collator=MyCollator(max_length),
    args=training_args)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
trainer.train()

Step,Training Loss,Validation Loss
100,-0.5319,-0.490023
200,-0.5539,-0.592076
300,-0.6521,-0.705454
400,-0.7003,-0.837613
500,-1.1271,-0.988161
600,-1.0692,-1.149871
700,-1.3069,-1.31905
800,-1.0717,-1.49235
900,-1.2993,-1.669061
1000,-2.1218,-1.834337


TrainOutput(global_step=1784, training_loss=-1.5648202286707447, metrics={'train_runtime': 755.1656, 'train_samples_per_second': 75.586, 'train_steps_per_second': 2.362, 'total_flos': 0.0, 'train_loss': -1.5648202286707447, 'epoch': 2.0})