In [2]:
import argparse
import math
import os
import sys
import time

import numpy as np
from loguru import logger
from collections import defaultdict


import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from transformers import AutoModelForSequenceClassification
from transformers import TFAutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoModel

from dataset.dataset_dbpedia import DBpedia
from dataset.load_data import SentDataLoader, RecDataLoader

In [4]:
import easydict


args = easydict.EasyDict({
    'seed': 42,
    'output_dir': 'save',
    'output_dataset_dir': 'data/redial/',
    'debug': False,
    'dataset': 'data/redial/',
    'kg_dataset': 'data/dbpedia/',
    
    'context_max_length': 256,
    'entity_max_length': 50,
    'num_workers': 0,
    'senti_tokenizer': 'roberta-base',
    'text_tokenizer': 'roberta-base',
    'text_encoder': 'roberta-base',
    'num_bases': 8,
    
    'num_epochs': 30, #10,
    'batch_size': 64, #64, #256,
    'entity_hidden_size': 128,  
    
    'gradient_accumulation_steps': 1,
    'learning_rate': 1e-3,
    'weight_decay': 0.01,
    'max_grad_norm': None,
    'num_warmup_steps': None,
    'device': 'cuda:1' if torch.cuda.is_available() else 'cpu',
})

print(args.device)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
print(">>>>>>>>>>> CUDA available :", torch.cuda.is_available())


cuda:1
>>>>>>>>>>> CUDA available : True


In [None]:
from transformers import pipeline
 
sentiment_model = pipeline(model="federicopascual/finetuning-sentiment-model-3000-samples")
sentiment_model(["I love this move", "This movie sucks!"])

In [46]:
# from evaluate_rec import RecEvaluator
# from model_gpt2 import PromptGPT2forCRS
# from config import gpt2_special_tokens_dict, prompt_special_tokens_dict
# from model_prompt import KGPrompt

###############################################################################
## Environment Setting
###############################################################################

