In [None]:
import torch
import torch.nn as nn
import torch.optim as optim


import math
import re
from random import *

import numpy as np

In [None]:
# Inline training corpus: a short two-speaker dialogue, one utterance per line.
# The vocabulary cell splits this on '\n', so each line becomes one "sentence";
# is_positive_example treats a pair of consecutive lines as an IsNext example.
# NOTE(review): the conversation.json cell below reassigns `text`; this literal
# is only what the model trains on if that reassignment does not happen.
text = (
    'Hey there, how’s it going? I’m Alex.\n'
    'Hi Alex! I’m Taylor. Nice to meet you.\n'
    'Nice to meet you too. How’s your day been?\n'
    'Pretty good, thanks for asking. How about yours?\n'
    'It’s been busy, but good overall.\n'
    'That’s great to hear. I just finished a painting.\n'
    'Wow, that’s awesome, Taylor! I’d love to see it sometime.\n'
    'Sure thing! Maybe next time we meet.\n'
    'Speaking of which, are you free this weekend?\n'
    'I think so. Do you have plans?\n'
    'I was thinking of checking out that new cafe downtown.\n'
    'That sounds fun! Count me in.\n'
    'Great, let’s meet there at 2 PM on Saturday.\n'
    'Perfect! I’ll see you then, Alex.\n'
    'See you, Taylor. Have a good night!\n'
    'You too, Alex. Good night!\n'
    'On Saturday, Alex and Taylor met at the cafe.\n'
    'This place has such a cozy vibe, doesn’t it?\n'
    'Definitely. The decor is really nice too.\n'
    'Yeah, and the coffee smells amazing.\n'
    'They enjoyed their drinks and chatted for hours.\n'
    'Thanks for hanging out today, Taylor.\n'
    'No problem, Alex. I had a blast.\n'
    'Let’s do this again soon.\n'
    'For sure! I’m looking forward to it.\n'
    'Until next time, Taylor.\n'
    'Until next time, Alex.'
)



In [None]:
import json
from pathlib import Path

# Prefer the dialogue stored in conversation.json (a JSON array of strings, one
# utterance per element). When the file is absent, keep the inline `text`
# defined in the previous cell so Restart & Run All still works on a fresh checkout.
CONVERSATION_PATH = Path('conversation.json')

if CONVERSATION_PATH.exists():
    # Explicit UTF-8: the dialogue contains non-ASCII apostrophes (’) and the
    # platform default encoding is not guaranteed to be UTF-8 (e.g. Windows).
    with CONVERSATION_PATH.open('r', encoding='utf-8') as file:
        text_lines = json.load(file)
    if not isinstance(text_lines, list) or not all(isinstance(line, str) for line in text_lines):
        raise ValueError(f'{CONVERSATION_PATH} must contain a JSON array of strings')
    text = '\n'.join(text_lines)
text

In [None]:
# Normalise the corpus: lower-case, strip basic punctuation, one sentence per line.
sentences = re.sub('[.,!?\\-]', '', text.lower()).split('\n')

# sorted() makes token ids reproducible: iteration order of a set of str depends
# on PYTHONHASHSEED, so list(set(...)) assigned different ids on every kernel start.
word_list = sorted(set(' '.join(sentences).split()))

# Special tokens occupy ids 0-3; ordinary words are numbered from 4.
word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for i, w in enumerate(word_list):
    word_dict[w] = i + 4

# Inverse vocabulary (id -> token), built from the id values themselves rather
# than relying on dict insertion order matching the ids.
number_dict = {i: w for w, i in word_dict.items()}
vocab_size = len(word_dict)

# Encode every sentence as a list of token ids.
token_list = [[word_dict[s] for s in sentence.split()] for sentence in sentences]

In [None]:
# Model / data hyper-parameters.
max_len = 30        # padded length of "[CLS] A [SEP] B [SEP]"
batch_size = 6      # make_batch splits this between IsNext and NotNext pairs
n_layers = 6        # number of encoder layers
n_heads = 12        # attention heads
d_model = 768       # hidden size
d_ff = 768 * 4      # feed-forward inner size
d_k = 64            # per-head query/key size
d_v = 64            # per-head value size
n_segments = 2      # segment ids: 0 = sentence A, 1 = sentence B

max_pred = 3        # max masked-LM predictions per example

