In [200]:
import os
import math
import datetime
from typing import List, Tuple, Dict, Optional, Union, Any

import numpy as np
import babel
import nlpaug.augmenter.char as nac
from babel.dates import format_date, format_datetime
from nltk.tokenize import sent_tokenize

from tensorflow.keras import backend as K
from tensorflow.keras.utils import Sequence
from tensorflow.keras.layers import (
    LSTM,
    Bidirectional,
    Dense,
    Dropout,
    Input,
    RepeatVector,
    TimeDistributed,
    Embedding,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

# Project root that contains the `data/` directory read further down. It can be
# overridden with the LAZYDATE_DIR environment variable so the notebook is not
# tied to a single machine; the default is the original hard-coded location.
PROJECT_DIR = os.environ.get("LAZYDATE_DIR", "/Users/danny/Desktop/lazydate")
os.chdir(PROJECT_DIR)

# Seed NumPy's global RNG; the sampling helpers in this notebook draw from `np.random`.
np.random.seed(42)

In [4]:
## Read wikitext-103
# The encoding is passed explicitly so decoding does not depend on the
# platform's default locale encoding.
with open('data/wiki.train.raw', 'r', encoding='utf-8') as f:
    wikitext = f.read()

In [5]:
# Only the first WIKI_CHAR_LIMIT characters of the corpus are split into sentences.
WIKI_CHAR_LIMIT = 10_000_000
wiki_sentences = sent_tokenize(wikitext[:WIKI_CHAR_LIMIT])

In [6]:
# Display one randomly chosen sentence from the tokenized corpus.
# Note that this draw advances NumPy's global RNG state.
np.random.choice(wiki_sentences)

'Irish rugby has become increasingly competitive at both the international and provincial levels since the sport went professional in 1994 .'

In [115]:
# Candidate pattern fragments that `random_format` samples from to build the
# format string handed to Babel's `format_datetime`. Entries that appear more
# than once (and the repeated empty timezone entries) are picked more often,
# because sampling is uniform over the list.
day_formats = ["d", "dd"]
month_formats = ["M", "MM", "MMM", "MMMM", "MMMM", "L", "LL", "LLL", "LLLL", "LLLL"]
year_formats = ["yy", "yyyy"]

second_formats = ["s", "ss"]
minute_formats = ["m", "mm"]
hour_formats = ["h", "hh", "H", "HH"]
timezone_formats = ["", "", "", "", "", "z", "zz", "zzz", "zzzz"]
time_separators = [":"]

# Separator between day, month and year, mapped to its sampling probability
# (the values sum to 1.0). "''" is handled downstream as a single apostrophe:
# `apply_noise` splits on its first character only.
separator_frequency = {
    ".": 0.1,
    "/": 0.15,
    "-": 0.15,
    "''": 0.1,
    " ": 0.5,
}

built_in_formats = ["short", "medium", "long", "full"]

# Region-specific English locales only. A prefix test is used instead of a
# substring test so identifiers that merely contain "en_" (for example a
# language code ending in "en") are not picked up by accident.
locales = [
    identifier
    for identifier in babel.localedata.locale_identifiers()
    if identifier.startswith("en_")
]


def random_date(n_years: int = 100) -> Tuple[datetime.datetime, Dict[str, int]]:
    """Sample a random timestamp in the `n_years` years following 1900-01-01.

    Args:
        n_years: Width of the sampling window in (365-day) years.

    Returns:
        A tuple `(date, gen_dict)` where `gen_dict` holds the sampled offsets
        (`days`, `hours`, `minutes`, `seconds`) that were added to
        1900-01-01 00:00:00 to obtain `date`.
    """
    start_date = datetime.datetime(1900, 1, 1, 0, 0, 0)
    gen_dict = {
        # 365 days per year; the previous factor of 265 silently shrank the
        # window to roughly 73% of the requested number of years.
        "days": np.random.randint(0, n_years * 365),
        "hours": np.random.randint(0, 24),
        "minutes": np.random.randint(0, 60),
        "seconds": np.random.randint(0, 60),
    }

    date = start_date + datetime.timedelta(**gen_dict)
    return date, gen_dict


def random_format(date: datetime.datetime) -> Tuple[str, Dict[str, str]]:
    """Sample a random date(-time) pattern string for `date`.

    The pattern is assembled from the module-level candidate lists
    (`day_formats`, `month_formats`, ...) and has the shape
    `<day><sep><month><sep><year>`, optionally followed by a time part.

    Args:
        date: The date that will be rendered; only its year is inspected.

    Returns:
        A tuple `(format_str, gen_dict)`. `gen_dict` records every sampled
        choice (`day`, `month`, `year`, `separator`, `append_time`, `second`,
        `minute`, `hour`, `timezone`, `time_separator`) plus the final
        `format_str`. The time-related entries are empty strings when no time
        part is appended.
    """
    possible_separators = list(separator_frequency.keys())

    # NOTE(review): `random_date` starts at 1900 and by default spans at most
    # 100 years, so this condition looks unreachable and two-digit years are
    # never sampled. Confirm the intended rule before changing it.
    if date.year >= datetime.datetime.now().year + 1:
        year_format = np.random.choice(year_formats)
    else:
        year_format = "yyyy"

    append_time = np.random.rand() <= 0.5
    gen_dict = {
        "day": np.random.choice(day_formats),
        "month": np.random.choice(month_formats),
        "year": year_format,
        "separator": np.random.choice(
            possible_separators, p=list(separator_frequency.values())
        ),
        "append_time": append_time,
    }
    if append_time:
        time_gen_dict = {
            "second": np.random.choice(second_formats),
            "minute": np.random.choice(minute_formats),
            "hour": np.random.choice(hour_formats),
            "timezone": np.random.choice(timezone_formats),
            "time_separator": np.random.choice(time_separators),
        }
    else:
        # Same keys as the branch above. The key used to be "hours" here, which
        # was inconsistent and overwrote the integer "hours" offset produced by
        # `random_date` once the dictionaries were merged in `generate_date`.
        time_gen_dict = {
            k: "" for k in ["second", "minute", "hour", "timezone", "time_separator"]
        }
    gen_dict.update(time_gen_dict)

    sep = gen_dict["separator"]
    # Two-digit years are sometimes written with a leading apostrophe ('63),
    # unless the apostrophe is already the date separator.
    if sep != "''" and gen_dict["year"] == "yy":
        if np.random.random() <= 0.5:
            gen_dict["year"] = "''" + gen_dict["year"]

    format_date_str = f"{gen_dict['day']}{sep}{gen_dict['month']}{sep}{gen_dict['year']}"
    format_time_str = ""

    if append_time:
        sep = gen_dict["time_separator"]
        format_time_str = f" {gen_dict['hour']}{sep}{gen_dict['minute']}"
        if np.random.random() <= 0.5:
            format_time_str += f"{sep}{gen_dict['second']}"
        if np.random.random() <= 0.5:
            format_time_str += " a"  # AM / PM
        if np.random.random() <= 0.5:
            # When the sampled timezone pattern is empty this only adds a space.
            format_time_str += f" {gen_dict['timezone']}"

    format_str = format_date_str + format_time_str
    gen_dict["format_str"] = format_str
    return format_str, gen_dict


def get_random_wiki_sentence() -> str:
    """Return one sentence picked uniformly at random from `wiki_sentences`."""
    n_sentences = len(wiki_sentences)
    return wiki_sentences[np.random.randint(0, n_sentences)]


def random_noise_dict(
    date: datetime.datetime, format_dict: Dict[str, str]
) -> Dict[str, Any]:
    """Sample the noise settings that `apply_noise` later applies to a date string.

    Args:
        date: The date being rendered; its day of month selects the ordinal suffix.
        format_dict: Format choices; only `format_dict["day"]` is read.

    Returns:
        A dict with the keys `locale`, `append_day_suffix`, `aug_char_action`,
        `place_in_sentence`, `sentence` (empty unless `place_in_sentence`) and
        `day_suffix` (empty unless `append_day_suffix`).
    """
    # A suffix is only considered for the two-letter day pattern.
    append_day_suffix = format_dict["day"] == "dd" and np.random.random() <= 0.5
    place_in_sentence = np.random.random() <= 0.5

    gen_dict = {
        "locale": np.random.choice(locales),
        "append_day_suffix": append_day_suffix,
        "aug_char_action": np.random.choice(["insert", "substitute"]),
        "place_in_sentence": place_in_sentence,
        "sentence": get_random_wiki_sentence() if place_in_sentence else "",
    }

    day_suffix = ""
    if append_day_suffix:
        # English ordinal suffix for a day of month (1-31); 11, 12 and 13 fall
        # through to "th".
        if date.day in (1, 21, 31):
            day_suffix = "st"
        elif date.day in (2, 22):
            day_suffix = "nd"  # was "st", which produced "2st" / "22st"
        elif date.day in (3, 23):
            day_suffix = "rd"
        else:
            day_suffix = "th"
    gen_dict["day_suffix"] = day_suffix

    return gen_dict


def put_datestr_in_sentence(datestr: str, sentence: str):
    """Replace one randomly chosen space-delimited word of `sentence` with `datestr`."""
    words = sentence.split(" ")
    position = np.random.randint(0, len(words))
    # Rebuild the word list around the chosen position instead of assigning in place.
    replaced = words[:position] + [datestr] + words[position + 1:]
    return " ".join(replaced)


def apply_noise(datestr: str, format_dict: Dict[str, str], noise_dict: Dict[str, Any]) -> str:
    """Apply the sampled noise to an already formatted date string.

    Steps, in order: optionally append the ordinal suffix to the first
    separator-delimited part, optionally inject a one-character spelling
    mistake into the second part (only for textual month patterns, with
    probability 0.3), and optionally embed the result in a sentence.

    Args:
        datestr: The formatted date string.
        format_dict: Format choices; reads `separator` and `month`.
        noise_dict: Noise choices; reads `append_day_suffix`, `day_suffix`,
            `aug_char_action`, `place_in_sentence` and `sentence`.

    Returns:
        The noisy date string.
    """
    sep = format_dict["separator"]
    # The "''" separator is rendered as a single apostrophe, so split on that.
    sep = sep[0] if len(sep) > 1 else sep

    date_parts = datestr.split(sep)

    if noise_dict["append_day_suffix"]:
        date_parts[0] = date_parts[0] + noise_dict["day_suffix"]

    # Add spelling mistake to month name
    if len(format_dict["month"]) > 2 and np.random.random() <= 0.3:
        if len(date_parts) > 1:
            aug = nac.RandomCharAug(
                action=noise_dict["aug_char_action"],
                aug_char_min=1,
                aug_char_max=1,
            )
            augmented = aug.augment(date_parts[1])
            # Newer nlpaug releases return a list of augmented strings, older
            # ones a plain string; accept both so the join below cannot fail.
            if isinstance(augmented, (list, tuple)):
                augmented = augmented[0] if augmented else date_parts[1]
            date_parts[1] = augmented

    out = sep.join(date_parts)

    if noise_dict["place_in_sentence"]:
        out = put_datestr_in_sentence(out, noise_dict["sentence"])

    return out


def generate_date(
    no_date_prob: float = 0.1
) -> Tuple[str, datetime.datetime, Dict[str, Any]]:
    """Produce one synthetic training example.

    A random date, format and noise configuration are always sampled and
    rendered. With probability `no_date_prob` the rendered string is then
    discarded and replaced by a random corpus sentence, in which case the
    returned date is `None`.

    Returns:
        `(datestr, date, gen_dict)`, where `gen_dict` merges the dictionaries
        returned by the three sampling helpers and adds a boolean `no_date` flag.
    """
    date, date_info = random_date()
    format_str, format_info = random_format(date)
    noise_info = random_noise_dict(date, format_info)

    rendered = format_datetime(date, format=format_str, locale=noise_info["locale"])
    datestr = apply_noise(rendered, format_info, noise_info)

    # Later dictionaries win on duplicate keys, the same as successive update() calls.
    gen_dict = {**date_info, **format_info, **noise_info, "no_date": False}

    # Example with no date
    if np.random.random() <= no_date_prob:
        replacement = get_random_wiki_sentence()
        gen_dict["no_date"] = True
        return replacement, None, gen_dict

    return datestr, date, gen_dict

In [116]:
%%time
# Draw one synthetic example and print the noisy input next to its ground truth.
# `generate_date` returns `date=None` for "no date" samples, in which case
# `datestr` is a plain corpus sentence.
datestr, date, gen_dict = generate_date()
print(f"Date string: {datestr}\n")
print(f"Correct Date: {date}\n")
# print(gen_dict)

Date string: 20th.11.1963

Correct Date: 1963-11-20 01:59:20

CPU times: user 1.46 ms, sys: 1.5 ms, total: 2.95 ms
Wall time: 9.71 ms


In [117]:
# Target encoding used for training: the date as YYYYMMDD.
# `date` is None when the cell above drew a "no date" sample, so guard the call
# instead of failing with AttributeError on those runs.
date.strftime("%Y%m%d") if date is not None else None

'19631120'

## Vectorizer and generator

In [195]:
# Characters the input vectorizer knows. Only lowercase letters are listed and
# `CharVectorizer` does not change case, so uppercase characters encode as UNK_TOKEN.
LETTERS = "abcdefghijklmnopqrstuvwxyz"
DIGITS = "0123456789"
SYMBOLS = "£&()[]+-/*;:@_\\\"'#" + "€$%!?,. "
VOCABULARY = LETTERS + DIGITS + SYMBOLS

MAX_SEQUENCE_LEN = 150
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"


class CharVectorizer:
    """Encode strings as fixed-length arrays of character indices.

    Index 0 is reserved for padding (`PAD_TOKEN`), vocabulary characters are
    numbered from 1 in order of first appearance, and `UNK_TOKEN` takes the
    last index. Previously the first vocabulary character shared index 0 with
    the padding value, which made it indistinguishable from padding and caused
    it to be masked by an `Embedding(mask_zero=True)` layer.
    """

    def __init__(
        self, vocabulary: str, max_sequence_len: int = MAX_SEQUENCE_LEN
    ):
        """
        Args:
            vocabulary: String whose characters form the known vocabulary.
            max_sequence_len: Length every encoded sequence is padded or
                truncated to.
        """
        self.max_sequence_len = max_sequence_len
        self.encoder: Dict[str, int] = {PAD_TOKEN: 0}
        for char in vocabulary:
            # setdefault keeps indices contiguous even if a character repeats.
            self.encoder.setdefault(char, len(self.encoder))
        self.encoder[UNK_TOKEN] = len(self.encoder)

    @property
    def vocabulary(self):
        """Sorted list of all encoder keys, including the pad and unknown tokens."""
        return sorted(list(self.encoder.keys()))

    def transform(self, inputs: List[str]) -> np.ndarray:
        """Encode each string; returns an int32 array of shape (len(inputs), max_sequence_len)."""
        rows = [self._get_char_indices_for_word(s) for s in inputs]
        # The reshape keeps the result two-dimensional for an empty input list.
        return np.array(rows, dtype=np.int32).reshape(-1, self.max_sequence_len)

    def _get_char_indices_for_word(self, text: str) -> np.ndarray:
        """Encode one string, truncating at `max_sequence_len` and right-padding with 0."""
        next_arr = np.zeros([self.max_sequence_len], dtype=np.int32)
        unk_idx = self.encoder[UNK_TOKEN]

        # Slicing truncates the end of the text if it is too long.
        for idx, token in enumerate(text[: self.max_sequence_len]):
            next_arr[idx] = self.encoder.get(token, unk_idx)
        return next_arr

In [201]:
class DataGenerator(Sequence):
    """Keras `Sequence` that synthesizes (noisy date string -> YYYYMMDD) batches on demand.

    Every batch is generated afresh with `generate_date`; the batch index
    passed to `__getitem__` is not used, so two calls with the same index
    return different data. Each batch holds exactly `batch_size` examples, so
    one pass over `len(self)` batches can yield slightly more than `n_examples`.

    NOTE(review): the sampling helpers use NumPy's global RNG. When Keras runs
    this generator in forked worker processes, confirm that the workers do not
    all inherit the same RNG state and emit identical batches.
    """

    def __init__(self, batch_size=32, n_examples=50000):
        self.batch_size = batch_size
        self.n_examples = n_examples
        # Inputs use the full character vocabulary; targets are 8 digits (YYYYMMDD).
        self.input_vectorizer = CharVectorizer(vocabulary=VOCABULARY)
        self.output_vectorizer = CharVectorizer(vocabulary=DIGITS, max_sequence_len=8)

    def __len__(self):
        """Number of batches per epoch."""
        n_batches = math.ceil(self.n_examples / self.batch_size)
        return int(n_batches)

    @property
    def input_sequence_len(self):
        """Length of every encoded input sequence."""
        return self.input_vectorizer.max_sequence_len

    @property
    def input_vocab_size(self):
        """Number of entries in the input vectorizer's vocabulary."""
        return len(self.input_vectorizer.vocabulary)

    @property
    def output_sequence_len(self):
        """Length of every encoded target sequence."""
        return self.output_vectorizer.max_sequence_len

    @property
    def output_vocab_size(self):
        """Number of entries in the output vectorizer's vocabulary."""
        return len(self.output_vectorizer.vocabulary)

    def generate_string_batch(self) -> Tuple[List[str], List[str]]:
        """Generate `batch_size` raw (input string, target string) pairs.

        The target is the date formatted as `%Y%m%d`. For "no date" examples it
        is `UNK_TOKEN` repeated `output_sequence_len` times.
        """
        no_date_target = UNK_TOKEN * self.output_vectorizer.max_sequence_len
        pairs: List[Tuple[str, str]] = []

        for _ in range(self.batch_size):
            datestr, date, _gen_dict = generate_date()
            target = date.strftime(format="%Y%m%d") if date is not None else no_date_target
            pairs.append((datestr, target))

        input_strings: List[str] = [source for source, _ in pairs]
        output_strings: List[str] = [target for _, target in pairs]
        return input_strings, output_strings

    def __getitem__(self, idx: int):
        """Return one freshly generated batch as `(inputs, outputs)` dicts; `idx` is ignored."""
        input_strings, output_strings = self.generate_string_batch()
        encoded_inputs = self.input_vectorizer.transform(input_strings)
        encoded_outputs = self.output_vectorizer.transform(output_strings)
        return {"datestr": encoded_inputs}, {"output_datestr": encoded_outputs}

In [202]:
# Generator with the default batch size and example count; the next cell uses
# it to time the creation of a single batch.
generator = DataGenerator()

In [203]:
%%time
# Fetch one batch through the indexing protocol; the index itself is ignored by
# DataGenerator.__getitem__, which always generates a new random batch.
inputs, outputs = generator[2]

CPU times: user 12.7 ms, sys: 1.21 ms, total: 13.9 ms
Wall time: 14.4 ms


## Model

In [204]:
def all_acc(y_true, y_pred):
    """Sequence-level accuracy: share of samples whose every position is predicted correctly.

    `y_true` is reduced with a max over its last axis and compared against the
    argmax of `y_pred` over its last axis; a sample counts as correct only if
    the comparison holds along axis 1.

    NOTE(review): presumably `y_true` arrives with a trailing singleton axis so
    that the max leaves a (batch, positions) tensor. Confirm against the Keras
    version in use. This metric is defined here but is not passed to
    `model.compile` in `lstm_encoder_decoder`.
    """
    target_ids = K.max(y_true, axis=-1)
    predicted_ids = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    position_hits = K.equal(target_ids, predicted_ids)
    return K.mean(K.all(position_hits, axis=1))


def lstm_encoder_decoder(
    input_sequence_len: int,
    input_vocab_size: int,
    output_sequence_len: int,
    output_vocab_size: int,
    embedding_dim: int = 64,
    lstm_hidden_dim: int = 64,
    learning_rate: float = 1e-3,
):
    """Build and compile a character-level LSTM encoder-decoder.

    The encoder embeds the input indices and summarizes them with a
    bidirectional LSTM. The decoder repeats that summary once per output
    position, runs an LSTM over it and predicts a softmax distribution over the
    output vocabulary at every position.

    Args:
        input_sequence_len: Length of the integer input sequences.
        input_vocab_size: Number of distinct input indices (embedding rows).
        output_sequence_len: Number of output positions to predict.
        output_vocab_size: Number of classes per output position.
        embedding_dim: Size of the character embeddings.
        lstm_hidden_dim: Hidden size of the encoder and decoder LSTMs.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A compiled Keras `Model` taking the input named "datestr" and producing
        the output keyed "output_datestr".
    """
    # Encoder
    _input = Input(shape=(input_sequence_len,), dtype="int32", name="datestr")
    # mask_zero=True makes index 0 a padding value that downstream layers skip,
    # so index 0 must not encode a real character.
    embedding = Embedding(output_dim=embedding_dim, input_dim=input_vocab_size, mask_zero=True)(_input)
    encoded = Bidirectional(LSTM(lstm_hidden_dim, return_sequences=False))(embedding)

    # Decoder
    repeated = RepeatVector(output_sequence_len)(encoded)
    decoded = LSTM(lstm_hidden_dim, return_sequences=True)(repeated)
    _output = TimeDistributed(Dense(output_vocab_size, activation='softmax'))(decoded)

    model = Model(inputs=[_input], outputs={"output_datestr": _output})
    # `learning_rate` is the supported argument name; `lr` is a deprecated alias.
    optimizer = Adam(learning_rate=learning_rate)
    # Targets are integer class ids per position, hence the sparse loss.
    model.compile(optimizer, loss='sparse_categorical_crossentropy', metrics=["accuracy"])
    return model

In [205]:
# Stop once the validation loss has failed to improve for 2 consecutive epochs
# and roll the model back to the best weights seen.
early_stopping = EarlyStopping(
    monitor="val_loss", patience=2, restore_best_weights=True
)

# Both generators synthesize random examples on every call, so the validation
# data is not a fixed held-out set and `val_loss` varies between epochs for
# that reason alone.
gen_train = DataGenerator()
gen_val = DataGenerator()

# Model dimensions are taken from the generator's vectorizers so that the model
# and the encoded data agree.
model = lstm_encoder_decoder(
    input_sequence_len=gen_train.input_sequence_len,
    input_vocab_size=gen_train.input_vocab_size,
    output_sequence_len=gen_train.output_sequence_len,
    output_vocab_size=gen_train.output_vocab_size,
)

# NOTE(review): with use_multiprocessing=True the batches are produced in
# worker processes (forked, per the traceback recorded below). The data helpers
# draw from NumPy's global RNG, so verify that the workers are not generating
# identical batches from an inherited RNG state.
history = model.fit(
    gen_train,
    epochs=10,
    callbacks=[early_stopping],
    validation_data=gen_val,
    max_queue_size=20,
    workers=2,
    use_multiprocessing=True,
)

Epoch 1/10
Epoch 2/10
  61/1563 [>.............................] - ETA: 3:25 - loss: 1.1508 - accuracy: 0.5702

Process Keras_worker_ForkPoolWorker-7:
Process Keras_worker_ForkPoolWorker-8:
Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/pool.py", line 110, in worker
    task = get()
  File "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 352, in get
    res = self._reader.recv_bytes()
  File "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/usr/local/C

KeyboardInterrupt: 