# Utils

In [None]:
!pip install transformers



In [None]:
from enum import Enum 

from transformers import AutoTokenizer

In [None]:
class ModelSettings(Enum):
  """Central table of hyper-parameters, Keras identifiers and control tokens.

  Every constant is read through ``.value``.

  Several members share a value (20 and 2048 each appear twice). ``Enum``
  treats the later definition as an alias of the earlier one, so iterating
  the class or reading ``.name`` does not report every constant separately
  (e.g. ``LATENT_DIMENSION.name`` is ``"SHUFFLE_SEED"``). Keep the member
  order stable if anything relies on names.
  """
  # --- data pipeline ---
  SEQUENCE_LENGTH = 20
  BATCH_SIZE = 64
  SHUFFLE_SEED = 2048
  PREFETCH = 16
  # --- architecture ---
  EMBEDDING_DIMENSION = 256
  LATENT_DIMENSION = 2048  # same value as SHUFFLE_SEED -> Enum alias
  NUMBER_HEADS = 8
  # --- training / regularisation ---
  EPOCHS = 1
  DROPOUT_RATE = 0.5
  # --- generation ---
  MAXIMUM_GENERATION_LENGTH = 20  # same value as SEQUENCE_LENGTH -> Enum alias
  # --- Keras string identifiers ---
  ACTIVATION = "relu"
  OPTIMISER = "rmsprop"
  LOSS = "sparse_categorical_crossentropy"
  METRIC = "accuracy"
  # --- pretrained tokeniser checkpoint ---
  TOKENISER = "t5-small"
  # --- control tokens prepended to the condition text ---
  STYLE_QUESTION = "<|style:question|>"
  STYLE_STATEMENT = "<|style:statement|>"
  SENTIMENT_POSITIVE = "<|sentiment:positive|>"
  SENTIMENT_NEGATIVE = "<|sentiment:negative|>"
  SENTIMENT_NEUTRAL = "<|sentiment:neutral|>"

# Lookup from the plain label used in the data to its control token.
MAP_SPECIAL_TOKEN = {
    "positive": ModelSettings.SENTIMENT_POSITIVE.value,
    "negative": ModelSettings.SENTIMENT_NEGATIVE.value,
    "neutral": ModelSettings.SENTIMENT_NEUTRAL.value,
    "statement": ModelSettings.STYLE_STATEMENT.value,
    "question": ModelSettings.STYLE_QUESTION.value,
}

# Load the pretrained tokeniser, then register the start-of-text marker and
# the control tokens. The list order is kept as-is because it determines the
# ids assigned to the new tokens. The call's return value (number of tokens
# added) is the cell output.
TOKENISER = AutoTokenizer.from_pretrained(ModelSettings.TOKENISER.value)
TOKENISER.add_special_tokens({
    "bos_token": "<|startoftext|>",
    "additional_special_tokens": [
        ModelSettings.STYLE_QUESTION.value,
        ModelSettings.STYLE_STATEMENT.value,
        ModelSettings.SENTIMENT_POSITIVE.value,
        ModelSettings.SENTIMENT_NEGATIVE.value,
        ModelSettings.SENTIMENT_NEUTRAL.value,
    ],
})

6

# Data

In [None]:
# Raw training examples: ((label, label, keywords), target sentence).
train_pairs = [
    (
        ("positive", "statement", ["police", "crime"]),
        "The police attended a crime",
    ),
    (
        ("negative", "statement", ["hope", "verdict"]),
        "Her hopes were dashed when she heard the verdict.",
    ),
    (
        ("neutral", "statement", ["relax"]),
        "Please relax.",
    ),
    (
        ("positive", "question", ["japan", "food"]),
        "I eat Japanese food.",
    ),
    (
        ("negative", "question", ["mary", "present", "daughter"]),
        "Mary bought a present for her friend's daughter.",
    ),
]

In [None]:
from typing import Tuple, Dict, List


In [None]:
def format_sentence(sentence:str) -> str:
  """Wrap `sentence` in the tokeniser's BOS and EOS markers, separated by single spaces."""
  return "{} {} {}".format(TOKENISER.bos_token, sentence, TOKENISER.eos_token)

