In [1]:
"""Some utility functions."""

import os
import sys
import tqdm
import random
import importlib
import subprocess
import config
import numpy as np
from typing import List, Union, Tuple, Any


def get_files_rec(path: str, forbidden=(), must_contain=()) -> List[str]:
    """Recursively list the files under `path`, filtered and deterministically shuffled.

    Args:
        path: Directory to list. Sub-directories are descended into recursively.
        forbidden: Substrings; a file is dropped if its path contains any of them.
        must_contain: Substrings; when non-empty, a file is kept only if its
            path contains at least one of them.

    Returns:
        The matching file paths (each one prefixed with `path`), sorted and then
        shuffled after seeding with `config.random_seed`, so the order is the
        same on every call.
    """
    def _inner(parent: str) -> List[str]:
        found = []
        for name in os.listdir(parent):
            child = os.path.join(parent, name)
            if os.path.isdir(child):
                found.extend(_inner(child))
            else:
                found.append(child)
        return found

    files = [file for file in _inner(path) if all(forbid not in file for forbid in forbidden)]
    if len(must_contain):
        files = [file for file in files if any(must in file for must in must_contain)]
    # Sort first so the shuffle below starts from a listing-order-independent state.
    files.sort()
    # Note: this reseeds the process-wide `random` generator on every call.
    random.seed(config.random_seed)
    random.shuffle(files)
    return files

def delete_template(tmpl: str):
    """Delete the numbered files produced from the path template `tmpl`.

    Every '$' in `tmpl` is substituted with 0, 1, 2, ... in turn; each resulting
    path is removed until the first one that does not exist.
    """
    index = 0
    while True:
        candidate = tmpl.replace('$', str(index))
        if not os.path.exists(candidate):
            break
        os.remove(candidate)
        index += 1

def bar(it, *args, **kwargs) -> tqdm.tqdm:
    """Wrap `it` in a tqdm progress bar using this project's default settings.

    Defaults (a green bar written to stdout) can be overridden through `kwargs`.
    When no `total` is given, it is filled in from `len(it)` if `it` has a length.

    Args:
        it: The iterable to wrap.
        *args: Extra positional arguments forwarded to `tqdm.tqdm`.
        **kwargs: Extra keyword arguments forwarded to `tqdm.tqdm`.

    Returns:
        The `tqdm.tqdm` object wrapping `it`.
    """
    defaults = {
        'colour': 'GREEN',
        'file': sys.stdout,
    }
    kwargs = {**defaults, **kwargs}
    if kwargs.get("total") is None:
        try:
            kwargs["total"] = len(it)
        except Exception:
            # Best effort: iterables without a usable length (generators, file
            # objects, ...) simply get a bar without a known total. Unlike a bare
            # `except`, this does not swallow KeyboardInterrupt / SystemExit.
            pass
    return tqdm.tqdm(it, *args, **kwargs)

def load_wordvecs(path: str, load_embs=False) -> Union[dict, Tuple[dict, Any]]:
    """Load the vocabulary (and optionally the embeddings) for a .vec file.

    The data is read from the `.npz` file that sits next to `path` (same name
    with the `.vec` suffix replaced by `.npz`). If that file does not exist yet,
    it is first generated from `path` with `vec2npz`.

    Args:
        path: Path of the .vec word embeddings file.
        load_embs: Also load and return the embeddings array when True.

    Returns:
        A dict mapping each word to its row index, or a `(words, embeddings)`
        tuple when `load_embs` is True.
    """
    npz_path = path.removesuffix(".vec") + ".npz"
    if not os.path.exists(npz_path):
        print("\nWord embeddings are not prepared! Preparing them now.")
        print("This will take some time...")
        vec2npz(path)
    # The archive object keeps the underlying file open; the context manager
    # closes it once the requested arrays have been read.
    with np.load(npz_path) as word2vec:
        words = {w: i for i, w in enumerate(word2vec["words"])}
        if not load_embs:
            return words
        # Loading the embeddings takes a lot of time.
        return words, word2vec["embeddings"]

