<a href="https://colab.research.google.com/github/csch7/CSCI-4170/blob/main/Homework-05/NLP_and_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn

In [2]:
def softmax(scores):
  return torch.exp(scores) / torch.sum(torch.exp(scores), dim = 0)

def scaled_dot_product_attention(queries, keys, values):
  queries = torch.unsqueeze(queries, 2).repeat((1,1,64))
  scores = (queries @ keys.permute(1,2,0)) / np.sqrt(keys.shape[-1])
  s = softmax(scores)
  return torch.squeeze(s @ values.permute(1,0,2), 1)


# class DotProductAttention(nn.Layer):
#   def __init__(self):
#     super().__init__()

#   def call(self, queries, keys, values):


In [3]:
class Scaled_Dot_Product_Attention(nn.Module):
  def __init__(self):
    super().__init__()
    self.sm = nn.Softmax(dim=0)

  def forward(self, queries, keys, values):
    queries = torch.unsqueeze(queries, 2).repeat((1,1,64))
    print(queries.shape, keys.shape, values.shape)
    scores = (queries.permute(1,0,2) @ keys.permute(1,2,0)) / np.sqrt(keys.shape[-1])
    s = softmax(scores)
    return torch.squeeze(s @ values.permute(1,0,2), 1)


class Encoder(nn.Module):
  def __init__(self, input_dim, embed_dim, hidden_dim):
    super().__init__()
    self.embed = nn.Embedding(input_dim, embed_dim)
    self.lstm = nn.LSTM(embed_dim, hidden_dim, bidirectional=True)
    self.fc = nn.Linear(2*hidden_dim, hidden_dim)
    self.tanh = nn.Tanh()

  def forward(self, x):
    em = self.embed(x)
    lstm_out, (hidden, cell) = self.lstm(em)
    return lstm_out, self.tanh(self.fc(torch.cat((hidden[0,:,:], hidden[1,:,:]), dim=1))).unsqueeze(0), cell

class Decoder(nn.Module):
  def __init__(self, embed_dim, hidden_dim, output_dim):
    super().__init__()
    self.embed = nn.Embedding(output_dim, embed_dim)
    self.lstm = nn.LSTM(2*hidden_dim + embed_dim, hidden_dim)
    self.attn = Scaled_Dot_Product_Attention()
    self.fc = nn.Linear(hidden_dim, output_dim)
    self.sm = nn.Softmax(dim=2)

  def forward(self, targets, hidden, cell, encoder_out):
    em = self.embed(targets)
    # print(hidden.shape)
    attn = self.attn(hidden[:,:,-1], encoder_out, encoder_out)
    print(hidden.shape, cell.shape)
    lstm_out, (hidden, _) = self.lstm(torch.unsqueeze(torch.cat((em, attn), dim=1),0), (hidden, cell))
    return self.fc(lstm_out), hidden, cell

class Seq2SeqAttn(nn.Module):
  def __init__(self, enc, dec, out_vocab_len):
    super().__init__()
    self.encoder = enc
    self.decoder = dec
    self.vocab_len = out_vocab_len

  def forward(self, inputs, targets):
    tar_len = targets.shape[0]
    tar_size = targets.shape[1]
    outputs = torch.zeros((tar_len, tar_size, self.vocab_len))
    enc_out, hidden, cell = self.encoder(inputs)
    for i in range(1, tar_len):
      dec_out, hidden, cell = self.decoder(targets[i], hidden, cell, enc_out)
      outputs[i] = dec_out
    return outputs


In [4]:
!pip install datasets



In [5]:
from datasets import load_dataset

ds = load_dataset('bentrevett/multi30k')

