# Word Segmentation on Buckeye VQ-VAE Codes

Herman Kamper, 2021

Train a segmental autoencoding recurrent neural network (segmental AE-RNN) and perform word segmentation on VQVAE-encoded Buckeye.

## Preliminaries

In [1]:
from datetime import datetime
from pathlib import Path
from scipy.stats import gamma
from sklearn import cluster
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import sys
import torch
import torch.nn as nn

sys.path.append("..")

from seg_aernn import datasets, models, viterbi
from utils import eval_segmentation

## Utility functions

In [2]:
def get_segmented_sentence(ids, boundaries):
    output = ""
    cur_word = []
    for i_symbol, boundary in enumerate(boundaries):
        cur_word.append(id_to_symbol[ids[i_symbol]])
        if boundary:
            output += "_".join(cur_word)
            output += " "
            cur_word = []
    return output.strip()

In [3]:
# Duration penalty functions

# Histogram
histogram = np.array([
    0, 1.66322800e-01, 2.35838129e-01, 2.10609187e-01,
    1.48025482e-01, 9.42918160e-02, 5.84211098e-02, 3.64679480e-02,
    2.18264741e-02, 1.25420784e-02, 7.18500018e-03, 4.27118399e-03,
    1.73743077e-03, 1.19448366e-03, 7.42027726e-04, 2.89571796e-04,
    2.35277084e-04, 0.00001, 0.00001, 0.00001, 0.00001, 0.00001
    ])  # to-do: check this
histogram = histogram/np.sum(histogram)
def neg_log_hist(dur):
    return -np.log(0 if dur >= len(histogram) else histogram[dur])

# Cached Gamma
shape, loc, scale = (2.3, 0, 1.3)
gamma_cache = []
for dur in range(200):
    gamma_cache.append(gamma.pdf(dur, shape, loc, scale))
gamma_cache = np.array(gamma_cache)/np.sum(gamma_cache)
def neg_log_gamma(dur):
    if dur < 200:
        return -np.log(gamma_cache[dur])
    else:
        return -np.log(0)

# Chorowski
def neg_chorowski(dur):
    return -(dur - 1)

## Data

In [4]:
# Dataset
vq_model = "vqvae"
dataset = "buckeye"
split = "val"
seg_tag = "phoneseg_dp_penalized"

# Paths
seg_dir = (
    Path("../../../vqwordseg/exp")/vq_model/dataset/split/seg_tag/"intervals"
    )
word_ref_dir = Path("../../../vqwordseg/data")/dataset/"word_intervals"

In [5]:
# Read phone segmentation
phoneseg_interval_dict = {}
print("Reading: {}".format(seg_dir))
phoneseg_interval_dict = eval_segmentation.get_intervals_from_dir(seg_dir)
utterances = phoneseg_interval_dict.keys()

 12%|█▏        | 2046/16512 [00:00<00:00, 20455.68it/s]

Reading: ../../../vqwordseg/exp/vqvae/buckeye/val/phoneseg_dp_penalized/intervals


100%|██████████| 16512/16512 [00:00<00:00, 26968.88it/s]


In [6]:
# Read word reference
print("Reading: {}".format(word_ref_dir))
word_ref_interval_dict = eval_segmentation.get_intervals_from_dir(word_ref_dir, utterances)

  0%|          | 0/16512 [00:00<?, ?it/s]

Reading: ../../../vqwordseg/data/buckeye/word_intervals


100%|██████████| 16512/16512 [00:00<00:00, 38498.63it/s]


In [7]:
# Convert intervals to boundaries
word_ref_boundaries_dict = {}
for utt_key in tqdm(word_ref_interval_dict):
    word_ref_boundaries_dict[utt_key] = eval_segmentation.intervals_to_boundaries(
        word_ref_interval_dict[utt_key]
        )

100%|██████████| 16512/16512 [00:00<00:00, 332689.70it/s]


In [8]:
prepared_text = []
for utt_key in tqdm(utterances):
    prepared_text.append(
        "_".join([i[2] for i in phoneseg_interval_dict[utt_key]])
        )
    
print(prepared_text[0])

100%|██████████| 16512/16512 [00:00<00:00, 870579.59it/s]

307_343_461_225_435_84_144_125_443_332_42_101_201_202