def vec2npz(path: str, extra_vocab=()):
    """Convert a .vec embeddings file into a .npz file stored next to it.

    The first line of the .vec file holds the word count and the embedding
    dimension; every following line holds a word followed by its vector.

    Args:
        path: Path of the .vec file. The output is written to the same path
            with the `.vec` suffix replaced by `.npz`.
        extra_vocab: Extra words to append after the words of the file. Each
            one gets a random vector (seeded with `config.random_seed`).

    The .npz file contains two arrays: `words` and `embeddings`.
    """
    with open(path) as f:
        w_count, dim = f.readline().split()
        w_count, dim = int(w_count), int(dim)
        total_count = w_count + len(extra_vocab)
        words = [""] * total_count
        embeddings = np.zeros((total_count, dim), dtype=np.float32)
        # Count the rows actually read: it can be less than the count in the
        # header when the vector file was trimmed, and it is 0 for a file that
        # only has a header.
        actual_count = 0
        for line in bar(f, total=w_count):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at the end).
                continue
            word, emb = line.split(maxsplit=1)
            words[actual_count] = word
            embeddings[actual_count] = np.array(emb.split(), dtype=np.float32)
            actual_count += 1

    np.random.seed(config.random_seed)
    for offset, word in enumerate(extra_vocab):
        words[actual_count + offset] = word
        # A random vector with values in [-0.99, 0.99), sized to match the
        # dimension declared by the file rather than a fixed 300.
        embeddings[actual_count + offset] = 1.98 * np.random.rand(dim).astype(np.float32) - 0.99
    actual_count += len(extra_vocab)

    # Convert the words list to a numpy array before storing.
    words = np.array(words)

    # Squeeze to the actual count.
    embeddings = embeddings[:actual_count]
    words = words[:actual_count]
    print("Storing the prepared word embeddings in a .npz file.")
    npz_path = path.removesuffix(".vec") + ".npz"
    np.savez(npz_path, embeddings=embeddings, words=words)
    print("Prepared word embeddings are now stored in: ", os.path.abspath(npz_path))

def trim_vocab_from(path: str, word2vec_file_path: str) -> str:
    """Trim the word embeddings to the words that appear in the files under `path`.

    Only the `.preprocessed` files under `path` are read. The full embeddings
    vocabulary is much larger than what the corpus actually uses, so keeping
    only the used words lowers the memory usage.

    Args:
        path: Directory that is searched recursively for `.preprocessed` files.
        word2vec_file_path: Path of the .vec embeddings file to trim.

    Returns:
        The `.vec`-style path of the trimmed embeddings (`<name>-trimmed.vec`).
        Only the matching `<name>-trimmed.npz` file is actually written; that is
        the file `load_wordvecs` looks for when given the returned path.
    """
    vocab, embeddings = load_wordvecs(word2vec_file_path, load_embs=True)
    words_in_files = set()
    for file in get_files_rec(path, must_contain=[".preprocessed"]):
        # Close every file right after reading it instead of leaking the handle.
        with open(file) as f:
            words_in_files.update(f.read().split())
    # Keep only the known words which appeared in the files we have. Ordering
    # them by their original index (instead of set iteration order) makes the
    # trimmed indices identical from one run to the next.
    new_vocab = sorted((word for word in words_in_files if word in vocab), key=vocab.__getitem__)
    new_embeddings = np.zeros((len(new_vocab), embeddings.shape[1]), dtype=np.float32)
    # Create the new embeddings table, row i belonging to new_vocab[i].
    for i, word in enumerate(new_vocab):
        new_embeddings[i] = embeddings[vocab[word]]
    # Store the trimmed vocab.
    npz_path = word2vec_file_path.removesuffix(".vec") + "-trimmed.npz"
    np.savez(npz_path, embeddings=new_embeddings, words=new_vocab)
    print(f"Trimmed word embeddings (size={len(new_vocab)}) are now stored in: ", os.path.abspath(npz_path))
    return npz_path.removesuffix(".npz") + ".vec"

def make_ascii(word: str) -> str:
    """Return `word` with every non-ASCII character dropped and the ends stripped."""
    utf8_bytes = word.encode('utf-8')
    # Bytes outside the ASCII range are skipped instead of raising an error.
    ascii_only = utf8_bytes.decode('ascii', 'ignore')
    return ascii_only.strip()

def _is_int(text: str) -> bool:
    """Tell whether `text` parses as an integer (surrounding whitespace allowed)."""
    try:
        int(text)
    except ValueError:
        return False
    return True


