# Word Segmentation on ZeroSpeech'17 Mandarin

Copyright (c) 2021 Herman Kamper, MIT License

Train a segmental autoencoding recurrent neural network (segmental AE-RNN) and use it to perform word segmentation on the encoded ZeroSpeech'17 Mandarin data.

## Preliminaries

In [7]:
from datetime import datetime
from pathlib import Path
from scipy.stats import gamma
from sklearn import cluster
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import torch.nn as nn

sys.path.append("..")

from seg_aernn import datasets, models, viterbi
from utils import eval_segmentation

## Utility functions

In [8]:
def get_segmented_sentence(ids, boundaries):
    """
    Render a sequence of symbol IDs as a string of space-separated words.

    A word is closed after every position where `boundaries` is truthy, and
    the symbols inside a word are joined with "_". IDs are mapped to symbols
    through the global `id_to_symbol`. Symbols that follow the last boundary
    are not part of the output.
    """
    words = []
    pending = []
    for position, is_boundary in enumerate(boundaries):
        pending.append(id_to_symbol[ids[position]])
        if not is_boundary:
            continue
        words.append("_".join(pending))
        pending = []
    return " ".join(words).strip()

In [9]:
# Duration penalty functions

# NOTE: The histogram and cached-Gamma penalties below are wrapped in a
# triple-quoted string, so none of this code is executed: `histogram`,
# `neg_log_hist`, `gamma_cache` and `neg_log_gamma` are not defined unless the
# quotes are removed. The commented-out histogram and Gamma cost alternatives
# in the segmentation cell need these names.
# As written, both functions take -log(0) for durations outside their table
# (NumPy evaluates this to +inf and emits a RuntimeWarning).
"""
# Histogram
histogram = np.array([
    0, 1.66322800e-01, 2.35838129e-01, 2.10609187e-01,
    1.48025482e-01, 9.42918160e-02, 5.84211098e-02, 3.64679480e-02,
    2.18264741e-02, 1.25420784e-02, 7.18500018e-03, 4.27118399e-03,
    1.73743077e-03, 1.19448366e-03, 7.42027726e-04, 2.89571796e-04,
    2.35277084e-04, 0.00001, 0.00001, 0.00001, 0.00001, 0.00001
    ])  # to-do: check this
histogram = histogram/np.sum(histogram)
def neg_log_hist(dur):
    return -np.log(0 if dur >= len(histogram) else histogram[dur])

# Cached Gamma
# shape, loc, scale = (2.3, 0, 1.3)  # VQ-VAE
shape, loc, scale = (2.6, 0, 1.8)    # CPC-big
# shape, loc, scale = (2.5, 0, 1.5)    # CPC-big (Gamma)
gamma_cache = []
for dur in range(200):
    gamma_cache.append(gamma.pdf(dur, shape, loc, scale))
gamma_cache = np.array(gamma_cache)/np.sum(gamma_cache)
def neg_log_gamma(dur):
    if dur < 200:
        return -np.log(gamma_cache[dur])
    else:
        return -np.log(0)
"""
    
# Chorowski
def neg_chorowski(dur):
    """
    Duration cost for a segment that spans `dur` symbols.

    The cost is linear in the duration: it is 0 for a single symbol and
    drops by one for each additional symbol.
    """
    symbols_beyond_first = dur - 1
    return -symbols_beyond_first

## Data

In [10]:
# Dataset selection; each value is one level of the experiment directory
vq_model = "cpc_big"                # alternative: "xlsr"
dataset = "zs2017_zh"
split = "train"
seg_tag = "phoneseg_dp_penalized"   # alternative: "phoneseg_dp_penalized_tune"

# Directory with the phone segmentation intervals that are read as input
seg_dir = Path("../../vqwordseg/exp").joinpath(
    vq_model, dataset, split, seg_tag, "intervals"
    )
# word_ref_dir = Path("../../vqwordseg/data")/dataset/"word_intervals"

In [11]:
# Read phone segmentation
# The result is indexed by utterance key below; each entry is iterated as
# (start, end, label) triples when the segmentation is converted to
# intervals later on.
print(f"Reading: {seg_dir}")
phoneseg_interval_dict = eval_segmentation.get_intervals_from_dir(seg_dir)
# `utterances` is a view of the keys, so it follows the dictionary's order
utterances = phoneseg_interval_dict.keys()

 32%|███▏      | 1689/5285 [00:00<00:00, 16884.74it/s]