train_dat = ds['train'][:len(ds['train'])//70]
valid_dat = ds['validation'][:len(ds['validation'])//70]
test_dat = ds['test'][:len(ds['test'])//70]
train_lab = train_dat['en']
train_dat = train_dat['de']
valid_lab = valid_dat['en']
valid_dat = valid_dat['de']
test_lab = test_dat['en']
test_dat = test_dat['de']

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [6]:
import re

def clean_text(text):
    text = str(text).lower() # Ensure no duplicate word embeddings due to capital letters
    test = re.sub(r'^[A-Za-zÀ-ȕ ]+', '', text)         # Remove certain special characters (need to be careful not to remove umlauds or eszetts from German)
    text = re.sub(r"\s+", " ", text).strip()      # Remove extra spaces
    return text

def pad_sentences(dat, max_len):
  for s in range(len(dat)):
    if len(dat[s]) > max_len:
      dat[s] = dat[s][:max_len]
    else:
      dat[s] = dat[s] + ['<PAD>']*(max_len-len(dat[s]))
  return dat

def process_sentences(dat, vocab, max_len):
  dat = [s for s in dat]
  dat = [['<SOS>']+[clean_text(si) for si in s.split()]+['<EOS>'] for s in dat]
  dat = pad_sentences(dat, max_len)
  dat = [[vocab[word] for word in s] for s in dat]
  return dat


max_len = 50

sentences_en = [s for ds in [train_lab, valid_lab, test_lab] for s in ds]
sentences_en = [['<SOS>']+[clean_text(si) for si in s.split()]+['<EOS>'] for s in sentences_en]
vocab_en = set([w for s in sentences_en for w in s])
vocab_en = {word: idx+1 for idx, word in enumerate(vocab_en)}
vocab_en['<PAD>'] = 0
token_to_value_en = {vocab_en[k]: k for k in vocab_en}

sentences_de = [s for ds in [train_dat, valid_dat, test_dat] for s in ds]
sentences_de = [['<SOS>']+[clean_text(si) for si in s.split()]+['<EOS>'] for s in sentences_de]
vocab_de = set([w for s in sentences_de for w in s])
vocab_de = {word: idx+1 for idx, word in enumerate(vocab_de)}
vocab_de['<PAD>'] = 0
token_to_value_de = {vocab_de[k]: k for k in vocab_de}

train_dat = process_sentences(train_dat, vocab_de, max_len)
train_lab = process_sentences(train_lab, vocab_en, max_len)
valid_dat = process_sentences(valid_dat, vocab_de, max_len)
valid_lab = process_sentences(valid_lab, vocab_en, max_len)
test_dat = process_sentences(test_dat, vocab_de, max_len)
test_lab = process_sentences(test_lab, vocab_en, max_len)

In [7]:
import torch.optim as optim

def one_hot_encode(labels, max_len, vocab_size):
  res = torch.zeros((len(labels), max_len, vocab_size))
  for i in range(len(labels)):
    for j in range(max_len):
      res[i,j,labels[i,j]] = 1
  return res


train_dat = torch.LongTensor(train_dat)
train_lab = torch.LongTensor(train_lab)
train_ohe = one_hot_encode(train_lab, max_len, len(vocab_en))
valid_dat = torch.LongTensor(valid_dat)
valid_lab = torch.LongTensor(valid_lab)
valid_ohe = one_hot_encode(valid_lab, max_len, len(vocab_en))
test_dat = torch.LongTensor(test_dat)
test_lab = torch.LongTensor(test_lab)
test_ohe = one_hot_encode(test_lab, max_len, len(vocab_en))

epochs = 10
enc = Encoder(len(vocab_de), 100, 32)
dec = Decoder(100, 32, len(vocab_en))
model = Seq2SeqAttn(enc, dec, len(vocab_en))
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

for e in range(epochs):
  model.train()
  optimizer.zero_grad()
  pred = model(train_dat.T, train_lab.T)
  pred = torch.permute(pred, (1,0,2))
  print(pred[0])
  print([token_to_value_en[int(w)] for w in train_lab[0]], [token_to_value_en[int(w)] for w in torch.argmax(pred[0], dim=1)])
  loss = loss_fn(pred, train_ohe)
  print(loss.item())
  loss.backward()
  optimizer.step()

  model.eval()


torch.Size([1, 414, 64]) torch.Size([50, 414, 64]) torch.Size([50, 414, 64])
torch.Size([1, 414, 32]) torch.Size([2, 414, 32])


RuntimeError: Expected hidden[1] size (1, 414, 32), got [2, 414, 32]

In [10]:
from math import inf

class PositionalEncoding(nn.Module):
  def __init__(self, seq_len, embedding_dim):
    super().__init__()
    self.seq_len = seq_len
    self.embed_dim = embedding_dim

  def forward(self, x):
    embedding = torch.zeros(self.seq_len, x.shape[1], self.embed_dim)
    positions = torch.arange(self.seq_len)
    for p in positions:
      embedding[p, :, ::2] = torch.sin(p/(10000**(2*torch.arange(self.embed_dim)[:self.embed_dim//2]/self.embed_dim)))
      embedding[p, :, 1::2] = torch.cos(p/(10000**(2*torch.arange(self.embed_dim)[:self.embed_dim//2]/self.embed_dim)))
    return embedding


class ScaledDotProductAttention(nn.Module):
  def __init__(self, masking = False):
    super().__init__()
    self.sm = nn.Softmax(dim=0)
    self.mask = masking

  def forward(self, q, k, v):
    scores = q @ k.permute(0,2,1) / np.sqrt(k.shape[-1])
    if self.mask:
      # print(q, k, v)
      mask = torch.full(scores.shape, -1*10^30)
      mask = torch.triu(mask, 1)
      scores = scores + mask
    s = self.sm(scores)
    # if self.mask:
    #   print(s, s@v)
    return s @ v

class MultiHeadAttention(nn.Module):
  def __init__(self, num_heads, qk, qv, dim_model, masking = False):
    super().__init__()
    self.i = 0
    self.nh = num_heads
    self.d_model = dim_model
    self.mask = masking
    self.Wq = nn.Parameter(torch.randn((num_heads, dim_model, qk)))
    self.Wk = nn.Parameter(torch.randn((num_heads, dim_model, qk)))
    self.Wv = nn.Parameter(torch.randn((num_heads, dim_model, qv)))
    self.Wo = nn.Parameter(torch.randn((num_heads*qv, dim_model)))
    self.attn = ScaledDotProductAttention(masking)

  def forward(self, Q, K, V):
    # print(self.i, V)
    self.i += 1
    output = self.attn(Q @ self.Wq[0], K @ self.Wk[0], V @ self.Wv[0])
    for i in range(1, self.nh):
      output = torch.cat((output, self.attn(Q @ self.Wq[i], K @ self.Wk[i], V @ self.Wv[i])), dim=2)
    return output @ self.Wo


class FFN(nn.Module):
  def __init__(self, embedding_dim = 64, hidden_dim = 128):
    super().__init__()
    self.w1 = nn.Parameter(torch.randn((embedding_dim, hidden_dim)))
    self.b1 = nn.Parameter(torch.randn(hidden_dim))
    self.w2 = nn.Parameter(torch.randn((hidden_dim, embedding_dim)))
    self.b2 = nn.Parameter(torch.randn(embedding_dim))
    self.relu = nn.ReLU()

  def forward(self, x):
    return self.relu(x @ self.w1 + self.b1) @ self.w2 + self.b2


class Encoder(nn.Module):
  def __init__(self, input_len, vocab_size, d_model = 64, hidden_dim = 128, num_heads = 8, num_layers = 2):
    super().__init__()
    self.d_model = d_model
    self.hidden_dim = hidden_dim
    self.embed = nn.Embedding(vocab_size, d_model)
    self.position = PositionalEncoding(input_len, d_model)
    self.ffn = FFN(d_model, hidden_dim)
    self.layernorm = nn.LayerNorm((input_len, d_model))
    self.attn = MultiHeadAttention(num_heads, int(d_model / num_heads), int(d_model / num_heads), d_model)
    self.L = num_layers

  def forward(self, inputs):
    em = self.embed(inputs)
    pos_en = self.position(inputs)
    out = (em + pos_en).permute(1,0,2)

    for l in range(self.L):
      self_attn = self.attn(out, out, out)
      attn_norm = self.layernorm(self_attn + out)
      ffn_out = self.ffn(attn_norm)
      out = self.layernorm(ffn_out + attn_norm)
    return out


class Decoder(nn.Module):
  def __init__(self, output_len, vocab_size, d_model = 64, hidden_dim = 128, num_heads = 8, num_layers = 2):
    super().__init__()
    self.d_model = d_model
    self.hidden_dim = hidden_dim
    self.embed = nn.Embedding(vocab_size, d_model)
    self.position = PositionalEncoding(output_len, d_model)
    self.ffn = FFN(d_model, hidden_dim)
    self.layernorm = nn.LayerNorm((output_len, d_model))
    self.attn = MultiHeadAttention(num_heads, int(d_model / num_heads), int(d_model / num_heads), d_model)
    self.masked_attn = MultiHeadAttention(num_heads, int(d_model / num_heads), int(d_model / num_heads), d_model, masking=True)
    self.L = num_layers

  def forward(self, outputs, enc_out):
    em = self.embed(outputs)
    pos_en = self.position(outputs)
    out = (em + pos_en).permute(1,0,2)
    for l in range(self.L):
      # print(out)
      self_attn = self.masked_attn(out, out, out)
      attn_norm = self.layernorm(self_attn + out)
      enc_attn = self.attn(enc_out, enc_out, attn_norm)
      # print(l, out)
      # print(enc_attn.shape, attn_norm.shape)
      attn_norm = self.layernorm(enc_attn + attn_norm)
      ffn_out = self.ffn(attn_norm)
      out = self.layernorm(ffn_out + attn_norm)
    return out


class Transformer(nn.Module):
  def __init__(self, input_len, output_len, in_vocab_size, out_vocab_size, d_model = 64, hidden_dim = 128, num_heads = 8, num_layers = 2):
    super().__init__()
    self.encoder = Encoder(input_len, in_vocab_size)
    self.decoder = Decoder(output_len, out_vocab_size)
    self.fc = nn.Linear(d_model, out_vocab_size)
    self.sm = nn.Softmax(dim=0)

  def forward(self, inputs, outputs):
    enc_out = self.encoder(inputs)
    dec_out = self.decoder(outputs, enc_out)
    return self.sm(self.fc(dec_out))

In [None]:
import torch.optim as optim

def one_hot_encode(labels, max_len, vocab_size):
  res = torch.zeros((len(labels), max_len, vocab_size))
  for i in range(len(labels)):
    for j in range(max_len):
      res[i,j,labels[i,j]] = 1
  return res


train_dat = torch.LongTensor(train_dat)
train_lab = torch.LongTensor(train_lab)
train_ohe = one_hot_encode(train_lab, max_len, len(vocab_en))
valid_dat = torch.LongTensor(valid_dat)
valid_lab = torch.LongTensor(valid_lab)
valid_ohe = one_hot_encode(valid_lab, max_len, len(vocab_en))
test_dat = torch.LongTensor(test_dat)
test_lab = torch.LongTensor(test_lab)
test_ohe = one_hot_encode(test_lab, max_len, len(vocab_en))

epochs = 1000
model = Transformer(max_len-1, max_len-1, len(vocab_de),len(vocab_en))
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for e in range(epochs):
  model.train()
  # for s in range(train_dat.shape[0]):
  optimizer.zero_grad()
  pred = model(train_dat[:,1:].T, train_lab[:,:-1].T)
  # print(pred.shape)
  print([token_to_value_en[int(w)] for w in train_lab[0]], [token_to_value_en[int(w)] for w in torch.argmax(pred[0], dim=1)])
  loss = loss_fn(pred, train_ohe[:,1:,:])
  print(loss.item())
  loss.backward()
  optimizer.step()

  model.eval()

['<SOS>', 'two', 'young,', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes.', '<EOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>'] ['talking', 'around', 'around', 'house', 'wedding', 'pink.', 'pink.', 'stick', 'talking', 'talking', 'pink.', 'talking', 'pink.', 'dressed', 'pink.', 'around', 'talking', 'around', 'around', 'around', 'around', 'around', 'derby', 'talking', 'around', 'talking', 'around', 'pink.', 'around', 'around', 'pink.', 'around', 'around', 'around', 'pink.', 'around', 'around', 'safety.', 'around', 'around', 'around', 'around', 'around', 'around', 'pile', 'pink.', 'around', 'pink.', 'around']
0.14579153060913086
['<SOS>', 'two', 'young,', 'white', 'males',