def _drop_srt_cue_headers(lines: List[str]) -> List[str]:
    """Return `lines` without SRT cue headers (a counter line followed by a `-->` timing line)."""
    kept = []
    i = 0
    while i < len(lines):
        has_next = i + 1 < len(lines)
        if has_next and '-->' in lines[i + 1] and _is_int(lines[i]):
            # Skip the counter and the timing line; the next line is examined on
            # the following iteration, so a header at the very end of the file
            # can no longer index past the list.
            i += 2
            continue
        kept.append(lines[i])
        i += 1
    return kept


def _collapse_repeats(lines: List[str]) -> List[str]:
    """Strip the lines, drop the empty ones and collapse runs of identical consecutive lines."""
    collapsed = []
    last_line = None
    for line in lines:
        line = line.strip()
        if line == '' or line == last_line:
            continue
        collapsed.append(line)
        last_line = line
    return collapsed


def get_sub(link: str, fname: str):
    """Download the auto-generated subtitles of a video as plain text.

    Runs `yt-dlp` to fetch the auto-generated subtitles (`<fname>.en.vtt`) and
    `ffmpeg` to convert them to SRT, then strips the cue headers and the
    duplicated lines and writes the remaining text, joined with single spaces,
    to `<fname>.asr`. The intermediate `.srt` and `.en.vtt` files are removed.

    Args:
        link: The youtube link to download the subtitles from.
        fname: Output path without extension.

    Raises:
        subprocess.CalledProcessError: If `yt-dlp` or `ffmpeg` exits with a
            non-zero status.
    """
    # Arguments are passed as a list (no shell), so `link` and `fname` cannot
    # inject shell syntax and may contain spaces. The `--` keeps a link that
    # starts with a dash from being parsed as an option.
    subprocess.run(
        ["yt-dlp", "--write-auto-sub", "--skip-download", "-o", fname, "--", link],
        check=True,
    )
    subprocess.run(["ffmpeg", "-y", "-i", f"{fname}.en.vtt", f"{fname}.srt"], check=True)

    with open(f"{fname}.srt") as file:
        text_lines = _drop_srt_cue_headers(file.readlines())

    # SRT files contain duplicate lines for some reason, delete them.
    with open(f"{fname}.asr", "w") as file:
        file.write(' '.join(_collapse_repeats(text_lines)))

    # Remove the intermediate files.
    os.remove(f"{fname}.srt")
    os.remove(f"{fname}.en.vtt")

In [2]:
"""Preprocess the wiki data by removing the stop words, punctuation and
irrelevant parts to make it look more like an ASR output."""

import os
import nltk
import utils
import config
import string
from typing import Tuple, List

# Extra tokens that are treated like stop words.
# NOTE(review): `recoverable_clean_section` replaces punctuation with spaces and
# lowercases every word before the stop-word check, so "***LIST***" looks like it
# can never match there — confirm whether this tag is still effective.
useless_tags = {"***LIST***"}
try:
    stop_words = set(nltk.corpus.stopwords.words('english')).union(useless_tags)
except LookupError:
    # NLTK raises LookupError when the stopwords corpus is not installed yet:
    # download it once and retry. Any other error is left to propagate.
    nltk.download('stopwords')
    stop_words = set(nltk.corpus.stopwords.words('english')).union(useless_tags)

def recoverable_clean_section(section: str, vocab: dict) -> Tuple[str, List[str]]:
    """Clean a section and also return every word it was made of.

    Punctuation is replaced by spaces and the words are lowercased. A word is
    kept in the cleaned text (in its ASCII form) only if that form is non-empty,
    is not a stop word and is found in `vocab`.

    Returns:
        A `(clean_text, original_words)` tuple: the kept ASCII words joined by
        single spaces, and the list of all lowercased words, kept or not.
    """
    # Turn every punctuation character into a space in a single pass.
    to_space = str.maketrans({mark: ' ' for mark in string.punctuation})
    lowered = (token.lower().strip() for token in section.translate(to_space).split())
    tokens = [token for token in lowered if token]

    kept = []
    for token in tokens:
        candidate = utils.make_ascii(token)
        # Keep only known, non-stop, ASCII-representable words.
        if candidate and candidate not in stop_words and candidate in vocab:
            kept.append(candidate)
    return ' '.join(kept), list(tokens)

