In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [88]:
import torch
import torch.nn as nn
import torch.nn.functional as Fn
from torch.utils.data import DataLoader, TensorDataset
import random
from tqdm.auto import tqdm
from collections import defaultdict

In [129]:
class RNNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(RNNModel, self).__init__()
        self.rnn = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, lengths):
        # TODO padding is on the LEFT here, make sure RNN reflects this
        packed_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.rnn(packed_input)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        output = self.fc(output)
        return output
    

class ContrastiveEmbeddingObjective(nn.Module):
    def __init__(self, tau=0.1):
        super(ContrastiveEmbeddingObjective, self).__init__()
        self.tau = tau

    def forward(self, embeddings, pos_embeddings, neg_embeddings):
        pos_dist = Fn.cosine_similarity(embeddings, pos_embeddings, dim=1)
        neg_dist = Fn.cosine_similarity(embeddings, neg_embeddings, dim=1)

        pos_loss = -torch.log(torch.exp(pos_dist / self.tau)).mean()
        neg_loss = -torch.log(torch.exp(-neg_dist / self.tau)).mean()

        loss = pos_loss + neg_loss

        return loss
    

class ContrastiveEmbeddingModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, tau=0.1):
        super(ContrastiveEmbeddingModel, self).__init__()
        self.rnn = RNNModel(input_dim, hidden_dim, output_dim)
        self.tau = tau

    def forward(self, batch):
        example_batch, example_lengths, pos_batch, pos_lengths, neg_batch, neg_lengths = batch
        embeddings, pos_embeddings, neg_embeddings = self.compute_batch_embeddings(example_batch, example_lengths, pos_batch, pos_lengths, neg_batch, neg_lengths)
        return ContrastiveEmbeddingObjective(tau=self.tau)(embeddings, pos_embeddings, neg_embeddings)
        
    def compute_batch_embeddings(self, example_batch, example_lengths, pos_batch, pos_lengths, neg_batch, neg_lengths):
        # Compute embeddings
        embeddings = self.rnn(example_batch, example_lengths)
        pos_embeddings = self.rnn(pos_batch, pos_lengths)
        neg_embeddings = self.rnn(neg_batch, neg_lengths)

        # Gather final embedding of each sequence
        embeddings = torch.gather(embeddings, 1, (example_lengths - 1).reshape(-1, 1, 1).expand(-1, 1, embeddings.shape[-1])).squeeze(1)
        pos_embeddings = torch.gather(pos_embeddings, 1, (pos_lengths - 1).reshape(-1, 1, 1).expand(-1, 1, pos_embeddings.shape[-1])).squeeze(1)
        neg_embeddings = torch.gather(neg_embeddings, 1, (neg_lengths - 1).reshape(-1, 1, 1).expand(-1, 1, neg_embeddings.shape[-1])).squeeze(1)

        return embeddings, pos_embeddings, neg_embeddings


def get_sequence(F, start_index, end_index, max_length):
    if end_index - start_index + 1 > max_length:
        start_index = end_index - max_length + 1
    sequence = F[start_index:end_index + 1]
    
    # Pad on right if necessary
    if len(sequence) < max_length:
        pad_size = max_length - len(sequence)
        padding = torch.zeros(pad_size, F.shape[1])
        sequence = torch.cat((sequence, padding), dim=0)
    
    return sequence


