In [None]:
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Mount Google Drive at /content/drive (Colab-only: imports google.colab).
# The vocab JSON paths in the commented-out get_vocab cell point under this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# def get_vocab(vocab_path):
#   with open(vocab_path,'r') as f:
#     vocab = json.load(f)
#     vocab.append('<SOS>')
#     vocab.append('<EOS>')
#     stoi = {s:(i+1) for i,s in enumerate(vocab)}
#     itos = {(i+1):s for i,s in enumerate(vocab)}
#     stoi['<PAD>'] = 0
#     itos[0] = '<PAD>'
#     return stoi,itos

# stoi_en,itos_en = get_vocab(r'/content/drive/MyDrive/2024Spring/641NaturalLanguageProcessing/NewFolder/vocab_en.json')
# stoi_zh,itos_zh = get_vocab(r'/content/drive/MyDrive/2024Spring/641NaturalLanguageProcessing/NewFolder/vocab_zh.json')

In [None]:
# # hyperparameters
# # model
# vocab_size_en = len(stoi_en)
# vocab_size_zh = len(stoi_zh)
# max_length = 32         # max length of the input sequence
# n_emb = 8               # embedding size
# n_head = 2              # number of heads in multi-head attention
# head_size = 4           # number of 'features' output by a single-head self-attention
# n_blocks = 1            # number of blocks in a encoder or decoder
# n_hidden = 2048
# assert head_size*n_head == n_emb, ''