def clean_section(section: str, vocab: dict) -> str:
    """Return only the cleaned text of `section` (see `recoverable_clean_section`)."""
    cleaned_text, _ = recoverable_clean_section(section, vocab)
    return cleaned_text

def process_doc(document: str, vocab: dict):
    """Split a document into sections, clean each one and join the long-enough ones.

    A section starts right after a line that begins with `config.section_start`
    and runs until the next such line. Empty lines and anything before the first
    section header are ignored. Cleaned sections with fewer than
    `config.min_words_per_section` words are dropped; the rest are joined with
    a blank line between them.
    """
    # Divide the document into sections.
    sections = []
    current = None  # Stays None until the first section header is seen.
    for line in document.split('\n'):
        if not line:
            continue
        if line.startswith(config.section_start):
            if current is not None:
                sections.append(' '.join(current))
            current = []
        elif current is not None:
            current.append(line)
    if current is not None:
        sections.append(' '.join(current))

    # Rejoin the sections together, leaving out the very short ones.
    kept = []
    for raw_section in sections:
        cleaned = clean_section(raw_section, vocab)
        if len(cleaned.split()) < config.min_words_per_section:
            continue
        kept.append(cleaned)
    return "\n\n".join(kept)

def preprocess_wiki(wiki_path):
    """Preprocess every raw document under `wiki_path`.

    Each file without a '.' in its path is run through `process_doc`, and the
    result is written next to it with a `.preprocessed` extension.
    """
    # Only grab files with no extension (no .preprocessed/.tf/.anything files).
    files = utils.get_files_rec(wiki_path, forbidden=['.'])
    vocab = utils.load_wordvecs(config.word2vec_file_path)
    print("\nPreprocessing the files in", os.path.abspath(wiki_path))
    for file in utils.bar(files):
        # Context managers close (and flush) both files deterministically.
        with open(file) as source:
            document = source.read()
        with open(file + ".preprocessed", 'w') as target:
            target.write(process_doc(document, vocab))


In [3]:
import os
import utils
import config

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

# Sentinel word that `read_files` puts in place of every blank-line break ("\n\n")
# and that `segment_generator` consumes as a topic-boundary signal; it is never
# added to a segment itself.
NEW_TOPIC_MARKER = "NEW_TOPIC_MARKER"


def read_files(files):
    """Lazily yield the content of each file with its blank-line breaks marked.

    Every "\\n\\n" in a file is replaced by `NEW_TOPIC_MARKER` surrounded by
    spaces. Files are read one at a time, only when the consumer asks for them.
    """
    for file in files:
        # Close the file before yielding so no handle stays open while the
        # consumer works on the text.
        with open(file) as source:
            text = source.read()
        yield text.replace("\n\n", " " + NEW_TOPIC_MARKER + " ")

def segment_generator(reader):
    """Cut the documents of `reader` into fixed-length word segments.

    Yields `(words, had_marker)` tuples where `words` is a list of exactly
    `config.segment_length` words and `had_marker` tells whether a topic change
    (a new document or a `NEW_TOPIC_MARKER`) was seen since the previous
    segment was yielded. Words are carried over from one document to the next;
    a trailing partial segment is never yielded.
    """
    buffer = []
    for document in reader:
        # A new document is considered a topical change.
        boundary_seen = True
        for token in document.split():
            if token == NEW_TOPIC_MARKER:
                boundary_seen = True
                continue
            buffer.append(token)
            if len(buffer) != config.segment_length:
                continue
            yield (buffer, boundary_seen)
            # Start a fresh segment with no boundary recorded yet.
            buffer = []
            boundary_seen = False

def write_segment(segment, writer, vocab):
    """Serialize one `(words, topic_boundary)` segment and write it to `writer`.

    The record is a `tf.train.SequenceExample` whose context feature
    "topic_bound" holds the boundary flag and whose feature list
    "token_sequence" holds one int64 vocabulary index per word.
    """
    words, topic_boundary = segment
    # All the words should be in our vocabulary since we eliminated unknown words while preprocessing.
    token_ids = [vocab[word] for word in words]

    example = tf.train.SequenceExample()
    example.context.feature["topic_bound"].int64_list.value.append(topic_boundary)

    token_features = example.feature_lists.feature_list["token_sequence"]
    for token_id in token_ids:
        step = token_features.feature.add()
        step.int64_list.value.append(token_id)

    writer.write(example.SerializeToString())