def format_condition(style:str,sentiment:str,keywords:List[str]) -> str:
  """Build the encoder condition text: control tokens followed by the keywords.

  The two control tokens are always emitted sentiment first, then style,
  regardless of which positional slot each label arrived in. Callers in this
  notebook pass the labels in both orders, and the model must see the same
  layout at training and generation time.

  Args:
    style: style label, a key of `MAP_SPECIAL_TOKEN` (e.g. "statement").
    sentiment: sentiment label, a key of `MAP_SPECIAL_TOKEN` (e.g. "positive").
    keywords: words appended after the control tokens.

  Returns:
    The space-separated condition string. With no keywords there is no
    trailing space.

  Raises:
    KeyError: if a label is not a key of `MAP_SPECIAL_TOKEN`.
  """
  control_tokens = [MAP_SPECIAL_TOKEN[style], MAP_SPECIAL_TOKEN[sentiment]]
  # Stable sort: sentiment tokens (key False) move ahead of everything else.
  control_tokens.sort(key=lambda token: not token.startswith("<|sentiment:"))
  return " ".join([*control_tokens, *keywords])

In [None]:
# Turn every raw ((label, label, keywords), sentence) pair into
# (condition text, sentence wrapped in BOS/EOS).
# This cell overwrites its own input, so pairs whose condition is already a
# string are passed through untouched; re-running the cell is then harmless.
train_pairs = [
    (condition, sentence)
    if isinstance(condition, str)
    else (format_condition(*condition), format_sentence(sentence))
    for condition, sentence in train_pairs
]

In [None]:
# Display the formatted (condition text, target sentence) pairs.
train_pairs

[('<|sentiment:positive|> <|style:statement|> police crime',
  '<|startoftext|> The police attended a crime </s>'),
 ('<|sentiment:negative|> <|style:statement|> hope verdict',
  '<|startoftext|> Her hopes were dashed when she heard the verdict. </s>'),
 ('<|sentiment:neutral|> <|style:statement|> relax',
  '<|startoftext|> Please relax. </s>'),
 ('<|sentiment:positive|> <|style:question|> japan food',
  '<|startoftext|> I eat Japanese food. </s>'),
 ('<|sentiment:negative|> <|style:question|> mary present daughter',
  "<|startoftext|> Mary bought a present for her friend's daughter. </s>")]

# Data Encoding

In [None]:
from tensorflow.data import Dataset 

In [None]:
def _vectorise(text, max_length:int):
    """Tokenise `text` into ids padded/truncated to exactly `max_length`.

    A list or tuple is tokenised as a batch, one row per item. Any other
    value is converted with `str()` and tokenised as a single sequence.
    """
    shared_options = dict(
        return_tensors='tf',
        padding='max_length',
        max_length=max_length,
        truncation=True,
    )
    if isinstance(text, (list, tuple)):
        # Without this branch `str(text)` would tokenise the list's repr
        # (brackets and quotes included) instead of the sentences themselves.
        return TOKENISER([str(item) for item in text], **shared_options)["input_ids"]
    return TOKENISER.encode(str(text), **shared_options)


def input_vectorisation(text):
    """Token ids of a condition text, fixed to SEQUENCE_LENGTH tokens."""
    return _vectorise(text, ModelSettings.SEQUENCE_LENGTH.value)


def output_vectorisation(text):
    """Token ids of a target sentence, fixed to SEQUENCE_LENGTH + 1 tokens.

    The extra token lets callers derive decoder inputs and shifted targets,
    each SEQUENCE_LENGTH long, from the same encoding.
    """
    return _vectorise(text, 1 + ModelSettings.SEQUENCE_LENGTH.value)

In [None]:
def format_data(input_condition:str, output_sentence:str) -> Tuple[Dict[str,object],object]:
    """Encode one (condition, sentence) pair into model inputs and targets.

    Must be called eagerly with Python strings. The vectorisers apply `str()`
    to their argument, so a symbolic tensor (as supplied when this function
    is traced by `Dataset.map`) would be tokenised as its repr, not its text.

    Args:
        input_condition: condition text fed to the encoder.
        output_sentence: target sentence, already wrapped in BOS/EOS.

    Returns:
        A pair `(inputs, targets)` of token-id tensors (not strings):
        `inputs` holds `encoder_inputs` and `decoder_inputs`; `targets` is
        the sentence encoding shifted left by one token, so position i of
        `decoder_inputs` is trained to predict position i of `targets`.
    """
    encoder_input_ids = input_vectorisation(input_condition)
    sentence_ids = output_vectorisation(output_sentence)
    inputs = dict(
        encoder_inputs=encoder_input_ids,
        decoder_inputs=sentence_ids[:, :-1],  # everything but the last token
    )
    decoder_output_ids = sentence_ids[:, 1:]  # everything but the first token
    return inputs, decoder_output_ids