def prepare_batches(F, Q, S, max_length, batch_size=32):
    dataset = []
    assert F.shape[0] == Q.shape[0] == S.shape[0]
    n_F = F.size(0)

    lengths = torch.arange(n_F) - S
    lengths = torch.minimum(lengths, torch.tensor(max_length))
    # TODO this is just a hack
    lengths[lengths == 0] = 1

    for i in range(n_F):
        pos_indices = (Q == Q[i]).nonzero(as_tuple=True)[0]
        neg_indices = (Q != Q[i]).nonzero(as_tuple=True)[0]

        if len(pos_indices) > 1 and len(neg_indices) > 0:
            pos_indices = pos_indices[pos_indices != i]
            pos_idx = random.choice(pos_indices)
            neg_idx = random.choice(neg_indices)

            # Extract sequences
            example_seq = get_sequence(F, S[i], i, max_length)
            pos_seq = get_sequence(F, S[pos_idx], pos_idx, max_length)
            neg_seq = get_sequence(F, S[neg_idx], neg_idx, max_length)

            dataset.append((example_seq, lengths[i],
                            pos_seq, lengths[pos_idx],
                            neg_seq, lengths[neg_idx]))

    return DataLoader(TensorDataset(
        # example frames and lengths
        torch.stack([x[0] for x in dataset]), 
        torch.stack([x[1] for x in dataset]), 

        # positive frames and lengths
        torch.stack([x[2] for x in dataset]),
        torch.stack([x[3] for x in dataset]),

        # negative frames and lengths
        torch.stack([x[4] for x in dataset]),
        torch.stack([x[5] for x in dataset])),
        batch_size=batch_size, shuffle=True)


def compute_batched_rnn_loss(model, data_loader, tau=0.1):
    total_loss = 0
    total_batches = 0

    for batch in data_loader:
        loss = model(batch)
        total_loss += loss.item()
        total_batches += 1

    return total_loss / total_batches


# Example usage
n_F, d = 100, 4  # Example dimensions
F = torch.randn(n_F, d) * 3  # Random frame features
Q = torch.randint(0, 10, (n_F,))  # Random frame matches
S = torch.maximum(torch.tensor(0), torch.arange(n_F) - torch.randint(1, 10, (n_F,)))  # Random span indices
max_length = 20  # Maximum sequence length for RNN

# rnn_model = RNNModel(input_dim=d, hidden_dim=256, output_dim=d)
model = ContrastiveEmbeddingModel(input_dim=d, hidden_dim=256, output_dim=d, tau=0.1)
# data_loader = prepare_batches(F, Q, S, max_length, batch_size=32)
# loss = compute_batched_rnn_loss(model, data_loader)
# print(loss)

# Build a test batch
sample_idx, pos_sample_idx, neg_sample_idx = 37, 23, 85
sample_length, pos_sample_length, neg_sample_length = 10, 8, 12
example_batch = get_sequence(F, S[sample_idx], sample_idx, max_length).unsqueeze(0)
pos_sample_batch = get_sequence(F, S[pos_sample_idx], pos_sample_idx, max_length).unsqueeze(0)
neg_sample_batch = get_sequence(F, S[neg_sample_idx], neg_sample_idx, max_length).unsqueeze(0)
example_lengths = torch.tensor([sample_length])
pos_sample_lengths = torch.tensor([pos_sample_length])
neg_sample_lengths = torch.tensor([neg_sample_length])
with torch.no_grad():
    embeddings, pos_embeddings, neg_embeddings = model.compute_batch_embeddings(example_batch, example_lengths, pos_sample_batch, pos_sample_lengths, neg_sample_batch, neg_sample_lengths)
# Manually compute embeddings
def compute_embedding_single(model, x, length):
    return model.rnn.fc(model.rnn.rnn(x[:, :length, :])[0][:, -1, :])
with torch.no_grad():
    embeddings_manual = compute_embedding_single(model, example_batch, example_lengths)
    pos_embeddings_manual = compute_embedding_single(model, pos_sample_batch, pos_sample_lengths)
    neg_embeddings_manual = compute_embedding_single(model, neg_sample_batch, neg_sample_lengths)

    # Check that the embeddings are the same
    print(embeddings)
    print(embeddings_manual)
    print("//")
    print(pos_embeddings)
    print(pos_embeddings_manual)
    print("//")
    print(neg_embeddings)
    print(neg_embeddings_manual)
    torch.testing.assert_close(embeddings, embeddings_manual)
    torch.testing.assert_close(pos_embeddings, pos_embeddings_manual)
    torch.testing.assert_close(neg_embeddings, neg_embeddings_manual)