In [9]:
# Gold segmentation, where boundaries are inserted in best possible positions
n_not_in_tolerance = 0
prepared_text_gold = []
for utt_key in tqdm(utterances):
    seg_intervals = phoneseg_interval_dict[utt_key].copy()
    ref_intervals = word_ref_interval_dict[utt_key].copy()
    seg_boundaries = np.array([i[1] - 1 for i in seg_intervals])
    ref_boundaries = np.array([i[1] - 1 for i in ref_intervals])
    for ref_boundary in ref_boundaries[:-1]:
        i_seg = np.argmin(np.abs(seg_boundaries - ref_boundary))
        seg_intervals.insert(
            i_seg + 1, (seg_intervals[i_seg][1], seg_intervals[i_seg][1], " ")
            )
        seg_boundaries = np.array([i[1] - 1 for i in seg_intervals])
    cur_text_gold = ""
    for start, end, label in seg_intervals:
        if label == " ":
            cur_text_gold = cur_text_gold[:-1]
            cur_text_gold += " "
        else:
            cur_text_gold += label + "_"
    cur_text_gold = cur_text_gold[:-1]
    prepared_text_gold.append(cur_text_gold)

print(prepared_text_gold[0])

100%|██████████| 16512/16512 [00:00<00:00, 32857.56it/s]

307_343_461_225 435_84_144_125_443_332_42_101 201_202





In [10]:
# Vocabulary
PAD_SYMBOL      = "<pad>"
SOS_SYMBOL      = "<s>"    # start of sentence
EOS_SYMBOL      = "</s>"   # end of sentence
BOUNDARY_SYMBOL = " "      # word boundary
symbols = set()
for sentence in prepared_text:
    for char in sentence.split("_"):
        symbols.add(char)
SYMBOLS = [PAD_SYMBOL, SOS_SYMBOL, EOS_SYMBOL, BOUNDARY_SYMBOL] + (sorted(list(symbols)))
symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)}
id_to_symbol = {i: s for i, s in enumerate(SYMBOLS)}

def text_to_id(text, add_sos_eos=False):
    """
    Convert text to a list of symbol IDs.

    Sentence start and end symbols can be added by setting `add_sos_eos`.
    """
    symbol_ids = []
    for word in text.split(" "):
        for code in word.split("_"):
            symbol_ids.append(symbol_to_id[code])
        symbol_ids.append(symbol_to_id[BOUNDARY_SYMBOL])
    symbol_ids = symbol_ids[:-1]  # remove last space

    if add_sos_eos:
        return [symbol_to_id[SOS_SYMBOL]] + symbol_ids + [symbol_to_id[EOS_SYMBOL]]
    else:
        return symbol_ids

print(text_to_id(prepared_text[0]))
print([id_to_symbol[i] for i in text_to_id(prepared_text[0])])

[231, 271, 398, 142, 369, 487, 54, 33, 378, 259, 352, 8, 116, 117]
['307', '343', '461', '225', '435', '84', '144', '125', '443', '332', '42', '101', '201', '202']


In [11]:
# First three words of training data
word_dataset = datasets.WordDataset(prepared_text, text_to_id)
for i in range(7):
    sample = word_dataset[i]
    print("_".join([id_to_symbol[i] for i in sample.numpy()]))

307_343_461_225_435_84_144_125_443_332_42_101_201_202
231_291_74_51_254
246_252_281_144_73_412_234_69_444_277_424_277_446_2
444_170_312_483_350_131_461_446
86_359_449_225_84_250_452_449_223_139_109_235_210_490_394_461_301
2_204_97_436_66_264_150_465_223_488
323_432_172_374_215_484_505_205_276_324_87_170_233_430_366_358_450_453_494_213_417_323_213


In [12]:
# Data

# Approximate ground truth (for debugging)
# cur_train_sentences = prepared_text_gold[:10000]
cur_val_sentences = prepared_text_gold[-1000:]

# No boundaries
cur_train_sentences = prepared_text[:10000]
# cur_val_sentences = prepared_text[-1000:]

# Random boundaries
np.random.seed(42)
# cur_train_sentences = insert_random_boundaries(cur_train_sentences)
# cur_val_sentences = insert_random_boundaries(cur_val_sentences)

