In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class DotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d)) V``.

    Parameters
    ----------
    @param dropout: float
        Dropout probability applied to the normalized attention weights.
    """

    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        """
        Attend from ``queries`` to ``keys`` and mix ``values`` accordingly.

        Parameters
        ----------
        @param queries: torch.Tensor
            Shape (..., num_queries, d).
        @param keys: torch.Tensor
            Shape (..., num_keys, d).
        @param values: torch.Tensor
            Shape (..., num_keys, value_dim).
        @return: torch.Tensor
            Shape (..., num_queries, value_dim).
        """
        d = queries.shape[-1]
        # Swap only the last two axes so that optional leading batch axes
        # are left in place (``keys.T`` is only meaningful for 2-D input).
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(d)
        # Normalize over the key axis so every query's weights sum to 1.
        # The weights are kept on the module for later inspection.
        self.attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(self.dropout(self.attention_weights), values)

In [2]:
# Instantiate the single-head attention module with dropout probability 0.1.
attention = DotProductAttention(0.1)

In [3]:
# Seed the global RNG so the toy inputs (and any later random initialisation)
# are reproducible when the notebook is re-run from a fresh kernel.
torch.manual_seed(0)
# Random stand-in feature matrices: 256 rows of 64 features each.
user = torch.randn(256, 64)
item = torch.randn(256, 64)

In [4]:
# Smoke test: queries come from `user`, keys and values from `item`;
# only the output shape is inspected.
attention(user, item, item).shape

torch.Size([256, 64])

In [5]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Parameters
    ----------
    @param embed_dim: int
        Size of the trailing feature axis that LayerNorm normalizes over.
    @param dropout: float
        Dropout probability applied to the sub-layer output before the
        residual sum.
    """

    def __init__(self, embed_dim, dropout):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.ln = nn.LayerNorm(normalized_shape=embed_dim)

    def forward(self, X, Y):
        """Return ``LayerNorm(dropout(Y) + X)``.

        ``X`` is the sub-layer input (the residual branch) and ``Y`` the
        sub-layer output; only ``Y`` passes through dropout.
        """
        regularized = self.dropout(Y)
        residual_sum = regularized + X
        return self.ln(residual_sum)

class MultiHeadAttention(nn.Module):
    """Multi-head attention for 2-D inputs of shape (batch_size, num_hiddens).

    Queries, keys and values are linearly projected to ``embed_dim``
    features, split into ``num_heads`` equal slices along the feature axis,
    attended one head at a time, re-assembled and sent through an output
    projection.

    Parameters
    ----------
    @param embed_dim: int
        Output width of every projection; must be divisible by
        ``num_heads``.
    @param num_heads: int
        Number of attention heads.
    @param dropout: float
        Dropout probability handed to ``DotProductAttention``.
    @param bias: bool
        Whether the four linear projections have a bias term.
    @param kwargs:
        Accepted for call-site compatibility; not used.
    """

    def __init__(self, embed_dim, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0, \
            "embed_dim must be divisible by num_heads"
        self.attention = DotProductAttention(dropout)
        # LazyLinear infers the input width from the first batch it sees.
        self.W_q = nn.LazyLinear(embed_dim, bias=bias)
        self.W_k = nn.LazyLinear(embed_dim, bias=bias)
        self.W_v = nn.LazyLinear(embed_dim, bias=bias)
        self.W_o = nn.LazyLinear(embed_dim, bias=bias)

    def forward(self, queries, keys, values):
        """
        Run every head independently and merge the results.

        Parameters
        ----------
        @param queries: torch.Tensor
            Shape (num_queries, query_features).
        @param keys: torch.Tensor
            Shape (num_keys, key_features).
        @param values: torch.Tensor
            Shape (num_keys, value_features).
        @return: torch.Tensor
            Shape (num_queries, embed_dim).
        """
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        # transpose_qkv orders rows as ``sample * num_heads + head``, so a
        # stride of num_heads starting at ``head`` selects exactly one head.
        # Attending per head keeps the heads separate; a single call on the
        # flattened tensors would let every head attend to the keys/values
        # of all other heads.
        # Because the wrapped module is invoked once per head, any state it
        # records per call (e.g. ``attention_weights``) reflects the last
        # head only.
        step = self.num_heads
        head_outputs = [
            self.attention(
                queries[head::step],
                keys[head::step],
                values[head::step]
            )
            for head in range(step)
        ]

        # (num_queries, num_heads, head_dim) -> rows back in the
        # ``sample * num_heads + head`` order expected by transpose_output.
        output = torch.stack(head_outputs, dim=1)
        output = output.reshape(-1, output.shape[-1])
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)

    def transpose_qkv(self, X):
        """
        Split the feature axis into heads and fold the heads into the rows.

        Row ``sample * num_heads + head`` of the result holds the
        ``head``-th feature slice of ``sample``.

        Parameters
        ----------
        @param X: torch.Tensor
            Shape
            (
                batch_size,
                num_hiddens
            ).
        @return X: torch.Tensor
            Shape
            (
                batch_size * num_heads,
                num_hiddens / num_heads
            )
        """
        X = X.reshape(X.shape[0], self.num_heads, -1)
        return X.reshape(-1, X.shape[2])

    def transpose_output(self, X):
        """
        Reverse the operation of transpose_qkv.

        Parameters
        ----------
        @param X: torch.Tensor
            Shape
            (
                batch_size * num_heads,
                num_hiddens / num_heads
            ).
        @return X: torch.Tensor
            Shape
            (
                batch_size,
                num_hiddens
            )
        """
        X = X.reshape(-1, self.num_heads, X.shape[1])
        return X.reshape(X.shape[0], -1)

