<a href="https://colab.research.google.com/github/elliot-brooks/nlu-coursework/blob/main/src/AV_LSTM_TRAIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
!pip install -U transformers
!pip install -U accelerate

In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
from transformers import DistilBertTokenizer, TFDistilBertModel
import nltk
import re

# Load training data

In [None]:
# Read the training pairs; later cells use the columns "text_1", "text_2" and "label".
training_corpus = pd.read_csv("train.csv", encoding='utf-8')

# Pre-process training data

In [None]:
def preprocess(string):
  """Return ``string`` coerced to ``str`` and case-folded to lower-case.

  Args:
    string: Value to normalise. It is passed through ``str()`` first, so
      non-string values (for example ``None`` or numbers) are accepted.

  Returns:
    The lower-cased string form of ``string``. Punctuation and whitespace
    are left exactly as they were.
  """
  # The previous version also built a punctuation-separated copy of the text
  # that was never used; that dead computation has been removed.
  return str(string).lower()

def prepare_data(data) :
  """Build one "text_1 [SEP] text_2" string per row for Distilled Bert.

  Both text columns are passed through ``preprocess`` before being joined.
  The input frame is not modified: pre-processing is applied to new Series
  derived from its columns.

  Args:
    data: DataFrame with the columns "text_1" and "text_2".

  Returns:
    A list of strings, one per row and in row order, each of the form
    ``<text_1> [SEP] <text_2>``.
  """
  first_texts = data["text_1"].map(preprocess)
  second_texts = data["text_2"].map(preprocess)
  # zip walks both columns in row order without the per-row overhead of iterrows.
  return [left + " [SEP] " + right for left, right in zip(first_texts, second_texts)]

# Build the list of "text_1 [SEP] text_2" strings for the whole training set.
concat_data = prepare_data(training_corpus)

# Create BERT embeddings

In [None]:
# The tokeniser and the encoder are both loaded from the same pretrained checkpoint.
CHECKPOINT = 'distilbert-base-uncased'
tokeniser = DistilBertTokenizer.from_pretrained(CHECKPOINT)
bert_model = TFDistilBertModel.from_pretrained(CHECKPOINT)

In [None]:
# Token length every input is padded/truncated to (passed to the tokeniser as max_length).
SEQ_LENGTH = 256
# Number of texts tokenised and passed through the BERT model at a time.
BATCH_SIZE = 32
def create_bert_embeddings_batch(texts, tokeniser, model, batch_size, seq_length) :
  """Encode ``texts`` in batches and collect each batch's first-token hidden state.

  Args:
    texts: Sliceable sequence of strings; it is consumed in slices of
      ``batch_size`` items.
    tokeniser: Callable tokeniser. It is called with the batch plus the
      padding/truncation options and must return a mapping containing
      'input_ids' and 'attention_mask'.
    model: Callable taking the input ids positionally and ``attention_mask``
      as a keyword, returning an object with a ``last_hidden_state``
      attribute indexable as ``[:, 0, :]``.
    batch_size: Number of texts per batch; must be positive.
    seq_length: Passed to the tokeniser as ``max_length``; inputs are padded
      or truncated to this many tokens.

  Returns:
    A list with one entry per batch, each being ``last_hidden_state[:, 0, :]``
    for that batch. An empty ``texts`` yields an empty list.

  Raises:
    ValueError: If ``batch_size`` is not positive.
  """
  if batch_size <= 0:
    # A negative step would make range() empty and silently return no embeddings.
    raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

  embeddings = []
  for start in range(0, len(texts), batch_size) :
    batch = texts[start:start + batch_size]
    # Call the tokeniser directly: batch_encode_plus is deprecated in favour of __call__.
    inputs = tokeniser(batch, padding='max_length', truncation=True, return_tensors='tf', max_length=seq_length, add_special_tokens=True)

    # Create embeddings
    outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])

    # Keep only position 0 of every sequence (presumably the [CLS] token, since
    # special tokens are added above - confirm for the tokeniser in use).
    last_hidden_state_CLS = outputs.last_hidden_state[:, 0, :]

    embeddings.append(last_hidden_state_CLS)
  return embeddings

# Embed every concatenated pair; the result is a list with one entry per batch.
bert_embeddings = create_bert_embeddings_batch(concat_data, tokeniser, bert_model, BATCH_SIZE, SEQ_LENGTH)
# Training targets: the "label" column of the training CSV as a NumPy array.
train_labels = np.array(training_corpus['label'])

# Define classification model

In [None]:
# Hyper-parameters of the recurrent classifier.
LSTM_UNITS = 128
DROPOUT_RATE = 0.2

# Assemble the network one layer at a time. Each input has shape (768, 1),
# matching the reshape applied to the embeddings before training, so the LSTM
# reads it as 768 steps of a single feature; one sigmoid unit gives the output.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(768, 1)))
model.add(tf.keras.layers.LSTM(LSTM_UNITS, dropout=DROPOUT_RATE))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

# Compile for binary classification, then summarise the model.
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_5 (InputLayer)        [(None, 150)]                0         []                            
                                                                                                  
 input_6 (InputLayer)        [(None, 150)]                0         []                            
                                                                                                  
 embedding_2 (Embedding)     (None, 150, 100)             1062880   ['input_5[0][0]',             
                                                          0          'input_6[0][0]']             
                                                                                                  
 lstm_2 (LSTM)               (None, 128)                  117248    ['embedding_2[0][0]',   

# Train Model

In [None]:
# Stack the per-batch embeddings into one array, then add a trailing axis so
# each example has the (768, 1) shape the model's input layer expects.
stacked_embeddings = np.concatenate(bert_embeddings, axis=0)
train_inputs = stacked_embeddings.reshape(-1, 768, 1)
print(train_inputs.shape)
model.fit(x=train_inputs, y=train_labels, batch_size=128, epochs=100)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x7fdd1c6c1510>

# Save Model

In [None]:
# Write the trained classifier to the relative path "AV_LSTM_MODEL".
model.save("AV_LSTM_MODEL")

In [None]:
# Archive the saved model directory into a single zip file.
# The absolute /content paths assume the notebook's working directory is
# /content (e.g. on Colab), where the model was saved - TODO confirm elsewhere.
!zip -r /content/LSTM_MODEL.zip /content/AV_LSTM_MODEL

updating: content/AV_LSTM_MODEL/ (stored 0%)
updating: content/AV_LSTM_MODEL/keras_metadata.pb (deflated 88%)
updating: content/AV_LSTM_MODEL/variables/ (stored 0%)
updating: content/AV_LSTM_MODEL/variables/variables.index (deflated 59%)
updating: content/AV_LSTM_MODEL/variables/variables.data-00000-of-00001 (deflated 10%)
updating: content/AV_LSTM_MODEL/fingerprint.pb (stored 0%)
updating: content/AV_LSTM_MODEL/assets/ (stored 0%)
updating: content/AV_LSTM_MODEL/saved_model.pb (deflated 90%)