In [None]:
def make_dataset(pairs:List[Tuple[str,str]]) -> Dataset:
    """Build a batched tf.data pipeline from (condition, sentence) text pairs.

    Tokenisation happens eagerly, in plain Python, before the dataset is
    created. Running `format_data` inside `Dataset.map` would hand it symbolic
    tensors, and the tokeniser would then encode the tensors' repr instead of
    the text, so every batch would be the same constant.

    Pipeline order is cache -> shuffle -> batch -> prefetch, so that:
      * the cache stores the encoded examples, not one frozen shuffle order;
      * individual examples, not whole batches, are shuffled;
      * `SHUFFLE_SEED` is used as the shuffle seed, with a buffer covering
        the full dataset.

    Args:
        pairs: formatted (condition text, target sentence) pairs.

    Returns:
        A dataset yielding `({"encoder_inputs", "decoder_inputs"}, targets)`.

    Raises:
        ValueError: if `pairs` is empty.
    """
    # Local import: this cell runs before the notebook's tensorflow import cell.
    from tensorflow import concat as concatenate_tensors

    if not pairs:
        raise ValueError("make_dataset requires at least one (condition, sentence) pair")

    # format_data returns one row per pair; stack the rows along the batch axis.
    encoded_pairs = [format_data(condition, sentence) for condition, sentence in pairs]
    feature_rows, target_rows = zip(*encoded_pairs)
    features = {
        name: concatenate_tensors([row[name] for row in feature_rows], axis=0)
        for name in feature_rows[0]
    }
    targets = concatenate_tensors(list(target_rows), axis=0)

    return (
        Dataset.from_tensor_slices((features, targets))
        .cache()
        .shuffle(buffer_size=len(pairs), seed=ModelSettings.SHUFFLE_SEED.value)
        .batch(ModelSettings.BATCH_SIZE.value)
        .prefetch(ModelSettings.PREFETCH.value)
    )

In [None]:
# Both splits are built from the same pairs, so the validation metrics
# reported during training are measured on the training examples.
train_dataset = make_dataset(pairs=train_pairs)
validation_dataset = make_dataset(pairs=train_pairs)

# Model

In [None]:
from tensorflow.keras.layers import Layer, MultiHeadAttention, Dense, LayerNormalization, Embedding, Dropout
from tensorflow import cast, newaxis, shape, range, minimum, reshape, concat, tile, expand_dims, constant
from tensorflow.math import not_equal
from keras import Sequential, Input, Model

In [None]:
class TransformerEncoder(Layer):
    """Single transformer encoder block.

    Self-attention and a two-layer feed-forward projection, each followed by
    a residual connection and layer normalisation. All sizes are read from
    `ModelSettings`.
    """

    def __init__(self) -> None:
        super().__init__()
        self.embed_dim = ModelSettings.EMBEDDING_DIMENSION.value
        self.dense_dim = ModelSettings.LATENT_DIMENSION.value
        self.num_heads = ModelSettings.NUMBER_HEADS.value
        self.attention = MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.embed_dim
        )
        # Activation comes from the settings table, as in TransformerDecoder.
        self.dense_projection = Sequential([
            Dense(self.dense_dim, activation=ModelSettings.ACTIVATION.value),
            Dense(self.embed_dim)
        ])
        self.layernorm_1 = LayerNormalization()
        self.layernorm_2 = LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        """Apply the block to `inputs`.

        Args:
            inputs: embedded sequence to encode.
            mask: optional padding mask for `inputs`; when omitted, attention
                is unrestricted.
        """
        attention_mask = None
        if mask is not None:
            # Two singleton axes are inserted so the mask broadcasts against
            # the attention scores.
            attention_mask = cast(mask[:, newaxis, newaxis, :], dtype="int32")
        attended_input = self.attention(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=attention_mask
        )
        projected_input = self.layernorm_1(inputs + attended_input)
        projected_output = self.dense_projection(projected_input)
        return self.layernorm_2(projected_input + projected_output)