# # training
# batch_size = 32
# learning_rate = 1e-3
# device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class embedding(nn.Module):
  """Token embedding plus fixed sinusoidal positional encoding.

  Token vectors are scaled by sqrt(n_emb) before the positional table is added.

  Args:
    vocab_size: number of rows in the token-embedding table.
    n_emb: embedding width. Odd widths are supported: the last (sin) column
      simply has no cos partner.
    max_len: longest sequence the positional table covers.
  """
  def __init__(self,vocab_size,n_emb,max_len):
    super().__init__()
    self.n_emb = n_emb
    self.max_len = max_len

    self.word_embedding = nn.Embedding(vocab_size,n_emb)

    pe = torch.zeros(max_len, n_emb)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)   # [max_len,1]
    # Frequencies 1 / 10000^(2i/n_emb), one per even column index 2i.
    div_term = torch.exp(torch.arange(0, n_emb, 2).float() * (-math.log(10000.0) / n_emb))
    pe[:, 0::2] = torch.sin(position * div_term)
    # With odd n_emb there is one fewer odd column than even column, so trim
    # div_term to match (otherwise the assignment fails with a shape error).
    pe[:, 1::2] = torch.cos(position * div_term[: n_emb // 2])
    # Registered as a buffer: follows .to(device) and is saved in state_dict,
    # but is not a trainable parameter.
    self.register_buffer('pe', pe.unsqueeze(0))  # [1,max_len,n_emb]

  def forward(self,x):
    """Embed token ids x of shape [B,T] into [B,T,n_emb].

    Raises:
      ValueError: if T exceeds max_len (no positional encoding exists for
        those positions).
    """
    T = x.size(-1)
    if T > self.max_len:
      raise ValueError(f'sequence length {T} exceeds max_len {self.max_len}')
    word_emb = self.word_embedding(x) * math.sqrt(self.n_emb)   # [B,T,n_emb]
    pos_emb = self.pe[:, :T, :]                                  # [1,T,n_emb], broadcast over B
    return word_emb + pos_emb                                    # [B,T,n_emb]

In [None]:
# # test embedding
# x = torch.randint(low=0,high=20,size=(2,5))
# emb = embedding(vocab_size_en,n_emb,max_length)
# out = emb(x)
# out.shape

torch.Size([2, 5, 8])
torch.Size([1, 32, 8])


torch.Size([2, 5, 8])

In [None]:
# single-head self-attention
class selfAttention(nn.Module):
  """Single-head scaled dot-product self-attention.

  Args:
    fan_in: feature width of the input x.
    fan_out: width of the query/key/value projections (the head size d_k).
    masked: if True, apply a causal mask so position t only attends to
      positions <= t (decoder self-attention).
  """
  def __init__(self,fan_in,fan_out,masked):
    super().__init__()
    self.query = nn.Linear(fan_in,fan_out)
    self.key = nn.Linear(fan_in,fan_out)
    self.value = nn.Linear(fan_in,fan_out)
    self.masked = masked

  def forward(self,x):
    """x: [B,T,fan_in] -> [B,T,fan_out]."""
    T = x.size(1)
    q = self.query(x)   # [B,T,fan_out]
    k = self.key(x)     # [B,T,fan_out]
    v = self.value(x)   # [B,T,fan_out]
    # Scale by sqrt(d_k), the query/key width, not by the input width.
    scores = q @ k.transpose(1, 2) * q.size(-1) ** -0.5   # [B,T,T]
    if self.masked:
      # Build the causal mask from positions, not from the scores: deriving it
      # from tril(scores) also masks any allowed entry whose score equals 0.
      causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
      scores = scores.masked_fill(~causal, float('-inf'))
    # Normalize for both masked and unmasked heads.
    weights = F.softmax(scores, dim=-1)
    return weights @ v   # [B,T,fan_out]

In [None]:
# single-head unmasked cross-attention for decoder
class crossAttention(nn.Module):
  """Single-head scaled dot-product attention from x (queries) onto a second
  sequence `cross` (keys and values), with no mask.

  Args:
    fan_in: feature width of both x and cross.
    fan_out: width of the query/key/value projections (the head size d_k).
  """
  def __init__(self,fan_in,fan_out):
    super().__init__()
    self.query = nn.Linear(fan_in,fan_out)
    self.key = nn.Linear(fan_in,fan_out)
    self.value = nn.Linear(fan_in,fan_out)

  def forward(self,x,cross):
    """x: [B,T_dec,fan_in], cross: [B,T_enc,fan_in] -> [B,T_dec,fan_out]."""
    q = self.query(x)       # [B,T_dec,fan_out]
    k = self.key(cross)     # [B,T_enc,fan_out]
    v = self.value(cross)   # [B,T_enc,fan_out]
    # Scale by sqrt(d_k), the query/key width, not by the input width.
    scores = q @ k.transpose(1, 2) * q.size(-1) ** -0.5   # [B,T_dec,T_enc]
    # Attention weights must be normalized over the key positions.
    weights = F.softmax(scores, dim=-1)
    return weights @ v   # [B,T_dec,fan_out]

In [None]:
# multi-head self-attention
class multiHead_sa(nn.Module):
  """n_head independent selfAttention heads whose outputs are concatenated
  along the feature axis (output width n_head * fan_out)."""
  def __init__(self,fan_in,fan_out,masked,n_head):
    super().__init__()
    # Construct a fresh head per slot. `[head] * n_head` would repeat a single
    # module, so every head would share weights and produce identical output.
    self.multi_head = nn.ModuleList(
      [selfAttention(fan_in,fan_out,masked) for _ in range(n_head)]
    )

  def forward(self,x):
    """x: [B,T,fan_in] -> [B,T,n_head*fan_out]."""
    return torch.cat([head(x) for head in self.multi_head], dim=-1)

In [None]:
# multi-head cross-attention
class multiHead_ca(nn.Module):
  """n_head independent crossAttention heads whose outputs are concatenated
  along the feature axis (output width n_head * fan_out)."""
  def __init__(self,fan_in,fan_out,n_head):
    super().__init__()
    # Construct a fresh head per slot. `[head] * n_head` would repeat a single
    # module, so every head would share weights and produce identical output.
    self.multi_head = nn.ModuleList(
      [crossAttention(fan_in,fan_out) for _ in range(n_head)]
    )

  def forward(self,x,cross):
    """x: [B,T_dec,fan_in], cross: [B,T_enc,fan_in] -> [B,T_dec,n_head*fan_out]."""
    return torch.cat([head(x,cross) for head in self.multi_head], dim=-1)

In [None]:
# encoder block
class encoderBlock(nn.Module):
  """Post-norm Transformer encoder layer.

  Two sublayers, each applied as LayerNorm(x + sublayer(x)):
  unmasked multi-head self-attention, then a position-wise feed-forward net.

  Args:
    fan_in: model width (input and output width).
    fan_out: per-head width; n_head * fan_out must equal fan_in so the
      attention output can be added back onto x.
    n_head: number of attention heads.
    n_hidden: hidden width of the feed-forward net; defaults to 4 * fan_in.

  Raises:
    ValueError: if n_head * fan_out != fan_in.
  """
  def __init__(self,fan_in,fan_out,n_head,n_hidden=None):
    super().__init__()
    if n_head * fan_out != fan_in:
      raise ValueError(f'n_head * fan_out ({n_head * fan_out}) must equal fan_in ({fan_in})')
    if n_hidden is None:
      n_hidden = 4 * fan_in
    self.multi_head = multiHead_sa(fan_in,fan_out,False,n_head)
    # One LayerNorm per sublayer; a single shared one would tie their
    # scale/shift parameters together.
    self.norm_attn = nn.LayerNorm(fan_in)
    self.norm_ffw = nn.LayerNorm(fan_in)
    # Expand -> ReLU -> project back. The last layer is linear so this branch
    # can add negative as well as positive values to the residual stream.
    self.ffw = nn.Sequential(
      nn.Linear(fan_in,n_hidden),
      nn.ReLU(),
      nn.Linear(n_hidden,fan_in),
    )

  def forward(self,x):
    """x: [B,T,fan_in] -> [B,T,fan_in]."""
    # No activation on the attention branch: clamping it at 0 would discard
    # every negative component of the attention output.
    out = self.norm_attn(x + self.multi_head(x))
    return self.norm_ffw(out + self.ffw(out))

In [None]:
# decoder block
class decoderBlock(nn.Module):
  """Post-norm Transformer decoder layer.

  Three sublayers, each applied as LayerNorm(x + sublayer(x)):
  causally-masked multi-head self-attention, multi-head cross-attention onto
  the encoder output `cross`, then a position-wise feed-forward net.

  Args:
    fan_in: model width (input/output width, and the width of `cross`).
    fan_out: per-head width; n_head * fan_out must equal fan_in so each
      attention output can be added back onto its input.
    n_head: number of attention heads.
    n_hidden: hidden width of the feed-forward net; defaults to 4 * fan_in.

  Raises:
    ValueError: if n_head * fan_out != fan_in.
  """
  def __init__(self,fan_in,fan_out,n_head,n_hidden=None):
    super().__init__()
    if n_head * fan_out != fan_in:
      raise ValueError(f'n_head * fan_out ({n_head * fan_out}) must equal fan_in ({fan_in})')
    if n_hidden is None:
      n_hidden = 4 * fan_in
    self.masked_multihead = multiHead_sa(fan_in,fan_out,True,n_head)
    self.cross_multihead = multiHead_ca(fan_in,fan_out,n_head)
    # One LayerNorm per sublayer; a single shared one would tie their
    # scale/shift parameters together.
    self.norm_self = nn.LayerNorm(fan_in)
    self.norm_cross = nn.LayerNorm(fan_in)
    self.norm_ffw = nn.LayerNorm(fan_in)
    # Expand -> ReLU -> project back. The last layer is linear so this branch
    # can add negative as well as positive values to the residual stream.
    self.ffw = nn.Sequential(
      nn.Linear(fan_in,n_hidden),
      nn.ReLU(),
      nn.Linear(n_hidden,fan_in),
    )

  def forward(self,x,cross):
    """x: [B,T_dec,fan_in], cross: [B,T_enc,fan_in] -> [B,T_dec,fan_in]."""
    # No activation on the attention branches: clamping them at 0 would
    # discard every negative component of the attention outputs.
    out = self.norm_self(x + self.masked_multihead(x))
    out = self.norm_cross(out + self.cross_multihead(out,cross))
    return self.norm_ffw(out + self.ffw(out))

In [None]:
# encoder
class Encoder(nn.Module):
  """Stack of n_blocks independent encoderBlocks applied in sequence.

  forward returns a list with one memory tensor per block, so the Decoder can
  pair each of its blocks with one entry. Every entry is the output of the
  full stack, so all decoder blocks attend to the final encoder output.
  """
  def __init__(self,n_emb,head_size,n_head,n_blocks):
    super().__init__()
    # Construct a fresh block per layer. `[block] * n_blocks` would repeat a
    # single module, so every layer would share the same weights.
    self.blocks = nn.ModuleList(
      [encoderBlock(n_emb,head_size,n_head) for _ in range(n_blocks)]
    )

  def forward(self,x):
    """x: [B,T,n_emb] -> list of len(self.blocks) tensors, each [B,T,n_emb]."""
    out = x
    for block in self.blocks:
      # Chain the blocks. Feeding every block the raw input would make the
      # stack depth irrelevant.
      out = block(out)
    return [out] * len(self.blocks)

In [None]:
# decoder
class Decoder(nn.Module):
  """Stack of n_blocks independent decoderBlocks followed by a Linear
  projection to vocabulary logits."""
  def __init__(self,n_emb,head_size,n_head,n_blocks,vocab_size):
    super().__init__()
    # Construct a fresh block per layer. `[block] * n_blocks` would repeat a
    # single module, so every layer would share the same weights.
    self.blocks = nn.ModuleList(
      [decoderBlock(n_emb,head_size,n_head) for _ in range(n_blocks)]
    )
    self.linear = nn.Linear(n_emb,vocab_size)

  def forward(self,x,crosses):
    """Decode x [B,T_dec,n_emb] against encoder memories.

    Args:
      x: embedded decoder input.
      crosses: sequence of encoder memories, one per decoder block (block i
        cross-attends to crosses[i]).

    Returns:
      Unnormalized logits [B,T_dec,vocab_size].

    Raises:
      ValueError: if len(crosses) != number of decoder blocks.
    """
    if len(crosses) != len(self.blocks):
      raise ValueError(f'expected {len(self.blocks)} encoder memories, got {len(crosses)}')
    out = x
    for block, cross in zip(self.blocks, crosses):
      out = block(out,cross)
    # Raw logits. A ReLU here would clamp every negative score to 0, so
    # softmax/cross-entropy could never push a token below the zero-logit level.
    return self.linear(out)

In [None]:
# transformer encoder-decoder
class Transformer(nn.Module):
  """Hand-written encoder-decoder Transformer assembled from `embedding`,
  `Encoder` and `Decoder`.

  Args:
    n_emb: model width.
    head_size: per-head width handed to every attention head.
    n_head: heads per attention layer.
    n_blocks: number of encoder blocks and of decoder blocks.
    vocab_size_enc: source vocabulary size.
    vocab_size_dec: target vocabulary size (width of the output logits).
    max_len: positional-table length and cap on generated length.
  """
  def __init__(self,n_emb,head_size,n_head,n_blocks,vocab_size_enc,vocab_size_dec,max_len):
    super().__init__()
    self.embedding_enc = embedding(vocab_size_enc,n_emb,max_len)
    self.embedding_dec = embedding(vocab_size_dec,n_emb,max_len)
    self.encoder = Encoder(n_emb,head_size,n_head,n_blocks)
    self.decoder = Decoder(n_emb,head_size,n_head,n_blocks,vocab_size_dec)
    self.max_len = max_len

  def forward(self,seq_enc,seq_dec):
    """Map token ids seq_enc [B,T_enc] and seq_dec [B,T_dec] to target
    logits [B,T_dec,vocab_size_dec]."""
    emb_enc = self.embedding_enc(seq_enc)
    emb_dec = self.embedding_dec(seq_dec)
    crosses = self.encoder(emb_enc)
    return self.decoder(emb_dec,crosses)

  def generate(self,seq_enc,seq_dec,eos_idx=None):
    """Sample a continuation of seq_dec until <EOS> or max_len tokens.

    Args:
      seq_enc: 1-D tensor of source token ids (a batch dim is added here).
      seq_dec: 1-D tensor holding the decoder prefix (e.g. just <SOS>).
      eos_idx: id of the <EOS> token. If None, the token is looked up in the
        notebook-global `itos_zh` vocabulary instead (original behavior).

    Returns:
      1-D tensor: seq_dec followed by the sampled token ids.
    """
    def is_eos(idx):
      if eos_idx is not None:
        return idx == eos_idx
      return itos_zh[idx] == '<EOS>'

    seq_enc = torch.unsqueeze(seq_enc,dim=0)   # [1,T_enc]
    # Inference only: don't build an autograd graph on every step.
    with torch.no_grad():
      while not is_eos(seq_dec[-1].item()):
        prefix = seq_dec.unsqueeze(0)                   # [1,T_dec]
        logits = self(seq_enc,prefix)[:, -1, :]        # [1,vocab_size]
        prob = F.softmax(logits,-1)
        next_tok = torch.multinomial(prob,num_samples=1,replacement=True)   # [1,1]
        seq_dec = torch.cat([prefix,next_tok],dim=-1).squeeze(0)           # [T_dec+1]
        if seq_dec.shape[0] >= self.max_len:
          break
    return seq_dec

In [None]:
# # test selfAttention
# x = torch.randn((2,5,32))
# sa_masked = selfAttention(32,8,True)
# sa_unmasked = selfAttention(32,8,False)
# out_masked = sa_masked(x)
# out_unmasked = sa_unmasked(x)
# out_masked.shape,out_unmasked.shape

In [None]:
# # test multi-head self-attention
# x = torch.randn((2,5,32))
# mhsa = multiHead_sa(32,8,True,n_head)
# out = mhsa(x)
# out.shape

In [None]:
# # test multi-head cross-attention
# x = torch.randn((2,5,32))
# cross = torch.randn((2,5,32))
# mhca = multiHead_ca(32,8,4)
# out = mhca(x,cross)
# out.shape

In [None]:
# # test encoder block
# x = torch.randn((2,5,32))
# en_block = encoderBlock(32,8,4)
# out = en_block(x)
# out.shape

In [None]:
# # test decoder block
# x = torch.randn((2,5,32))
# cross = torch.randn((2,5,32))
# de_block = decoderBlock(32,8,4)
# out = de_block(x,cross)
# out.shape

In [None]:
# # test Encoder
# x = torch.randn((2,5,32))
# encoder = Encoder(32,8,4,3)
# crosses = encoder(x)
# len(crosses)

In [None]:
# # test Decoder
# x = torch.randn((2,5,32))
# decoder = Decoder(32,8,4,3,vocab_size_zh)
# out = decoder(x,crosses)
# out.shape

In [None]:
# # test Transformer
# seq_enc = torch.randint(0,100,(2,4))
# seq_dec = torch.randint(0,100,(2,7))
# transformer = Transformer(n_emb,head_size,n_head,n_blocks,vocab_size_en,vocab_size_zh,max_length)
# out = transformer(seq_enc,seq_dec)

In [None]:
class TorchTransformer(nn.Module):
  """Encoder-decoder model built on torch.nn.Transformer.

  Tokens are embedded with the notebook's `embedding` layer (token embedding
  plus sinusoidal positions). Decoder states are projected to target-vocab
  logits by a final Linear layer.

  Args:
    n_emb: model width (d_model).
    head_size: not used by this class; accepted so the signature parallels
      `Transformer`.
    n_head: attention heads per layer (nn.Transformer `nhead`).
    n_blocks: number of encoder layers and of decoder layers.
    vocab_size_enc: source vocabulary size.
    vocab_size_dec: target vocabulary size (width of the output logits).
    n_hidden: feed-forward width (nn.Transformer `dim_feedforward`).
    max_len: positional-table length and cap on generated length.
  """
  def __init__(self,n_emb,head_size,n_head,n_blocks,vocab_size_enc,vocab_size_dec,n_hidden,max_len):
    super().__init__()
    self.embedding_enc = embedding(vocab_size_enc,n_emb,max_len)
    self.embedding_dec = embedding(vocab_size_dec,n_emb,max_len)
    self.transformer = nn.Transformer(d_model=n_emb,nhead=n_head,num_encoder_layers=n_blocks,num_decoder_layers=n_blocks,dim_feedforward=n_hidden,batch_first=True)
    self.linear = nn.Linear(n_emb,vocab_size_dec)
    self.max_len = max_len

  def forward(self,seq_enc,seq_dec,mask_enc=None,mask_dec=None,mask_enc_padding=None,mask_dec_padding=None,memory_key_padding_mask=None):
    """Return logits [B,T_dec,vocab_size_dec] for batched token ids.

    batch_first=True, so seq_enc is [B,T_enc] and seq_dec is [B,T_dec].
    The masks are passed straight to nn.Transformer:
    - mask_enc -> src_mask
    - mask_dec -> tgt_mask (pass a causal mask for teacher-forced training)
    - mask_enc_padding / mask_dec_padding -> src_/tgt_key_padding_mask
    """
    emb_enc = self.embedding_enc(seq_enc)
    emb_dec = self.embedding_dec(seq_dec)
    out = self.transformer(src=emb_enc,tgt=emb_dec,
                           src_mask=mask_enc,tgt_mask=mask_dec,
                           src_key_padding_mask=mask_enc_padding,tgt_key_padding_mask=mask_dec_padding,
                           memory_key_padding_mask=memory_key_padding_mask)
    return self.linear(out)

  def encode(self,seq_enc,mask_enc=None):
    """Run only the encoder stack; returns memory [B,T_enc,n_emb]."""
    emb_enc = self.embedding_enc(seq_enc)
    return self.transformer.encoder(emb_enc,mask_enc)

  def decode(self,seq_dec,memory,mask_dec):
    """Run only the decoder stack against `memory` (mask_dec is its tgt_mask).

    Returns decoder hidden states [B,T_dec,n_emb], without the output
    projection.
    """
    emb_dec = self.embedding_dec(seq_dec)
    return self.transformer.decoder(emb_dec,memory,mask_dec)

  def generate(self,seq_enc,seq_dec,test=False,eos_idx=None):
    """Sample a continuation of seq_dec until <EOS> or max_len tokens.

    Call model.eval() first: nn.Transformer applies dropout in training mode.

    Args:
      seq_enc: 1-D tensor of source token ids (a batch dim is added here).
      seq_dec: 1-D tensor holding the decoder prefix (e.g. just <SOS>).
      test: if True, print each step's input, prefix and sampled token.
      eos_idx: id of the <EOS> token. If None, the token is looked up in the
        notebook-global `itos_zh` vocabulary instead (original behavior).

    Returns:
      1-D tensor: seq_dec followed by the sampled token ids.
    """
    def is_eos(idx):
      if eos_idx is not None:
        return idx == eos_idx
      return itos_zh[idx] == '<EOS>'

    seq_enc = torch.unsqueeze(seq_enc,dim=0)   # [1,T_enc]
    # Inference only: don't build an autograd graph on every step.
    with torch.no_grad():
      # The source is fixed, so encode it once instead of on every step.
      memory = self.encode(seq_enc)
      while not is_eos(seq_dec[-1].item()):
        prefix = seq_dec.unsqueeze(0)   # [1,T_dec]
        # Causal tgt_mask so the decoder sees the same attention pattern it
        # was trained with; without it earlier positions attend to later ones.
        causal = self.transformer.generate_square_subsequent_mask(prefix.size(1))
        causal = causal.to(device=memory.device, dtype=memory.dtype)
        logits = self.linear(self.decode(prefix,memory,causal))[:, -1, :]   # [1,vocab_size]
        prob = F.softmax(logits,-1)
        next_tok = torch.multinomial(prob,num_samples=1)                    # [1,1]
        seq_dec = torch.cat([prefix,next_tok],dim=-1).squeeze(0)           # [T_dec+1]

        if test:
          print('Use ')
          print('input sequence: ', seq_enc)
          print('output sequence: ', prefix)
          print('to predict', next_tok)
          print('***************************')
        if seq_dec.shape[0] >= self.max_len:
          break

    return seq_dec

In [None]:
# # test TorchTransformer
# import spacy
# nlp = spacy.load('en_core_web_sm')

# def generate(input_seq,test=False):
#   tokens = nlp(input_seq)
#   tokens = [stoi_en['<SOS>']] + [stoi_en[token.text.lower()] for token in tokens] + [stoi_en['<EOS>']]
#   seq_enc = torch.tensor(tokens).to(device)
#   seq_dec = torch.tensor([stoi_zh['<SOS>']]).to(device)
#   output = model.generate(seq_enc,seq_dec,test)
#   output = [itos_zh[o.item()] for o in output][1:]
#   return ''.join(output)

# model = TorchTransformer(n_emb,head_size,n_head,n_blocks,vocab_size_en,vocab_size_zh,n_hidden,max_length)
# input_seq = "harry potter"
# output_seq = generate(input_seq,test=True)