## This project implements N-gram language modeling and RNN language modeling on text dataset.


In [1]:
import sys


def print_line(*args):
    """ Inline print and go to the begining of line
    """
    args1 = [str(arg) for arg in args]
    str_ = ' '.join(args1)
    print('\r' + str_, end='')

In [3]:
import tensorflow as tf
print(tf.config.list_physical_devices('GPU'))

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [4]:
import tensorflow as tf

# Check TensorFlow version and GPU availability
print("TensorFlow version:", tf.__version__)
print("Is GPU available?", tf.config.list_physical_devices('GPU'))

TensorFlow version: 2.10.0
Is GPU available? [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [5]:
from typing import List, Tuple, Union, Dict
from collections import defaultdict
import numpy as np

In [7]:
import os
import pickle


data_path = './datasets/a3-data'

train_sentences = open(os.path.join(data_path, 'train.txt')).readlines()
valid_sentences = open(os.path.join(data_path, 'valid.txt')).readlines()
test_sentences = open(os.path.join(data_path, 'input.txt')).readlines()
print('number of train sentences:', len(train_sentences))
print('number of valid sentences:', len(valid_sentences))
print('number of test sentences:', len(test_sentences))

number of train sentences: 42068
number of valid sentences: 3370
number of test sentences: 3165


In [8]:
import re


class Preprocessor:
    def __init__(self, punctuation=True, url=True, number=True):
        self.punctuation = punctuation
        self.url = url
        self.number = number

    def apply(self, sentence: str) -> str:
        """ Apply the preprocessing rules to the sentence
        Args:
            sentence: raw sentence
        Returns:
            sentence: clean sentence
        """
        sentence = sentence.lower()
        sentence = sentence.replace('<unk>', '')
        if self.url:
            sentence = Preprocessor.remove_url(sentence)
        if self.punctuation:
            sentence = Preprocessor.remove_punctuation(sentence)
        if self.number:
            sentence = Preprocessor.remove_number(sentence)
        sentence = re.sub(r'\s+', ' ', sentence)
        return sentence

    @staticmethod
    def remove_punctuation(sentence: str) -> str:
        """ Remove punctuations in sentence with re
        Args:
            sentence: sentence with possible punctuations
        Returns:
            sentence: sentence without punctuations
        """
        sentence = re.sub(r'[^\w\s]', ' ', sentence)
        return sentence

    @staticmethod
    def remove_url(sentence: str) -> str:
        """ Remove urls in text with re
        Args:
            sentence: sentence with possible urls
        Returns:
            sentence: sentence without urls
        """
        sentence = re.sub(r'(https|http)?://(\w|\.|/|\?|=|&|%)*\b', ' ', sentence)
        return sentence

    @staticmethod
    def remove_number(sentence: str) -> str:
        """ Remove numbers in sentence with re
        Args:
            sentence: sentence with possible numbers
        Returns:
            sentence: sentence without numbers
        """
        sentence = re.sub(r'\d+', ' ', sentence)
        return sentence

In [9]:
class Tokenizer:
    def __init__(self, sos_token='<s>', eos_token='</s>', pad_token='<pad>', unk_token='<unk>', mask_token='<mask>'):
        # Special tokens.
        self.sos_token = sos_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.mask_token = mask_token
        
        self.vocab = { sos_token: 0, eos_token: 1, pad_token: 2, unk_token: 3, mask_token: 4 }  # token -> id
        self.inverse_vocab = { 0: sos_token, 1: eos_token, 2: pad_token, 3: unk_token, 4: mask_token }  # id -> token
        self.token_occurrence = { sos_token: 0, eos_token: 0, pad_token: 0, unk_token: 0, mask_token: 0 }  # token -> occurrence
        
        self.preprocessor = Preprocessor()

    @property
    def sos_token_id(self):
        """ Create a property method.
            You can use self.sos_token_id or tokenizer.sos_token_id to get the id of the sos_token.
        """
        return self.vocab[self.sos_token]

    @property
    def eos_token_id(self):
        return self.vocab[self.eos_token]

    @property
    def pad_token_id(self):
        return self.vocab[self.pad_token]

    @property
    def unk_token_id(self):
        return self.vocab[self.unk_token]

    @property
    def mask_token_id(self):
        return self.vocab[self.mask_token]
        
    def __len__(self):
        """ A magic method that enable program to know the number of tokens by calling:
            ```python
            tokenizer = Tokenizer()
            num_tokens = len(tokenizer)
            ```
        """
        return len(self.vocab)
        
    def fit(self, sentences: List[str]):
        """ Fit the tokenizer using all sentences.
        Args:
            sentences: All sentences in the dataset.
        """
        n = len(sentences)
        for i, sentence in enumerate(sentences):
            if i % 100 == 0 or i == n - 1:
                print_line('Fitting Tokenizer:', (i + 1), '/', n)
            tokens = self.preprocessor.apply(sentence.strip()).split()
            if len(tokens) <= 1:
                continue
            for token in tokens:
                if token == '<unk>':
                    continue
                self.token_occurrence[token] = self.token_occurrence.get(token, 0) + 1
        print_line('\n')

        token_occurrence = sorted(self.token_occurrence.items(), key=lambda e: e[1], reverse=True)
        for token, occurrence in token_occurrence[:-5]:
            token_id = len(self.vocab)
            self.vocab[token] = token_id
            self.inverse_vocab[token_id] = token

        print('The number of distinct tokens:', len(self.vocab))
        
    def encode(self, sentences: List[str]) -> List[List[int]]:
        """ Encode the sentences into token ids
        Args:
            sentences: Raw sentences
        Returns:
            sent_token_ids: A list of id list
        """
        n = len(sentences)
        sent_token_ids = []
        for i, sentence in enumerate(sentences):
            if i % 100 == 0 or i == n - 1:
                print_line('Encoding with Tokenizer:', (i + 1), '/', n)
            token_ids = []
            tokens = self.preprocessor.apply(sentence.strip()).split()
            for token in tokens:
                if token == '<unk>':
                    continue
                if token in self.vocab:
                    token_ids.append(self.vocab[token])
            if len(token_ids) <= 1:
                continue
            token_ids = [self.sos_token_id] + token_ids + [self.eos_token_id]
            sent_token_ids.append(token_ids)
        print_line('\n')
        return sent_token_ids

In [10]:
tokenizer = Tokenizer()
tokenizer.fit(train_sentences[:2])
print()

token_occurrence = sorted(tokenizer.token_occurrence.items(), key=lambda e: e[1], reverse=True)
for token, occurrence in token_occurrence[:10]:
    print(token, ':', occurrence)
print()
sent_token_ids = tokenizer.encode(train_sentences[:2])
print()
for original_sentence, token_ids in zip(train_sentences[:2], sent_token_ids):
    sentence = [tokenizer.inverse_vocab[token] for token in token_ids]
    print(original_sentence, sentence, '\n')

Fitting Tokenizer: 2 / 2
The number of distinct tokens: 44

n : 2
aer : 1
banknote : 1
berlitz : 1
calloway : 1
centrust : 1
cluett : 1
fromstein : 1
gitano : 1
guterman : 1

Encoding with Tokenizer: 2 / 2

 aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim snack-food ssangyong swapo wachter 
 ['<s>', 'aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro', 'quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim', 'snack', 'food', 'ssangyong', 'swapo', 'wachter', '</s>'] 

 pierre <unk> N years old will join the board as a nonexecutive director nov. N 
 ['<s>', 'pierre', 'n', 'years', 'old', 'will', 'join', 'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'nov', 'n', '</s>'] 



In [11]:
tokenizer = Tokenizer()
tokenizer.fit(train_sentences)
train_token_ids = tokenizer.encode(train_sentences)
valid_token_ids = tokenizer.encode(valid_sentences)
test_token_ids = tokenizer.encode(test_sentences)

Fitting Tokenizer: 42068 / 42068
The number of distinct tokens: 9614
Encoding with Tokenizer: 42068 / 42068
Encoding with Tokenizer: 3370 / 3370
Encoding with Tokenizer: 3165 / 3165


In [12]:
def get_unigram_count(train_token_ids: List[List[int]]) -> Dict:
    """ Calculate the occurrence of each token in the dataset.
    
    Args:
        train_token_ids: each element is a list of token ids
    Return:
        unigram_count: A map from token_id to occurrence
    """
    unigram_count = {}
    for i in range(len(train_token_ids)):
        for j in range(len(train_token_ids[i])):
            if unigram_count.get(train_token_ids[i][j],-1)==-1:
                unigram_count[train_token_ids[i][j]]=1
            else:
                unigram_count[train_token_ids[i][j]]+=1
    return unigram_count

In [13]:
def get_bigram_count(train_token_ids: List[List[int]]) -> Dict[int, Dict]:
    """ Calculate the occurrence of bigrams in the dataset.
    
    Args:
        train_token_ids: each element is a list of token ids
    Return:
        bigram_count: A map from token_id to next token occurrence. Key: token_id, value: Dict[token_id -> occurrence]
                      For example, {
                          5: { 10: 5, 20: 4 }
                      } means (5, 10) occurs 10 times, (5, 20) occurs 4 times.
    
                      (5,10)=5
                      (5,20)=4
                      5:{10:5,20:4}
    """
    bigram_count = {}
    pair_count={}
    bigram_count = defaultdict(lambda: defaultdict(int))
    for doc in train_token_ids:
        for i in range(len(doc)-1):
            if pair_count.get(tuple((doc[i],doc[i+1])),0)==0:
                pair_count[tuple((doc[i],doc[i+1]))] = 1
            else:
                pair_count[tuple((doc[i],doc[i+1]))] += 1
    
    ls=list(pair_count.keys())
    for i in range(len(ls)):
        k,v=ls[i]
        temp={v:pair_count[(k,v)]}
        if bigram_count.get(k,{})=={}:
            bigram_count[k]=temp
        else:
            temp1=bigram_count[k]
            temp1[v]=pair_count[(k,v)]
            bigram_count[k]=temp1
    
    return bigram_count

In [14]:
unigram_count = get_unigram_count(train_token_ids)
bigram_count = get_bigram_count(train_token_ids)

In [15]:
print(len(bigram_count[672]))

69


In [16]:
class BiGram:
    def __init__(self, unigram_count, bigram_count):
        self.unigram_count = unigram_count
        self.bigram_count = bigram_count
        
    def calc_prob(self, w1: int, w2: int) -> float:
        """
        
        Args:
            w1, w2: current token and next token
        Note:
            if prob you calculated is 0, you should return 1e-5.
        """
        count_w1 = self.unigram_count.get(w1, 0)
        count_w1_w2 = self.bigram_count.get(w1, {}).get(w2, 0)
    
        if count_w1 == 0:
            return 1e-5
        else:
            prob = count_w1_w2 / count_w1
            if prob == 0:
                return 1e-5
        return prob

###  Good Turing

In [18]:
from scipy.optimize import curve_fit


def power_law(x, a, b):
    """ Power law to fit the number of occurrence
    """
    return a * np.power(x, b)


class GoodTuring(BiGram):
    def __init__(self, unigram_count, bigram_count, threshold=100):
        super().__init__(unigram_count, bigram_count)
        self.threshold = threshold
        self.bigram_Nc = self.calc_Nc()
        self.bi_c_star, self.bi_N = self.smoothing(self.bigram_Nc)
        self.unigram_count=unigram_count

    def calc_Nc(self) -> Dict[int, Union[float, int]]:
        """ 
        
        Return:
            bigram_Nc: A map from count to the occurrence (count of count)
                       For example {
                           10: 78
                       } means there are 78 bigrams occurs 10 times in the dataset.
                       Also, 10 is a small c, for large c, it's occurrence will be replaced with the power law.
        """
        # Count the occurrence of count in self.bigram_count.
        bigram_Nc = defaultdict(int)
        for counts in self.bigram_count.values():
            for count in counts.values():
                bigram_Nc[count] += 1

        self.replace_large_c(bigram_Nc)
        return bigram_Nc

    def replace_large_c(self, Nc):
        """ Fit with power law
        """
        x, y = zip(*sorted(Nc.items(), reverse=True))
        popt, pcov = curve_fit(power_law, x, y, bounds=([0, -np.inf], [np.inf, 0]))
        a, b = popt

        max_count = max(Nc.keys())
        for c in range(self.threshold + 1, max_count + 2):
            Nc[c] = power_law(c, a, b)

    def smoothing(self, Nc: Dict[int, Union[float, int]]) -> Tuple[Dict[int, float], float]:
        """ Calculate the c_star and N
        
        Args:
            self.bigram_Nc
        Returns:
            c_star: The mapping from bigram count to smoothed count
            N: The sum of c multiplied by Nc
        """
        c_star = {}
        N = 0
        max_count = max(Nc.keys())
        for count in range(1, max_count + 1):
            if count not in Nc or count + 1 not in Nc:
                continue
            c_star[count] = (count + 1) * Nc[count + 1] / Nc[count]
            N += count * Nc[count]
        c_star[0] = Nc[1] / N
        return c_star, N

    def calc_prob(self, w1, w2):
        """         
        Good-turing smoothening
        Args:
            w1, w2: current token and next token
        """
        prob = 0
        # Get the unigram count of w1
        unigram_count_w1 = self.unigram_count.get(w1, 0)
        # Get the smoothed bigram count of (w1, w2)
        smoothed_bigram_count_w1_w2 = self.bi_c_star.get(self.bigram_count.get(w1, {}).get(w2, 0), 0)

        if unigram_count_w1 == 0:
            # w1 is not in the training data, so the probability is 0
            prob = 0
        else:
            if smoothed_bigram_count_w1_w2 == 0:
                # (w1, w2) is not in the training data, so use the unigram probability of w2
                if self.unigram_count.get(w2,0)==0:
                    self.unigram_count[w2]=1e-5
                    prob = 1e-5/len(np.sum(self.unigram_count.values()))
                else:
                    prob = self.unigram_count[w2]/len(np.sum(self.unigram_count.values()))
            else:
                prob = smoothed_bigram_count_w1_w2 / unigram_count_w1

        return prob

###  Kneser-Ney smoothening

In [19]:
class KneserNey(BiGram):
    def __init__(self, unigram_count, bigram_count, d=0.75):
        super().__init__(unigram_count, bigram_count)
        self.d = d
        
        self.lambda_ = self.calc_lambda()
        self.p_continuation = self.calc_p_continuation()
        
    def calc_lambda(self):
        """ 
        Return:
            lambda_: A dict from token_id (w) to λ(w).
        """
        lambda_ = {}
        for w1 in self.unigram_count:
            lambda_[w1] = self.d * len([w2 for w2 in self.bigram_count[w1] if self.bigram_count[w1][w2] > 0])
            lambda_[w1] /= self.unigram_count[w1]
        return lambda_
    
    def calc_p_continuation(self):
        """ 
        Return:
            lambda_: A dict from token_id (w) to λ(w).
        """
        numerator = {}  # token -> type of previous token
        denominator = len(self.bigram_count)  # type of all previous tokens
        for w1 in self.bigram_count:
            for w2 in self.bigram_count[w1]:
                if self.bigram_count[w1][w2] > 0:
                    numerator[w2] = numerator.get(w2, 0) + 1
        p_continuation = { 0: 0, 2: 0, 3: 0, 4: 0 }
        for w, count in numerator.items():
            p_continuation[w] = count / denominator
        return p_continuation
    
    def calc_prob(self, w1, w2):
        """ Calculate the probability of p(w2 | w1) using the Kneser-Ney model.
        
        Args:
            w1, w2: current token and next token
        """
        c_w1_w2 = self.bigram_count[w1][w2] if w1 in self.bigram_count and w2 in self.bigram_count[w1] else 0
        prob = max(c_w1_w2 - self.d, 0) / self.unigram_count[w1] + self.lambda_[w1] * self.p_continuation[w2]
        
        return prob

### Shows that perplexity is the exponential of the total loss divided by the number of predictions.

In [20]:
from math import log2
def perplexity(model, token_ids):
    """ 
    Args:
        model: the model you want to evaluate (BiGram, GoodTuring, or KneserNey)
        token_ids: a list of validation token_ids
    Return:
        perplexity: the perplexity of the model on texts      
    """
    log_probs = 0
    n = len(token_ids)
    n_words = 0
    for i, tokens in enumerate(token_ids):
        if i % 100 == 0 or i == n - 1:
            print_line('Calculating perplexity:', (i + 1), '/', n)
        log_prob = 0
        # Calculate the probability of each bigram
        for j in range(len(tokens) - 1):
            prob = model.calc_prob(tokens[j], tokens[j + 1])
            if prob == 0:
                prob = 1e-5  # handle zero probabilities
            log_prob += np.log(prob)
            n_words += 1
        log_probs += log_prob

    # Calculate the final perplexity
    avg_log_prob = log_probs / n_words
    perplexity = np.exp(-avg_log_prob)
    print('\n')
    
    return perplexity

In [21]:
bigram = BiGram(unigram_count, bigram_count)

# Perplexity
bigram_perplexity = perplexity(bigram, valid_token_ids)
print(f'The perplexity of Bigram is: {bigram_perplexity:.4f}')

Calculating perplexity: 3352 / 3352

The perplexity of Bigram is: 325.8354


In [22]:
gt = GoodTuring(unigram_count, bigram_count, threshold=100)

# Perplexity
gt_perplexity = perplexity(gt, valid_token_ids)
print(f'The perplexity of Good Turing is: {gt_perplexity:.4f}')

Calculating perplexity: 3352 / 3352

The perplexity of Good Turing is: 130.5334


In [23]:
kn = KneserNey(unigram_count, bigram_count, d=0.75)

# Perplexity
kn_perplexity = perplexity(kn, valid_token_ids)
print(f'The perplexity of Kneser-Ney is: {kn_perplexity:.4f}')

Calculating perplexity: 3352 / 3352

The perplexity of Kneser-Ney is: 62.5943


In [24]:
def predict(model: 'BiGram', w1: int, vocab_size: int):
    """ Predict the w2 with the hightest probability given w1
    
    Args:
        model: A BiGram, GoodTuring, or KneserNey model that has the calc_prob function
        w1: current word
        vocab_size: the number of tokens in the vocabulary
    """
    result = None
    highest_prob = 0
    #start your code
    for w2 in range(1, vocab_size):
        prob = model.calc_prob(w1, w2)
        if prob > highest_prob:
            highest_prob = prob
            result = w2
        # End
    return result

In [25]:
np.random.seed(12345)

vocab_size = len(tokenizer)
indexes = np.random.choice(len(test_token_ids), 10, replace=False)
for i in indexes:
    token_ids = test_token_ids[i][1:-1]
    print(' '.join([tokenizer.inverse_vocab[token_id] for token_id in token_ids]) + ' ____')
    pred = predict(gt, token_ids[-1], vocab_size)
    print(f'predicted last token: {tokenizer.inverse_vocab[pred]}')
    print('---------------------------------------------')

sharply falling stock prices do reduce consumer wealth damage business ____
predicted last token: </s>
---------------------------------------------
but robert an official of the association said no ____
predicted last token: longer
---------------------------------------------
it also has interests in military electronics and marine ____
predicted last token: s
---------------------------------------------
first chicago since n has reduced its loans to such ____
predicted last token: as
---------------------------------------------
david m jones vice president at g ____
predicted last token: s
---------------------------------------------
the n stock specialist firms on the big board floor ____
predicted last token: traders
---------------------------------------------
at the same time the business was hurt by ____
predicted last token: the
---------------------------------------------
salomon will cover the warrants by buying sufficient shares or ____
predicted last token: n
--------

In [26]:
def get_feature_label(token_ids: List[List[int]], window_size: int=-1):
    """ Split features and labels for the training, validation, and test datasets.
    
    Note:
        If window size is -1, for a sentence with n tokens,
            it selects the tokens rangeing from [0, n - 1) as the feature,
            and selects tokens ranging from [1, n) as the label.
        Otherwise, it divides a sentence with multiple windows and do the previous split.
    """
    x = []
    y = []
    seq_lens = []
    for sent_token_ids in token_ids:
        if window_size == -1:
            x.append(sent_token_ids[:-1])
            y.append(sent_token_ids[1:])
            seq_lens.append(len(sent_token_ids) - 1)
        else:
            if len(sent_token_ids) > window_size:
                sub_sent_size = window_size + 1
                n_window = len(sent_token_ids) // (sub_sent_size)
                for i in range(n_window):
                    start = i * sub_sent_size
                    sub_sent = sent_token_ids[start:(start + sub_sent_size)]
                    x.append(sub_sent[:-1])
                    y.append(sub_sent[1:])
                    seq_lens.append(len(sub_sent) - 1)
                if len(sent_token_ids) % sub_sent_size > 0:
                    sub_sent = sent_token_ids[-sub_sent_size:]
                    x.append(sub_sent[:-1])
                    y.append(sub_sent[1:])
                    seq_lens.append(len(sub_sent) - 1)
            else:
                x.append(sent_token_ids[:-1])
                y.append(sent_token_ids[1:])
                seq_lens.append(len(sent_token_ids) - 1)
    return x, y, seq_lens

In [27]:
window_size = 40
x_train, y_train, train_seq_lens = get_feature_label(train_token_ids, window_size)
x_valid, y_valid, valid_seq_lens = get_feature_label(valid_token_ids)
x_test, y_test, test_seq_lens = get_feature_label(valid_token_ids)
print(max(train_seq_lens), max(valid_seq_lens), max(test_seq_lens))

40 68 68


In [28]:
def pad_batch(x_batch: List[List[int]], y_batch: List[List[int]], seq_lens_batch: List[int], pad_val: int):
    """ Pad the sentences in a batch with pad_val based on the longest sentence.
    
    Args:
        x_batch, y_batch, seq_lens_batch: the input data
        pad_val: the padding value you need to fill to pad the sentences to the longest sentence.
        
    Return:
        x_batch: Tensor, (batch_size x max_seq_len)
        y_batch: Tensor, (batch_size x max_seq_len)
        seq_lens_batch: Tensor, (batch_size, )
    """
    max_len = max(seq_lens_batch)
    # Start your code here
    # Padding
    x_batch = [sentence + [pad_val] * (max_len - len(sentence)) for sentence in x_batch]
    y_batch = [sentence + [pad_val] * (max_len - len(sentence)) for sentence in y_batch]
    # End
    x_batch, y_batch = tf.convert_to_tensor(x_batch, dtype=tf.int64), tf.convert_to_tensor(y_batch, dtype=tf.int64)
    seq_lens_batch = tf.convert_to_tensor(seq_lens_batch, dtype=tf.int64)
    return x_batch, y_batch, seq_lens_batch

###  RNN language model

In [29]:
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense
from typing import Tuple

class RNN(Model):
    def __init__(self, vocab_size, embedding_dim, hidden_units):
        """ Init of the RNN model
        
        Args:
            vocab_size, embedding_dim: used for initialze the embedding layer.
            hidden_units: number of hidden units of the RNN layer.
        """
        super().__init__()
        # Start your code here
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.lstm = LSTM(hidden_units, return_sequences=True)
        self.dense = Dense(vocab_size)

        # End
        
    def call(self, x):
        """ Forward of the RNN model
        
        Args:
            x: Tensor, (batch_size x max_seq_len). Input tokens. Here, max_seq_len is the longest length of sentences in this batch becasue we did pad_batch.
        Return:
            outputs: Tensor, (batch_size x max_seq_len x vocab_size). Logits for every time step. !!!NO SOFTMAX HERE!!!
        """
        # Start your code here
        x = self.embedding(x)
        outputs = self.lstm(x)
        outputs = self.dense(outputs)
        # End
        return outputs

### Seq2seq loss

In [34]:
## deprecated

#from tensorflow_addon.seq2seq import sequence_loss
# def seq2seq_loss(logits, target, seq_lens):
#     """
#     Args:
#         logits: Tensor (batch_size x max_seq_len x vocab_size). The output of the RNN model.
#         target: Tensor (batch_size x max_seq_len). The groud-truth of words.
#         seq_lens: Tensor (batch_size, ). The real sequence length before padding.
#     """
#     loss = 0
#     mask = tf.sequence_mask(seq_lens, dtype=tf.float32)
#     loss = sequence_loss(logits, target, weights=mask)
#     return loss

def seq2seq_loss(logits, targets, seq_lens):
    """
    Custom implementation of sequence loss for handling sequence-to-sequence models.

    Args:
        logits: Tensor of shape (batch_size, max_seq_len, vocab_size) containing the logits.
        targets: Tensor of shape (batch_size, max_seq_len) containing the ground truth labels.
        seq_lens: Tensor of shape (batch_size,) containing the real lengths of each sequence.

    Returns:
        A scalar tensor containing the average loss per batch, considering only the real sequence lengths.
    """
    # Computing the cross entropy loss
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits)
    
    # Creating a mask for padding positions
    mask = tf.sequence_mask(lengths=seq_lens, maxlen=tf.shape(targets)[1], dtype=tf.float32)
    
    # Applying the mask to the loss
    masked_loss = loss * mask
    
    # Calculating the average loss over all batches
    loss_sum = tf.reduce_sum(masked_loss)
    total_elements = tf.reduce_sum(mask)
    average_loss = loss_sum / total_elements
    
    return average_loss

In [51]:
vocab_size = len(tokenizer)
hidden_units = 256
embedding_dim = 512
num_epoch = 10
batch_size = 64

In [52]:
model = RNN(vocab_size, embedding_dim, hidden_units)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

### Training RNNs

In [53]:
num_samples = len(x_train)
n_batch = int(np.ceil(num_samples / batch_size))
n_valid_batch = int(np.ceil(len(x_valid) / batch_size))
for epoch in range(num_epoch):
    epoch_loss = 0.0
    for batch_idx in range(n_batch):
        start = batch_idx * batch_size
        end = start + batch_size
        x_batch, y_batch, seq_lens_batch = x_train[start:end], y_train[start:end], train_seq_lens[start:end]
        real_batch_size = len(x_batch)
        x_batch, y_batch, seq_lens_batch = pad_batch(x_batch, y_batch, seq_lens_batch, pad_val=tokenizer.pad_token_id)

        with tf.GradientTape() as tape:
            output = model(x_batch)
            loss = seq2seq_loss(output, y_batch, seq_lens_batch)

        if batch_idx % 1 == 0 or batch_idx == num_samples - 1:
            print_line(f'Epoch {epoch + 1} / {num_epoch} - Step {batch_idx + 1} / {n_batch} - loss: {loss:.4f}')
            
        trainable_vars = model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        optimizer.apply_gradients(zip(gradients, trainable_vars))
        epoch_loss += loss * real_batch_size
    
    valid_loss = 0.0
    for batch_idx in range(n_valid_batch):
        start = batch_idx * batch_size
        end = start + batch_size
        x_batch, y_batch, seq_lens_batch = x_valid[start:end], y_valid[start:end], valid_seq_lens[start:end]
        real_batch_size = len(x_batch)
        x_batch, y_batch, seq_lens_batch = pad_batch(x_batch, y_batch, seq_lens_batch, pad_val=tokenizer.pad_token_id)
        output = model(x_batch)
        loss = seq2seq_loss(output, y_batch, seq_lens_batch)

        if batch_idx % 1 == 0 or batch_idx == len(x_valid) - 1:
            print_line(f'Epoch {epoch + 1} / {num_epoch} - Step {batch_idx + 1} / {n_valid_batch} - loss: {loss:.4f}')

        valid_loss += loss * real_batch_size
    print(f'\rEpoch {epoch + 1} / {num_epoch} - Step {n_batch} / {n_batch} - train loss: {epoch_loss / num_samples:.4f} - valid loss: {valid_loss / len(x_valid):.4f}')

Epoch 1 / 10 - Step 677 / 677 - train loss: 6.4266 - valid loss: 5.9178
Epoch 2 / 10 - Step 677 / 677 - train loss: 5.7559 - valid loss: 5.5953
Epoch 3 / 10 - Step 677 / 677 - train loss: 5.4572 - valid loss: 5.4244
Epoch 4 / 10 - Step 677 / 677 - train loss: 5.2477 - valid loss: 5.3129
Epoch 5 / 10 - Step 677 / 677 - train loss: 5.0824 - valid loss: 5.2377
Epoch 6 / 10 - Step 677 / 677 - train loss: 4.9441 - valid loss: 5.1845
Epoch 7 / 10 - Step 677 / 677 - train loss: 4.8225 - valid loss: 5.1491
Epoch 8 / 10 - Step 677 / 677 - train loss: 4.7121 - valid loss: 5.1264
Epoch 9 / 10 - Step 677 / 677 - train loss: 4.6099 - valid loss: 5.1142
Epoch 10 / 10 - Step 677 / 677 - train loss: 4.5146 - valid loss: 5.1113


### Perplexity of RNN 

In [55]:
n = len(x_valid)
log_probs = 0
n_words = 0  # number of words to predict in the entire dataset
total_loss = 0  # total loss of each word's loss
for i in range(n):
    if i % 1 == 0 or i == n - 1:
        print_line('Calculating perplexity:', (i + 1), '/', n)
    x_line, y_line, line_seq_lens = x_valid[i:i + 1], y_valid[i: i + 1], valid_seq_lens[i:i + 1]
    x_line, y_line, line_seq_lens = pad_batch(x_line, y_line, line_seq_lens, tokenizer.pad_token_id)
    output = model(x_line)
    pred_probs = tf.nn.softmax(output, axis=-1)

    for real_token, probs in zip(y_line[0], pred_probs[0]):
        log_prob = tf.math.log(tf.clip_by_value(probs[real_token], 1e-9, 1.0))
        log_probs += log_prob
        n_words += 1
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_line, logits=output))
    total_loss += loss * tf.size(y_line[0])
print('\n')
perplexity = 2 ** ((-1 / n_words) * log_probs)
print(f'Perplexity by definition: {perplexity:.4f}, Perplexity by loss: {np.exp(total_loss / n_words):.4f}')

# If you implement correctly, the two perplexity will be almost the same.

Calculating perplexity: 1 / 3352

InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a float tensor but is a int32 tensor [Op:Mul]

###  Predicting the next word given a previous sentence

In [None]:
np.random.seed(12345)

vocab_size = len(tokenizer)
indexes = np.random.choice(len(test_token_ids), 10, replace=False)
for i in indexes:
    token_ids = test_token_ids[i][1:-1]
    print(' '.join([tokenizer.inverse_vocab[token_id] for token_id in token_ids]) + ' ____')
    x = tf.convert_to_tensor(token_ids, dtype=tf.int64)  # now x is a tensor of (seq_len, )
    # Start your code here


    # End
    print(f'predicted last token: {tokenizer.inverse_vocab[pred]}')
    print('---------------------------------------------')

Briefly analyze the result of N-Gram and RNN