tensor([[-0.0148, -0.0528, -0.0368,  0.0270]])
tensor([[-0.0148, -0.0528, -0.0368,  0.0270]])
//
tensor([[-0.0667, -0.0486, -0.0589,  0.0061]])
tensor([[-0.0667, -0.0486, -0.0589,  0.0061]])
//
tensor([[-0.0149, -0.0511, -0.0373,  0.0261]])
tensor([[-0.0149, -0.0511, -0.0373,  0.0261]])


In [9]:
from src.utils.timit import load_or_prepare_timit_corpus
from src.models.frame_level import LexicalAccessDataCollator
import transformers

In [4]:
tokenizer = transformers.Wav2Vec2Tokenizer.from_pretrained("charsiu/tokenizer_en_cmu")
feature_extractor = transformers.Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
processor = transformers.Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'Wav2Vec2CTCTokenizer'. 
The class this function is called from is 'Wav2Vec2Tokenizer'.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [130]:
model = transformers.Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")



In [45]:
dataset = load_or_prepare_timit_corpus("data/timit_phoneme", "data/timit_raw",
                                       processor)

def add_indices(item, idx):
    item["idx"] = idx
    return item
dataset = dataset.map(add_indices, batched=True, batch_size=2000, with_indices=True)

In [50]:
# Compute frame-level encodings for each example
flat_idxs = []
frame_states = []

def collate(features):
    input_features = [{"input_values": feature["input_values"]} for feature in features]
    # For classification
    # label_features = [feature["labels"] for feature in features]

    ret = processor.pad(
        input_features,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors="pt",
        return_attention_mask=True,
    )

    if "idx" in features[0]:
        ret["idx"] = torch.tensor([feature["idx"] for feature in features])

    return ret


dev_dataset = dataset["train"].select(range(5))
dataloader = DataLoader(dev_dataset, batch_size=8, collate_fn=collate, shuffle=False)

# Collect an index of frames in the dataset by phoneme within matched word type.
# We will use this to equivalence-class the model frames later on.
phoneme_within_word: dict[tuple[str, int], list[tuple[int, int]]] = {}

for batch in tqdm(dataloader):
    idxs = batch.pop("idx")
    with torch.no_grad():
        output = model(output_hidden_states=True, **batch)

    # num_layers * batch_size * sequence_length * hidden_size
    batch_hidden = torch.stack(output.hidden_states)
    
    # Extract just non-padding values for each example and concatenate
    batch_num_samples = batch.attention_mask.sum(1)
    batch_num_frames = model._get_feat_extract_output_lengths(batch_num_samples).tolist()
    
    for i, (idx, num_frames) in enumerate(zip(idxs, batch_num_frames)):
        frame_states.append(batch_hidden[:, i, :num_frames, :])
        flat_idxs.extend([(idx, i) for i in range(num_frames)])

  0%|          | 0/1 [00:00<?, ?it/s]

In [61]:
list(dev_dataset[0].keys())

['file',
 'audio',
 'text',
 'phonetic_detail',
 'word_detail',
 'dialect_region',
 'sentence_type',
 'speaker_id',
 'id',
 'phonemic_detail',
 'word_phonetic_detail',
 'word_phonemic_detail',
 'input_values',
 'phone_targets',
 'idx']

In [62]:
dev_dataset[0]["word_phonemic_detail"]

