In [1]:
import os

from pprint import pprint
from contextlib import redirect_stdout

import numpy as np
import tensorflow as tf

In [2]:
# --------------------------------------- DATA FUNCTIONS -----------------------------------

RAW_STRINGS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
STRINGS = RAW_STRINGS + RAW_STRINGS + RAW_STRINGS

# CHAR_TO_INT = dict(zip([item for item in RAW_STRINGS], range(len(RAW_STRINGS))))
CHAR_TO_INT = dict(zip([item for item in RAW_STRINGS], range(1, len(RAW_STRINGS) + 1)))
INT_TO_CHAR = dict(zip(list(CHAR_TO_INT.values()), list(CHAR_TO_INT.keys())))

START_TOKEN = "0"
END_TOKEN = "1"

def generate_data(data_size=100):
    
    def generate_cond_sequence(condition):
        for i in range(len(RAW_STRINGS)):
            if condition == RAW_STRINGS[i]:
                
                random_number = np.random.choice([0, 1, 2])
                
                if random_number == 0:
                    length = 3
                elif random_number == 1:
                    length = 6
                elif random_number == 2:
                    length = 10
                else:
                    raise NotImplementedError()
                
                return STRINGS[i:i+length]
        
        return RAW_STRINGS[::-1][:10]
    
    conditions_list = []
    sequences_lists = []
    
    for i in range(data_size):
        condition = np.random.choice([item for item in RAW_STRINGS])[0]
        
        seq = generate_cond_sequence(condition)
        
        conditions_list.append(condition)
        sequences_lists.append(seq)
        
    return conditions_list, sequences_lists


In [3]:
def add_start_end_tokens(conds, seqs):
    new_seqs = []
    for seq in seqs:
        new_seqs.append(START_TOKEN + seq + END_TOKEN)
        
    return conds, new_seqs
    
    
def make_same_length(conds, seqs):
    new_seqs = []
    
    max_length = np.max([len(seq) for seq in seqs])
    
    for seq in seqs:
        remaining_tokens = max_length - len(seq)
        new_seqs.append(seq + "".join([END_TOKEN]*remaining_tokens))
        
    return conds, new_seqs


def make_training_data(conds, seqs):
    
    input_seqs = []
    
    for seq in seqs:
        input_seqs.append(seq[:-1])
    
    input_data = (conds, input_seqs)
    
    output_seqs = []
    
    for seq in seqs:
        output_seqs.append(seq[1:])
        
    output_data = output_seqs
    
    return input_data, output_data

In [4]:
class ConditionMapper():
    
    def __init__(self):
        self.integer_map = CHAR_TO_INT.copy()
        self.string_map = INT_TO_CHAR.copy()
        
        # self.dimension = len(self.integer_map)
        self.dimension = len(self.integer_map) + 1
        
    def map_to_ints(self, input_string):
        input_numbers = [self.integer_map[item] for item in input_string]
        
        return np.asarray(input_numbers)
    
    def map_list_to_ints_vectors(self, list_of_strings):
        input_length = len(list_of_strings)
        
        vectors = []
        
        for input_string in list_of_strings:
            vectors.append(self.map_to_ints(input_string))
            
        vectors = np.asarray(vectors)
        
        return vectors
    
    def map_ints_to_string(self, input_ints):
        string = ""
        
        for integer in input_ints:
            string += self.string_map[integer]
            
        return string

In [5]:
class SeqMapper():
    
    def __init__(self):
        self.integer_map = CHAR_TO_INT.copy()
        
        n_ints = len(self.integer_map)
        
        # self.integer_map.update({START_TOKEN: n_ints, END_TOKEN: n_ints + 1})
        # self.dimension = len(self.integer_map)
        
        self.integer_map.update({START_TOKEN: n_ints + 1, END_TOKEN: n_ints + 2})
        
        self.dimension = len(self.integer_map) + 1
        
        self.string_map = dict(zip(list(self.integer_map.values()), list(self.integer_map.keys())))
    
    def map_to_ints(self, input_string):
        input_numbers = [self.integer_map[item] for item in input_string]
        
        return np.asarray(input_numbers)
    
    def map_list_to_ints_vectors(self, list_of_strings):
        input_length = len(list_of_strings)
        
        vectors = []
        
        for input_string in list_of_strings:
            vectors.append(self.map_to_ints(input_string))
            
        vectors = np.asarray(vectors)
        
        return vectors
    
    def map_ints_to_string(self, input_ints):
        string = ""
        
        for integer in input_ints:
            string += self.string_map[integer]
            
        return string