In [None]:
class PositionalEmbedding(Layer):
    """Learned token embedding plus learned position embedding.

    The vocabulary size is `len(TOKENISER)`, which includes the special
    tokens added after loading. Positions are limited to
    `ModelSettings.SEQUENCE_LENGTH`.
    """

    def __init__(self) -> None:
        super().__init__()
        self.sequence_length = ModelSettings.SEQUENCE_LENGTH.value
        self.vocab_size = len(TOKENISER)
        # Width of the embedding vectors. This was previously recorded as
        # LATENT_DIMENSION, which did not match the layers below.
        self.embed_dim = ModelSettings.EMBEDDING_DIMENSION.value
        self.token_embeddings = Embedding(
            input_dim=self.vocab_size,
            output_dim=self.embed_dim
        )
        self.position_embeddings = Embedding(
            input_dim=self.sequence_length,
            output_dim=self.embed_dim
        )

    def call(self, inputs):
        """Embed token ids and add the embedding of each position index."""
        # `range` is tensorflow's range (imported by this notebook), not the builtin.
        positions = range(start=0, limit=shape(inputs)[-1], delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    @staticmethod
    def compute_mask(inputs, mask=None):
        """Mark every position whose id differs from the tokeniser's pad id."""
        # Ask the tokeniser for its padding id rather than hard-coding 0.
        return not_equal(inputs, TOKENISER.pad_token_id)

In [None]:
class TransformerDecoder(Layer):
    """Single transformer decoder block.

    Causal self-attention over the target sequence, cross-attention over the
    encoder outputs, then a two-layer feed-forward projection. Each stage is
    followed by a residual connection and layer normalisation.
    """

    def __init__(self):
        super().__init__()
        self.embed_dim = ModelSettings.EMBEDDING_DIMENSION.value
        self.latent_dim = ModelSettings.LATENT_DIMENSION.value
        self.num_heads = ModelSettings.NUMBER_HEADS.value
        self.attention_1 = MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.embed_dim
        )
        self.attention_2 = MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.embed_dim
        )
        # Spelling aligned with TransformerEncoder.dense_projection.
        self.dense_projection = Sequential([
            Dense(self.latent_dim, activation=ModelSettings.ACTIVATION.value),
            Dense(self.embed_dim)
        ])
        self.layernorm_1 = LayerNormalization()
        self.layernorm_2 = LayerNormalization()
        self.layernorm_3 = LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        """Decode `inputs` while attending to `encoder_outputs`.

        Args:
            inputs: embedded target sequence.
            encoder_outputs: encoder output sequence to attend to.
            mask: optional padding mask for `inputs`.
        """
        causal_mask = self.get_causal_attention_mask(inputs)
        if mask is None:
            # Without an incoming mask the name used to be left unbound and
            # the cross-attention call below raised UnboundLocalError.
            padding_mask = None
        else:
            padding_mask = cast(mask[:, newaxis, :], dtype="int32")
            padding_mask = minimum(padding_mask, causal_mask)

        # Self-attention: each position sees itself and earlier positions only.
        attended_inputs = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask
        )
        decoder_inputs = self.layernorm_1(inputs + attended_inputs)

        # Cross-attention over the encoder outputs.
        # NOTE(review): the mask is built from the target sequence, so this
        # appears to need equal encoder and decoder lengths - confirm.
        attended_decoder_inputs = self.attention_2(
            query=decoder_inputs,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
        )
        outputs = self.layernorm_2(decoder_inputs + attended_decoder_inputs)

        projected_outputs = self.dense_projection(outputs)
        return self.layernorm_3(outputs + projected_outputs)

    @staticmethod
    def get_causal_attention_mask(inputs):
        """Return a (batch, length, length) int32 lower-triangular mask.

        Entry [b, i, j] is 1 when j <= i, i.e. position i may attend to j.
        """
        input_shape = shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        # `range` is tensorflow's range (imported by this notebook).
        i = range(sequence_length)[:, newaxis]
        j = range(sequence_length)
        mask = cast(i >= j, dtype="int32")
        mask = reshape(mask, (1, sequence_length, sequence_length))
        # Repeat the single mask once per batch element.
        multiple = concat([
            expand_dims(batch_size, -1),
            constant([1, 1], dtype="int32")
        ], axis=0)
        return tile(mask, multiple)

In [None]:
from numpy import argmax, ndarray

