In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [3]:
from collections import defaultdict, Counter
from pathlib import Path

import torch
from tqdm.auto import tqdm, trange
import transformers

from src.models.integrator import ContrastiveEmbeddingModel, prepare_dataset
from src.utils.timit import load_or_prepare_timit_corpus

In [4]:
tokenizer = transformers.Wav2Vec2Tokenizer.from_pretrained("charsiu/tokenizer_en_cmu")
feature_extractor = transformers.Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
processor = transformers.Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'Wav2Vec2CTCTokenizer'. 
The class this function is called from is 'Wav2Vec2Tokenizer'.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
model: transformers.Wav2Vec2Model = transformers.Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
model = model.to("cuda")



In [9]:
dataset = load_or_prepare_timit_corpus("data/timit_phoneme", "data/timit_raw",
                                       processor)

def add_indices(item, idx):
    item["idx"] = idx
    return item
dataset = dataset.map(add_indices, batched=True, batch_size=2000, with_indices=True)

In [13]:
# Collect this many frames per phoneme at maximum
num_frames_per_phoneme = 1

# Compute frame-level encodings for each example
flat_idxs = []
frames_by_item = {}
frame_states = []

# Collect an index of frames in the dataset by phoneme within matched word type.
# We will use this to equivalence-class the model frames later on.
equivalence_classers = {
    "phoneme_within_word": lambda word, i: (tuple(phone["phone"] for phone in word), i),
    "phoneme_within_word_prefix": lambda word, i: tuple(phone["phone"] for phone in word[:i + 1]),
    "phoneme_within": lambda word, i: word[i]["phone"],
}
frame_groups = {classer: defaultdict(list) for classer in equivalence_classers}

def process(item, idx):
    with torch.no_grad():
        output = model(output_hidden_states=True,
                       input_values=torch.tensor(item["input_values"]).unsqueeze(0).to(model.device))
        
    # num_layers * sequence_length * hidden_size
    batch_hidden = torch.stack(output.hidden_states).squeeze(1).cpu()

    flat_idx_offset = len(flat_idxs)
    flat_idxs.extend([(idx, i) for i in range(batch_hidden.shape[1])])
    frames_by_item[idx] = (flat_idx_offset, len(flat_idxs))
    frame_states.append(batch_hidden)

    # Now align and store frame metadata
    compression_ratio = batch_hidden.shape[1] / len(item["input_values"])
    for word in item["word_phonemic_detail"]:
        word_str = tuple(phone["phone"] for phone in word)
        word_start = int(word[0]["start"] * compression_ratio)
        word_end = int(word[-1]["stop"] * compression_ratio)

        for j, phone in enumerate(word):
            word_prefix = word_str[:j + 1]

            phone_str = phone["phone"]
            phone_start = int(phone["start"] * compression_ratio)
            phone_end = int(phone["stop"] * compression_ratio)

            ks = list(range(phone_start, phone_end + 1))
            if len(ks) > num_frames_per_phoneme:
                ks = ks[-num_frames_per_phoneme:]
            for k in ks:
                for classer, fn in equivalence_classers.items():
                    class_key = fn(word, j)
                    frame_groups[classer][class_key].append((idx, k))

    return None

dev_dataset = dataset["train"].select(range(1000))
dev_dataset.map(process, with_indices=True)

frame_states = torch.cat(frame_states, dim=1)
assert frame_states.shape[1] == len(flat_idxs)

# num_frames * num_layers * hidden_size
frame_states = frame_states.transpose(0, 1)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [14]:
Qs = {k: torch.zeros(len(flat_idxs), dtype=torch.long) - 1
      for k in frame_groups.keys()}
group_idxs = {k: {class_key: idx for idx, class_key in enumerate(frame_groups[k].keys())}
              for k in frame_groups.keys()}
flat_idxs_rev = {idx: i for i, idx in enumerate(flat_idxs)}

for classer, groups in frame_groups.items():
    for class_key, group in tqdm(groups.items(), desc=f"Grouping by {classer}"):
        for idx, k in group:
            Qs[classer][flat_idxs_rev[idx, k]] = group_idxs[classer][class_key]

Grouping by phoneme_within_word:   0%|          | 0/16722 [00:00<?, ?it/s]

Grouping by phoneme_within_word_prefix:   0%|          | 0/8879 [00:00<?, ?it/s]

Grouping by phoneme_within:   0%|          | 0/40 [00:00<?, ?it/s]

In [54]:
# For each frame, store the preceding frame at which the word event began.
start_reference = {
    "phoneme_within_word": "word",
    "phoneme_within_word_prefix": "word",
    "phoneme_within": "phoneme",
}
Ss = {k: torch.zeros(len(flat_idxs), dtype=torch.long) - 1
      for k in start_reference.values()}