# Seed every RNG the notebook draws from (Python `random` for pair sampling and
# masking, NumPy, and torch for weight initialisation) so Restart & Run All
# reproduces the same batch, weights and losses.
SEED = 42
seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def select_random_sentence_pair(sentences, token_list):
    """Draw two sentence indices uniformly at random, with replacement.

    Both indices are sampled from ``range(len(sentences))`` and used to index
    ``token_list``. Returns ``(tokens_a, tokens_b, index_a, index_b)``.
    """
    first_idx = randrange(len(sentences))
    second_idx = randrange(len(sentences))
    return token_list[first_idx], token_list[second_idx], first_idx, second_idx


# Demo: draw one pair from the real corpus built above.
# Do NOT rebind `sentences` / `token_list` here. make_batch reads them later,
# and toy ids such as 1-3 would collide with [CLS]/[SEP]/[MASK].
tokens_a, tokens_b, tokens_a_index, tokens_b_index = select_random_sentence_pair(sentences, token_list)
print(tokens_a, tokens_b, tokens_a_index, tokens_b_index)

In [None]:
def construct_input_segment_ids(tokens_a, tokens_b, word_dict):
    """Assemble BERT input ids ``[CLS] A [SEP] B [SEP]`` and their segment ids.

    Segment id 0 covers [CLS], sentence A and the first [SEP]; segment id 1
    covers sentence B and the closing [SEP].
    """
    cls_id = word_dict['[CLS]']
    sep_id = word_dict['[SEP]']
    input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]

    first_segment_len = len(tokens_a) + 2   # [CLS] + A + [SEP]
    second_segment_len = len(tokens_b) + 1  # B + [SEP]
    segment_ids = [0] * first_segment_len + [1] * second_segment_len
    return input_ids, segment_ids

# Build the "[CLS] A [SEP] B [SEP]" encoding for the sampled demo pair.
input_ids, segment_ids = construct_input_segment_ids(
    tokens_a=tokens_a, tokens_b=tokens_b, word_dict=word_dict
)

print(input_ids, '\n', segment_ids)

In [None]:
def mask_tokens(input_ids, word_dict, max_pred, vocab_size, number_dict):
    """Apply BERT masked-LM corruption to one encoded sentence pair.

    Roughly 15% of the sequence length (at least 1, at most ``max_pred``) is
    selected from positions that are not [CLS]/[SEP]. Each selected position is
    replaced by [MASK] with probability 0.8, by a random *ordinary* word with
    probability 0.1, and left unchanged otherwise.

    Args:
        input_ids: token ids of the pair. Not modified; a copy is corrupted.
        word_dict: token -> id mapping with '[CLS]', '[SEP]', '[MASK]'
            (and '[PAD]', assumed 0 if absent).
        max_pred: upper bound on the number of masked positions.
        vocab_size: ids are ``range(vocab_size)``; random replacements exclude
            the special-token ids.
        number_dict: id -> token mapping (kept for interface compatibility).

    Returns:
        ``(masked_tokens, masked_pos, corrupted_input_ids)``, where
        ``masked_tokens[i]`` is the original id at ``masked_pos[i]``.
    """
    # Work on a copy so the caller's list is not silently corrupted.
    input_ids = list(input_ids)
    cls_id = word_dict['[CLS]']
    sep_id = word_dict['[SEP]']
    mask_id = word_dict['[MASK]']
    # Ids that must never be injected as a "random word" replacement.
    special_ids = {word_dict.get('[PAD]', 0), cls_id, sep_id, mask_id}

    n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
    cand_masked_pos = [i for i, token in enumerate(input_ids)
                       if token != cls_id and token != sep_id]
    shuffle(cand_masked_pos)

    masked_tokens, masked_pos = [], []
    for pos in cand_masked_pos[:n_pred]:
        masked_pos.append(pos)
        masked_tokens.append(input_ids[pos])
        if random() < 0.8:          # 80%: replace with [MASK]
            input_ids[pos] = mask_id
        elif random() < 0.5:        # 10%: replace with a random ordinary word
            # Drawing [PAD]/[CLS]/[SEP]/[MASK] here would inject structural
            # tokens into the middle of the sentence.
            replacement_ids = [i for i in range(vocab_size) if i not in special_ids]
            if replacement_ids:
                input_ids[pos] = choice(replacement_ids)
        # Remaining 10%: keep the original token.

    return masked_tokens, masked_pos, input_ids