In [6]:
def convert_training_data(input_train, output_train, mappers):
    converted_input_train = mappers[0].map_list_to_ints_vectors(input_train[0]), mappers[1].map_list_to_ints_vectors(input_train[1])
    converted_output_train = mappers[1].map_list_to_ints_vectors(output_train)
    
    return converted_input_train, converted_output_train

In [7]:
# --------------------------------------- TENSORFLOW LAYERS AND MODELS -----------------------------------
    
# --------------------------------------- masked layers
def masked_loss(label, pred):
    mask = label != 0
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    loss = loss_object(label, pred)

    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask

    loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)
    return loss


def masked_accuracy(label, pred):
    pred = tf.argmax(pred, axis=2)
    label = tf.cast(label, pred.dtype)
    match = label == pred

    mask = label != 0

    match = match & mask

    match = tf.cast(match, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.reduce_sum(match) / tf.reduce_sum(mask)    

In [8]:
# --------------------------------------- positional embedding
def positional_encoding(length, depth):
    depth = depth / 2

    positions = np.arange(length)[:, np.newaxis]  # (seq, 1)
    depths = np.arange(depth)[np.newaxis, :] / depth  # (1, depth)

    angle_rates = 1 / (10000 ** depths)  # (1, depth)
    angle_rads = positions * angle_rates  # (pos, depth)

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)

    return tf.cast(pos_encoding, dtype=tf.float32)


class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size,
                 d_model,
                 use_embedding=True):

        super().__init__()
        self.d_model = d_model
        self.use_embedding = use_embedding

        if self.use_embedding:
            self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        else:
            self.embedding = tf.keras.layers.Dense(d_model, activation="relu")

        self.pos_encoding = positional_encoding(length=2048, depth=d_model)

    def compute_mask(self, *args, **kwargs):
        if self.use_embedding:
            return self.embedding.compute_mask(*args, **kwargs)
        else:
            return None

    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and positional_encoding.
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x

In [9]:
# --------------------------------------- attention layers
class BaseAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()


class CrossAttention(BaseAttention):
    def call(self, x, context):
        attn_output, attn_scores = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True)

        # Cache the attention scores for plotting later.
        self.last_attn_scores = attn_scores

        x = self.add([x, attn_output])
        x = self.layernorm(x)

        return x


class GlobalSelfAttention(BaseAttention):
    def call(self, x):
        attn_output = self.mha(
            query=x,
            value=x,
            key=x)
        x = self.add([x, attn_output])
        x = self.layernorm(x)
        return x


class CausalSelfAttention(BaseAttention):
    def call(self, x):
        attn_output = self.mha(
            query=x,
            value=x,
            key=x,
            use_causal_mask=True)
        x = self.add([x, attn_output])
        x = self.layernorm(x)
        return x
    

class FeedForward(tf.keras.layers.Layer):
    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate)
        ])
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        x = self.add([x, self.seq(x)])
        x = self.layer_norm(x)
        return x


class SelfAttentionFeedForwardLayer(tf.keras.layers.Layer):
    def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1,
                 attention="global"):

        super().__init__()

        if attention == "global":
            self.self_attention = GlobalSelfAttention(
                num_heads=num_heads,
                key_dim=d_model,
                dropout=dropout_rate)
        elif attention == "causal":
            self.self_attention = CausalSelfAttention(
                num_heads=num_heads,
                key_dim=d_model,
                dropout=dropout_rate)
        else:
            raise NotImplemented(f"The choice {attention} for attention is not implemented.")

        self.ffn = FeedForward(d_model, dff)

    def call(self, x):
        x = self.self_attention(x)
        x = self.ffn(x)
        return x


class SelfAttentionCrossAttentionFeedForwardLayer(tf.keras.layers.Layer):
    def __init__(self,
                 *,
                 d_model,
                 num_heads,
                 dff,
                 dropout_rate=0.1,
                 attention="causal"):

        super(SelfAttentionCrossAttentionFeedForwardLayer, self).__init__()

        if attention == "global":
            self.self_attention = GlobalSelfAttention(
                num_heads=num_heads,
                key_dim=d_model,
                dropout=dropout_rate)
        elif attention == "causal":
            self.self_attention = CausalSelfAttention(
                num_heads=num_heads,
                key_dim=d_model,
                dropout=dropout_rate)
        else:
            raise NotImplemented(f"The choice {attention} for attention is not implemented.")

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.ffn = FeedForward(d_model, dff)

    def call(self, x, context):
        x = self.self_attention(x=x)
        x = self.cross_attention(x=x, context=context)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x
    
