<a href="https://colab.research.google.com/github/ihatethinkingname/dl_practice/blob/main/sasrec_practice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import copy
import math

# class SASRec(nn.Module):
#   def __init__(self, item_num) -> None:
#     super().__init__()

class EnD(nn.Module):
  """Encoder-decoder wrapper around embedding layers and two layer stacks.

  Args:
    src_embed: callable mapping source inputs to embeddings.
    encoder: stack called as ``encoder(x, src_mask)``.
    decoder: stack called as ``decoder(memory, src_mask, tgt, tgt_mask)``.
    tgt_embed: callable mapping target inputs to embeddings.
    generator: output head; stored on the module but not applied by
      ``forward`` -- callers apply it to the decoder output themselves.
  """
  def __init__(self,src_embed,encoder,decoder,tgt_embed,generator):
    super().__init__()
    self.encoder=encoder
    self.decoder=decoder
    self.src_embed=src_embed
    self.tgt_embed=tgt_embed
    self.generator=generator

  def forward(self,src,tgt,src_mask,tgt_mask):
    """Encode ``src`` and decode ``tgt`` against the resulting memory."""
    return self.decode(self.encode(src,src_mask),src_mask,tgt,tgt_mask)

  def encode(self,src,src_mask):
    """Embed ``src`` and run it through the registered encoder stack."""
    # Use the registered sub-module; constructing a new stack here would
    # bypass the trained weights entirely.
    return self.encoder(self.src_embed(src),src_mask)

  def decode(self,memory,src_mask,tgt,tgt_mask):
    """Embed ``tgt`` and run the decoder stack against encoder ``memory``."""
    return self.decoder(memory,src_mask,self.tgt_embed(tgt),tgt_mask)

class Embeddings(nn.Module):
  """Embedding lookup whose output is multiplied by sqrt(d_model).

  Args:
    vocab_size: number of rows in the lookup table.
    d_model: embedding width; also determines the output scale factor.
  """
  def __init__(self,vocab_size,d_model):
    super().__init__()
    self.d_model=d_model
    self.lut=nn.Embedding(vocab_size,d_model)

  def forward(self,x):
    """Look up ``x`` in the table and scale the vectors by sqrt(d_model)."""
    scale=math.sqrt(self.d_model)
    vectors=self.lut(x)
    return vectors*scale

class Generator(nn.Module):
  """Linear projection to ``vocab_size`` scores followed by log-softmax.

  Args:
    d_model: input feature width.
    vocab_size: number of output classes.
  """
  def __init__(self,d_model,vocab_size):
    super().__init__()
    self.proj=nn.Linear(d_model,vocab_size)

  def forward(self,x):
    """Return log-probabilities over the last dimension."""
    logits=self.proj(x)
    return logits.log_softmax(dim=-1)

class EnDcoder(nn.Module):
  """Stack of N deep-copied layers followed by a final LayerNorm.

  Acts as an encoder when called as ``(x, src_mask)`` and as a decoder when
  called as ``(memory, src_mask, tgt, tgt_mask)``.

  Args:
    layer: prototype layer, deep-copied N times; must expose ``d_model``.
    N: number of layers in the stack.
  """
  def __init__(self,layer,N):
    super().__init__()
    self.layers=clone(layer,N)
    self.norm=nn.LayerNorm(layer.d_model)

  def forward(self,x,src_mask,tgt=None,tgt_mask=None):
    """Run every layer in order and normalize the final output.

    Args:
      x: encoder input, or the encoder memory in decoder mode.
      src_mask: mask forwarded to each layer for attention over ``x``.
      tgt: decoder input, or None in encoder mode.
      tgt_mask: mask for ``tgt``, or None in encoder mode.

    Returns:
      LayerNorm of the last layer's output (the refined ``x`` in encoder
      mode, the refined ``tgt`` in decoder mode).
    """
    if tgt is None:
      # Encoder mode: each layer refines x in turn.
      for layer in self.layers:
        x=layer(x,src_mask,tgt,tgt_mask)
      return self.norm(x)
    # Decoder mode: the memory x must stay fixed for every layer; only the
    # target representation is carried from one layer to the next.
    for layer in self.layers:
      tgt=layer(x,src_mask,tgt,tgt_mask)
    return self.norm(tgt)