# Corrupt the demo pair and show which positions were selected for prediction.
masked_tokens, masked_pos, input_ids = mask_tokens(
    input_ids=input_ids,
    word_dict=word_dict,
    max_pred=max_pred,
    vocab_size=vocab_size,
    number_dict=number_dict,
)

for label, value in (('masked_tokens', masked_tokens),
                     ('masked_pos', masked_pos),
                     ('input_ids', input_ids)):
    print(label, value)

In [None]:
def pad_sequences(input_ids, segment_ids, masked_tokens, masked_pos, max_len, max_pred):
    """Right-pad one example in place to fixed lengths and return the same lists.

    ``input_ids`` and ``segment_ids`` are padded with 0 ([PAD]) by
    ``max_len - len(input_ids)`` elements. ``masked_tokens`` and ``masked_pos``
    are padded with 0 up to ``max_pred``, so padded prediction slots read as
    (target 0, position 0).

    Raises:
        ValueError: if ``input_ids`` is already longer than ``max_len``, or more
            than ``max_pred`` tokens were masked. Without this check, such an
            example stays over-length and only fails later, as a ragged batch or
            an out-of-range position embedding.
    """
    if len(input_ids) > max_len:
        raise ValueError(
            f'sequence of length {len(input_ids)} exceeds max_len={max_len}; '
            'increase max_len or use shorter sentences'
        )
    if len(masked_tokens) > max_pred:
        raise ValueError(f'{len(masked_tokens)} masked tokens exceed max_pred={max_pred}')

    n_pad = max_len - len(input_ids)
    input_ids.extend([0] * n_pad)
    segment_ids.extend([0] * n_pad)

    n_pred_pad = max_pred - len(masked_tokens)
    masked_tokens.extend([0] * n_pred_pad)
    masked_pos.extend([0] * n_pred_pad)

    return input_ids, segment_ids, masked_tokens, masked_pos

In [None]:
# Pad the demo example to fixed lengths (max_len tokens, max_pred predictions).
input_ids, segment_ids, masked_tokens, masked_pos = pad_sequences(
    input_ids=input_ids, segment_ids=segment_ids,
    masked_tokens=masked_tokens, masked_pos=masked_pos,
    max_len=max_len, max_pred=max_pred,
)

In [None]:
def is_positive_example(tokens_a_index, tokens_b_index):
    """Return True when sentence B is the line immediately after sentence A (IsNext)."""
    return tokens_b_index == tokens_a_index + 1

In [None]:
def make_batch(sentences, token_list, word_dict, number_dict, batch_size, max_len,
               max_pred, vocab_size):
    """Build one pre-training batch with a balanced IsNext / NotNext split.

    Each example is ``[input_ids, segment_ids, masked_tokens, masked_pos, is_next]``.
    It is produced by construct_input_segment_ids, then mask_tokens, then
    pad_sequences. ``batch_size // 2`` examples are positive (B is the sentence
    right after A) and the rest are negative. The label order is shuffled.

    Raises:
        ValueError: if fewer than two sentences are available. No positive pair
            can exist then, and the previous sampling loop would never terminate.
    """
    n_sentences = len(sentences)
    if batch_size > 0 and n_sentences < 2:
        raise ValueError('make_batch needs at least two sentences to form an IsNext pair')

    # Integer split: an odd batch_size used to make the `== batch_size / 2`
    # termination test unreachable (infinite loop).
    n_positive = batch_size // 2
    labels = [True] * n_positive + [False] * (batch_size - n_positive)
    shuffle(labels)

    batch = []
    for is_next in labels:
        if is_next:
            # Build the positive pair directly rather than waiting for a random
            # draw to land on consecutive sentences (probability ~1/n_sentences).
            tokens_a_index = randrange(n_sentences - 1)
            tokens_b_index = tokens_a_index + 1
            tokens_a = token_list[tokens_a_index]
            tokens_b = token_list[tokens_b_index]
        else:
            # Rejection-sample a non-consecutive pair. With >= 2 sentences an
            # acceptable pair always exists (e.g. A == B), so this terminates.
            while True:
                tokens_a, tokens_b, tokens_a_index, tokens_b_index = \
                    select_random_sentence_pair(sentences, token_list)
                if not is_positive_example(tokens_a_index, tokens_b_index):
                    break

        # Only accepted pairs are encoded, masked and padded.
        input_ids, segment_ids = construct_input_segment_ids(tokens_a, tokens_b, word_dict)
        masked_tokens, masked_pos, input_ids = mask_tokens(
            input_ids, word_dict, max_pred, vocab_size, number_dict
        )
        input_ids, segment_ids, masked_tokens, masked_pos = pad_sequences(
            input_ids, segment_ids, masked_tokens, masked_pos, max_len, max_pred
        )
        batch.append([input_ids, segment_ids, masked_tokens, masked_pos, is_next])

    return batch