In [6]:
# Rebinds `attention` (previously the DotProductAttention instance) to a
# multi-head module: embed_dim=64, num_heads=8, dropout=0.5.
attention = MultiHeadAttention(64, 8, 0.5)



In [7]:
# Smoke test of the multi-head module: queries from `user`, keys/values from
# `item`; only the output shape is inspected.
attention(user, item, item).shape

torch.Size([256, 64])

In [8]:
class SelfBlock(nn.Module):
    """Self-attention block: attention and feed-forward, each wrapped in AddNorm.

    Parameters
    ----------
    @param embed_dim: int
        Feature width of the input and of the block output.
    @param num_heads: int
        Number of heads passed to ``MultiHeadAttention``.
    @param dropout: float
        Dropout probability used by the attention and both AddNorm layers.
    @param use_bias: bool
        Whether the attention projections and feed-forward layers use a bias.
    """

    def __init__(self, embed_dim, num_heads, dropout,
                 use_bias=False):
        super().__init__()
        # The feed-forward network widens to twice the embedding size.
        hidden_dim = 2 * embed_dim
        self.addnorm1 = AddNorm(embed_dim, dropout)
        self.attention = MultiHeadAttention(
            embed_dim, num_heads, dropout, bias=use_bias
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim, bias=use_bias),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim, bias=use_bias),
        )
        self.addnorm2 = AddNorm(embed_dim, dropout)

    def forward(self, X):
        """Attend from ``X`` to itself, then apply the feed-forward network."""
        # Queries, keys and values are all the same tensor.
        attended = self.attention(X, X, X)
        Y = self.addnorm1(X, attended)
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)

In [9]:
# Self-attention block (embed_dim=64, num_heads=8, dropout=0.5); applied to
# `user` in the next cell.
uu_block = SelfBlock(64, 8, 0.5)

In [10]:
# Smoke test: the output should keep the (256, 64) shape of `user`.
uu_block(user).shape

torch.Size([256, 64])

In [11]:
# A second, independently initialised self-attention block with the same
# settings; applied to `item` in the next cell.
ii_block = SelfBlock(64, 8, 0.5)

In [12]:
# Smoke test: the output should keep the (256, 64) shape of `item`.
ii_block(item).shape

torch.Size([256, 64])