Reading: ../../vqwordseg/exp/cpc_big/zs2017_zhtmp/train/phoneseg_dp_penalized/intervals


100%|██████████| 5285/5285 [00:00<00:00, 16734.72it/s]


In [12]:
# One string per utterance: the third field (the label) of each of its phone
# intervals, joined with "_"
prepared_text = [
    "_".join([interval[2] for interval in phoneseg_interval_dict[utt_key]])
    for utt_key in tqdm(utterances)
    ]
print(prepared_text[0])

100%|██████████| 5285/5285 [00:00<00:00, 501218.66it/s]

4_37_12_18_39_8_28_18_24_1_18_20_49_36_24_1_31_45_2_13_3





In [13]:
# Vocabulary: four special symbols followed by every code found in the data
PAD_SYMBOL = "<pad>"
SOS_SYMBOL = "<s>"      # start of sentence
EOS_SYMBOL = "</s>"     # end of sentence
BOUNDARY_SYMBOL = " "   # word boundary
symbols = {code for sentence in prepared_text for code in sentence.split("_")}

# The codes are strings, so `sorted` orders them lexicographically
SYMBOLS = [PAD_SYMBOL, SOS_SYMBOL, EOS_SYMBOL, BOUNDARY_SYMBOL]
SYMBOLS.extend(sorted(symbols))

# Lookup tables in both directions; an ID is the position in `SYMBOLS`
id_to_symbol = dict(enumerate(SYMBOLS))
symbol_to_id = {symbol: index for index, symbol in id_to_symbol.items()}

def text_to_id(text, add_sos_eos=False):
    """
    Map a sentence string to its list of symbol IDs.

    Words in `text` are separated by spaces and the codes within a word by
    "_". The boundary ID is placed between consecutive words. When
    `add_sos_eos` is set, the start and end of sentence IDs are added around
    the result.
    """
    boundary_id = symbol_to_id[BOUNDARY_SYMBOL]
    symbol_ids = []
    for i_word, word in enumerate(text.split(" ")):
        if i_word > 0:
            symbol_ids.append(boundary_id)
        symbol_ids.extend(symbol_to_id[code] for code in word.split("_"))

    if not add_sos_eos:
        return symbol_ids
    return (
        [symbol_to_id[SOS_SYMBOL]] + symbol_ids + [symbol_to_id[EOS_SYMBOL]]
        )

# Sanity check: encode the first sentence and map the IDs back to symbols
first_sentence_ids = text_to_id(prepared_text[0])
print(first_sentence_ids)
print([id_to_symbol[symbol_id] for symbol_id in first_sentence_ids])

[38, 35, 8, 14, 37, 52, 25, 14, 21, 5, 14, 17, 48, 34, 21, 5, 29, 44, 16, 9, 27]
['4', '37', '12', '18', '39', '8', '28', '18', '24', '1', '18', '20', '49', '36', '24', '1', '31', '45', '2', '13', '3']


In [14]:
# First seven items of the dataset built from the prepared text
word_dataset = datasets.WordDataset(prepared_text, text_to_id)
for i_item in range(7):
    item_ids = word_dataset[i_item].numpy()
    print("_".join([id_to_symbol[symbol_id] for symbol_id in item_ids]))

4_37_12_18_39_8_28_18_24_1_18_20_49_36_24_1_31_45_2_13_3
4_25_12_27_31_3_47_1_15_20_22
4_37_38_23_7_11_46_38_18_5_43_40_14_36_9_43_40_11
8_48_18_20_8_41_18_20_29
4_0_32_41_47_1_18_34_44_0_28_23_11_14_34_44_0_41_1_19_31_3_16_0_28_18_47_1_21_8_48_27_29
4_9_47_18_42_1_48_27_20_29_49_43_7_14_39_32_28_18_47_1_18_29_16
4_25_10_37_48_27_5_24_1_19_33_45_25_12_27_31_3_16_44_37_41_18_21_46_44_37_25_40_7_10