[[{'phone': 'SH', 'start': 3050, 'stop': 4559},
  {'phone': 'IH', 'start': 4559, 'stop': 5723}],
 [{'phone': 'HH', 'start': 5723, 'stop': 6642},
  {'phone': 'EH', 'start': 6642, 'stop': 8772},
  {'phone': 'D', 'start': 8772, 'stop': 9190},
  {'phone': 'JH', 'start': 9190, 'stop': 10337}],
 [{'phone': 'JH', 'start': 9190, 'stop': 10337},
  {'phone': 'IH', 'start': 10337, 'stop': 11517}],
 [{'phone': 'D', 'start': 11517, 'stop': 12640},
  {'phone': 'AH', 'start': 12640, 'stop': 14714},
  {'phone': 'K', 'start': 14714, 'stop': 16334}],
 [{'phone': 'S', 'start': 16334, 'stop': 18088},
  {'phone': 'UW', 'start': 18088, 'stop': 20417},
  {'phone': 'T', 'start': 20417, 'stop': 21199}],
 [{'phone': 'AH', 'start': 21199, 'stop': 22560},
  {'phone': 'N', 'start': 21199, 'stop': 22560}],
 [{'phone': 'G', 'start': 22560, 'stop': 23271},
  {'phone': 'R', 'start': 23271, 'stop': 24229},
  {'phone': 'IH', 'start': 24229, 'stop': 25566},
  {'phone': 'S', 'start': 25566, 'stop': 27156},
  {'phone': 'IH

In [102]:
# Compute frame-level encodings for each example
flat_idxs = []
frames_by_item = {}
frame_states = []

# Collect an index of frames in the dataset by phoneme within matched word type.
# We will use this to equivalence-class the model frames later on.
phoneme_within_word: dict[tuple[tuple[str, ...], int], list[tuple[int, int]]] = defaultdict(list)
phoneme_within_word_prefix: dict[tuple[str, ...], list[tuple[int, int]]] = defaultdict(list)
phoneme_within: dict[str, list[tuple[int, int]]] = defaultdict(list)

def process(item, idx):
    with torch.no_grad():
        output = model(output_hidden_states=True,
                       input_values=torch.tensor(item["input_values"]).unsqueeze(0))
        
    # num_layers * sequence_length * hidden_size
    batch_hidden = torch.stack(output.hidden_states).squeeze(1)

    flat_idx_offset = len(flat_idxs)
    flat_idxs.extend([(idx, i) for i in range(batch_hidden.shape[1])])
    frames_by_item[idx] = (flat_idx_offset, len(flat_idxs))
    frame_states.append(batch_hidden)

    # Now align and store frame metadata
    compression_ratio = batch_hidden.shape[1] / len(item["input_values"])
    for word in item["word_phonemic_detail"]:
        word_str = tuple(phone["phone"] for phone in word)
        word_start = int(word[0]["start"] * compression_ratio)
        word_end = int(word[-1]["stop"] * compression_ratio)

        for j, phone in enumerate(word):
            word_prefix = word_str[:j + 1]

            phone_str = phone["phone"]
            phone_start = int(phone["start"] * compression_ratio)
            phone_end = int(phone["stop"] * compression_ratio)

            for k in range(phone_start, phone_end + 1):
                phoneme_within_word[word_str, j].append((idx, k))
            for k in range(phone_start, phone_end + 1):
                phoneme_within_word_prefix[word_prefix].append((idx, k))
            for k in range(phone_start, phone_end + 1):
                phoneme_within[phone_str].append((idx, k))

    return None

dev_dataset = dataset["train"].select(range(5))
dev_dataset.map(process, with_indices=True)

frame_states = torch.cat(frame_states, dim=1)
assert frame_states.shape[1] == len(flat_idxs)

# num_frames * num_layers * hidden_size
frame_states = frame_states.transpose(0, 1)

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

In [90]:
Q_full = torch.zeros(len(flat_idxs), dtype=torch.long) - 1
flat_idxs_rev = {idx: i for i, idx in enumerate(flat_idxs)}

for i, (word, j) in enumerate(phoneme_within_word):
    for idx, k in phoneme_within_word[word, j]:
        Q_full[flat_idxs_rev[idx, k]] = i

In [91]:
Q_prefix = torch.zeros(len(flat_idxs), dtype=torch.long) - 1
flat_idxs_rev = {idx: i for i, idx in enumerate(flat_idxs)}

for i, prefix in enumerate(phoneme_within_word_prefix):
    for idx, k in phoneme_within_word_prefix[prefix]:
        Q_prefix[flat_idxs_rev[idx, k]] = i

In [96]:
Q_phoneme = torch.zeros(len(flat_idxs), dtype=torch.long) - 1
flat_idxs_rev = {idx: i for i, idx in enumerate(flat_idxs)}

for i, key in enumerate(phoneme_within):
    for idx, k in phoneme_within[key]:
        Q_phoneme[flat_idxs_rev[idx, k]] = i

In [138]:
# For each frame, store the preceding frame at which the word event began.
S = torch.zeros(len(flat_idxs), dtype=torch.long) - 1

def compute_S(item, idx):
    flat_idx_offset, flat_idx_end = frames_by_item[idx]
    num_frames = flat_idx_end - flat_idx_offset
    compression_ratio = num_frames / len(item["input_values"])

    for word in item["word_phonemic_detail"]:
        word_str = tuple(phone["phone"] for phone in word)
        word_start = int(word[0]["start"] * compression_ratio)
        word_end = int(word[-1]["stop"] * compression_ratio)

        for j in range(word_start, word_end + 1):
            S[flat_idx_offset + j] = word_start

    return None

dev_dataset.map(compute_S, with_indices=True)

# If Q is set, then S should be set
# (Q != -1) => (S != -1)
assert ((Q_phoneme == -1) | (S != -1)).all()

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

In [145]:
def prepare_dataloader(F, Q, S, max_length, batch_size=32):
    dataset = []
    assert F.shape[0] == Q.shape[0] == S.shape[0]
    n_F = F.size(0)

    lengths = torch.arange(n_F) - S
    lengths = torch.minimum(lengths, torch.tensor(max_length))
    # TODO this is just a hack
    lengths[lengths == 0] = 1
    lengths[S == -1] = -1

    for i in range(n_F):
        if lengths[i] == -1:
            continue

        pos_indices = (Q == Q[i]).nonzero(as_tuple=True)[0]
        neg_indices = ((Q != -1) & (Q != Q[i])).nonzero(as_tuple=True)[0]

        if len(pos_indices) > 1 and len(neg_indices) > 0:
            pos_indices = pos_indices[pos_indices != i]
            pos_idx = random.choice(pos_indices)
            neg_idx = random.choice(neg_indices)

            # Extract sequences
            example_seq = get_sequence(F, S[i], i, max_length)
            pos_seq = get_sequence(F, S[pos_idx], pos_idx, max_length)
            neg_seq = get_sequence(F, S[neg_idx], neg_idx, max_length)

            dataset.append((example_seq, lengths[i],
                            pos_seq, lengths[pos_idx],
                            neg_seq, lengths[neg_idx]))

    return DataLoader(TensorDataset(
        # example frames and lengths
        torch.stack([x[0] for x in dataset]), 
        torch.stack([x[1] for x in dataset]), 

        # positive frames and lengths
        torch.stack([x[2] for x in dataset]),
        torch.stack([x[3] for x in dataset]),

        # negative frames and lengths
        torch.stack([x[4] for x in dataset]),
        torch.stack([x[5] for x in dataset])),
        batch_size=batch_size, shuffle=True)

In [146]:
layer = 6
output_dim = 32
dataloader = prepare_dataloader(frame_states[:, layer, :], Q_full, S, max_length, batch_size=32)

In [147]:
ce_model = ContrastiveEmbeddingModel(input_dim=frame_states.shape[-1], hidden_dim=256,
                                     output_dim=output_dim, tau=0.1)

In [148]:
dataloader.dataset[0]

(tensor([[-0.2177, -0.0353, -0.2646,  ..., -0.1344,  0.1016, -0.1929],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),
 tensor(1),
 tensor([[-0.2177, -0.0353, -0.2646,  ..., -0.1344,  0.1016, -0.1929],
         [-0.2657,  0.1511, -0.2867,  ..., -0.0080, -0.1367, -0.1434],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),
 tensor(1),
 tensor([[ 0.1642,  0.3462, -0.1093,  ..., -0.1092,  0.0529, -0.0363],
         [ 0.2049,  0

In [149]:
ce_model(next(iter(dataloader)))

tensor(-4.7878, grad_fn=<AddBackward0>)