In [13]:
class CrossBlock(nn.Module):
    """Cross-attention block: ``item`` rows attend to ``user`` rows.

    The layout mirrors a self-attention block (attention and feed-forward,
    each wrapped in AddNorm), but queries and the residual branch come from
    ``item`` while keys and values come from ``user``.

    Parameters
    ----------
    @param embed_dim: int
        Feature width of the inputs and of the block output.
    @param num_heads: int
        Number of heads passed to ``MultiHeadAttention``.
    @param dropout: float
        Dropout probability used by the attention and both AddNorm layers.
    @param use_bias: bool
        Whether the attention projections and feed-forward layers use a bias.
    """

    def __init__(self, embed_dim, num_heads, dropout,
                 use_bias=False):
        super().__init__()
        # The feed-forward network widens to twice the embedding size.
        hidden_dim = 2 * embed_dim
        self.addnorm1 = AddNorm(embed_dim, dropout)
        self.attention = MultiHeadAttention(
            embed_dim, num_heads, dropout, bias=use_bias
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim, bias=use_bias),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim, bias=use_bias),
        )
        self.addnorm2 = AddNorm(embed_dim, dropout)

    def forward(self, user, item):
        """Return ``item`` features updated with information from ``user``."""
        # item supplies the queries; user supplies both keys and values.
        attended = self.attention(item, user, user)
        Y = self.addnorm1(item, attended)
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)

In [14]:
# Cross-attention block (embed_dim=64, num_heads=8, dropout=0.5).
ui_block = CrossBlock(64, 8, 0.5)

In [15]:
# Smoke test: `item` rows attend to `user` rows; only the output shape is inspected.
ui_block(user, item).shape

torch.Size([256, 64])

In [16]:
# NOTE(review): project-local import placed mid-notebook; consider moving it
# to the imports cell at the top so a fresh-kernel run surfaces it early.
from gcn.transformerf import TransforMerF
# Build the model. The 1000/1000 sizes match the upper bound of the random
# index tensors drawn in the next cell; the exact meaning of each argument is
# defined in gcn.transformerf, which is not shown here.
model = TransforMerF(
    n_users=1000,
    m_items=1000,
    embed_dim=64
)

In [22]:
# NOTE: rebinds `user` / `item` from the float (256, 64) matrices used above
# to 1-D integer tensors of 256 values drawn from [0, 1000) -- presumably
# user/item ids; verify against TransforMerF. Earlier cells that expect the
# float matrices will not work if re-run after this one.
user = torch.randint(0,1000,(256,))
item = torch.randint(0,1000,(256,))
user.shape, item.shape

(torch.Size([256]), torch.Size([256]))

In [24]:
# Run the imported model on the two index tensors.
# NOTE(review): this displays the whole output tensor; `.shape` or a short
# slice would keep the notebook output compact.
model(user, item)

tensor([-1.4154e+01, -1.7320e-01, -1.9497e+00,  1.9831e+00,  1.2043e+01,
        -1.0175e+01, -5.1725e+00, -1.0127e+01, -5.5775e+00,  3.1541e+00,
         3.8758e-01, -3.3462e+00,  2.0182e+01, -2.5343e+00,  8.0695e+00,
        -1.1938e+01, -2.6288e+00, -1.3323e+01,  3.8617e+00, -6.0513e+00,
        -7.6396e+00,  1.0637e+01, -1.5773e+00,  4.6440e-01, -7.2481e+00,
         5.6005e+00, -6.9028e+00, -7.3112e+00, -4.6141e+00,  9.4800e+00,
         1.0274e+01, -2.4179e+01,  7.1802e+00,  2.4965e+00, -5.0181e+00,
        -5.3158e+00,  8.6317e+00,  2.8862e-01,  4.7187e+00, -7.6584e+00,
         9.7339e+00,  1.1335e+01,  6.3949e+00,  8.3217e+00,  1.1241e+01,
        -2.0072e+01,  2.6638e+00, -7.7120e+00,  1.5399e+01,  4.9124e+00,
        -7.2256e+00, -1.7304e+00,  4.0716e+00, -8.7875e+00, -6.0653e+00,
         9.4789e-01,  3.6162e+00,  1.5173e+01, -5.8039e+00, -5.9595e+00,
        -4.9906e+00,  2.3655e+00,  2.1212e+00,  1.0127e+01,  9.7590e+00,
         7.9508e+00, -1.9721e+00,  1.6780e+00, -5.8