def record_gen_predict(asr: str, record_path_template: str):
    """Same as record_gen but works on only one input document.
    More suitable for predicting.

    Args:
        asr: The text of the document to turn into records.
        record_path_template: Path template of the record files; all the
            segments are written to the file obtained by replacing '$' with '0'.
    """
    vocab = utils.load_wordvecs(config.word2vec_file_path)
    # The context manager closes the writer, so the records are flushed to disk
    # before this function returns.
    with tf.io.TFRecordWriter(record_path_template.replace('$', '0')) as writer:
        for segment in segment_generator([asr]):
            write_segment(segment, writer, vocab)
    print("Generated records are stored in", os.path.abspath(record_path_template))

def record_gen(input_dir_path: str, record_path_template: str):
    """Generate the TF record files for the `.preprocessed` files of a directory.

    Existing record files matching the template are deleted first. The segments
    are then written `config.segments_per_tfrecord` per file, the files being
    numbered 0, 1, 2, ... through the '$' of `record_path_template`.

    Args:
        input_dir_path: Directory searched recursively for `.preprocessed` files.
        record_path_template: Path template of the record files, with a '$'
            standing for the file number.
    """
    utils.delete_template(record_path_template)
    files = utils.get_files_rec(input_dir_path, must_contain=[".preprocessed"])
    if len(files) == 0:
        print("WARNING: There are no .preprocessed files in", input_dir_path)
    print("\nGenerating TF records for the preprocessed files in", os.path.abspath(input_dir_path))
    # Don't load all the files in memory at once, create a lazy file reader instead.
    reader = read_files(utils.bar(files))
    vocab = utils.load_wordvecs(config.word2vec_file_path)
    writer = None
    try:
        for index, segment in enumerate(segment_generator(reader)):
            # Start a new record file every `segments_per_tfrecord` segments, so
            # the data is spread over several files that can be read in parallel.
            if index % config.segments_per_tfrecord == 0:
                if writer is not None:
                    # Flush and release the finished file before opening the next one.
                    writer.close()
                    writer = None
                writer = tf.io.TFRecordWriter(record_path_template.replace('$', str(index // config.segments_per_tfrecord)))
            write_segment(segment, writer, vocab)
    finally:
        # Make sure the last (or only) file is flushed, even on error.
        if writer is not None:
            writer.close()
    print("Generated records are stored in", os.path.abspath(record_path_template))

def input_fn(record_path_template: str):
    """Build the input tensors read from the record files of `record_path_template`.

    Returns a `(features, labels)` pair. Each element comes from its own dataset
    pipeline (see `make_dataset`) built over the same record files: the first
    one parses the "token_sequence" of every record, the second one parses its
    "topic_bound".

    NOTE(review): the two pipelines are independent iterators; features and
    labels line up only if both read the records in the same order — confirm.
    """
    def make_dataset(deserializer):
        # Collect the record files by replacing '$' with 0, 1, 2, ... until a
        # path does not exist.
        paths, index = [], 0
        while os.path.exists(record_path_template.replace('$', str(index))):
            paths.append(record_path_template.replace('$', str(index)))
            index += 1
        dataset = tf.data.TFRecordDataset(paths)
        dataset = dataset.repeat(config.epochs)
        dataset = dataset.map(deserializer, num_parallel_calls=128)
        # Create a sliding window over the data with a `snippet_stride` slider size.
        dataset = dataset.window(size=config.snippet_length, shift=config.snippet_stride, drop_remainder=True)
        # Join the windows back together. (This outputs every window batched instead of being a `_VariantDataset`).
        dataset = dataset.flat_map(lambda window: window.batch(config.snippet_length))
        # Batch every `config.batch_size` windows together as one training example.
        dataset = dataset.batch(config.batch_size, drop_remainder=True)
        # Buffer some batches to lower the IO latency.
        dataset = dataset.prefetch(20)
        return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

    # Parses only the "topic_bound" context feature of a serialized record.
    def deserialize_labels(sr):
        return tf.io.parse_single_sequence_example(serialized=sr, context_features={"topic_bound": tf.io.FixedLenFeature([], dtype=tf.int64)})[0]["topic_bound"]
    # Parses only the "token_sequence" sequence feature of a serialized record.
    def deserialize_features(sr):
        return tf.io.parse_single_sequence_example(serialized=sr, sequence_features={"token_sequence": tf.io.FixedLenSequenceFeature([], dtype=tf.int64)})[1]["token_sequence"]
    # Parses both features at once.
    # NOTE(review): this helper is not referenced anywhere in this function.
    def deserialize(serialized_record):
        context_parsed, sequence_parsed = tf.io.parse_single_sequence_example(
            serialized=serialized_record,
            context_features={"topic_bound": tf.io.FixedLenFeature([], dtype=tf.int64)},
            sequence_features={"token_sequence": tf.io.FixedLenSequenceFeature([], dtype=tf.int64)},
        )
        return (sequence_parsed["token_sequence"], context_parsed["topic_bound"])

    return make_dataset(deserialize_features), make_dataset(deserialize_labels)


def get_batch_count(record_path_template: str):
    """Estimate the number of batches the record files of a template will produce.

    Args:
        record_path_template: Path template of the record files, with a '$'
            standing for the file number (0, 1, 2, ...).

    Returns:
        The estimated batch count. It assumes that every record file is full
        (so it can overestimate) and that `config.segments_per_tfrecord` has not
        changed since the records were generated.
    """
    # Count the consecutive record files that exist, starting from number 0.
    record_count = 0
    while os.path.exists(record_path_template.replace('$', str(record_count))):
        record_count += 1
    segment_count = record_count * config.epochs * config.segments_per_tfrecord
    snippet_count = segment_count // config.snippet_stride
    batch_count = snippet_count // config.batch_size
    return batch_count


In [4]:
import os
import sys
import utils
import config
import numpy as np
import tensorflow as tf
from transformer.model import transformer, model_utils


def model_fn(features, labels, mode, params):
    """Estimator model function: a token-level then a sentence-level transformer encoder.

    Args:
        features: Token indices; reshaped below to
            [batch_size, snippet_length, segment_length].
        labels: Topic-boundary flags; reshaped below to
            [batch_size, snippet_length]. Only used when not predicting.
        mode: A `tf.estimator.ModeKeys` value. PREDICT selects the prediction
            parameters; every other mode takes the training branch.
        params: Dict holding the word embeddings table under "embeddings".

    Returns:
        A `tf.estimator.EstimatorSpec`: with `predictions` (the per-segment
        boundary probabilities) when predicting, with `loss` and `train_op`
        otherwise.
    """
    # Print ops registered through `out` are attached to the train op below.
    # NOTE(review): `out` is not called anywhere in this function.
    prints = []
    def out(*args, **kwargs):
        prints.append(tf.print(args, kwargs, output_stream=sys.stdout))
    # Glorot-uniform initial value for a tensor of shape `s`.
    def uniform(s):
        return tf.initializers.GlorotUniform()(s)

    predicting = mode == tf.estimator.ModeKeys.PREDICT
    if not predicting:
        TT_PARAMS = config.token_transformer_training_params
        ST_PARAMS = config.sentence_transformer_training_params
    else:
        TT_PARAMS = config.token_transformer_prediction_params
        ST_PARAMS = config.sentence_transformer_prediction_params

    embeddings = tf.convert_to_tensor(params["embeddings"])
    snippet_batch = tf.reshape(features, [config.batch_size, config.snippet_length, config.segment_length])

    # The distributions used for positional embeddings. Each, segment & snippet level positions take half the `config.positional_embeddings_length`.
    # NOTE(review): with an odd `positional_embeddings_length` the two halves would not add up to it — confirm it is always even.
    token_positional_embeddings_dist = tf.Variable(uniform([config.segment_length, config.positional_embeddings_length // 2]))
    sentence_positional_embeddings_dist = tf.Variable(uniform([config.snippet_length, config.positional_embeddings_length // 2]))

    # Arrays with the same shape as `snippet_batch` that has the nominal positions of the token on a segment level and snippet level.
    token_nominal_positions = np.tile(np.arange(config.segment_length), [config.batch_size, config.snippet_length, 1])
    sentence_nominal_positions = np.tile(np.tile(np.arange(config.snippet_length), [config.segment_length, 1]).T, [config.batch_size, 1, 1])

    # Look up the word and positional embeddings and concatenate them together as the final embeddings.
    snippet_batch_embeddings = tf.nn.embedding_lookup(embeddings, snippet_batch)
    token_positional_embeddings = tf.nn.embedding_lookup(token_positional_embeddings_dist, token_nominal_positions)
    sentence_positional_embeddings = tf.nn.embedding_lookup(sentence_positional_embeddings_dist, sentence_nominal_positions)
    snippet_batch_embeddings = tf.concat([snippet_batch_embeddings, token_positional_embeddings, sentence_positional_embeddings], axis=3)

    # Token-level transformer (TT):
    hidden_size = embeddings.shape[1] + config.positional_embeddings_length
    # Every segment of every snippet becomes one input sequence of the token-level encoder.
    tt_input = tf.reshape(snippet_batch_embeddings, [config.batch_size * config.snippet_length, config.segment_length, hidden_size])
    # NOTE(review): `update` modifies the params object taken from `config` in place.
    TT_PARAMS.update({"hidden_size": hidden_size})
    tt_trans = transformer.EncoderStack(TT_PARAMS, mode)
    # Since we don't have any padding in our input, the padding is gonna be all zeros.
    attention_padding = tf.zeros([config.batch_size * config.snippet_length, config.segment_length])
    attention_bias = model_utils.get_padding_bias(attention_padding, padding_value=-1)
    tt_output = tt_trans(tt_input, attention_bias, attention_padding)
    # We will use the first and last token to represent the sentence.
    sentence_embeddings = tf.concat([tt_output[:, 0, :], tt_output[:, -1, :]], axis=1)

    # Sentence-level transformer (ST):
    hidden_size = sentence_embeddings.shape[1]
    st_input = tf.reshape(sentence_embeddings, [config.batch_size, config.snippet_length, hidden_size])
    # NOTE(review): same in-place modification of the `config` params as for the token-level encoder.
    ST_PARAMS.update({"hidden_size": hidden_size})
    st_trans = transformer.EncoderStack(ST_PARAMS, mode)
    # No padding here as well.
    attention_padding = tf.zeros([config.batch_size, config.snippet_length])
    attention_bias = model_utils.get_padding_bias(attention_padding, padding_value=-1)
    st_output = st_trans(st_input, attention_bias, attention_padding)

    # Segmentation classifier:
    # A linear layer followed by a softmax over two classes for every segment.
    seg_classifier_w = tf.Variable(uniform([st_output.shape[2], 2]))
    seg_classifier_b = tf.Variable(uniform([2]))
    seg_probabilities = tf.nn.softmax(tf.add(tf.tensordot(st_output, seg_classifier_w, axes=[[2], [0]]), seg_classifier_b))

    if not predicting:
        # Prepare segment labels:
        labels = tf.reshape(labels, [config.batch_size, config.snippet_length])
        # One-hot encode every label: 1 -> [0., 1.], anything else -> [1., 0.].
        label_2d_fn = lambda x: tf.cond(tf.equal(x, 1), lambda: tf.constant([0., 1.]), lambda: tf.constant([1., 0.]))
        segment_labels = tf.map_fn(lambda x: tf.map_fn(label_2d_fn, x, dtype=tf.float32), labels, dtype=tf.float32)

        # Define the loss:
        # Cross-entropy between the predicted probabilities and the one-hot labels, summed over the whole batch.
        segmentation_loss = -1 * tf.reduce_sum(tf.multiply(tf.math.log(seg_probabilities), segment_labels))
        tf.summary.scalar("Segmentation Loss", segmentation_loss)

        optimizer_seg = tf.compat.v1.train.AdamOptimizer(learning_rate=config.learning_rate)
        # Make this statement depend on the prints we have above to not optimize them away.
        with tf.control_dependencies(prints):
            train_op = optimizer_seg.minimize(segmentation_loss, tf.compat.v1.train.get_global_step())

        print("Model defined.")
        return tf.estimator.EstimatorSpec(mode, loss=segmentation_loss, train_op=train_op)
    else:
        return tf.estimator.EstimatorSpec(mode, predictions=seg_probabilities)


def train(tfrecord_tmpl):
    """Train the model on the records found through the `tfrecord_tmpl` path template.

    Loads the word embeddings, builds an estimator around `model_fn` that stores
    its state in `config.model_dir`, and trains it on the output of `input_fn`.
    """
    print("Loading the word2vec model...")
    embeddings = utils.load_wordvecs(config.word2vec_file_path, load_embs=True)[1]
    run_config = tf.estimator.RunConfig(model_dir=config.model_dir)
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config,
        params={"embeddings": embeddings},
    )
    print("Training...")

    def training_input():
        return input_fn(tfrecord_tmpl)

    estimator.train(input_fn=training_input)

In [5]:
import tensorflow as tf
# Show the name of the GPU device that TensorFlow reports.
print(tf.test.gpu_device_name())

/device:GPU:0


# Preprocessing (ASR-like) and TFRecord generation

In [6]:
importlib.reload(config)

# Preprocessing and record generation only need to run when the data or the
# config changed. Flip the switch to True to (re)build every split.
RUN_PREPROCESSING = False
ALL_INPUT_DIRS = ["data/wiki_727K/train", "data/wiki_727K/dev", "data/wiki_727K/test"]

input_dirs = ALL_INPUT_DIRS if RUN_PREPROCESSING else []
for input_dir in input_dirs:
    tfrecord_tmpl = os.path.join(input_dir, config.tfrecord_tmpl)
    # Preprocess the documents.
    preprocess_wiki(input_dir)
    # Generate the tensorflow records.
    record_gen(input_dir, tfrecord_tmpl)
    print(f"Number of {input_dir} examples =", get_batch_count(tfrecord_tmpl))

# Training

In [7]:
# Pick up any edit made to config.py since the kernel started.
importlib.reload(config)

# Set the input directory.
input_dir = "data/wiki_727K/train"
tfrecord_tmpl = os.path.join(input_dir, config.tfrecord_tmpl)

print("Number of training examples =", get_batch_count(tfrecord_tmpl))
print("Loading the word2vec model...")
_, embeddings = utils.load_wordvecs(config.word2vec_file_path, load_embs=True)
print("Loaded wordvecs.")

# The estimator keeps its state in `config.model_dir` and receives the embeddings through `params`.
model_conf = tf.estimator.RunConfig(model_dir=config.model_dir)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=model_conf, params={"embeddings": embeddings})

print("Training...")
# The lambda defers building the input pipeline until the estimator calls it.
estimator.train(input_fn=lambda: input_fn(tfrecord_tmpl))

Number of training examples = 232857
Loading the word2vec model...
Loaded wordvecs.
Instructions for updating:
Use tf.keras instead.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using config: {'_model_dir': 'data/model/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief'

KeyboardInterrupt: 

# Evaluation

In [None]:
importlib.reload(config)

# Directory holding the records to evaluate on.
# Use "data/wiki_727K/dev" instead to evaluate on the Wiki-727K dev split.
input_dir = "data/ext"
tfrecord_tmpl = os.path.join(input_dir, config.tfrecord_tmpl)

print("Number of evaluation examples =", get_batch_count(tfrecord_tmpl))
print("Loading the word2vec model...")
_, embeddings = utils.load_wordvecs(config.word2vec_file_path, load_embs=True)
model_conf = tf.estimator.RunConfig(model_dir=config.model_dir)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=model_conf, params={"embeddings": embeddings})

print("Evaluating...")
estimator.evaluate(input_fn=lambda: input_fn(tfrecord_tmpl))

# Prediction

In [None]:
import tempfile
import shutil
import predict
importlib.reload(predict)

input_file = "100368.preprocessed"
# Run the prediction on a copy of the input placed alone in a scratch directory.
tmp_dir = tempfile.mkdtemp()
print("Working inside", tmp_dir)
try:
    shutil.copyfile(input_file, os.path.join(tmp_dir, os.path.basename(input_file)))
    predict.predict(tmp_dir)
finally:
    # Remove the scratch directory even when the prediction fails.
    shutil.rmtree(tmp_dir)