In [15]:
# Data
# NOTE(review): the validation sentences are the last 100 entries of
# `prepared_text`, which is also used in full for training, so the two sets
# overlap -- confirm that this is intended.
cur_train_sentences = prepared_text
cur_val_sentences = prepared_text[-100:]

# Random boundaries
np.random.seed(42)
# cur_train_sentences = insert_random_boundaries(cur_train_sentences)
# cur_val_sentences = insert_random_boundaries(cur_val_sentences)

# Summary statistics; a sentence's length is its number of "_"-separated codes
train_lengths = [len(sentence.split("_")) for sentence in cur_train_sentences]
print("No. train sentences:", len(cur_train_sentences))
print("Examples:", cur_train_sentences[:3])
print("Min length: ", min(train_lengths))
print("Max length: ", max(train_lengths))
print("Mean length: {:.4f}".format(np.mean(train_lengths)))

No. train sentences: 5285
Examples: ['4_37_12_18_39_8_28_18_24_1_18_20_49_36_24_1_31_45_2_13_3', '4_25_12_27_31_3_47_1_15_20_22', '4_37_38_23_7_11_46_38_18_5_43_40_14_36_9_43_40_11']
Min length:  1
Max length:  170
Mean length: 29.0956


## Model

In [16]:
# AE-RNN model
# Hyperparameters for the encoder-decoder built below and for pre-training.
# The bare numbers in trailing comments are presumably alternative settings
# that were tried -- TODO confirm.
n_symbols = len(SYMBOLS)  # vocabulary size, including the special symbols
symbol_embedding_dim = 10  # 25
hidden_dim = 500  # 250  # 500  # 1000  # 200
embedding_dim = 50  # 150  # 300  # 25
# Only passed to the commented-out `Decoder1` below
teacher_forcing_ratio = 0.5  # 1.0  # 0.5  # 1.0
n_encoder_layers = 1  # 1  # 3  # 10
n_decoder_layers = 1  # 1  # 1
batch_size = 32  # 32*3  # 32
learning_rate = 0.001
input_dropout = 0.0  # 0.0 # 0.5
dropout = 0.0
# Passed as `n_symbols_max` to the training `WordDataset` in the
# pre-training cell
n_symbols_max = 50  # 25
# n_epochs_max = 5
# When None, the pre-training cell derives the number of epochs from
# `n_steps_max`, the number of training sentences and `batch_size`
n_epochs_max = None  # determined from n_steps_max and batch size
n_steps_max = 1500  # 2500  # 1500  # 1000  # None
# n_steps_max = None  # Only use n_epochs_max
bidirectional_encoder = False  # False

encoder = models.Encoder(
    n_symbols=n_symbols,
    symbol_embedding_dim=symbol_embedding_dim,
    hidden_dim=hidden_dim,
    embedding_dim=embedding_dim,
    n_layers=n_encoder_layers,
    dropout=dropout,
    input_dropout=input_dropout,
    bidirectional=bidirectional_encoder
    )
# Alternative decoder, kept for reference; unlike `Decoder2` it takes the
# symbol embedding size, the start-of-sentence ID and the teacher forcing
# ratio as arguments
# decoder = models.Decoder1(
#     n_symbols=n_symbols,
#     symbol_embedding_dim=symbol_embedding_dim,
#     hidden_dim=hidden_dim,
#     embedding_dim=embedding_dim,
#     n_layers=n_decoder_layers,
#     sos_id = symbol_to_id[SOS_SYMBOL],
#     teacher_forcing_ratio=teacher_forcing_ratio,
#     dropout=dropout
#     )
decoder = models.Decoder2(
    n_symbols=n_symbols,
    hidden_dim=hidden_dim,
    embedding_dim=embedding_dim,
    n_layers=n_decoder_layers,
    dropout=dropout
    )
model = models.EncoderDecoder(encoder, decoder)

## Pre-training

In [17]:
# Pre-train the model as an autoencoder: every batch is passed to the model
# as both the input and the target, and the summed NLL of reconstructing the
# non-padding symbols is minimised.