def compute_S_word(item, idx):
    flat_idx_offset, flat_idx_end = frames_by_item[idx]
    num_frames = flat_idx_end - flat_idx_offset
    compression_ratio = num_frames / len(item["input_values"])

    for word in item["word_phonemic_detail"]:
        word_str = tuple(phone["phone"] for phone in word)
        word_start = int(word[0]["start"] * compression_ratio)
        word_end = int(word[-1]["stop"] * compression_ratio)

        for j in range(word_start, word_end + 1):
            Ss["word"][flat_idx_offset + j] = flat_idx_offset + word_start

    return None

def compute_S_phoneme(item, idx):
    flat_idx_offset, flat_idx_end = frames_by_item[idx]
    num_frames = flat_idx_end - flat_idx_offset
    compression_ratio = num_frames / len(item["input_values"])

    for word in item["word_phonemic_detail"]:
        for phone in word:
            phone_str = phone["phone"]
            phone_start = int(phone["start"] * compression_ratio)
            phone_end = int(phone["stop"] * compression_ratio)

            for j in range(phone_start, phone_end + 1):
                Ss["phoneme"][flat_idx_offset + j] = flat_idx_offset + phone_start

    return None

start_reference_fns = {
    "word": compute_S_word,
    "phoneme": compute_S_phoneme,
}
for k, fn in start_reference_fns.items():
    dev_dataset.map(fn, with_indices=True)

# If Q is set, then S should be set
# (Q != -1) => (S != -1)
# assert ((Q_phoneme == -1) | (S_phoneme != -1)).all()

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [55]:
# Sanity check: no crazy long events
for start_reference, S in Ss.items():
    evident_lengths = torch.arange(len(S)) - S
    max_evident_length = evident_lengths[S != -1].max()
    assert max_evident_length < 100, f"{start_reference} has event with length {max_evident_length} frames"

In [None]:
# TODO compute how many positive examples each Q lines up. we want to make sure we have a minimal
# number of positive examples for each Q, even the sparse word-level ones

In [16]:
layer = 6
output_dim = 32

eval_classer = "phoneme_within_word" #phoneme_within_word_prefix"
output_dir = f"out/ce_model_{eval_classer}_{layer}_{output_dim}"

In [27]:
Q = Qs[eval_classer]
S = Ss[start_reference[eval_classer]]

# Compute ideal max_length based on S.
evident_lengths = torch.arange(len(S)) - S
evident_lengths = evident_lengths[S != -1]
max_length = evident_lengths.max().item()

RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.

In [23]:
S[S != -1], len(S[S != -1])

(tensor([ 9,  9,  9,  ..., 88, 88, 88]), 129037)

In [19]:
evident_lengths

tensor([     0,      1,      2,  ..., 150573, 150574, 150575])

In [18]:
max_length
# TODO max_length is too big -- figure out why tomorrow.

150583

In [18]:
dataset = prepare_dataset(frame_states[:, layer, :], Q, S, max_length)

  0%|          | 0/31712 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
Path(output_dir).mkdir(exist_ok=True, parents=True)
dataset.save_to_disk(Path(output_dir) / "dataset")

In [None]:
ce_model = ContrastiveEmbeddingModel(
    input_dim=frame_states.shape[-1],
    hidden_dim=32,
    output_dim=output_dim, tau=0.1)

In [None]:
training_args = transformers.TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=2,
    per_device_train_batch_size=32,
    save_steps=100,
    eval_steps=100,
    save_total_limit=5,
    logging_steps=10,
    logging_dir=f"{output_dir}/logs",
    evaluation_strategy="steps",
    logging_first_step=True,
    load_best_model_at_end=True,
    greater_is_better=False,
    remove_unused_columns=False,
)

dataset_split = dataset.train_test_split(test_size=0.1, shuffle=True)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
trainer = transformers.Trainer(
    model=ce_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # compute_metrics=compute_metrics,
    # data_collator=MyCollator(max_length),
    args=training_args)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
trainer.train()

Step,Training Loss,Validation Loss
100,-0.5319,-0.490023
200,-0.5539,-0.592076
300,-0.6521,-0.705454
400,-0.7003,-0.837613
500,-1.1271,-0.988161
600,-1.0692,-1.149871
700,-1.3069,-1.31905
800,-1.0717,-1.49235
900,-1.2993,-1.669061
1000,-2.1218,-1.834337


TrainOutput(global_step=1784, training_loss=-1.5648202286707447, metrics={'train_runtime': 755.1656, 'train_samples_per_second': 75.586, 'train_steps_per_second': 2.362, 'total_flos': 0.0, 'train_loss': -1.5648202286707447, 'epoch': 2.0})