print("No. train sentences:", len(cur_train_sentences))
print("Examples:", cur_train_sentences[:3])
print("Min length: ", min([len(i.split("_")) for i in cur_train_sentences]))
print("Max length: ", max([len(i.split("_")) for i in cur_train_sentences]))
print("Mean length: {:.4f}".format(np.mean([len(i.split("_")) for i in cur_train_sentences])))

No. train sentences: 10000
Examples: ['307_343_461_225_435_84_144_125_443_332_42_101_201_202', '231_291_74_51_254', '246_252_281_144_73_412_234_69_444_277_424_277_446_2']
Min length:  1
Max length:  106
Mean length: 11.3592


## Model

In [13]:
# AE-RNN model
n_symbols = len(SYMBOLS)
symbol_embedding_dim = 25  # 25
hidden_dim = 500  # 1000  # 200
embedding_dim = 150  # 300  # 25
teacher_forcing_ratio = 0.5  # 1.0  # 0.5  # 1.0
n_encoder_layers = 1  # 1  # 3  # 10
n_decoder_layers = 1  # 1
batch_size = 32  # 32*3  # 32
learning_rate = 0.001
input_dropout = 0.0  # 0.5
dropout = 0.0
n_symbols_max = 25
n_epochs_max = 25

encoder = models.Encoder(
    n_symbols=n_symbols,
    symbol_embedding_dim=symbol_embedding_dim,
    hidden_dim=hidden_dim,
    embedding_dim=embedding_dim,
    n_layers=n_encoder_layers,
    dropout=dropout,
    input_dropout=input_dropout
    )
decoder = models.Decoder1(
    n_symbols=n_symbols,
    symbol_embedding_dim=symbol_embedding_dim,
    hidden_dim=hidden_dim,
    embedding_dim=embedding_dim,
    n_layers=n_decoder_layers,
    sos_id = symbol_to_id[SOS_SYMBOL],
    teacher_forcing_ratio=teacher_forcing_ratio,
    dropout=dropout
    )
# decoder = Decoder2(
#     n_symbols=n_symbols,
#     hidden_dim=hidden_dim,
#     embedding_dim=embedding_dim,
#     n_layers=n_decoder_layers,
#     dropout=dropout
#     )
model = models.EncoderDecoder(encoder, decoder)

## Pre-training

In [14]:
# Training device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Random seed
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Training data
train_dataset = datasets.WordDataset(
    cur_train_sentences, text_to_id, n_symbols_max=n_symbols_max
    )
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=datasets.pad_collate
    )

# Validation data
val_dataset = datasets.WordDataset(cur_val_sentences, text_to_id)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=datasets.pad_collate
    )

# Loss
criterion = nn.NLLLoss(
    reduction="sum", ignore_index=symbol_to_id[PAD_SYMBOL]
    )
optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)

for i_epoch in range(n_epochs_max):

    # Training
    model.train()
    train_losses = []
    for i_batch, (data, data_lengths) in enumerate(tqdm(train_loader)):
        optimiser.zero_grad()
        data = data.to(device)       
        encoder_embedding, decoder_output = model(
            data, data_lengths, data, data_lengths
            )

        loss = criterion(
            decoder_output.contiguous().view(-1, decoder_output.size(-1)),
            data.contiguous().view(-1)
            )
        loss /= len(data_lengths)
        loss.backward()
        optimiser.step()
        train_losses.append(loss.item())
    
    # Validation
    model.eval()
    val_losses = []
    with torch.no_grad():
        for i_batch, (data, data_lengths) in enumerate(val_loader):
            data = data.to(device)            
            encoder_embedding, decoder_output = model(
                data, data_lengths, data, data_lengths
                )

            loss = criterion(
                decoder_output.contiguous().view(-1,
                decoder_output.size(-1)), data.contiguous().view(-1)
                )
            loss /= len(data_lengths)
            val_losses.append(loss.item())
    
    print(
        "Epoch {}, train loss: {:.3f}, val loss: {:.3f}".format(
        i_epoch,
        np.mean(train_losses),
        np.mean(val_losses))
        )
    sys.stdout.flush()


100%|██████████| 313/313 [00:05<00:00, 54.64it/s]


Epoch 0, train loss: 58.955, val loss: 14.040


100%|██████████| 313/313 [00:05<00:00, 54.54it/s]


