# In [3]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# =============================================================================
# 1. Define custom layers for embeddings and a single transformer block.
# =============================================================================

class PositionalEmbedding(layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        max_len: Number of rows in the position table. Position ids are
            ``0 .. sequence_length - 1``, so inputs longer than ``max_len``
            index past the table.
        vocab_size: Number of rows in the token table.
        embed_dim: Width of both embedding tables.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``,
            ``dtype``, ...), forwarded to the base class.
    """

    def __init__(self, max_len, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Token embeddings: converts token ids to vectors.
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        # Positional embeddings: each position in the sequence gets its own vector.
        self.pos_emb = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
        # Kept so get_config() can rebuild the layer.
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, x):
        """Embed token ids ``x`` of shape (batch, sequence_length).

        Returns the token embeddings plus the position embeddings; the
        position term has no batch axis and is broadcast over the batch.
        """
        # Use the runtime length so the layer works for any sequence length.
        positions = tf.range(start=0, limit=tf.shape(x)[1], delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "max_len": self.max_len,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

class TransformerBlock(layers.Layer):
    """One post-norm decoder block: causal self-attention then a feed-forward net.

    Args:
        embed_dim: Width of the block's input and output. It is also passed
            as ``key_dim`` to the attention layer.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout_rate: Rate used for attention dropout and for both residual
            dropouts.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``,
            ``dtype``, ...), forwarded to the base class.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can rebuild the layer.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate
        # Multi-head self-attention; the causal mask is applied in call().
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=dropout_rate)
        # Feed-forward network: expand to ff_dim, project back to embed_dim.
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"),
             layers.Dense(embed_dim)]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, x, training=None):
        """Run the block on ``x``.

        Args:
            x: Input activations; the residual additions require the last
                axis to match the ``embed_dim`` given at construction.
            training: Forwarded to the dropout layers. ``None`` (the default)
                lets Keras supply the mode of the enclosing call, so the
                layer can be called without the argument.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Apply causal self-attention so that each position can only attend to previous ones.
        attn_output = self.att(x, x, x, use_causal_mask=True)
        attn_output = self.dropout1(attn_output, training=training)
        # Residual connection, then normalise (post-norm).
        out1 = self.layernorm1(x + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "ff_dim": self.ff_dim,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config

# =============================================================================
# 2. Build the decoder-only transformer model.
# =============================================================================

def create_transformer_model(max_len, vocab_size, embed_dim, num_heads, ff_dim, num_layers):
    """Build a decoder-only transformer that maps token ids to next-token logits.

    Args:
        max_len: Size of the position-embedding table, i.e. the longest
            sequence the model can embed.
        vocab_size: Size of the token-embedding table and of the output layer.
        embed_dim: Embedding / hidden width.
        num_heads: Attention heads per block.
        ff_dim: Hidden width of each block's feed-forward network.
        num_layers: Number of stacked ``TransformerBlock`` layers.

    Returns:
        An uncompiled ``keras.Model`` taking int32 ids of shape
        (batch, sequence_length) and returning unnormalised logits of shape
        (batch, sequence_length, vocab_size).
    """
    # The sequence axis is left unspecified: the training pipeline feeds the
    # shifted sequence (one token shorter than the vectorizer output) while
    # inference pads to max_len, and a fixed (max_len,) input shape does not
    # match the former.
    inputs = keras.Input(shape=(None,), dtype="int32")
    x = PositionalEmbedding(max_len, vocab_size, embed_dim)(inputs)
    for _ in range(num_layers):
        # training=None defers the mode to whatever Keras supplies when the
        # model is run (fit -> training, predict / model(..., training=False)
        # -> inference). Hard-coding True here pins dropout on at inference.
        x = TransformerBlock(embed_dim, num_heads, ff_dim)(x, training=None)
    # No activation: the loss is expected to be configured with from_logits=True.
    outputs = layers.Dense(vocab_size)(x)
    return keras.Model(inputs=inputs, outputs=outputs)

# =============================================================================
# 3. Sample data: pairs of (abstract, title).
# =============================================================================

# Toy corpus: each entry is an (abstract, title) pair.
data_samples = [
    ("In this paper, we propose a novel transformer-based approach to sequence modeling.",
     "Transformer Approach"),
    ("We present an analysis of deep neural networks in image recognition tasks.",
     "Deep Learning in Vision"),
    ("A comprehensive study on reinforcement learning methods and their applications.",
     "Reinforcement Learning Study"),
]

# Each training text is "<abstract> sep <title>": the separator word marks
# where the abstract ends and the title begins.
separator = " sep "
combined_texts = [separator.join(pair) for pair in data_samples]

# =============================================================================
# 4. Tokenization using Keras TextVectorization.
# =============================================================================

max_tokens = 1000  # upper bound on the vocabulary size
max_len = 50       # every text is padded / truncated to this many tokens
vectorizer = layers.TextVectorization(
    max_tokens=max_tokens,
    output_mode="int",
    output_sequence_length=max_len,
)
# Learn the vocabulary from the combined "abstract sep title" texts.
vectorizer.adapt(combined_texts)

# Look up the id of the separator word; the loss mask and the decoding code
# both rely on it.
vocab = vectorizer.get_vocabulary()
sep_token = separator.strip()  # the bare word "sep"
if sep_token in vocab:
    sep_token_id = vocab.index(sep_token)
else:
    raise ValueError("Separator token not found in vocabulary!")

# =============================================================================
# 5. Create the training data with loss mask.
#
# For a GPT-style (decoder-only) training, we input a sequence that is the
# concatenation of abstract and title. The target is the same sequence shifted
# one token to the left. We then create a sample weight mask so that only tokens
# corresponding to the title (i.e. those after the separator) contribute to loss.
# =============================================================================

def create_inputs_targets(token_seq):
    """Split one token sequence into (input, target, loss weight).

    Args:
        token_seq: 1-D tensor of token ids that contains the separator id.
            Indexing the first match fails if the separator is absent.

    Returns:
        A tuple ``(inp, target, weight)``:
        ``inp`` is ``token_seq`` without its last token, ``target`` is
        ``token_seq`` without its first token, and ``weight`` is a float32
        mask over ``target`` that is 1.0 only for non-padding tokens located
        after the first separator.
    """
    # Next-token setup: the target is the input shifted left by one.
    inp, target = token_seq[:-1], token_seq[1:]

    # Index of the first separator in the unshifted sequence.
    first_sep = tf.cast(tf.where(tf.equal(token_seq, sep_token_id))[0][0], tf.int32)

    # target[j] is token_seq[j + 1], so j >= first_sep selects exactly the
    # tokens that come strictly after the separator.
    target_positions = tf.range(tf.shape(target)[0])
    after_separator = tf.cast(target_positions >= first_sep, tf.float32)

    # Padding (id 0) never contributes to the loss.
    is_real_token = tf.cast(tf.not_equal(target, 0), tf.float32)

    return inp, target, after_separator * is_real_token

# Input pipeline: raw text -> token ids -> (input, target, weight) -> batches.
BATCH_SIZE = 2
dataset = (
    tf.data.Dataset.from_tensor_slices(combined_texts)
    .map(lambda text: vectorizer(text))
    .map(create_inputs_targets)
    .shuffle(10)
    .batch(BATCH_SIZE)
)

# =============================================================================
# 6. Create and compile the model.
# =============================================================================

# Hyperparameters. The output layer is sized to the adapted vocabulary.
vocab_size = len(vocab)
embed_dim = 64
num_heads = 4
ff_dim = 128
num_layers = 2

model = create_transformer_model(
    max_len, vocab_size, embed_dim, num_heads, ff_dim, num_layers
)
# The model emits raw logits, hence from_logits=True.
model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
model.summary()

# =============================================================================
# 7. Train the model.
# =============================================================================

# The dataset yields (input, target, weight) triples; Keras uses the third
# element as per-token sample weights, so only title tokens drive the loss.
model.fit(dataset, epochs=10)

# =============================================================================
# 8. Inference: Generate a title given an abstract.
#
# We provide only the abstract plus the separator as a prompt. Then we use a
# simple greedy decoding loop to generate a fixed number of tokens. The generated
# sequence is decoded and the title part (the words after the separator) is returned.
# =============================================================================

def generate_title(abstract, max_gen_len=10):
    """Greedily generate a title for ``abstract`` with the trained model.

    The prompt is ``abstract + separator``. At each step the sequence is
    padded to ``max_len``, run through the model, and the most likely token
    at the last real position is appended.

    Args:
        abstract: Abstract text to condition on.
        max_gen_len: Maximum number of tokens to generate.

    Returns:
        The generated tokens (everything produced after the prompt) decoded
        to words and joined with single spaces. Empty if nothing was
        generated.
    """
    prompt = abstract + separator
    # Vectorize the prompt and drop the batch axis -> shape (max_len,).
    token_seq = tf.squeeze(vectorizer([prompt]), axis=0)
    # Remove trailing padding (padding token is 0).
    token_seq = token_seq[token_seq != 0]
    # Remember where the prompt ends so the title can be sliced off by
    # position. Searching the decoded words for the separator picks the wrong
    # span when the abstract itself contains that word or when the prompt was
    # truncated before the separator.
    prompt_len = int(tf.shape(token_seq)[0])

    for _ in range(max_gen_len):
        cur_len = int(tf.shape(token_seq)[0])
        # Pad the current sequence on the right up to the model input length.
        inp = tf.expand_dims(token_seq, 0)
        inp = tf.pad(inp, [[0, 0], [0, max_len - cur_len]])
        predictions = model(inp, training=False)  # shape: (1, max_len, vocab_size)
        # The logits at the last real position predict the next token.
        next_token_logits = predictions[0, cur_len - 1, :]
        # Keep the id dtype identical to the vectorizer output (typically
        # int64) so the concat below never mixes integer dtypes.
        next_token = tf.argmax(next_token_logits, axis=-1, output_type=token_seq.dtype)
        if int(next_token) == 0:
            break  # stop if the model generates a padding token
        token_seq = tf.concat([token_seq, tf.expand_dims(next_token, 0)], axis=0)
        if int(tf.shape(token_seq)[0]) >= max_len:
            break  # no room left in the model's input window

    # Decode only the newly generated ids; ids outside the vocabulary map to "".
    generated_ids = token_seq.numpy().tolist()[prompt_len:]
    title_words = [vocab[i] if 0 <= i < len(vocab) else "" for i in generated_ids]
    return " ".join(title_words).strip()

# Smoke-test inference on an abstract that is not in the training data.
test_abstract = (
    "This study introduces a new method for natural language understanding "
    "using attention mechanisms to better capture contextual dependencies."
)
generated_title = generate_title(test_abstract)
print("Generated Title:", generated_title)


# wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\mhuep\_netrc
#
#
# <_PrefetchDataset element_spec=(TensorSpec(shape=(None, 390), dtype=tf.int32, name=None), TensorSpec(shape=(None, 390), dtype=tf.int32, name=None))>