# Training device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Random seed
# NOTE(review): `model` was constructed in an earlier cell, before
# `torch.manual_seed` is called here, so the parameter initialisation is
# presumably not controlled by this seed -- confirm.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Training data
train_dataset = datasets.WordDataset(
    cur_train_sentences, text_to_id, n_symbols_max=n_symbols_max
    )
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=datasets.pad_collate
    )

# Validation data
# NOTE: this loader also shuffles, so the order of the validation batches
# differs between passes
val_dataset = datasets.WordDataset(cur_val_sentences, text_to_id)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=datasets.pad_collate
    )

# Loss
# Summed (not averaged) over all target positions; padding is ignored.
# `nn.NLLLoss` expects log-probabilities, so `decoder_output` presumably
# holds log-probabilities over the symbols -- TODO confirm.
criterion = nn.NLLLoss(
    reduction="sum", ignore_index=symbol_to_id[PAD_SYMBOL]
    )
optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Derive the number of epochs needed to reach `n_steps_max` updates. This
# requires `n_steps_max` to be set whenever `n_epochs_max` is None.
if n_epochs_max is None:
    steps_per_epoch = np.ceil(len(cur_train_sentences)/batch_size)
    n_epochs_max = int(np.ceil(n_steps_max/steps_per_epoch))

i_step = 0  # number of parameter updates so far, counted across epochs
for i_epoch in range(n_epochs_max):

    # Training
    model.train()
    train_losses = []
    for i_batch, (data, data_lengths) in enumerate(tqdm(train_loader)):
        optimiser.zero_grad()
        data = data.to(device)       
        encoder_embedding, decoder_output = model(
            data, data_lengths, data, data_lengths
            )

        # Flatten all leading dimensions so that each target position is
        # scored against its own distribution over the symbols
        loss = criterion(
            decoder_output.contiguous().view(-1, decoder_output.size(-1)),
            data.contiguous().view(-1)
            )
        # Divide by the number of entries in `data_lengths` (presumably one
        # per sequence in the batch)
        loss /= len(data_lengths)
        loss.backward()
        optimiser.step()
        train_losses.append(loss.item())
        i_step += 1
        # Stop in the middle of the epoch once the step budget is used up
        if i_step == n_steps_max and n_steps_max is not None:
            break

    # Validation
    model.eval()
    val_losses = []
    with torch.no_grad():
        for i_batch, (data, data_lengths) in enumerate(val_loader):
            data = data.to(device)            
            encoder_embedding, decoder_output = model(
                data, data_lengths, data, data_lengths
                )

            loss = criterion(
                decoder_output.contiguous().view(-1,
                decoder_output.size(-1)), data.contiguous().view(-1)
                )
            loss /= len(data_lengths)
            val_losses.append(loss.item())
    
    # The train loss is the mean over the batches run in this epoch; for an
    # epoch cut short by the step limit that is only the batches before the
    # break
    print(
        "Epoch {}, train loss: {:.3f}, val loss: {:.3f}".format(
        i_epoch,
        np.mean(train_losses),
        np.mean(val_losses))
        )
    sys.stdout.flush()

    # Leave the epoch loop as well once the step budget is used up
    if i_step == n_steps_max and n_steps_max is not None:
        break

100%|██████████| 166/166 [00:04<00:00, 33.68it/s]


Epoch 0, train loss: 100.139, val loss: 94.218


100%|██████████| 166/166 [00:05<00:00, 32.39it/s]


Epoch 1, train loss: 93.460, val loss: 101.742


100%|██████████| 166/166 [00:05<00:00, 32.30it/s]

Epoch 2, train loss: 90.882, val loss: 92.639



100%|██████████| 166/166 [00:05<00:00, 32.70it/s]

Epoch 3, train loss: 88.192, val loss: 85.760



100%|██████████| 166/166 [00:04<00:00, 33.50it/s]


Epoch 4, train loss: 85.429, val loss: 86.196


100%|██████████| 166/166 [00:04<00:00, 33.68it/s]


Epoch 5, train loss: 81.994, val loss: 83.851


100%|██████████| 166/166 [00:04<00:00, 33.62it/s]


Epoch 6, train loss: 78.668, val loss: 78.533


100%|██████████| 166/166 [00:04<00:00, 33.94it/s]