Epoch 1, train loss: 50.525, val loss: 9.919


100%|██████████| 313/313 [00:05<00:00, 54.54it/s]


Epoch 2, train loss: 45.519, val loss: 8.245


100%|██████████| 313/313 [00:05<00:00, 53.97it/s]


Epoch 3, train loss: 41.384, val loss: 6.841


100%|██████████| 313/313 [00:05<00:00, 54.08it/s]


Epoch 4, train loss: 37.725, val loss: 6.155


100%|██████████| 313/313 [00:05<00:00, 54.56it/s]


Epoch 5, train loss: 34.620, val loss: 5.919


100%|██████████| 313/313 [00:05<00:00, 54.21it/s]


Epoch 6, train loss: 31.914, val loss: 5.785


100%|██████████| 313/313 [00:05<00:00, 54.20it/s]


Epoch 7, train loss: 29.568, val loss: 5.879


100%|██████████| 313/313 [00:05<00:00, 54.55it/s]


Epoch 8, train loss: 27.478, val loss: 5.858


100%|██████████| 313/313 [00:05<00:00, 54.55it/s]


Epoch 9, train loss: 25.830, val loss: 5.974


100%|██████████| 313/313 [00:05<00:00, 54.56it/s]


Epoch 10, train loss: 24.263, val loss: 6.046


100%|██████████| 313/313 [00:05<00:00, 54.53it/s]


Epoch 11, train loss: 22.963, val loss: 6.312


100%|██████████| 313/313 [00:05<00:00, 54.72it/s]


Epoch 12, train loss: 21.810, val loss: 6.368


100%|██████████| 313/313 [00:05<00:00, 54.57it/s]


Epoch 13, train loss: 20.759, val loss: 6.625


100%|██████████| 313/313 [00:05<00:00, 54.12it/s]


Epoch 14, train loss: 19.840, val loss: 6.785


100%|██████████| 313/313 [00:05<00:00, 54.62it/s]


Epoch 15, train loss: 19.167, val loss: 7.021


100%|██████████| 313/313 [00:05<00:00, 53.41it/s]


Epoch 16, train loss: 18.403, val loss: 6.956


100%|██████████| 313/313 [00:06<00:00, 52.11it/s]


Epoch 17, train loss: 17.695, val loss: 6.972


100%|██████████| 313/313 [00:06<00:00, 48.84it/s]


Epoch 18, train loss: 17.027, val loss: 7.193


100%|██████████| 313/313 [00:06<00:00, 51.68it/s]


Epoch 19, train loss: 16.483, val loss: 7.429


100%|██████████| 313/313 [00:05<00:00, 52.22it/s]


Epoch 20, train loss: 16.014, val loss: 7.511


100%|██████████| 313/313 [00:06<00:00, 50.61it/s]


Epoch 21, train loss: 15.669, val loss: 7.751


100%|██████████| 313/313 [00:06<00:00, 51.43it/s]


Epoch 22, train loss: 15.258, val loss: 7.721


100%|██████████| 313/313 [00:05<00:00, 53.33it/s]


Epoch 23, train loss: 14.667, val loss: 7.668


100%|██████████| 313/313 [00:05<00:00, 52.54it/s]


Epoch 24, train loss: 14.302, val loss: 7.726


In [15]:
# Examples without segmentation

# Apply to validation data
model.eval()
with torch.no_grad():
    for i_batch, (data, data_lengths) in enumerate(val_loader):
        data = data.to(device)
        encoder_embedding, decoder_output = model(
            data, data_lengths, data, data_lengths
            )
        
        y, log_probs = model.decoder.greedy_decode(
            encoder_embedding,
            max_length=25,
            )
        x = data.cpu().numpy()
        
        for i_input in range(y.shape[0]):
            # Only print up to EOS symbol
            input_symbols = []
            for i in x[i_input]:
                if i == symbol_to_id[EOS_SYMBOL] or i == symbol_to_id[PAD_SYMBOL]:
                    break
                input_symbols.append(id_to_symbol[i])
            output_symbols = []
            for i in y[i_input]:
                if i == symbol_to_id[EOS_SYMBOL] or i == symbol_to_id[PAD_SYMBOL]:
                    break
                output_symbols.append(id_to_symbol[i])

            print("Input: ", "_".join(input_symbols))
            print("Output:", "_".join(output_symbols))
            print()
            
            if i_input == 10:
                break
        
        break