In [None]:
# One fixed pre-training batch; the training loop below reuses it every epoch.
batch = make_batch(
    sentences=sentences,
    token_list=token_list,
    word_dict=word_dict,
    number_dict=number_dict,
    batch_size=batch_size,
    max_len=max_len,
    max_pred=max_pred,
    vocab_size=vocab_size,
)

In [None]:
# Transpose the batch (list of examples -> one column per field) and turn each
# column into a LongTensor; isNext booleans become 0/1.
input_ids, segment_ids, masked_tokens, masked_pos, isNext = (
    torch.LongTensor(column) for column in zip(*batch)
)

In [None]:
def gelu(x):
    """Exact (erf-based) GELU: x * Phi(x), with Phi the standard normal CDF."""
    half_x = x * 0.5
    return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))


In [None]:
def get_attn_pad_mask(seq_q, seq_k):
    """Mark padded key positions (token id 0) for every query position.

    ``seq_q`` is (batch, len_q) and ``seq_k`` is (batch, len_k). Returns a
    (batch, len_q, len_k) bool tensor where True means "ignore this key".
    """
    _, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    key_is_pad = (seq_k.data == 0)[:, None, :]
    return key_is_pad.expand(batch_size, len_q, len_k)

In [None]:
class Embedding(nn.Module):
    """BERT input embedding: token + position + segment embeddings, then LayerNorm.

    All size arguments are optional. When omitted, they fall back to the notebook
    globals ``vocab_size``, ``d_model``, ``max_len`` and ``n_segments``, so the
    existing ``Embedding()`` calls keep working unchanged.
    """

    def __init__(self, num_tokens=None, hidden_size=None, max_positions=None,
                 num_segments=None):
        super(Embedding, self).__init__()

        num_tokens = vocab_size if num_tokens is None else num_tokens
        hidden_size = d_model if hidden_size is None else hidden_size
        max_positions = max_len if max_positions is None else max_positions
        num_segments = n_segments if num_segments is None else num_segments

        self.tok_embed = nn.Embedding(num_tokens, hidden_size)
        self.pos_embed = nn.Embedding(max_positions, hidden_size)
        self.seg_embed = nn.Embedding(num_segments, hidden_size)

        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x, seg):
        """Embed token ids ``x`` (batch, seq_len) with segment ids ``seg`` (same shape).

        Raises:
            ValueError: if ``seq_len`` exceeds the number of position embeddings.
        """
        seq_len = x.size(1)
        if seq_len > self.pos_embed.num_embeddings:
            raise ValueError(
                f'sequence length {seq_len} exceeds the {self.pos_embed.num_embeddings} '
                'available position embeddings'
            )
        # Create positions on the input's device. Without this they default to CPU,
        # which breaks once the model and batch are moved to a GPU.
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

In [None]:
# Padding mask for the assembled batch: True where the key token id is 0 ([PAD]).
attn_mask = get_attn_pad_mask(seq_q=input_ids, seq_k=input_ids)


