In [1]:
import numpy as np

def reweight_distribution(original_distribution, temperature=0.5):
    distribution = np.log(original_distribution) / temperature
    distribution = np.exp(distribution)
    return distribution / np.sum(distribution)

In [None]:
!wget https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

In [None]:
import tensorflow as tf
from tensorflow import keras

dataset = keras.utils.text_dataset_from_directory(directory='./aclImdb', label_mode=None, batch_size=256)
dataset = dataset.map(lambda x: tf.strings.regex_replace(x, "<br />", " "))

In [4]:
from tensorflow.keras.layers import TextVectorization

sequence_length = 100
vocab_size = 15000

text_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode='int',
    output_sequence_length=sequence_length
)

text_vectorization.adapt(dataset)

In [5]:
def prepare_lm_dataset(text_batch):
    vectorized_sequences = text_vectorization(text_batch)
    x = vectorized_sequences[:, :-1]
    y = vectorized_sequences[:, 1:]
    return x, y

lm_dataset = dataset.map(prepare_lm_dataset)

In [6]:
from tensorflow.keras import layers

In [7]:
class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention1 = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.attention2 = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.dense_projection = keras.Sequential([layers.Dense(dense_dim, activation="relu"), layers.Dense(embed_dim)])
        self.layer_norm1 = layers.LayerNormalization()
        self.layer_norm2 = layers.LayerNormalization()
        self.layer_norm3 = layers.LayerNormalization()
        self.supports_masking = True
        
    def get_config(self):
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'dense_dim': self.dense_dim,
            'num_heads': self.num_heads
        })
        return config
    
    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, seq_length = input_shape[0], input_shape[1]
        i = tf.range(seq_length)[:, tf.newaxis]
        j = tf.range(seq_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat([tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype="int32")], axis=0)
        return tf.tile(mask, mult)
    
    def call(self, inputs, enconder_outputs, mask=None):
        causal_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")
            padding_mask = tf.minimum(padding_mask, causal_mask)
        attention_output_1 = self.attention1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask 
        )
        attention_output_1 = self.layer_norm1(inputs + attention_output_1)
        attention_output_2 = self.attention2(
            query=attention_output_1, value=enconder_outputs, key=enconder_outputs, attention_mask=padding_mask
        )
        attention_output_2 = self.layer_norm2(attention_output_1 + attention_output_2)
        projection_output = self.dense_projection(attention_output_2)
        return self.layer_norm3(projection_output + attention_output_2)

In [13]:
class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embedding = layers.Embedding(input_dim=input_dim, output_dim=output_dim)
        self.position_embedding = layers.Embedding(input_dim=sequence_length, output_dim=output_dim)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim
        
    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embedding(inputs)
        embedded_positions = self.position_embedding(positions)
        embedding = embedded_tokens + embedded_positions
        return embedding
    
    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)
    
    def get_config(self):
        config = super().get_config()
        config.update({
            "output_dim": self.output_dim,
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim
        })
        return config

In [20]:
embed_dim = 256
latent_dim = 2048
num_heads = 2

inputs = keras.Input(shape=(None, ), dtype='int64')
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, x)
outputs = keras.layers.Dense(vocab_size, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss='sparse_categorical_crossentropy', optimizer='rmsprop')

In [21]:
tokens_index = dict(enumerate(text_vectorization.get_vocabulary()))

def sample_next(predictions, temperature=1.0):
    predictions = np.asarray(predictions).astype('float64')
    predictions = np.log(predictions) / temperature
    exp_predictions = np.exp(predictions)
    predictions = exp_predictions / np.sum(exp_predictions)
    probas = np.random.multinomial(1, predictions, 1)
    return np.argmax(probas)

class TextGenerator(keras.callbacks.Callback):
    def __init__(self, prompt, generate_length, model_input_length, temperatures=(1.,), print_freq=1):
        self.prompt = prompt
        self.generate_length = generate_length
        self.model_input_length = model_input_length
        self.temperatures = temperatures
        self.print_freq = print_freq
        
    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.print_freq != 0:
            return
        for temperature in self.temperatures:
            print("Temperature = {}".format(temperature))
            sentence = self.prompt
            for i in range(self.generate_length):
                tokenized_sentence = text_vectorization([sentence])
                predictions = self.model.predict(tokenized_sentence, verbose=0)
                next_token = sample_next(predictions[0, i, :])
                sampled_token = tokens_index[next_token]
                sentence += " " + sampled_token
            print(sentence)

In [22]:
prompt = "This movie"
text_gen_callback = TextGenerator(prompt, generate_length=50, model_input_length=sequence_length, temperatures=(0.2, 0.7, 1.2), print_freq=10)

In [23]:
model.fit(lm_dataset, epochs=200, callbacks=[text_gen_callback])

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
This movie is is extremely absolutely bad incredible a no b wait credibility for how a does movie its has just a someone bunch discussed of so the bad movie movie bad turned he out discovers of that cheesy do idea with that this bright sort american of version plot would big
Temperature = 0.7
This movie is was okay amazing but maybe i a saw hundred it years the ago only i thing bought about it danny again [UNK] for but this the talent worst and movie action there movie are never four did total not daniels call dont it let ranks me in there the
Temperature = 1.2
This movie is was devils a the pilot worst who series wrote of but their just lifetime [UNK] done deol poorly really done are poorly required handled for quite the some dont of look alice of and developing serves a had big promise shots of since human its plans release the the
Epoch 11/200
Epoch 12/200
Epoch 1

<keras.src.callbacks.History at 0x7f5b8112a490>