Input:  342_213_258_147
Output: 342_40_258_147_213_258_147_258_147_258_147_258_147_258_147_258_147_258_222_258_147_258_147_258_222

Input:  91_431_10
Output: 91_91_431_10_10_151_151_10_151_431_431_151_151_164_148_151_90_277_90_184_117_90_164_148_151

Input:  323_317_491_339_144_4_225_88
Output: 323_260_491_19_225_88_144_339_144_339_144_339_144_339_144_88_225_144_339_144_88_225_88_144_339

Input:  123_109
Output: 123_109_109_109_109_109_109_109_109_54_109_293_486_109_109_486_109_320_486_109_486_109_320_486_109

Input:  217_246
Output: 217_217_246_246_246_246_246_246_246_246_246_246_246_246_246_246_246_246_246_246_246_246_246_246_246

Input:  192
Output: 192_192_192_192_192_192_192_192_425_158_196_223_425_158_425_196_290_141_468_466_203_21_135_484_353

Input:  210_195_118_466_156
Output: 210_466_466_156_466_156_466_156_164_156_466_164_24_156_164_192_164_164_156_466_156_164_164_156_164

Input:  319_391_463_31_242
Output: 391_391_463_242_31_31_463_31_242_31_242_31_242_171_463_31_242_171_17

## Segmentation

In [16]:
# Utterances for evaluation
n_eval_utterances = 1000 # 10000 # 1000
# eval_sentences = prepared_text[-n_eval_utterances:]  # val sentences
# eval_utterances = list(utterances)[-n_eval_utterances:]
eval_sentences = prepared_text[:n_eval_utterances]
eval_utterances = list(utterances)[:n_eval_utterances]

In [17]:
# Embed segments

# Random seed
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Data
sentences = eval_sentences
# sentences = cur_val_sentences
interval_dataset = datasets.SentenceIntervalDataset(
    sentences,
    text_to_id,
    join_char="_"
    )
segment_loader = DataLoader(
    interval_dataset, 
    batch_size=batch_size, 
    shuffle=False, 
    collate_fn=datasets.pad_collate,
    drop_last=False
    )

# Apply model to data
model.decoder.teacher_forcing_ratio = 1.0
model.eval()
rnn_losses = []
lengths = []
eos = []
with torch.no_grad():
    for i_batch, (data, data_lengths) in enumerate(tqdm(segment_loader)):
        data = data.to(device)
        
        encoder_embedding, decoder_output = model(
            data, data_lengths, data, data_lengths
            )

        for i_item in range(data.shape[0]):
            item_loss = criterion(
                decoder_output[i_item].contiguous().view(-1,
                decoder_output[i_item].size(-1)),
                data[i_item].contiguous().view(-1)
                )
            rnn_losses.append(item_loss)
            lengths.append(data_lengths[i_item])

100%|██████████| 2982/2982 [00:14<00:00, 208.70it/s]


In [18]:
# Segment

# dur_weight = 12.0 #  2.5  # 12  # 2.5  # Chorowski
dur_weight = 1.0

i_item = 0
losses = []
cur_segmented_sentences = []
for i_sentence, intervals in enumerate(tqdm(interval_dataset.intervals)):
    
    # Costs for segment intervals
    costs = np.inf*np.ones(len(intervals))
    i_eos = intervals[-1][-1]
    for i_seg, interval in enumerate(intervals):
        if interval is None:
            continue
        i_start, i_end = interval
        dur = i_end - i_start
        assert dur == lengths[i_item]
        eos = (i_end == i_eos)  # end-of-sequence
        
#         # Chorowski
#         costs[i_seg] = (
#             rnn_losses[i_item]
#             + dur_weight*neg_chorowski(dur)
#             )
        
#         # Gamma
#         costs[i_seg] = (
#             rnn_losses[i_item]
#             + dur_weight*neg_log_gamma(dur)
#             + np.log(np.sum(gamma_cache**dur_weight))
#             )
        