def clone(layer,N):
  """Return an ``nn.ModuleList`` holding N independent deep copies of ``layer``.

  The copies share no parameters with each other or with ``layer``.
  """
  copies=nn.ModuleList()
  for _ in range(N):
    copies.append(copy.deepcopy(layer))
  return copies

class EnDcoderLayer(nn.Module):
  """One encoder or decoder layer built from residual sublayer wrappers.

  With ``src_attn`` None the layer is an encoder layer (self-attention, then
  FFN). Otherwise it is a decoder layer (self-attention, attention over the
  encoder memory, then FFN).

  Args:
    self_attn: attention module called as ``self_attn(x, x, x, mask)``.
    ffn: feed-forward module applied last.
    d_model: model width; given to each SublayerConnection and read by the
      stack that owns this layer.
    drop_rate: dropout probability for the residual wrappers.
    src_attn: optional attention over the encoder memory; when given, a
      third residual wrapper is created for it.
  """
  def __init__(self,self_attn,ffn,d_model,drop_rate,src_attn=None):
    super().__init__()
    is_decoder=src_attn is not None
    self.self_attn=self_attn
    self.ffn=ffn
    # Residual wrappers: [self-attn, ffn] for an encoder layer and
    # [self-attn, src-attn, ffn] for a decoder layer; the FFN always uses
    # the last one.
    self.connector=clone(SublayerConnection(d_model,drop_rate),3 if is_decoder else 2)
    self.d_model=d_model
    if is_decoder:
      self.src_attn=src_attn

  def forward(self,x,src_mask,tgt=None,tgt_mask=None):
    """Apply the layer.

    In encoder mode (``tgt``/``tgt_mask`` None) ``x`` is the input. In
    decoder mode ``x`` is the encoder memory and ``tgt`` is the input that
    flows through the sublayers.
    """
    if not isEncoder(tgt,tgt_mask):
      memory=x
      out=self.connector[0](tgt,lambda y: self.self_attn(y,y,y,tgt_mask))
      out=self.connector[1](out,lambda y: self.src_attn(y,memory,memory,src_mask))
    else:
      out=self.connector[0](x,lambda y: self.self_attn(y,y,y,src_mask))
    return self.connector[-1](out,self.ffn)

def isEncoder(tgt,tgt_mask):
  """Decide whether a layer call is in encoder mode or decoder mode.

  Args:
    tgt: decoder input, or None for an encoder call.
    tgt_mask: mask for ``tgt``, or None for an encoder call.

  Returns:
    True when both are None (encoder mode); False when both are given
    (decoder mode).

  Raises:
    ValueError: if exactly one of ``tgt`` / ``tgt_mask`` is None.
  """
  # Use identity checks: ``==`` on tensors is elementwise and its result
  # cannot be used as a Python bool, so it breaks for real decoder inputs.
  if tgt is None and tgt_mask is None:
    return True
  if tgt is not None and tgt_mask is not None:
    return False
  raise ValueError("Encoder instance input error: "
                   "tgt and tgt_mask should be both None or both not None.")

class SublayerConnection(nn.Module):
  """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

  Args:
    d_model: feature size normalized by the LayerNorm.
    drop_rate: dropout probability applied to the sublayer output.
  """
  def __init__(self,d_model,drop_rate):
    super().__init__()
    self.norm=nn.LayerNorm(d_model)
    self.dropout=nn.Dropout(drop_rate)

  def forward(self,x,sublayer):
    """Apply ``sublayer`` to the normalized input and add the result to ``x``."""
    normalized=self.norm(x)
    update=self.dropout(sublayer(normalized))
    return x+update

class FFN(nn.Module):
  """Feed-forward block computing ``L2(dropout(relu(L1(x))))``.

  Args:
    d_model: input and output width.
    d_ff: hidden width.
    dropout_rate: dropout probability on the hidden activations.
  """
  def __init__(self,d_model,d_ff,dropout_rate=0.1):
    super().__init__()
    self.L1=nn.Linear(d_model,d_ff)
    # Use the caller-supplied rate rather than a fixed constant.
    self.dropout=nn.Dropout(dropout_rate)
    self.L2=nn.Linear(d_ff,d_model)

  def forward(self,x):
    """Expand to ``d_ff``, apply ReLU and dropout, then project back to ``d_model``."""
    return self.L2(self.dropout(F.relu(self.L1(x))))