Epoch 7, train loss: 75.795, val loss: 71.432


100%|██████████| 166/166 [00:04<00:00, 33.82it/s]


Epoch 8, train loss: 72.489, val loss: 76.813


  3%|▎         | 5/166 [00:00<00:05, 28.03it/s]


Epoch 9, train loss: 77.842, val loss: 60.722


In [18]:
# Examples without segmentation

def _symbols_before_stop(symbol_ids):
    """Map IDs to symbols, stopping at the first EOS or padding ID."""
    eos_id = symbol_to_id[EOS_SYMBOL]
    pad_id = symbol_to_id[PAD_SYMBOL]
    kept = []
    for symbol_id in symbol_ids:
        if symbol_id == eos_id or symbol_id == pad_id:
            break
        kept.append(id_to_symbol[symbol_id])
    return kept


# At most this many examples are printed
n_examples_max = 11

# Greedily decode the first batch served by the validation loader and print
# each input next to the decoder's output
model.eval()
with torch.no_grad():
    for i_batch, (data, data_lengths) in enumerate(val_loader):
        data = data.to(device)
        encoder_embedding, decoder_output = model(
            data, data_lengths, data, data_lengths
            )

        y, log_probs = model.decoder.greedy_decode(
            encoder_embedding,
            max_length=25,
            )
        x = data.cpu().numpy()

        for i_input in range(min(y.shape[0], n_examples_max)):
            input_symbols = _symbols_before_stop(x[i_input])
            output_symbols = _symbols_before_stop(y[i_input])

            print("Input: ", "_".join(input_symbols))
            print("Output:", "_".join(output_symbols))
            print()

        # Only the first batch is shown
        break

Input:  4_25_12_40_12_33_31_22_38_26_39_32_35_28_36_17_22_16_0_35_41_17_39_32_35_41_17_29_4
Output: 4_25_12_40_12_26_26_26_39_39_39_35_35_35_35_35_35_17_17_17_17_17_17_17_17

Input:  4_9_1_15_20_49_23_40_14_34_38_23_12_14_39_32_28_19_31_5_1_21_39_35_6_1_15_29
Output: 4_9_1_15_20_49_23_40_14_39_23_23_14_14_39_39_28_39_39_35_35_35_15_15_15

Input:  4_9_43_11_3_24_1_43_33_13_10
Output: 4_9_43_14_36_9_1_1_43_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10

Input:  24_1_21_0_43_40_12_26_20_22_16_37_41_24_1_21_49_38_13_4
Output: 4_9_1_43_45_12_26_22_16_16_16_16_16_21_21_38_33_13_13_13_4_4_4_4_4

Input:  4_0_28_23_33_31_3_16_37_25_12_26_20_22
Output: 4_0_28_19_31_31_16_37_37_12_12_20_22_16_16_16_16_16_16_16_37_37_12_12_12

Input:  4_32_35_28_19_33_39_8_38_34_37_2_33_10_8_38_26_5_39_41_15_20_22_16_44_0_28_23_40_14_36_34_38_20_39_32_28_42_36_34_41_1_15_20_34_44_35_30_15_21_44_37_25_40_10_45_25_33_31_22_16_38_26_29
Output: 4_0_35_28_19_31_3_8_38_38_26_34_38_0_38_38_26_38_44_44_44_44_34_44_44

In

## Segmentation

In [19]:
# Utterances for evaluation: the first `n_eval_utterances` sentences together
# with their utterance keys (with the current value this is every sentence)
n_eval_utterances = len(prepared_text)  # 1000  # 10000  # 1000
# To use the last sentences instead (val sentences), select
# slice(-n_eval_utterances, None)
eval_selection = slice(None, n_eval_utterances)
eval_sentences = prepared_text[eval_selection]
eval_utterances = list(utterances)[eval_selection]

In [20]:
# Embed segments
# Score every item of the interval dataset with the pre-trained model. The
# per-item reconstruction losses and lengths are stored in loader order and
# are consumed in that same order by the segmentation cell.

# Random seed
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Data
sentences = eval_sentences
# sentences = cur_val_sentences
# Presumably yields one item per candidate segment (interval) of each
# sentence -- TODO confirm against `datasets.SentenceIntervalDataset`
interval_dataset = datasets.SentenceIntervalDataset(
    sentences,
    text_to_id,
    join_char="_"
    )