#         # Poisson
#         costs[i_seg] = (
#             rnn_losses[i_item]
#             + neg_log_poisson(dur)
#             )

        # Histogram
        costs[i_seg] = (
            rnn_losses[i_item]
            + dur_weight*(neg_log_hist(dur))
            + np.log(np.sum(histogram**dur_weight))
            )
    
        # Sequence boundary
        alpha = 0.3  # 0.9
        if eos:
            costs[i_seg] += -np.log(alpha)
        else:
            costs[i_seg] += -np.log(1 - alpha)

        # Temp
#         if dur > 10 or dur <= 1:
#             costs[i_seg] = +np.inf
        i_item += 1
    
    # Viterbi segmentation
    n_frames = len(interval_dataset.sentences[i_sentence])
    summed_cost, boundaries = viterbi.custom_viterbi(costs, n_frames)
    losses.append(summed_cost)
    
    reference_sentence = sentences[i_sentence]
    segmented_sentence = get_segmented_sentence(
            interval_dataset.sentences[i_sentence],
            boundaries
            )
    cur_segmented_sentences.append(segmented_sentence)
#     # Print examples of the first few sentences
#     if i_sentence < 10:
#         print(reference_sentence)
#         print(segmented_sentence)
#         print()
    
print("NLL: {:.4f}".format(np.sum(losses)))

100%|██████████| 1000/1000 [00:06<00:00, 161.95it/s]

NLL: 11245.1908





In [19]:
print(cur_segmented_sentences[0])

# To evaluate gold segmentation
# cur_segmented_sentences = prepared_text_gold[:n_eval_utterances]
# print(cur_segmented_sentences[0])

307_343 461_225 435_84 144_125_443 332 42_101 201_202


## Evaluation

In [20]:
# Convert segmentation to intervals
segmentation_interval_dict = {}
for i_utt, utt_key in tqdm(enumerate(eval_utterances)):
    words_segmented = cur_segmented_sentences[i_utt].split(" ")
    word_start = 0
    word_label = ""
    i_word = 0
    segmentation_interval_dict[utt_key] = []
    for (phone_start, phone_end,
            phone_label) in phoneseg_interval_dict[utt_key]:
        word_label += phone_label + "_"
        if words_segmented[i_word] == word_label[:-1]:
            segmentation_interval_dict[utt_key].append((
                word_start, phone_end, word_label[:-1]
                ))
            word_label = ""
            word_start = phone_end
            i_word += 1

    if i_utt < 10:
        print(segmentation_interval_dict[utt_key])
        print(word_ref_interval_dict[utt_key])
        print()        

1000it [00:00, 114087.26it/s]

[(0, 10, '307_343'), (10, 16, '461_225'), (16, 24, '435_84'), (24, 38, '144_125_443'), (38, 44, '332'), (44, 52, '42_101'), (52, 62, '201_202')]
[(0, 19, 'those'), (19, 55, 'parents'), (55, 64, 'were')]

[(0, 26, '231_291_74_51_254')]
[(0, 11, 'i'), (11, 29, 'went')]

[(0, 28, '246_252_281_144_73'), (28, 40, '412_234'), (40, 50, '69'), (50, 68, '444_277'), (68, 90, '424_277'), (90, 98, '446_2')]
[(0, 20, 'no'), (20, 33, 'in'), (33, 57, 'one'), (57, 99, 'week')]

[(0, 48, '444_170_312_483_350_131_461_446')]
[(0, 9, 'when'), (9, 15, 'it'), (15, 42, 'happened'), (42, 49, 'in')]

[(0, 16, '86_359_449'), (16, 36, '225_84_250'), (36, 56, '452_449_223'), (56, 66, '139_109'), (66, 88, '235_210_490'), (88, 100, '394_461_301')]
[(0, 10, 'it'), (10, 25, 'was'), (25, 59, 'useless'), (59, 69, 'to'), (69, 95, 'fight'), (95, 103, 'it')]

[(0, 24, '2_204_97_436'), (24, 54, '66_264_150_465'), (54, 70, '223_488')]
[(0, 15, 'and'), (15, 31, 'not'), (31, 71, 'just')]