In [None]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention followed by residual + LayerNorm.

    Args:
        d_model: input/output feature size.
        n_heads: number of attention heads.
        d_k: per-head query/key size.
        d_v: per-head value size.
    """

    def __init__(self, d_model, n_heads, d_k, d_v):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        # Creation order is kept fixed so parameter init consumes the RNG the same way.
        self.w_q = nn.Linear(d_model, n_heads * d_k)
        self.w_k = nn.Linear(d_model, n_heads * d_k)
        self.w_v = nn.Linear(d_model, n_heads * d_v)
        self.fc = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, q, k, v, attn_mask):
        """Attend from ``q`` to ``k``/``v``.

        ``q``, ``k``, ``v`` are (batch, seq, d_model). ``attn_mask`` is a
        (batch, len_q, len_k) boolean mask where True marks keys to ignore.
        Returns ``(LayerNorm(q + projected context), attention probabilities)``;
        the probabilities have shape (batch, n_heads, len_q, len_k).
        """
        batch = q.size(0)

        def split_heads(projected, head_dim):
            # (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)
            return projected.view(batch, -1, self.n_heads, head_dim).transpose(1, 2)

        queries = split_heads(self.w_q(q), self.d_k)
        keys = split_heads(self.w_k(k), self.d_k)
        values = split_heads(self.w_v(v), self.d_v)

        scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.d_k)
        # Broadcast the mask across the head dimension instead of copying it per head.
        scores.masked_fill_(attn_mask.unsqueeze(1), -1e9)
        probs = torch.softmax(scores, dim=-1)

        merged = torch.matmul(probs, values).transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.n_heads * self.d_v)
        return self.layer_norm(self.fc(merged) + q), probs
    

In [None]:
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward sublayer: Linear(d_model -> d_ff), GELU, Linear(d_ff -> d_model)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = gelu(self.fc1(x))
        return self.fc2(hidden)

class EncoderLayer(nn.Module):
    """One post-LayerNorm Transformer encoder layer (the BERT layout).

    Sublayer 1 is multi-head self-attention; MultiHeadAttention already applies
    its own residual connection and LayerNorm. Sublayer 2 is the position-wise
    feed-forward network, wrapped here in its own residual connection and
    LayerNorm. Both sublayers therefore follow the "Add & Norm" pattern.
    Without this, the raw FFN output fed the next layer directly and the
    second residual path was missing.
    """

    def __init__(self):
        super(EncoderLayer, self).__init__()

        self.enc_self_attn = MultiHeadAttention(d_model, n_heads, d_k, d_v)
        self.pos_ffn = PoswiseFeedForwardNet()
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Return ``(layer output, self-attention probabilities)`` for ``enc_inputs``."""
        attn_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.ffn_norm(attn_outputs + self.pos_ffn(attn_outputs))
        return enc_outputs, attn
    



In [None]:
# Smoke-test the building blocks on the current batch.
enc = EncoderLayer()

emb = Embedding()

embeds = emb(input_ids, segment_ids)

# Padding is identified from the token ids (0 == [PAD]). Segment ids are also 0
# for all of sentence A, so passing them here would mask real tokens.
attn_mask = get_attn_pad_mask(input_ids, input_ids)

mha_output = MultiHeadAttention(d_model, n_heads, d_k, d_v)(embeds, embeds, embeds, attn_mask)

output, a = mha_output

# Call the module itself (not .forward) so nn.Module hooks run as usual.
output = enc(enc_inputs=embeds, enc_self_attn_mask=attn_mask)