# No shuffling and no dropped batch: the item order must stay aligned with
# `interval_dataset.intervals`
segment_loader = DataLoader(
    interval_dataset, 
    batch_size=batch_size, 
    shuffle=False, 
    collate_fn=datasets.pad_collate,
    drop_last=False
    )

# Apply model to data
# NOTE(review): whether the `Decoder2` instance used here reads
# `teacher_forcing_ratio` is not visible in this notebook -- confirm
model.decoder.teacher_forcing_ratio = 1.0
model.eval()
rnn_losses = []  # summed NLL of reconstructing each item
lengths = []     # matching entry of `data_lengths` for each item
eos = []         # not filled in this cell; rebound in the segmentation cell
with torch.no_grad():
    for i_batch, (data, data_lengths) in enumerate(tqdm(segment_loader)):
        data = data.to(device)
        
        encoder_embedding, decoder_output = model(
            data, data_lengths, data, data_lengths
            )

        # Score the items one at a time so that every item gets its own
        # loss; `criterion` sums over positions and ignores padding
        for i_item in range(data.shape[0]):
            item_loss = criterion(
                decoder_output[i_item].contiguous().view(-1,
                decoder_output[i_item].size(-1)),
                data[i_item].contiguous().view(-1)
                )
            rnn_losses.append(item_loss.cpu().numpy())
            lengths.append(data_lengths[i_item])

100%|██████████| 55684/55684 [04:31<00:00, 205.36it/s]


In [25]:
# Segment
# For every sentence, give each candidate interval a cost (reconstruction
# loss plus a weighted duration penalty) and let Viterbi pick the boundaries.

dur_weight = 3.0  # weight of the duration penalty relative to the RNN loss

# Running index into `rnn_losses` and `lengths`. It advances once per
# interval that is not None, across all sentences, and therefore relies on
# those lists being in the same order as `interval_dataset.intervals`.
i_item = 0
losses = []
cur_segmented_sentences = []
for i_sentence, intervals in enumerate(tqdm(interval_dataset.intervals)):
    
    # Costs for segment intervals
    # Entries start at infinity, so intervals that are None keep an
    # infinite cost
    costs = np.inf*np.ones(len(intervals))
    # End index of the last interval, taken as the end of the sequence; this
    # requires the last entry of `intervals` not to be None
    i_eos = intervals[-1][-1]
    for i_seg, interval in enumerate(intervals):
        if interval is None:
            continue
        i_start, i_end = interval
        dur = i_end - i_start
        # Guards the alignment between the intervals and the stored losses
        assert dur == lengths[i_item]
        # Only used by the commented-out sequence boundary term below
        eos = (i_end == i_eos)  # end-of-sequence
        
        # Chorowski
        costs[i_seg] = (
            rnn_losses[i_item]
            + dur_weight*neg_chorowski(dur)
            )
        
        # The Gamma and histogram alternatives below need the penalty
        # functions that are currently disabled in the duration penalty cell;
        # `neg_log_poisson` is not defined anywhere in this notebook.
#         # Gamma
#         costs[i_seg] = (
#             rnn_losses[i_item]
#             + dur_weight*neg_log_gamma(dur)
#             + np.log(np.sum(gamma_cache**dur_weight))
#             )
        
#         # Poisson
#         costs[i_seg] = (
#             rnn_losses[i_item]
#             + neg_log_poisson(dur)
#             )

#         # Histogram
#         costs[i_seg] = (
#             rnn_losses[i_item]
#             + dur_weight*(neg_log_hist(dur))
#             + np.log(np.sum(histogram**dur_weight))
#             )
    
#         # Sequence boundary
#         alpha = 0.3  # 0.3  # 0.9
#         if eos:
#             costs[i_seg] += -np.log(alpha)
#         else:
#             costs[i_seg] += -np.log(1 - alpha)
# #             K = 5000
# #             costs[i_seg] += -np.log((1 - alpha)/K)

        # Temp
