# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class JDMR(nn.Module):
    """Classifier over (user, item, review-text) triples.

    User and item ids are embedded and passed through small Tanh MLPs. The
    review word ids are embedded and encoded by a bidirectional GRU. The GRU
    outputs are pooled with a learned per-time-step attention score
    (Linear + Tanh, softmax-normalised over real tokens). The three
    representations are concatenated and projected to ``num_classes``
    unnormalised logits.

    Args:
        vocab_size: Number of word ids; id 0 is the padding token.
        emb_size: Dimension of the user, item and word embeddings.
        hidden_size: GRU hidden size per direction and MLP output size.
        num_layers: Number of stacked GRU layers.
        num_users: Number of distinct user ids.
        num_items: Number of distinct item ids.
        num_classes: Size of the output logit vector.
        dropout_rate: Dropout probability for the MLPs and between GRU layers.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, num_layers, num_users, num_items, num_classes, dropout_rate):
        super().__init__()

        self.user_emb = nn.Embedding(num_users, emb_size)
        self.item_emb = nn.Embedding(num_items, emb_size)
        self.word_emb = nn.Embedding(vocab_size, emb_size, padding_idx=0)

        # nn.GRU only applies dropout *between* stacked layers; a non-zero value
        # with a single layer has no effect and triggers a UserWarning.
        self.rnn = nn.GRU(
            emb_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout_rate if num_layers > 1 else 0.0,
            bidirectional=True,
        )

        # These MLPs consume the raw id embeddings, so their input width must be
        # emb_size (not the 2*hidden_size width of the GRU output).
        self.user_mlp = nn.Sequential(nn.Linear(emb_size, hidden_size),
                                      nn.Tanh(),
                                      nn.Dropout(dropout_rate))
        self.item_mlp = nn.Sequential(nn.Linear(emb_size, hidden_size),
                                      nn.Tanh(),
                                      nn.Dropout(dropout_rate))

        # One scalar score per time step from the 2*hidden_size GRU features.
        self.attn_mlp = nn.Sequential(nn.Linear(hidden_size * 2, 1),
                                      nn.Tanh())

        # hidden (user) + hidden (item) + 2*hidden (pooled review) = 4*hidden.
        self.fc = nn.Linear(hidden_size * 4, num_classes)

    def _encode_review(self, review_inputs, review_lengths):
        """Return one attention-pooled (batch, 2*hidden_size) vector per review."""
        review_emb = self.word_emb(review_inputs)  # (batch, seq_len, emb_size)
        seq_len = review_inputs.shape[1]

        if review_lengths is None:
            # No length information: every position counts as a real token.
            rnn_output, _ = self.rnn(review_emb)  # (batch, seq_len, 2*hidden)
            mask = None
        else:
            # pack_padded_sequence needs a 1-D int64 CPU tensor of lengths in
            # [1, seq_len]; clamping also avoids an all-masked (NaN) softmax row.
            lengths = torch.as_tensor(review_lengths, dtype=torch.long).cpu()
            lengths = lengths.clamp(min=1, max=seq_len)
            # Packing keeps padding out of both GRU directions; the backward
            # direction would otherwise start from the pad tokens.
            packed = nn.utils.rnn.pack_padded_sequence(
                review_emb, lengths, batch_first=True, enforce_sorted=False)
            packed_out, _ = self.rnn(packed)
            # total_length restores the original seq_len so shapes are stable.
            rnn_output, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=seq_len)
            positions = torch.arange(seq_len, device=rnn_output.device)
            mask = positions.unsqueeze(0) < lengths.to(rnn_output.device).unsqueeze(1)

        # Attention mechanism over time steps.
        scores = self.attn_mlp(rnn_output).squeeze(-1)  # (batch, seq_len)
        if mask is not None:
            # Padding positions receive zero attention weight.
            scores = scores.masked_fill(~mask, float("-inf"))
        weights = F.softmax(scores, dim=-1).unsqueeze(-1)  # (batch, seq_len, 1)
        return torch.sum(weights * rnn_output, dim=1)  # (batch, 2*hidden)

    def forward(self, user_inputs, item_inputs, review_inputs, review_lengths=None):
        """Compute class logits for a batch of (user, item, review) triples.

        Args:
            user_inputs: Long tensor of user ids, shape (batch,).
            item_inputs: Long tensor of item ids, shape (batch,).
            review_inputs: Long tensor of word ids, shape (batch, seq_len);
                id 0 is padding.
            review_lengths: Number of real tokens per review (sequence or 1-D
                tensor of length batch). Padding is assumed to sit at the end
                of each row. Values are clamped to [1, seq_len]. If None, every
                position is treated as a real token.

        Returns:
            Unnormalised logits of shape (batch, num_classes).
        """
        user_emb = self.user_emb(user_inputs)  # (batch, emb_size)
        item_emb = self.item_emb(item_inputs)  # (batch, emb_size)

        review_rep = self._encode_review(review_inputs, review_lengths)  # (batch, 2*hidden)

        # User and item representations.
        user_rep = self.user_mlp(user_emb)  # (batch, hidden)
        item_rep = self.item_mlp(item_emb)  # (batch, hidden)

        x = torch.cat([user_rep, item_rep, review_rep], dim=1)  # (batch, 4*hidden)
        return self.fc(x)  # (batch, num_classes)