In [None]:
class BERT(nn.Module):
    """Minimal BERT for joint masked-LM and next-sentence-prediction pre-training.

    ``forward(input_ids, segment_ids, masked_pos)`` returns
    ``(logits_lm, logits_clsf)``:

    * ``logits_lm``: vocabulary logits for the hidden states gathered at
      ``masked_pos``.
    * ``logits_clsf``: 2-way IsNext/NotNext logits computed from position 0 ([CLS]).

    The masked-LM decoder shares its weight matrix with the token embedding.
    """

    def __init__(self):
        super().__init__()

        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

        # NSP head: tanh "pooler" over the [CLS] state.
        self.fc = nn.Linear(d_model, d_model)
        self.activ1 = nn.Tanh()

        # MLM transform (Linear -> GELU -> LayerNorm), then the NSP classifier.
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)

        # MLM decoder tied to the token-embedding matrix, with its own bias.
        tok_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = tok_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = tok_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos):
        hidden = self.embedding(input_ids, segment_ids)
        pad_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            hidden, _ = layer(hidden, pad_mask)

        # NSP: classify from the final hidden state at position 0, where
        # construct_input_segment_ids places [CLS].
        pooled = self.activ1(self.fc(hidden[:, 0]))
        logits_clsf = self.classifier(pooled)

        # MLM: masked_pos is (batch, n_masked); add a feature axis and expand it
        # so torch.gather can pick whole hidden vectors along the sequence dim.
        gather_index = masked_pos.unsqueeze(2).expand(-1, -1, hidden.size(-1))
        h_masked = torch.gather(hidden, 1, gather_index)

        h_masked = self.norm(self.activ2(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias
        return logits_lm, logits_clsf
    

bert = BERT()

# Sanity-check one forward pass. Call the module (not .forward) so hooks run,
# skip autograd bookkeeping, and display output shapes instead of raw logits.
with torch.no_grad():
    sanity_logits_lm, sanity_logits_clsf = bert(
        input_ids=input_ids, segment_ids=segment_ids, masked_pos=masked_pos
    )
sanity_logits_lm.shape, sanity_logits_clsf.shape

In [None]:
# Adam over all BERT parameters, and a plain cross-entropy criterion.
optimizer = optim.Adam(bert.parameters(), lr=1e-5)

criterion = nn.CrossEntropyLoss()

In [None]:
# The masked-LM loss must skip padded prediction slots. pad_sequences fills them
# with target 0 ([PAD]) at position 0, so the model would otherwise also learn to
# predict [PAD] from the [CLS] state. A separate criterion is needed because
# `criterion` is shared with NSP, where target 0 (NotNext) is a real label.
criterion_lm = nn.CrossEntropyLoss(ignore_index=word_dict['[PAD]'])

bert.train()
for epoch in range(300):
    optimizer.zero_grad()

    logits_lm, logits_clsf = bert(input_ids, segment_ids, masked_pos)
    # CrossEntropyLoss expects class scores on dim 1: (batch, vocab, n_masked).
    loss_lm = criterion_lm(logits_lm.transpose(1, 2), masked_tokens)
    loss_clsf = criterion(logits_clsf, isNext)

    loss = loss_lm + loss_clsf
    if (epoch + 1) % 10 == 0:
        print('epoch: [{}] cost: [{}]'.format(epoch + 1, loss.item()))
    loss.backward()
    optimizer.step()


In [None]:
import torch

# A few unseen sentences to probe the model after pre-training. Words missing
# from the training vocabulary are mapped to the [MASK] id.
test_sentences = [
    "Romeo loves Juliet.",
    "Juliet loves Romeo.",
    "Romeo and Juliet are together."
]

# Apply the same normalisation as the training corpus (lower-case + punctuation
# strip). Otherwise tokens like "together." keep their punctuation and can never
# match a vocabulary entry.
test_token_list = [
    [word_dict.get(s, word_dict['[MASK]']) for s in re.sub('[.,!?\\-]', '', sentence.lower()).split()]
    for sentence in test_sentences
]

tokens_a, tokens_b, _, _ = select_random_sentence_pair(test_sentences, test_token_list)

input_ids, segment_ids = construct_input_segment_ids(tokens_a, tokens_b, word_dict)

masked_tokens, masked_pos, input_ids = mask_tokens(
    input_ids, word_dict, max_pred, vocab_size, number_dict
)

input_ids, segment_ids, masked_tokens, masked_pos = pad_sequences(
    input_ids, segment_ids, masked_tokens, masked_pos, max_len, max_pred
)

input_ids_tensor = torch.LongTensor([input_ids])
segment_ids_tensor = torch.LongTensor([segment_ids])
masked_pos_tensor = torch.LongTensor([masked_pos])

# Inference only: eval mode and no autograd graph.
bert.eval()
with torch.no_grad():
    logits_lm, logits_clsf = bert(input_ids=input_ids_tensor, segment_ids=segment_ids_tensor, masked_pos=masked_pos_tensor)

predicted_tokens = logits_lm.argmax(dim=2)[0].tolist()
# Keep only real prediction slots; padded slots have target id 0 ([PAD]).
# Filtering on the predicted id instead would keep padded slots and drop
# genuine predictions.
real_slots = [i for i, target in enumerate(masked_tokens) if target != word_dict['[PAD]']]
predicted_masked_tokens = [number_dict[predicted_tokens[i]] for i in real_slots]
actual_masked_tokens = [number_dict[masked_tokens[i]] for i in real_slots]

predicted_isNext = logits_clsf.argmax(dim=1)[0].item()

print("Original Sentences:")
print("Sentence A:", " ".join([number_dict[token] for token in tokens_a]))
print("Sentence B:", " ".join([number_dict[token] for token in tokens_b]))
print("\nActual Masked Tokens:", actual_masked_tokens)
print("Predicted Masked Tokens:", predicted_masked_tokens)
print("Predicted isNext:", "True" if predicted_isNext else "False")