#         if dur > 10 or dur <= 1:
#             costs[i_seg] = +np.inf
        i_item += 1
    
    # Viterbi segmentation
    n_frames = len(interval_dataset.sentences[i_sentence])
    summed_cost, boundaries = viterbi.custom_viterbi(costs, n_frames)
    losses.append(summed_cost)
    
    # `reference_sentence` is only used by the commented-out printing below
    reference_sentence = sentences[i_sentence]
    segmented_sentence = get_segmented_sentence(
            interval_dataset.sentences[i_sentence],
            boundaries
            )
    cur_segmented_sentences.append(segmented_sentence)
#     # Print examples of the first few sentences
#     if i_sentence < 10:
#         print(reference_sentence)
#         print(segmented_sentence)
#         print()
    
# This is the sum of the Viterbi costs returned for all sentences. The costs
# include the weighted duration term, which is zero or negative for durations
# of one or more, so the total can be negative and is not a pure negative
# log-likelihood.
print("NLL: {:.4f}".format(np.sum(losses)))

100%|██████████| 5285/5285 [00:03<00:00, 1450.33it/s]

NLL: -180370.4314





## Evaluation

In [26]:
# Show the segmentation of the first evaluation sentence: words are separated
# by spaces and the symbols within a word by "_"
print(cur_segmented_sentences[0])

# # To evaluate gold segmentation:
# cur_segmented_sentences = prepared_text_gold[:n_eval_utterances]
# print(cur_segmented_sentences[0])

4_37_12_18_39_8_28_18_24_1_18_20 49_36_24_1_31_45_2_13_3


In [27]:
# Convert segmentation to intervals
# Walk through the phone intervals of each utterance and collect phone labels
# until they spell out the next segmented word; the word then spans from the
# end of the previous word (0 for the first word) to the end of the current
# phone.
segmentation_interval_dict = {}
# tqdm wraps the list itself (not `enumerate`) so that it knows the total and
# can show the percentage and remaining time
for i_utt, utt_key in enumerate(tqdm(eval_utterances)):
    words_segmented = cur_segmented_sentences[i_utt].split(" ")
    word_intervals = []
    word_start = 0
    phone_labels = []  # labels of the phones in the word being built
    i_word = 0
    for (phone_start, phone_end,
            phone_label) in phoneseg_interval_dict[utt_key]:
        phone_labels.append(phone_label)
        word_label = "_".join(phone_labels)
        # NOTE: if phones remain after the last segmented word has been
        # matched, this lookup raises an IndexError
        if words_segmented[i_word] == word_label:
            word_intervals.append((word_start, phone_end, word_label))
            phone_labels = []
            word_start = phone_end
            i_word += 1
    segmentation_interval_dict[utt_key] = word_intervals

#     if i_utt < 10:
#         print(segmentation_interval_dict[utt_key])
#         print(word_ref_interval_dict[utt_key])
#         print()        

5285it [00:00, 84995.77it/s]


In [28]:
# Write intervals to a directory
# The output tag is derived from the input tag, e.g. "phoneseg_dp_penalized"
# becomes "wordseg_segaernn_dp_penalized"
output_tag = "wordseg_segaernn_{}".format(seg_tag.replace("phoneseg_", ""))
output_dir = (
    Path("../../vqwordseg/exp")/vq_model/dataset/split/output_tag/"intervals"
    )
output_dir.mkdir(exist_ok=True, parents=True)
print(f"Writing to: {output_dir}")
for utt_key in tqdm(segmentation_interval_dict):
    # Append ".txt" to the key instead of calling `with_suffix(".txt")`:
    # `with_suffix` replaces everything after the last "." of the name, so an
    # utterance key that contains a dot would be written under a wrong (and
    # possibly clashing) filename.
    with open(output_dir/f"{utt_key}.txt", "w") as f:
        # One line per word: start, end and the word label followed by "_"
        for start, end, label in segmentation_interval_dict[utt_key]:
            f.write(f"{start:d} {end:d} {label}_\n")

 20%|██        | 1059/5285 [00:00<00:00, 8904.23it/s]

Writing to: ../../vqwordseg/exp/cpc_big/zs2017_zhtmp/train/wordseg_segaernn_dp_penalized/intervals


100%|██████████| 5285/5285 [00:00<00:00, 9877.93it/s]