class ConditionalGeneratorTransformer:
    """Encoder-decoder transformer generating a sentence from style, sentiment and keywords.

    The Keras model is built and compiled on construction. `train` fits it and
    `generate` decodes greedily, one token per step.
    """

    def __init__(self) -> None:
        self.model = self._build_model()

    def train(self, data:Dataset, validation_data:Dataset) -> None:
        """Fit the model on `data` for `ModelSettings.EPOCHS` epochs."""
        self.model.fit(data, epochs=ModelSettings.EPOCHS.value, validation_data=validation_data)

    def generate(self, style:str, sentiment:str, keywords:List[str]) -> str:
        """Greedily decode a sentence for the given condition.

        Args:
            style: style label understood by `format_condition`.
            sentiment: sentiment label understood by `format_condition`.
            keywords: words the condition is built from.

        Returns:
            The decoded text, starting with the tokeniser's BOS token.
            Decoding stops at the first EOS prediction or when the step
            limit is reached.
        """
        # This notebook re-binds `range` to tensorflow.range; use the builtin
        # explicitly so `position` is a plain int.
        import builtins

        input_condition = format_condition(style, sentiment, keywords)
        # Vectorise the string itself rather than a one-element list, so that
        # exactly this text is what gets tokenised.
        tokenised_input_condition = input_vectorisation(input_condition)
        decoded_sentence = TOKENISER.bos_token

        # The decoder input holds SEQUENCE_LENGTH positions; never index past it.
        step_limit = min(
            ModelSettings.MAXIMUM_GENERATION_LENGTH.value,
            ModelSettings.SEQUENCE_LENGTH.value,
        )
        for position in builtins.range(step_limit):
            # Encode one token more than needed, then drop the last one, as
            # the training pipeline does for the decoder inputs.
            tokenised_target_sentence = output_vectorisation(decoded_sentence)[:, :-1]
            token_probabilities = self.model([tokenised_input_condition, tokenised_target_sentence])

            predicted_token_id = self._greedy_decode(token_probabilities, position)
            # Compare ids rather than decoded text to detect end of sentence.
            if predicted_token_id == TOKENISER.eos_token_id:
                break
            decoded_sentence += f" {TOKENISER.decode(predicted_token_id)}"

        return decoded_sentence

    @staticmethod
    def _greedy_decode(logits:ndarray,position:int) -> int:
        """Return the highest-scoring token id at `position` of the first batch row."""
        return int(argmax(logits[0, position, :]))

    @staticmethod
    def _build_model() -> Model:
        """Assemble and compile the encoder-decoder model.

        The model takes `encoder_inputs` and `decoder_inputs` (token ids) and
        returns a softmax distribution over the vocabulary for every decoder
        position.
        """
        # Encoder: token ids -> contextualised embeddings.
        encoder_inputs = Input(shape=(None,), dtype="int64", name="encoder_inputs")
        contextualised_encoder_inputs = PositionalEmbedding()(encoder_inputs)
        encoder_outputs = TransformerEncoder()(contextualised_encoder_inputs)
        encoder = Model(encoder_inputs, encoder_outputs)

        # Decoder: target ids + encoder states -> per-position token distribution.
        decoder_inputs = Input(shape=(None,), dtype="int64", name="decoder_inputs")
        encoded_inputs = Input(shape=(None, ModelSettings.EMBEDDING_DIMENSION.value), name="decoder_state_inputs")
        contextualised_decoder_inputs = PositionalEmbedding()(decoder_inputs)
        projected_decoder_inputs = TransformerDecoder()(contextualised_decoder_inputs, encoded_inputs)
        regularised_decoder_states = Dropout(ModelSettings.DROPOUT_RATE.value)(projected_decoder_inputs)
        decoder_outputs = Dense(len(TOKENISER), activation="softmax")(regularised_decoder_states)
        decoder = Model([decoder_inputs, encoded_inputs], decoder_outputs)

        # Wire the decoder onto the encoder outputs to form the full model.
        transformer_outputs = decoder([decoder_inputs, encoder_outputs])

        transformer = Model(
            [encoder_inputs, decoder_inputs],
            transformer_outputs,
            name="transformer"
        )
        transformer.compile(
            ModelSettings.OPTIMISER.value,
            loss=ModelSettings.LOSS.value,
            metrics=[ModelSettings.METRIC.value]
        )
        return transformer

# Train

In [None]:
# Build the model, then fit it on the prepared datasets.
x = ConditionalGeneratorTransformer()
x.train(data=train_dataset, validation_data=validation_dataset)



# Test

In [None]:
# Keywords shared by every generation example below.
keywords = ["car", "made"]

In [None]:
x.generate(style="statement", sentiment="neutral", keywords=keywords)

'<|startoftext|> s s ( s _ s _ ( s _ s : = ( s one  s '

In [None]:
x.generate(style="statement", sentiment="negative", keywords=keywords)

In [None]:
x.generate(style="statement", sentiment="positive", keywords=keywords)

In [None]:
x.generate(style="question", sentiment="neutral", keywords=keywords)

In [None]:
x.generate(style="question", sentiment="negative", keywords=keywords)

In [None]:
x.generate(style="question", sentiment="positive", keywords=keywords)