def set_seed(seed):
    """
        Setting random seeds
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    
set_seed(args.seed)

# ###############################################################################
# ## Load Data & KG
# ###############################################################################

kg = DBpedia(dataset=args.dataset).get_entity_kg_info()


task='sentiment'
MODEL = f"cardiffnlp/twitter-roberta-base-{task}"

sent_model = AutoModelForSequenceClassification.from_pretrained(MODEL)
sent_model = sent_model.to(args.device)

# Load the tokenizer
senti_tokenizer = AutoTokenizer.from_pretrained(MODEL)

sent_train_dataloader, sent_val_dataloader, sent_test_dataloader = SentDataLoader(args, senti_tokenizer)
#train_dataloader, val_dataloader, test_dataloader = RecDataLoader(args, senti_tokenizer, pad_entity_id=kg['pad_entity_id'])

# pad_entity_id = kg['pad_entity_id']
# train_dataset = CRSRecDataset(
#             dpath=args.dataset, split='train', 
#             tokenizer=senti_tokenizer, context_max_length=args.context_max_length,
#         )

# valid_dataset = CRSRecDataset(
#         dpath=args.dataset, split='valid', 
#         tokenizer=senti_tokenizer, context_max_length=args.context_max_length,
#     )

# test_dataset = CRSRecDataset(
#         dpath=args.dataset, split='test', 
#         tokenizer=senti_tokenizer, context_max_length=args.context_max_length,
#     )

# data_collator = CRSRecDataCollator(
#         tokenizer=senti_tokenizer, context_max_length=args.context_max_length,
#         pad_entity_id=pad_entity_id, device=args.device
#     )

# train_dataloader = DataLoader(
#         train_dataset,
#         batch_size=args.batch_size,
#         collate_fn=data_collator,
#         shuffle=True
#     )
# val_dataloader = DataLoader(
#         valid_dataset,
#         batch_size=args.batch_size,
#         collate_fn=data_collator,
#     )
# test_dataloader = DataLoader(
#         test_dataset,
#         batch_size=args.batch_size,
#         collate_fn=data_collator,
#     )

>>>>>>>>>>> CUDA available : True


2023-12-12 13:53:12.170 | DEBUG    | dataset.dataset_dbpedia:_process_entity_kg:40 - #edge: 186339, #relation: 25, #entity: 30890, #item: 6281


  0%|          | 0/110206 [00:00<?, ?it/s]

  0%|          | 0/12284 [00:00<?, ?it/s]

  0%|          | 0/15452 [00:00<?, ?it/s]

  0%|          | 0/101514 [00:00<?, ?it/s]

  0%|          | 0/11294 [00:00<?, ?it/s]

  0%|          | 0/14174 [00:00<?, ?it/s]

In [80]:
for sent_batch in sent_train_dataloader:
    print("!")
    break

You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


!




In [85]:
for batch in train_dataloader:
    print("!")
    break

!




In [94]:
train_ent_sent['17839']

{19174: tensor(0.9979, device='cuda:1'),
 21083: tensor(0.9976, device='cuda:1'),
 19018: tensor(0.9987, device='cuda:1'),
 12533: tensor(0.9987, device='cuda:1')}

In [90]:
batch['conv_id'][2], batch['entity'][2]

('17839',
 tensor([19018,  1771, 12533, 30889, 30889, 30889, 30889, 30889, 30889, 30889,
         30889, 30889, 30889, 30889, 30889, 30889, 30889], device='cuda:1'))

In [123]:
entity_lists = batch['entity']

user_repr_list = []

for i, entity_list in enumerate(entity_lists):
    if i == 2:
        # pad_entity_id를 제외한 실제 entity ID만 필터링
        filtered_entity_list = [entity for entity in entity_list if entity != kg['pad_entity_id']]

        if not filtered_entity_list:
            # 실제 엔티티가 없는 경우
            user_repr_list.append(torch.zeros(128))#args.kg_emb_dim))
            continue

        filtered_entity_tensor = torch.tensor(filtered_entity_list, dtype=torch.long, )#device=self.device)
        user_repr = kg_embedding[filtered_entity_tensor] # (n_e, dim)

In [130]:
train_ent_sent['17839']

{19174: tensor(0.9979, device='cuda:1'),
 21083: tensor(0.9976, device='cuda:1'),
 19018: tensor(0.9987, device='cuda:1'),
 12533: tensor(0.9987, device='cuda:1')}

In [97]:
batch['conv_id'][2], batch['entity'][2], filtered_entity_tensor

('17839',
 tensor([19018,  1771, 12533, 30889, 30889, 30889, 30889, 30889, 30889, 30889,
         30889, 30889, 30889, 30889, 30889, 30889, 30889], device='cuda:1'),
 tensor([19018,  1771, 12533]))

In [111]:
ent_sent_dict

{19174: tensor(0.9979, device='cuda:1'),
 21083: tensor(0.9976, device='cuda:1'),
 19018: tensor(0.9987, device='cuda:1'),
 12533: tensor(0.9987, device='cuda:1')}

In [None]:
ent_sent_dict = self.ent_sent[conv_id]
for i, entity_id in enumerate(entity_list):
    if entity_id in ent_sent_dict:
        sentiment_score = ent_sent_dict[entity_id]

In [188]:
conv_id = "17839"
ent_sent_dict = train_ent_sent[conv_id]
for i, entity_id in enumerate(filtered_entity_tensor.tolist()):
    if entity_id in ent_sent_dict:
        sentiment_score = ent_sent_dict[entity_id]
        print(entity_id, sentiment_score)

19018 tensor(0.9987, device='cuda:1')
12533 tensor(0.9987, device='cuda:1')


In [171]:
class SelfAttentionBatch(nn.Module):
    def __init__(self, dim, da, alpha=0.2, dropout=0.5):
        super(SelfAttentionBatch, self).__init__()
        self.dim = dim
        self.da = da
        self.alpha = alpha
        self.dropout = dropout
        self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        nn.init.xavier_uniform_(self.b.data, gain=1.414)

    def forward(self, h, entity_list, train_ent_sent):
        # h: (N, dim) => user가 상호작용한 item/entity 개수만큼의 행렬
        
        e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b).squeeze(dim=1)  # (N)
        print(e.shape, e)
        
        attention = F.softmax(e, dim=0)  # (N)
        print(attention.shape, attention)
        
        conv_id = "17839"
        ent_sent_dict = train_ent_sent[conv_id]
        for i, entity_id in enumerate(entity_list.tolist()):
            if entity_id in ent_sent_dict:
                sent_score = ent_sent_dict[entity_id]
                attention[i] *= sent_score.detach().cpu()
            
        print(attention.shape, attention)
        
        # 조정된 attention 가중치를 정규화
        norm_attention = attention / attention.sum()
        print(norm_attention.shape, norm_attention)
        
        return torch.matmul(norm_attention, h)  # (dim)

In [156]:
# # train_ent_sent["17839"]

# from copy import deepcopy

# copy_train_ent_sent = deepcopy(train_ent_sent)

In [153]:
# copy_train_ent_sent["17839"][12533] = torch.tensor(0.2312, device=args.device)
copy_train_ent_sent["17839"]

{19174: tensor(0.9979, device='cuda:1'),
 21083: tensor(0.9976, device='cuda:1'),
 19018: tensor(0.9987, device='cuda:1'),
 12533: tensor(0.2312, device='cuda:1')}

In [170]:
copy_train_ent_sent

defaultdict(dict,
            {'14536': {17897: tensor(0.9613, device='cuda:1'),
              26306: tensor(0.9987, device='cuda:1'),
              27698: tensor(0.9987, device='cuda:1')},
             '1516': {18488: tensor(0.9987, device='cuda:1'),
              28626: tensor(0.9988, device='cuda:1'),
              26306: tensor(0.9988, device='cuda:1'),
              25107: tensor(0.9988, device='cuda:1'),
              28695: tensor(0.9980, device='cuda:1')},
             '5458': {12533: tensor(0.9982, device='cuda:1'),
              2115: tensor(0.9980, device='cuda:1'),
              7102: tensor(0.9968, device='cuda:1')},
             '5219': {20183: tensor(0.9967, device='cuda:1'),
              16507: tensor(0.9965, device='cuda:1'),
              25417: tensor(0.9974, device='cuda:1'),
              26306: tensor(0.9861, device='cuda:1'),
              8565: tensor(0.9966, device='cuda:1')},
             '8126': {26306: tensor(0.9991, device='cuda:1'),
              9090: te

In [142]:
filtered_entity_tensor

tensor([19018,  1771, 12533])

In [180]:

class ApplySentSelfAttentionBatch(nn.Module):
    def __init__(self, dim, da, alpha=0.2, dropout=0.5): #, ent_sent=None):
        super(ApplySentSelfAttentionBatch, self).__init__()
        self.dim = dim
        self.da = da
        self.alpha = alpha
        self.dropout = dropout
        
        # self.ent_sent = ent_sent
        self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        nn.init.xavier_uniform_(self.b.data, gain=1.414)

    def forward(self, h, entity_list, conv_id, ent_sent):
        # h: (N, dim) => user가 상호작용한 item/entity 개수만큼의 행렬
        e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b).squeeze(dim=1)  # (N)
        attention = F.softmax(e, dim=0)  # (N)
        
        ent_sent_dict = ent_sent[conv_id]
        for i, entity_id in enumerate(entity_list.tolist()):
            if entity_id in ent_sent_dict:
                sent_score = ent_sent_dict[entity_id]
                attention[i] *= sent_score.detach().cpu()
        
        print(attention)
        
        # attention 가중치를 정규화
        norm_attention = attention / attention.sum()
        
        print(norm_attention)
        
        return torch.matmul(norm_attention, h)  # (dim)

In [187]:
filtered_entity_tensor

tensor([19018,  1771, 12533])

In [189]:
0.04008 + 0.437960

0.47804

In [190]:
0.037360 + 0.441232

0.478592

In [181]:
attn = ApplySentSelfAttentionBatch(128, 128)

# (user_repr, user_entity_tensor_list, conv_ids[i], ent_sent)

attn_user_repr = attn(user_repr, filtered_entity_tensor, conv_id="17839", ent_sent=copy_train_ent_sent)
attn_user_repr.shape

tensor([0.2120, 0.6119, 0.0407], grad_fn=<CopySlices>)
tensor([0.2452, 0.7078, 0.0470], grad_fn=<DivBackward0>)


torch.Size([128])

In [179]:
copy_train_ent_sent["17839"]

{19174: tensor(0.9979, device='cuda:1'),
 21083: tensor(0.9976, device='cuda:1'),
 19018: tensor(0.9987, device='cuda:1'),
 12533: tensor(0.2312, device='cuda:1')}

In [178]:
filtered_entity_tensor

tensor([19018,  1771, 12533])

In [53]:
class EntitySentModel(nn.Module):
    def __init__(self, model):
        super(EntitySentModel, self).__init__()
        
        self.model = model
        self.ent_sent = defaultdict(dict)

    def forward(self, batch):
        input_ids, attention_mask = batch["context"]['input_ids'], batch["context"]['attention_mask']
        
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
            
        logits = logits.detach() ## 추론 시 기울기 흐르는 것 방지 (오류 발생했었음) => 결론적으로 해결 안됐음;;
        
        probabilities = F.softmax(logits, dim=1)
        
        # Extracting probabilities for negative and positive sentiments
        neg, neu, pos = probabilities.unbind(dim=-1)
        sent_score = pos / (pos + neg)
        
        for i, score in enumerate(sent_score):
            if len(batch['entity'][i]) != 0:
                for entity in batch['entity'][i]:
                    #Store the sentiment score in train_ent_sent
                    self.ent_sent[batch['conv_id'][i]][entity] = score

In [54]:
train_ent_sent_model = EntitySentModel(sent_model).to(args.device)
val_ent_sent_model = EntitySentModel(sent_model).to(args.device)
test_ent_sent_model = EntitySentModel(sent_model).to(args.device)

for batch in tqdm(sent_train_dataloader):
    with torch.no_grad():
        train_ent_sent_model(batch)
    
for batch in tqdm(sent_val_dataloader):
    with torch.no_grad():
        val_ent_sent_model.forward(batch)
    
for batch in tqdm(sent_test_dataloader):
    with torch.no_grad():
        test_ent_sent_model.forward(batch)

train_ent_sent = train_ent_sent_model.ent_sent
val_ent_sent = val_ent_sent_model.ent_sent
test_ent_sent = test_ent_sent_model.ent_sent

  0%|          | 0/1722 [00:00<?, ?it/s]



  0%|          | 0/192 [00:00<?, ?it/s]

  0%|          | 0/242 [00:00<?, ?it/s]

In [50]:
list(train_ent_sent.items())

('14205', {26306: tensor(0.9974, device='cuda:1')})

In [49]:
len(train_ent_sent), len(val_ent_sent)

list(train_ent_sent.items())[-50], list(val_ent_sent.items())[-50], list(test_ent_sent.items())[-50]

(('14205', {26306: tensor(0.9974, device='cuda:1')}),
 ('18172',
  {17897: tensor(0.1186, device='cuda:1'),
   18093: tensor(0.9979, device='cuda:1'),
   1925: tensor(0.1186, device='cuda:1'),
   22254: tensor(0.9990, device='cuda:1')}),
 ('23109',
  {30546: tensor(0.9579, device='cuda:1'),
   12533: tensor(0.6316, device='cuda:1'),
   27382: tensor(0.5373, device='cuda:1')}))

In [40]:
train_ent_sent

defaultdict(dict,
            {'19429': {19174: tensor(0.9959, device='cuda:1')},
             '8618': {3682: tensor(0.9924, device='cuda:1'),
              8828: tensor(0.9924, device='cuda:1'),
              12533: tensor(0.2536, device='cuda:1'),
              7822: tensor(0.2536, device='cuda:1'),
              26306: tensor(0.3332, device='cuda:1')},
             '16534': {13573: tensor(0.6659, device='cuda:1'),
              6696: tensor(0.9966, device='cuda:1')},
             '13724': {2115: tensor(0.9983, device='cuda:1'),
              3812: tensor(0.9939, device='cuda:1'),
              20028: tensor(0.9983, device='cuda:1'),
              19174: tensor(0.9983, device='cuda:1'),
              26306: tensor(0.9979, device='cuda:1'),
              5251: tensor(0.9979, device='cuda:1'),
              3828: tensor(0.9979, device='cuda:1'),
              4028: tensor(0.9979, device='cuda:1')},
             '17100': {25417: tensor(0.1408, device='cuda:1'),
              2115: tenso

In [186]:
# Assuming ent_sent is your dictionary
for conv_id, entities in train_ent_sent.items():
    for entity_id, sent_score in entities.items():
        print(f"Conv ID: {conv_id}, Entity ID: {entity_id}, requires_grad: {sent_score.requires_grad}")
        # You can break or continue based on your needs
        break  # This breaks out of the inner loop
    break  # This breaks out of the outer loop

Conv ID: 19429, Entity ID: 19174, requires_grad: False


In [185]:
train_ent_sent

defaultdict(dict,
            {'19429': {19174: tensor(0.9981, device='cuda:1')},
             '8618': {3682: tensor(0.9941, device='cuda:1'),
              8828: tensor(0.9941, device='cuda:1'),
              12533: tensor(0.9917, device='cuda:1'),
              7822: tensor(0.9917, device='cuda:1'),
              26306: tensor(0.9946, device='cuda:1')},
             '16534': {13573: tensor(0.8968, device='cuda:1'),
              6696: tensor(0.9980, device='cuda:1')},
             '13724': {2115: tensor(0.9981, device='cuda:1'),
              3812: tensor(0.9968, device='cuda:1'),
              20028: tensor(0.9981, device='cuda:1'),
              19174: tensor(0.9981, device='cuda:1'),
              26306: tensor(0.9979, device='cuda:1'),
              5251: tensor(0.9979, device='cuda:1'),
              3828: tensor(0.9979, device='cuda:1'),
              4028: tensor(0.9979, device='cuda:1')},
             '17100': {25417: tensor(0.9767, device='cuda:1'),
              2115: tenso

In [76]:
# from transformers import pipeline
# sentiment_analysis = pipeline("sentiment-analysis",model=args.text_encoder)
# print(sentiment_analysis("I love this!"))

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'roberta.pooler.dense.bias', 'roberta.pooler.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.

[{'label': 'LABEL_0', 'score': 0.5122907161712646}]


In [40]:
# model = AutoModel.from_pretrained("roberta-base").to(args.device)
# senti_tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# text = "I hate this!"

# encoded_input = senti_tokenizer(text, return_tensors='pt')
# for k, v in encoded_input.items():
#     encoded_input[k] = torch.as_tensor(v, device=args.device)
    
# output = model(**encoded_input)
# scores = output[0][0]

# li = nn.Linear(768, 2).to(args.device)
# scores = li(scores)

# sigmoid = nn.Sigmoid()
# scores = sigmoid(scores)


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [54]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

task='sentiment'
MODEL = f"cardiffnlp/twitter-roberta-base-{task}"

# Assuming 'args.device' is set to 'cuda' if a GPU is available, else 'cpu'
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# Load the pre-trained RoBERTa model for sequence classification
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
model = model.to(device)

# Load the tokenizer
senti_tokenizer = AutoTokenizer.from_pretrained(MODEL)

# Sample text
text = "I like to watch scary movies"

# Tokenize the text
inputs = senti_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Get the logits from the model
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

# Apply softmax to get probabilities
probabilities = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()

# Note: twitter-roberta-senti model => [negative, neutral, positive]
negative, neutral, positive = probabilities[0]
sentiment_score = positive / (positive + negative)

print(sentiment_score)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


0.995256


In [90]:
from transformers import RobertaTokenizer, RobertaModel
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)
output[0][0]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[-0.1146,  0.1103, -0.0149,  ..., -0.0809, -0.0018, -0.0271],
        [-0.0225,  0.1612,  0.0556,  ...,  0.5366,  0.1196,  0.1576],
        [ 0.0532, -0.0020,  0.0370,  ..., -0.4887,  0.1641,  0.2736],
        ...,
        [-0.1586,  0.0837,  0.1302,  ...,  0.3970,  0.1715, -0.0848],
        [-0.1065,  0.1044, -0.0383,  ..., -0.1068, -0.0015, -0.0517],
        [ 0.0059,  0.0758,  0.1228,  ...,  0.1037,  0.0075,  0.0976]],
       grad_fn=<SelectBackward0>)

## KG 모델링

In [4]:
kg = DBpedia(dataset=args.dataset).get_entity_kg_info()

In [None]:
class SelfAttentionBatch(nn.Module):
    def __init__(self, dim, da, alpha=0.2, dropout=0.5):
        super(SelfAttentionBatch, self).__init__()
        self.dim = dim
        self.da = da
        self.alpha = alpha
        self.dropout = dropout
        self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        nn.init.xavier_uniform_(self.b.data, gain=1.414)

    def forward(self, h):
        # h: (N, dim)
        e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b).squeeze(dim=1)
        attention = F.softmax(e, dim=0)  # (N)
        return torch.matmul(attention, h)  # (dim)

In [13]:
import torch
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import RGCNConv


class RecModel(nn.Module):
    def __init__(self, kg_emb_dim, n_entity, num_relations,
                 num_bases, edge_index, edge_type, device
    ):
        super(RecModel, self).__init__()
        
        self.device = device
        self.kg_emb_dim = kg_emb_dim
        self.kg_encoder = RGCNConv(kg_emb_dim, kg_emb_dim, num_relations=num_relations,
                                   num_bases=num_bases)
        
        self.edge_index = nn.Parameter(edge_index, requires_grad=False) # edge_index가 entity 정보 및 연결이 담긴 인덱스들 => grad=False로 변하지 말아라 (관계성이니)
        self.edge_type = nn.Parameter(edge_type, requires_grad=False)
    
        self.kg_attn = SelfAttentionBatch(self.kg_emb_dim, self.kg_emb_dim)
        
        self.rec_bias = nn.Linear(self.kg_emb_dim, n_entity)
        # self.rec_loss = nn.CrossEntropyLoss()
    
    def get_user_rep(self, batch):
        user_rep, item_rep = self.get_kg_user_rep(batch)

        return user_rep, item_rep

    def get_kg_user_rep(self, batch):
        user_rep, kg_embedding = self.get_kg_rep(batch) # (bs, user_dim), (n_entities, user_dim)

        return user_rep, kg_embedding # (bs, user_dim), (n_entities, user_dim)
    
    def get_kg_rep(self, batch):
        context_entities = batch['entity']

        kg_user_rep, kg_embedding = self._get_kg_user_rep(context_entities)  # (bs, dim), (n_entity, dim)

        return kg_user_rep, kg_embedding
    
    def _get_kg_user_rep(self, context_entities):
        # ipdb.set_trace()
        kg_embedding = self.kg_encoder(None, self.edge_index, self.edge_type)
        # print(self.config.edge_idx.shape)
        # ipdb.set_trace()
        user_rep = self._encode_user(context_entities, kg_embedding)  # (bs, dim)

        return user_rep, kg_embedding
    
    def _encode_user(self, entity_lists, kg_embedding):
        user_repr_list = []
        for entity_list in entity_lists:
            if not entity_list:
                user_repr_list.append(torch.zeros(self.kg_emb_dim, device=self.device))
                continue
            user_repr = kg_embedding[entity_list]
            user_repr = self.kg_attn(user_repr)
            user_repr_list.append(user_repr)
        return torch.stack(user_repr_list, dim=0)  # (bs, dim)
    
    def forward(self, batch):        
        user_rep, item_rep = self.get_user_rep(batch) # (bs, user_dim), (n_entities, user_dim)
        rec_scores = F.linear(user_rep, item_rep, self.rec_bias.bias)  # (bs, n_entity)

        return rec_scores

In [None]:

rec_loss = nn.CrossEntropyLoss()
rec_loss = self.rec_loss(rec_scores, y)