# --------------------------------------- encoder
class Encoder(tf.keras.Model):
    def __init__(self, *, 
                 num_layers, 
                 d_model, 
                 num_heads,
                 dff, 
                 vocab_size, 
                 dropout_rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(
            vocab_size=vocab_size, d_model=d_model)

        self.enc_layers = [
            SelfAttentionFeedForwardLayer(d_model=d_model,
                                          num_heads=num_heads,
                                          dff=dff,
                                          dropout_rate=dropout_rate)
            for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x):
        x = self.pos_embedding(x)  # Shape `(batch_size, seq_len, d_model)`.
        x = self.dropout(x)
        for i in range(self.num_layers):
            x = self.enc_layers[i](x)

        return x  # Shape `(batch_size, seq_len, d_model)`.
    
# --------------------------------------- decoder
class Decoder(tf.keras.Model):
    def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
                 dropout_rate=0.1):
        super().__init__()
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                                 d_model=d_model)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.dec_layers = [
            SelfAttentionCrossAttentionFeedForwardLayer(d_model=d_model, num_heads=num_heads,
                                                        dff=dff, dropout_rate=dropout_rate)
            for _ in range(num_layers)]

        self.last_attn_scores = None

    def call(self, x, context):
        x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)

        x = self.dropout(x)

        for i in range(self.num_layers):
            x = self.dec_layers[i](x, context)

        self.last_attn_scores = self.dec_layers[-1].last_attn_scores

        return x # (batch_size, target_seq_len, d_model)

In [10]:
# --------------------------------------- transformer
class Transformer(tf.keras.Model):
    def __init__(self, *, 
                 num_layers, 
                 d_model, 
                 num_heads, 
                 dff,
                 input_vocab_size, 
                 target_vocab_size, 
                 dropout_rate=0.1):
        
        super().__init__()
        self.encoder = Encoder(num_layers=num_layers, 
                               d_model=d_model,
                               num_heads=num_heads, 
                               dff=dff,
                               vocab_size=input_vocab_size,
                               dropout_rate=dropout_rate)

        self.decoder = Decoder(num_layers=num_layers, 
                               d_model=d_model,
                               num_heads=num_heads, 
                               dff=dff,
                               vocab_size=target_vocab_size,
                               dropout_rate=dropout_rate)

        self.final_layer = tf.keras.Sequential([tf.keras.layers.Dense(target_vocab_size)])
        
    def save_model_weights(self, save_folder):
        self.encoder.save_weights(os.path.join(save_folder, "encoder.weights.h5"))
        self.decoder.save_weights(os.path.join(save_folder, "decoder.weights.h5"))
        self.final_layer.save_weights(os.path.join(save_folder, "final_layer.weights.h5"))

    def load_model_from_weights(self, save_folder):
        self.encoder.load_weights(os.path.join(save_folder, "encoder.weights.h5"))
        self.decoder.load_weights(os.path.join(save_folder, "decoder.weights.h5"))
        self.final_layer.load_weights(os.path.join(save_folder, "final_layer.weights.h5"))

    def get_models(self):
        return self.encoder, self.decoder

    def call(self, inputs):
        context, x = inputs
        context = self.encoder(context)  # (batch_size, context_len, d_model)
        x = self.decoder(x, context)  # (batch_size, target_len, d_model)
        logits = self.final_layer(x)  # (batch_size, target_len, target_vocab_size)

        try:
            # Drop the keras mask, so it doesn't scale the losses/metrics.
            # b/250038731
            del logits._keras_mask
        except AttributeError:
            pass

        return logits

In [11]:
# --------------------------------------- translator
class Translator(tf.Module):
    def __init__(self, transformer, mappers, sample=True, min_prob=0.05):
        self.cond_mapper = mappers[0]
        self.seq_mapper = mappers[1]
        self.transformer = transformer
        
        self.sample = sample
        
        assert 0 <= min_prob < 1
        
        self.min_prob = min_prob

    def __call__(self, cond_seq, max_length=15):
        if not isinstance(cond_seq, tf.Tensor):
            cond_seq = tf.convert_to_tensor(cond_seq)

        encoder_input = tf.expand_dims(cond_seq, 0)
        
        seq_start = self.seq_mapper.integer_map[START_TOKEN]
        seq_end = self.seq_mapper.integer_map[END_TOKEN]

        output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
        output_array = output_array.write(0, seq_start)

        for i in tf.range(max_length):
            output = tf.transpose(output_array.stack())
            predictions = self.transformer([encoder_input, tf.expand_dims(output, 0)], training=False)

            predictions = predictions[0, -1:, :]  # Shape `(batch_size, 1, vocab_size)`.
            
            if self.sample:
                # calculate and transform probabilities
                probabilities = tf.math.softmax(predictions)
                high_probs = tf.where(probabilities > self.min_prob, probabilities, 0)[0]
                high_probs = high_probs / tf.math.reduce_sum(high_probs)
                
                # print(high_probs)
                
                # sample next id
                predicted_id = np.random.choice(len(high_probs), size=1, p=high_probs.numpy())[0]
            else:
                # take id with highest probability
                predicted_id = tf.argmax(predictions, axis=-1)[0]

            output_array = output_array.write(i + 1, predicted_id)

            if predicted_id == seq_end:
                break

        output = tf.transpose(output_array.stack())

        text = self.seq_mapper.map_ints_to_string(output.numpy())  # Shape: `()`.

        self.transformer([encoder_input, tf.expand_dims(output[:-1], 0)], training=False)
        attention_weights = self.transformer.decoder.last_attn_scores

        return text, attention_weights