[(0, 28, '323_432_172_374_215'), (28,




In [21]:
# Intervals to boundaries
segmentation_boundaries_dict = {}
for utt_key in tqdm(segmentation_interval_dict):
    segmentation_boundaries_dict[utt_key] = eval_segmentation.intervals_to_boundaries(
        segmentation_interval_dict[utt_key]
        )
word_ref_boundaries_dict = {}
for utt_key in tqdm(word_ref_interval_dict):
    word_ref_boundaries_dict[utt_key] = eval_segmentation.intervals_to_boundaries(
        word_ref_interval_dict[utt_key]
        )

# Evaluate word boundaries
reference_list = []
segmentation_list = []
for utterance in segmentation_boundaries_dict:
    reference_list.append(word_ref_boundaries_dict[utterance])
    segmentation_list.append(segmentation_boundaries_dict[utterance])

tolerance = 2
p, r, f = eval_segmentation.score_boundaries(
    reference_list, segmentation_list, tolerance=tolerance
    )
print("-"*(79 - 4))
print("Word boundaries:")
print("Precision: {:.2f}%".format(p*100))
print("Recall: {:.2f}%".format(r*100))
print("F-score: {:.2f}%".format(f*100))
print("OS: {:.2f}%".format(eval_segmentation.get_os(p, r)*100))
print("R-value: {:.2f}%".format(eval_segmentation.get_rvalue(p, r)*100))
print("-"*(79 - 4))

p, r, f = eval_segmentation.score_word_token_boundaries(
    reference_list, segmentation_list, tolerance=tolerance
    )
print("Word token boundaries:")
print("Precision: {:.2f}%".format(p*100))
print("Recall: {:.2f}%".format(r*100))
print("F-score: {:.2f}%".format(f*100))
print("OS: {:.2f}%".format(eval_segmentation.get_os(p, r)*100))
# print("R-value: {:.2f}%".format(get_rvalue(p, r)*100))
print("-"*(79 - 4))

100%|██████████| 1000/1000 [00:00<00:00, 292673.50it/s]
100%|██████████| 16512/16512 [00:00<00:00, 310726.82it/s]

---------------------------------------------------------------------------
Word boundaries:
Precision: 22.52%
Recall: 28.55%
F-score: 25.18%
OS: 26.79%
R-value: 27.12%
---------------------------------------------------------------------------
Word token boundaries:
Precision: 13.25%
Recall: 15.79%
F-score: 14.41%
OS: 19.11%
---------------------------------------------------------------------------





## Quantization

In [22]:
# clustering_sentences = prepared_text_gold[:10000]
clustering_sentences = cur_segmented_sentences

In [23]:
# K-means centroids

# Data
train_dataset = datasets.WordDataset(
    clustering_sentences, text_to_id, n_symbols_max=n_symbols_max
    )
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=datasets.pad_collate
    )

# Apply model to data
model.eval()
encoder_embeddings = []
with torch.no_grad():
    for i_batch, (data, data_lengths) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        encoder_embedding, decoder_output = model(
            data, data_lengths, data, data_lengths
            )
        encoder_embeddings.append(encoder_embedding.cpu().numpy())
        
# Cluster
X = np.vstack(encoder_embeddings)
print("X shape:", X.shape)
print(datetime.now())
K = 1024  # 1024  # 2048
print("Clustering: K = {}".format(K))
vq_model = cluster.KMeans(n_clusters=K, max_iter=10)
vq_model.fit(X)
print("Inertia: {:.4f}".format(vq_model.inertia_))
centroids = vq_model.cluster_centers_
print(datetime.now())

100%|██████████| 130/130 [00:00<00:00, 400.74it/s]


X shape: (4157, 150)
2021-07-26 15:18:19.410862
Clustering: K = 1024
Inertia: 425332.2500
2021-07-26 15:18:41.403616


In [24]:
# Examples without segmentation

# Apply to validation data
model.eval()
with torch.no_grad():
    for i_batch, (data, data_lengths) in enumerate(val_loader):
#     for i_batch, (data, data_lengths) in enumerate(train_loader):
        data = data.to(device)
        encoder_embedding, decoder_output = model(
            data, data_lengths, data, data_lengths
            )

        encoder_embedding = encoder_embedding.cpu().numpy()
        clusters = vq_model.predict(encoder_embedding)
        embedding_reconstructed = centroids[clusters, :].reshape(
            encoder_embedding.shape
            )
        embedding_reconstructed = torch.from_numpy(
            embedding_reconstructed
            ).to(device)
        
        y, log_probs = model.decoder.greedy_decode(
            embedding_reconstructed,
            max_length=n_symbols_max,
            )
        x = data.cpu().numpy()
        
        for i_input in range(y.shape[0]):
            # Only print up to EOS symbol
            input_symbols = []
            for i in x[i_input]:
                if i == symbol_to_id[EOS_SYMBOL] or i == symbol_to_id[PAD_SYMBOL]:
                    break
                input_symbols.append(id_to_symbol[i])
            output_symbols = []
            for i in y[i_input]:
                if i == symbol_to_id[EOS_SYMBOL] or i == symbol_to_id[PAD_SYMBOL]:
                    break
                output_symbols.append(id_to_symbol[i])

            print("Input: ", "_".join(input_symbols))
            print("Output:", "_".join(output_symbols))
            print()
            
            if i_input == 10:
                break
        
        break

Input:  97
Output: 97_97_97_97_97_97_97_97_97_97_97_97_97_97_97_97_97_97_97_97_97_97_97_182_414

Input:  476_89_81_225
Output: 225_225_225_225_225_225_225_97_225_225_225_232_225_232_225_225_232_225_232_225_225_232_144_232_225

Input:  406_282
Output: 394_394_394_394_394_394_394_394_394_394_394_394_394_394_394_394_173_321_321_97_394_219_321_394_394

Input:  43_210_443_486
Output: 486_486_486_486_486_486_486_486_486_486_486_486_486_486_486_486_486_486_486_486_486_486_486_486_486

Input:  449_428_233_429_208_223
Output: 208_208_208_263_208_31_263_208_485_208_208_218_97_332_208_245_97_504_208_245_245_208_504_332_97

Input:  290_95_433
Output: 486_433_433_433_433_433_433_433_433_433_433_449_433_433_433_433_433_449_433_45_433_433_433_449_449

Input:  490
Output: 21_490_490_490_490_490_490_490_490_490_490_490_490_490_490_490_490_490_490_490_490_490_490_490_490

Input:  162_346
Output: 330_79_79_79_79_79_470_79_79_79_79_79_79_253_79_79_79_79_253_483_79_79_253_79_79

Input:  254_272
Output: 261

In [25]:
# Utterances for evaluation
n_eval_utterances = 1000
# eval_sentences = prepared_text[-n_eval_utterances:]  # val sentences
# eval_utterances = list(utterances)[-n_eval_utterances:]
eval_sentences = prepared_text[:n_eval_utterances]
eval_utterances = list(utterances)[:n_eval_utterances]

In [26]:
# Embed segments

# Random seed
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Data
sentences = eval_sentences
interval_dataset = datasets.SentenceIntervalDataset(
    sentences,
    text_to_id,
    "_"
    )
segment_loader = DataLoader(
    interval_dataset, 
    batch_size=batch_size, 
    shuffle=False, 
    collate_fn=datasets.pad_collate,
    drop_last=False
    )

# Apply model to data
model.decoder.teacher_forcing_ratio = 1.0  # to-do: adjust this
model.eval()
rnn_losses = []
lengths = []
with torch.no_grad():
    for i_batch, (data, data_lengths) in enumerate(tqdm(segment_loader)):
        data = data.to(device)

        encoder_embedding, decoder_output = model(
            data, data_lengths, data, data_lengths
            )

        encoder_embedding = encoder_embedding.cpu().numpy()
        clusters = vq_model.predict(encoder_embedding)
        embedding_reconstructed = centroids[clusters, :].reshape(
            encoder_embedding.shape
            )
        embedding_reconstructed = torch.from_numpy(
            embedding_reconstructed
            ).to(device)
        
        decoder_rnn, decoder_output = model.decoder(
            embedding_reconstructed, data, data_lengths
            )

        for i_item in range(data.shape[0]):
            item_loss = criterion(
                decoder_output[i_item].contiguous().view(-1,
                decoder_output[i_item].size(-1)),
                data[i_item].contiguous().view(-1)
                )
            rnn_losses.append(item_loss)
            lengths.append(data_lengths[i_item])

100%|██████████| 2982/2982 [00:25<00:00, 116.28it/s]


Options:

- Want to evaluate this segmentation: Go back up to the cell where segmentation is done (after segments are embedded).
- Want to retrain K-means model based on this segmentation: Go back to start of quantization cell.