In [12]:
# --------------------------------------- main use case
def main(train=True, 
         compute_results=True,
         n_conds=-1,
         n_samples=3,
         epochs=1, 
         checkpoint_name = "test_transformer"):
    
    # get data
    conditions, sequences = generate_data(data_size=5000)
    
    with open(os.path.join(checkpoint_name, "raw_data.txt"), "w") as f:
        print_dict = {}
        for i, cond in enumerate(conditions):
            if cond not in print_dict:
                print_dict[cond] = []
            
            print_dict[cond].append(sequences[i])
                
        with redirect_stdout(f):
            # print("conditions:")
            # pprint(conditions)
            # print("\n")
            # print("sequences")
            # pprint(sequences)
            
            pprint(print_dict)
                
    conditions, sequences = add_start_end_tokens(conditions, sequences)
    conditions, sequences = make_same_length(conditions, sequences)
    
    with open(os.path.join(checkpoint_name, "train_data.txt"), "w") as f:
        print_dict = {}
        for i, cond in enumerate(conditions):
            if cond not in print_dict:
                print_dict[cond] = []
            
            print_dict[cond].append(sequences[i])
                
        with redirect_stdout(f):
            # print("conditions:")
            # pprint(conditions)
            # print("\n")
            # print("sequences")
            # pprint(sequences)
            
            pprint(print_dict)
    
    cond_mapper = ConditionMapper()
    seq_mapper = SeqMapper()
    
    # generate training data
    raw_model_input_data, raw_model_output_data = make_training_data(conditions, sequences)
    model_input_data, model_output_data = convert_training_data(raw_model_input_data, 
                                                                raw_model_output_data, 
                                                                mappers=(cond_mapper, seq_mapper))
    
    # build transformer
    
    # tf example
    # num_layers = 4
    # d_model = 128
    # dff = 512
    # num_heads = 8
    # dropout_rate = 0.1
    
    transformer = Transformer(
        num_layers=3,
        d_model=32,
        num_heads=8,
        dff=128,
        input_vocab_size=cond_mapper.dimension,
        target_vocab_size=seq_mapper.dimension,
        dropout_rate=0.1,
    )
    
    transformer(model_input_data)
    transformer.summary()
    
    if not os.path.exists(checkpoint_name):
        os.mkdir(checkpoint_name)
    
    if os.path.exists(os.path.join(checkpoint_name, "encoder.weights.h5")):
        transformer.load_model_from_weights(checkpoint_name)
        
        print("Model loaded successfully!")
        print("\n")

    # compile transformer
    transformer.compile(
        loss=masked_loss,
        optimizer='Adam',
        metrics=[masked_accuracy])
    
    if train:
        print("---------------------------- Training model ----------------------------")
        
        # fit transformer on batches
        transformer.fit(model_input_data,
                        model_output_data,
                        epochs=epochs,
                        # validation_data=val_batches
                        )
        
        transformer.save_model_weights(checkpoint_name)
    
    # build translator
    translator = Translator(transformer, mappers=(cond_mapper, seq_mapper))
    
    # translate examples
    
    if compute_results:
        print("---------------------------- Using model ----------------------------")
            
        translation_results = dict()
        
        if n_conds == -1:
            strings_to_use = RAW_STRINGS
        elif 0 < n_conds < len(RAW_STRINGS):
            strings_to_use = RAW_STRINGS[:n_conds]
        else:
            raise NotImplementedError()
            
        for char in RAW_STRINGS:
            
            print(f"Generating output for condition {char}")
            
            translation_results[char] = []
            example_cond = cond_mapper.map_to_ints(char)
            
            for i in range(n_samples):
                model_text, _ = translator(example_cond)
                
                print(model_text)
                
                translation_results[char].append(model_text)
        
        with open(os.path.join(checkpoint_name, "results.txt"), "w") as f:
            with redirect_stdout(f):
                pprint(translation_results)

In [13]:
if __name__ == '__main__':
    main(train=True, 
         compute_results=True, 
         n_conds=-1,
         n_samples=10,
         epochs=5)

FileNotFoundError: [Errno 2] No such file or directory: 'test_transformer\\raw_data.txt'