In [1]:
import math
import random
import os.path as osp

import os
import pandas as pd

import numpy as np

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter

from torch.utils.data import Dataset as BaseDataset
from torch.utils.data import DataLoader

from tqdm import tqdm
from datetime import datetime

from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import mean_squared_error, ndcg_score, recall_score
from tensorflow.python.keras.preprocessing.sequence import pad_sequences

In [2]:
import utils

In [3]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)

        # one kernel make one feature map
        # kernel 厚度 = channel 數
        self.is_changed = in_channels != out_channels
        self.trans = nn.Conv2d(in_channels, out_channels,
                               kernel_size=1, stride=stride)

    def forward(self, x):
        f_x = self.conv1(x)
        f_x = self.bn1(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv2(f_x)
        f_x = self.bn2(f_x)

        if self.is_changed:
            x = self.trans(x)

        x = f_x + x
        x = self.relu(x)
        return x


class BottleNeck(nn.Module):

    def __init__(self, in_channels, out_channels, stride=1, dilation=4):
        super(BottleNeck, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0)
        self.conv2 = nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(
            in_channels=out_channels, out_channels=dilation*out_channels, kernel_size=1, padding=0)

        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        self.bn3 = nn.BatchNorm2d(num_features=out_channels * dilation)

        self.relu = nn.ReLU(inplace=True)

        # one kernel make one feature map
        # kernel 厚度 = channel 數
        self.is_changed = in_channels != (out_channels*dilation)
        self.trans = nn.Conv2d(in_channels, out_channels *
                               dilation, kernel_size=1, stride=stride)

    def forward(self, x):

        f_x = self.conv1(x)
        f_x = self.bn1(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv2(f_x)
        f_x = self.bn2(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv3(f_x)
        f_x = self.bn3(f_x)

        if self.is_changed:
            x = self.trans(x)

        x = f_x + x
        x = self.relu(x)
        return x


class _ResNet(nn.Module):
    def __init__(self, block, block_cnts, dilation=1):
        super(_ResNet, self).__init__()

        self.in_channels = 64
        self.out_channels = 64

        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=self.in_channels, kernel_size=7, stride=2, padding=3)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._layer(
            block, block_cnts[0], dilation, self.in_channels, self.out_channels, stride=1)
        self.layer2 = self._layer(
            block, block_cnts[1], dilation, self.in_channels, self.out_channels, stride=2)
        self.layer3 = self._layer(
            block, block_cnts[2], dilation, self.in_channels, self.out_channels, stride=2)
        self.layer4 = self._layer(
            block, block_cnts[3], dilation, self.in_channels, self.out_channels, stride=2)

        self.avg = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.dense = nn.Linear(in_features=self.in_channels, out_features=6)
        self.activation = nn.Softmax(dim=1)
        self.flatten = nn.Flatten(start_dim=1)

    def _layer(self, block, block_cnt, dilation, in_channels, out_channels, stride):
        # in_channels: param of previous block output channel
        # out_channels: param of current block input channel

        blocks = []
        blocks.append(
            block(in_channels=in_channels,
                  out_channels=out_channels, stride=stride)
        )

        for cnt in range(1, block_cnt):
            b = block(in_channels=dilation * out_channels,
                      out_channels=out_channels)
            blocks.append(b)

        self.in_channels = out_channels * dilation
        self.out_channels = out_channels * 2

        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.conv1(x)
        x = self.max_pool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg(x)

        x = self.flatten(x)
        x = self.dense(x)

        x = self.activation(x)
        return x

In [5]:
device = 'cpu'
use_cuda = True
if use_cuda and torch.cuda.is_available():
    print('cuda ready...')
    device = 'cuda:1'

root = os.path.join(os.getcwd(), "DoubanBook")
rel_p = os.path.join(root, "user_book.dat")

user_cnt = 16239
item_cnt = 14284

cuda ready...


In [6]:
class Args:
    def __init__(self,
                 user_cnt,
                 item_cnt,
                embed_dim=50,
                inp_drop=0.5,
                feat_drop=0.5,
                hid_drop=0.2,
                perm=1,
                k_w=5,
                k_h=10,
                ker_sz=3):
        self.user_cnt = user_cnt
        self.item_cnt = item_cnt
        self.embed_dim = embed_dim
        self.inp_drop = inp_drop
        self.feat_drop = feat_drop
        self.hid_drop = hid_drop
        self.perm = perm
        self.k_w = k_w
        self.k_h = k_h
        self.ker_sz = ker_sz

## Interact

In [7]:
sparse_features = ["cols_0", "cols_1"] # user_id, movie_id
rating = "cols_2"

In [8]:
rel = utils.read_file(rel_p)
rel.head()

Unnamed: 0,cols_0,cols_1,cols_2
0,10855,938,4
1,10027,3,3
2,741,2426,5
3,453,1263,4
4,11665,7717,5


In [9]:
rel_feat, rel_pos = utils.get_explicit_features(rel_p)

In [10]:
user_cnt, item_cnt = rel_feat.shape

In [28]:
class Dataset(BaseDataset):
    
    def __init__(self, dataframe):
        
        # Read node features
        self.users = dataframe.cols_0.values
        self.items = dataframe.cols_1.values
        self.rats = dataframe.cols_2.values
        
    def __len__(self):
        return len(self.rats)
        
    def __getitem__(self, idx):
        
        user = self.users[idx]
        item = self.items[idx]
        rat = self.rats[idx] if self.rats[idx] >= 3 else 0
        label = np.zeros(6)
        label[rat] = 1
        
        return user, item, rat, label

In [29]:
args = Args(user_cnt, item_cnt)

In [30]:
def get_chequer_perm(perm=1, k_w=30, k_h=50):
    
    embed_dim = k_w * k_h
    ent_perm  = np.int32([np.random.permutation(embed_dim) for _ in range(perm)])
    rel_perm  = np.int32([np.random.permutation(embed_dim) for _ in range(perm)])

    comb_idx = []
    for k in range(perm):
        temp = []
        ent_idx, rel_idx = 0, 0

        for i in range(k_h):
            for j in range(k_w):
                if k % 2 == 0:
                    if i % 2 == 0:
                        temp.append(ent_perm[k, ent_idx]); ent_idx += 1;
                        temp.append(rel_perm[k, rel_idx]+embed_dim); rel_idx += 1;
                    else:
                        temp.append(rel_perm[k, rel_idx]+embed_dim); rel_idx += 1;
                        temp.append(ent_perm[k, ent_idx]); ent_idx += 1;
                else:
                    if i % 2 == 0:
                        temp.append(rel_perm[k, rel_idx]+embed_dim); rel_idx += 1;
                        temp.append(ent_perm[k, ent_idx]); ent_idx += 1;
                    else:
                        temp.append(ent_perm[k, ent_idx]); ent_idx += 1;
                        temp.append(rel_perm[k, rel_idx]+embed_dim); rel_idx += 1;

        comb_idx.append(temp)

    chequer_perm = torch.LongTensor(np.int32(comb_idx)).to(device)
    return chequer_perm


In [14]:
chequer = get_chequer_perm(perm=args.perm, k_w=args.k_w, k_h=args.k_h)
chequer.shape

torch.Size([1, 100])

In [15]:
chequer

tensor([[ 3, 72, 44, 75, 22, 63, 15, 73, 14, 95, 69, 26, 50, 38, 55, 13, 83, 34,
         84, 29, 43, 60, 23, 98,  6, 61, 36, 85, 27, 65, 81, 40, 77, 17, 89, 49,
         88, 16, 76, 11, 28, 86, 37, 78, 12, 51, 18, 53,  1, 79, 71, 46, 52,  4,
         91, 21, 70, 35, 99, 48, 32, 87, 10, 67,  5, 97, 25, 57, 33, 82, 68, 24,
         62,  7, 94, 30, 59, 41, 92, 39,  9, 90, 45, 93, 42, 66, 47, 54,  0, 74,
         58, 19, 56,  8, 80, 20, 96,  2, 64, 31]], device='cuda:1')

In [16]:
class InteractE(torch.nn.Module):
    """
    Proposed method in the paper. Refer Section 6 of the paper for mode details 
    Parameters
    ----------
    params:         Hyperparameters of the model
    chequer_perm:   Reshaping to be used by the model

    Returns
    -------
    The InteractE model instance

    """
    def __init__(self, params, chequer_perm):
        super(InteractE, self).__init__()

        self.p = params
        self.user_embed = torch.nn.Embedding(self.p.user_cnt+1, self.p.embed_dim); # xavier_normal_(self.ent_embed.weight)
        self.item_embed = torch.nn.Embedding(self.p.item_cnt+1, self.p.embed_dim);
        self.chequer_perm = chequer_perm


        self.resnet = _ResNet(BasicBlock, [2, 2, 2, 2], 1)
    
    def circular_padding_chw(self, batch, padding):
        upper_pad = batch[..., -padding:, :]
        lower_pad = batch[..., :padding, :]
        temp = torch.cat([upper_pad, batch, lower_pad], dim=2)

        left_pad = temp[..., -padding:]
        right_pad = temp[..., :padding]
        padded = torch.cat([left_pad, temp, right_pad], dim=3)
        return padded

    def forward(self, user_idx, item_idx):

        user = self.user_embed(user_idx)
        item = self.item_embed(item_idx)
        
        comb_emb = torch.cat([user, item], dim=-1)
        chequer_perm = comb_emb[:, self.chequer_perm] # batch, 1, embed_size
 
        stack_inp = chequer_perm.reshape((-1, self.p.perm, 2*self.p.k_w, self.p.k_h)) # batch, 1, 2*k_w, k_h
#         print(user_idx, item_idx, chequer_perm.shape, stack_inp.shape)
        #### start
        x = self.circular_padding_chw(stack_inp, self.p.ker_sz//2)
        
        pred = self.resnet(x)

        return pred


In [31]:
def train(model, optim, loader, device):
    
    model.train()
    model = model.to(device)
    
    total_loss = 0
    for index, data in tqdm(enumerate(loader)):
        
        optim.zero_grad()
        
        user, item, label, _ = data
        
        user = user.to(device)
        item = item.to(device)
        label = label.to(device, dtype=torch.long)

        pred = model(user, item) 
        
        loss = torch.nn.CrossEntropyLoss()(pred, label)

        loss.backward()
        optim.step()
        
        total_loss += loss.item()

    total_loss = total_loss/(index+1)
    return total_loss

In [32]:
@torch.no_grad()
def test(model, loader, device="cpu"):
    
    from sklearn.metrics import mean_squared_error, ndcg_score, recall_score
    
    model.eval()
    model = model.to(device)
    
    mse_list = []
    recall_list = []
    ndcg_list = []
    for data in tqdm(loader):
        
        user, item, label, onehot = data
        user = user.to(device)
        item = item.to(device)

        pred = model(user, item) 
        mse = mean_squared_error(onehot, pred.detach().cpu().numpy())
        mse_list.append(mse)
        
        pred_val = torch.argmax(pred, dim=1).cpu().detach().numpy()
        pred_val = np.where(pred_val > 3, 1, 0).reshape(1, -1)
        label = np.where(label > 3, 1, 0).reshape(1, -1)
        
#         print(label, pred_val)
        recall = recall_score(label.reshape(-1, 1), pred_val.reshape(-1, 1))
        recall_list.append(recall)
        ndcg = ndcg_score(label, pred_val)
        ndcg_list.append(ndcg)
     
    return sum(mse_list)/len(mse_list), sum(recall_list)/len(recall_list), sum(ndcg_list)/len(ndcg_list)

In [33]:
@torch.no_grad()
def predict(model, loader, device="cpu"):
    
#     from sklearn.metrics import roc_auc_score
    
    model.eval()
    model = model.to(device)
    
    y_pred = []
    for data in tqdm(loader):
        follower, followee, _, _ = data
        
        user, item, label = data
        user = user.to(device)
        item = item.to(device)
        label = label.to(device)

        pred = model(user, item)
        
        y_pred.append(pred.view(-1).cpu())
     
    return torch.cat(y_pred)

In [34]:
k = 5
kf = KFold(n_splits=5)

In [35]:
fold_cnt = 0

mse_list = []
recall_list = []
ndcg_list = []

for train_index, test_index in kf.split(rel):
    
    fold_cnt += 1
    print("========= Fold: {} ==========".format(fold_cnt))
    
    test_df = rel.iloc[test_index]
    
    train_index, valid_index = train_test_split(train_index, test_size=0.1)
    train_df = rel.iloc[train_index]
    valid_df = rel.iloc[valid_index]
    
    trainset = Dataset(train_df)
    validset = Dataset(valid_df)
    testset = Dataset(test_df)
    
    trainloader = DataLoader(trainset, batch_size=256, shuffle=True, num_workers=2) 
    testloader = DataLoader(testset, batch_size=256, shuffle=False, num_workers=2) 
    validloader = DataLoader(validset, batch_size=256, shuffle=False, num_workers=2) 
    
    model = InteractE(args, chequer).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    
    for epoch in range(10):
        print("======= Epoch {} ========".format(epoch+1))
        loss = train(model, optimizer, trainloader, device)
        mse, recall, ndcg = test(model, trainloader, device)
        print("Trainset: mse={:.5f}, recall={:.5f}, ndcg={:.5f}".format(mse, recall, ndcg))
        mse, recall, ndcg = test(model, validloader, device)
        print("Validset: mse={:.5f}, recall={:.5f}, ndcg={:.5f}".format(mse, recall, ndcg))
        
    
    mse, recall, ndcg = test(model, testloader, device)
    mse_list.append(mse)
    recall_list.append(recall)
    ndcg_list.append(ndcg)



2228it [01:52, 19.85it/s]
100%|██████████| 2228/2228 [00:23<00:00, 96.11it/s] 
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14666, recall=1.00000, ndcg=0.93355


100%|██████████| 248/248 [00:02<00:00, 90.16it/s]

Validset: mse=0.14725, recall=1.00000, ndcg=0.93361



2228it [01:51, 19.96it/s]
100%|██████████| 2228/2228 [00:23<00:00, 96.85it/s] 
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.13422, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 89.39it/s]


Validset: mse=0.13605, recall=1.00000, ndcg=0.93361


2228it [01:51, 20.03it/s]
100%|██████████| 2228/2228 [00:23<00:00, 96.58it/s] 
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14075, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 89.84it/s]

Validset: mse=0.14460, recall=1.00000, ndcg=0.93361



2228it [01:51, 19.91it/s]
100%|██████████| 2228/2228 [00:23<00:00, 93.84it/s] 
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14269, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:03<00:00, 81.82it/s]

Validset: mse=0.14709, recall=1.00000, ndcg=0.93361



2228it [01:54, 19.53it/s]
100%|██████████| 2228/2228 [00:26<00:00, 83.02it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.13999, recall=1.00000, ndcg=0.93355


100%|██████████| 248/248 [00:03<00:00, 77.13it/s]

Validset: mse=0.14556, recall=1.00000, ndcg=0.93361



2228it [01:54, 19.49it/s]
100%|██████████| 2228/2228 [00:25<00:00, 87.45it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14035, recall=1.00000, ndcg=0.93355


100%|██████████| 248/248 [00:02<00:00, 83.05it/s]

Validset: mse=0.14745, recall=1.00000, ndcg=0.93361



2228it [01:51, 19.92it/s]
100%|██████████| 2228/2228 [00:26<00:00, 84.71it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14303, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:03<00:00, 75.35it/s]

Validset: mse=0.15025, recall=1.00000, ndcg=0.93361



2228it [01:54, 19.54it/s]
100%|██████████| 2228/2228 [00:27<00:00, 80.95it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.15531, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:03<00:00, 75.53it/s]

Validset: mse=0.16207, recall=1.00000, ndcg=0.93361



2228it [01:53, 19.66it/s]
100%|██████████| 2228/2228 [00:26<00:00, 83.02it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14797, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:03<00:00, 79.91it/s]

Validset: mse=0.15686, recall=1.00000, ndcg=0.93361



2228it [01:54, 19.52it/s]
100%|██████████| 2228/2228 [00:26<00:00, 84.54it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14559, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:03<00:00, 77.27it/s]
  0%|          | 0/619 [00:00<?, ?it/s]

Validset: mse=0.15639, recall=1.00000, ndcg=0.93361


100%|██████████| 619/619 [00:07<00:00, 81.77it/s]




2228it [01:53, 19.71it/s]
100%|██████████| 2228/2228 [00:24<00:00, 92.69it/s] 
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14336, recall=1.00000, ndcg=0.93359


100%|██████████| 248/248 [00:02<00:00, 82.71it/s]

Validset: mse=0.14490, recall=1.00000, ndcg=0.93334



2228it [01:56, 19.18it/s]
100%|██████████| 2228/2228 [00:24<00:00, 92.51it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.13881, recall=1.00000, ndcg=0.93359


100%|██████████| 248/248 [00:02<00:00, 86.66it/s]

Validset: mse=0.14160, recall=1.00000, ndcg=0.93334



2228it [01:55, 19.30it/s]
100%|██████████| 2228/2228 [00:24<00:00, 90.03it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.13448, recall=1.00000, ndcg=0.93359


100%|██████████| 248/248 [00:02<00:00, 84.93it/s]

Validset: mse=0.13785, recall=1.00000, ndcg=0.93334



2228it [01:52, 19.77it/s]
100%|██████████| 2228/2228 [00:23<00:00, 93.38it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.15747, recall=1.00000, ndcg=0.93359


100%|██████████| 248/248 [00:02<00:00, 85.55it/s]

Validset: mse=0.16132, recall=1.00000, ndcg=0.93334



2228it [01:55, 19.36it/s]
100%|██████████| 2228/2228 [00:24<00:00, 92.08it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.15030, recall=1.00000, ndcg=0.93359


100%|██████████| 248/248 [00:02<00:00, 84.51it/s]

Validset: mse=0.15525, recall=1.00000, ndcg=0.93334



2228it [01:55, 19.23it/s]
100%|██████████| 2228/2228 [00:24<00:00, 92.36it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14747, recall=1.00000, ndcg=0.93359


100%|██████████| 248/248 [00:02<00:00, 84.32it/s]

Validset: mse=0.15422, recall=1.00000, ndcg=0.93334



2228it [01:54, 19.45it/s]
100%|██████████| 2228/2228 [00:23<00:00, 94.23it/s] 
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14215, recall=1.00000, ndcg=0.93360


100%|██████████| 248/248 [00:02<00:00, 86.40it/s]

Validset: mse=0.15002, recall=1.00000, ndcg=0.93334



2228it [01:56, 19.16it/s]
100%|██████████| 2228/2228 [00:24<00:00, 92.00it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14192, recall=1.00000, ndcg=0.93358


100%|██████████| 248/248 [00:02<00:00, 87.75it/s]

Validset: mse=0.15112, recall=1.00000, ndcg=0.93334



2228it [01:52, 19.74it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.36it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14850, recall=1.00000, ndcg=0.93359


100%|██████████| 248/248 [00:02<00:00, 87.69it/s]

Validset: mse=0.15797, recall=1.00000, ndcg=0.93334



2228it [01:55, 19.32it/s]
100%|██████████| 2228/2228 [00:24<00:00, 92.13it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14457, recall=1.00000, ndcg=0.93359


100%|██████████| 248/248 [00:02<00:00, 84.90it/s]
  0%|          | 0/619 [00:00<?, ?it/s]

Validset: mse=0.15594, recall=1.00000, ndcg=0.93334


100%|██████████| 619/619 [00:06<00:00, 91.83it/s]




2228it [01:54, 19.38it/s]
100%|██████████| 2228/2228 [00:24<00:00, 90.20it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.12716, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 87.37it/s]


Validset: mse=0.12735, recall=1.00000, ndcg=0.93407


2228it [01:55, 19.32it/s]
100%|██████████| 2228/2228 [00:24<00:00, 89.18it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.13365, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 86.89it/s]

Validset: mse=0.13544, recall=1.00000, ndcg=0.93407



2228it [01:55, 19.34it/s]
100%|██████████| 2228/2228 [00:23<00:00, 93.90it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14533, recall=1.00000, ndcg=0.93355


100%|██████████| 248/248 [00:02<00:00, 88.20it/s]

Validset: mse=0.14725, recall=1.00000, ndcg=0.93407



2228it [01:53, 19.68it/s]
100%|██████████| 2228/2228 [00:24<00:00, 90.06it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.15043, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 86.79it/s]

Validset: mse=0.15253, recall=1.00000, ndcg=0.93407



2228it [01:53, 19.55it/s]
100%|██████████| 2228/2228 [00:23<00:00, 96.68it/s] 
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14638, recall=1.00000, ndcg=0.93355


100%|██████████| 248/248 [00:02<00:00, 89.34it/s]

Validset: mse=0.15176, recall=1.00000, ndcg=0.93407



2228it [01:54, 19.52it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.91it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.16555, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:03<00:00, 82.43it/s]

Validset: mse=0.16813, recall=1.00000, ndcg=0.93407



2228it [01:55, 19.22it/s]
100%|██████████| 2228/2228 [00:23<00:00, 94.60it/s] 
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14018, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 83.52it/s]

Validset: mse=0.14764, recall=1.00000, ndcg=0.93407



2228it [01:55, 19.34it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.45it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.15236, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 86.64it/s]

Validset: mse=0.15834, recall=1.00000, ndcg=0.93407



2228it [01:53, 19.66it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.09it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14754, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 85.13it/s]

Validset: mse=0.15634, recall=1.00000, ndcg=0.93407



2228it [01:50, 20.22it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.41it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14808, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 83.74it/s]
  0%|          | 0/619 [00:00<?, ?it/s]

Validset: mse=0.15749, recall=1.00000, ndcg=0.93407


100%|██████████| 619/619 [00:06<00:00, 90.27it/s]




2228it [01:55, 19.37it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.84it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14152, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 87.87it/s]

Validset: mse=0.14242, recall=1.00000, ndcg=0.93364



2228it [01:53, 19.66it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.60it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.13835, recall=1.00000, ndcg=0.93357


100%|██████████| 248/248 [00:03<00:00, 82.51it/s]

Validset: mse=0.14057, recall=1.00000, ndcg=0.93364



2228it [01:53, 19.68it/s]
100%|██████████| 2228/2228 [00:23<00:00, 93.20it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.13922, recall=1.00000, ndcg=0.93356


100%|██████████| 248/248 [00:02<00:00, 86.93it/s]

Validset: mse=0.14217, recall=1.00000, ndcg=0.93364



2228it [01:55, 19.37it/s]
100%|██████████| 2228/2228 [00:24<00:00, 89.68it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.15410, recall=1.00000, ndcg=0.93357


100%|██████████| 248/248 [00:03<00:00, 82.58it/s]

Validset: mse=0.15702, recall=1.00000, ndcg=0.93364



2228it [01:54, 19.49it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.38it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14793, recall=1.00000, ndcg=0.93357


100%|██████████| 248/248 [00:03<00:00, 82.37it/s]

Validset: mse=0.15209, recall=1.00000, ndcg=0.93364



2228it [01:55, 19.26it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.97it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.15185, recall=1.00000, ndcg=0.93357


100%|██████████| 248/248 [00:02<00:00, 83.07it/s]

Validset: mse=0.15655, recall=1.00000, ndcg=0.93364



2228it [01:55, 19.29it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.77it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14366, recall=1.00000, ndcg=0.93357


100%|██████████| 248/248 [00:02<00:00, 84.14it/s]


Validset: mse=0.15017, recall=1.00000, ndcg=0.93364


2228it [01:56, 19.12it/s]
100%|██████████| 2228/2228 [00:24<00:00, 92.57it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14259, recall=1.00000, ndcg=0.93357


100%|██████████| 248/248 [00:02<00:00, 83.85it/s]

Validset: mse=0.15060, recall=1.00000, ndcg=0.93364



2228it [01:55, 19.31it/s]
100%|██████████| 2228/2228 [00:24<00:00, 92.62it/s] 
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14656, recall=1.00000, ndcg=0.93357


100%|██████████| 248/248 [00:02<00:00, 85.13it/s]


Validset: mse=0.15483, recall=1.00000, ndcg=0.93364


2228it [01:55, 19.21it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.92it/s] 
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.15254, recall=1.00000, ndcg=0.93357


100%|██████████| 248/248 [00:02<00:00, 93.03it/s] 
  0%|          | 0/619 [00:00<?, ?it/s]

Validset: mse=0.16184, recall=1.00000, ndcg=0.93364


100%|██████████| 619/619 [00:06<00:00, 98.04it/s] 




2228it [01:55, 19.37it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.40it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.13432, recall=1.00000, ndcg=0.93364


100%|██████████| 248/248 [00:02<00:00, 85.41it/s]

Validset: mse=0.13592, recall=1.00000, ndcg=0.93263



2228it [01:55, 19.30it/s]
100%|██████████| 2228/2228 [00:24<00:00, 89.13it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.13576, recall=1.00000, ndcg=0.93364


100%|██████████| 248/248 [00:02<00:00, 84.92it/s]

Validset: mse=0.13843, recall=1.00000, ndcg=0.93263



2228it [01:54, 19.47it/s]
100%|██████████| 2228/2228 [00:25<00:00, 88.20it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.12990, recall=1.00000, ndcg=0.93364


100%|██████████| 248/248 [00:02<00:00, 85.67it/s]

Validset: mse=0.13317, recall=1.00000, ndcg=0.93263



2228it [01:54, 19.38it/s]
100%|██████████| 2228/2228 [00:25<00:00, 88.78it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14130, recall=1.00000, ndcg=0.93363


100%|██████████| 248/248 [00:02<00:00, 83.61it/s]

Validset: mse=0.14656, recall=1.00000, ndcg=0.93263



2228it [01:54, 19.45it/s]
100%|██████████| 2228/2228 [00:25<00:00, 88.68it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14030, recall=1.00000, ndcg=0.93363


100%|██████████| 248/248 [00:02<00:00, 83.29it/s]

Validset: mse=0.14640, recall=1.00000, ndcg=0.93263



2228it [01:53, 19.62it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.02it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14415, recall=1.00000, ndcg=0.93364


100%|██████████| 248/248 [00:02<00:00, 85.10it/s]

Validset: mse=0.15235, recall=1.00000, ndcg=0.93263



2228it [01:54, 19.42it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.89it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14110, recall=1.00000, ndcg=0.93363


100%|██████████| 248/248 [00:02<00:00, 86.23it/s]

Validset: mse=0.15014, recall=1.00000, ndcg=0.93263



2228it [01:54, 19.48it/s]
100%|██████████| 2228/2228 [00:24<00:00, 91.15it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14916, recall=1.00000, ndcg=0.93364


100%|██████████| 248/248 [00:02<00:00, 86.02it/s]

Validset: mse=0.15891, recall=1.00000, ndcg=0.93263



2228it [01:56, 19.20it/s]
100%|██████████| 2228/2228 [00:24<00:00, 90.15it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14836, recall=1.00000, ndcg=0.93363


100%|██████████| 248/248 [00:02<00:00, 84.19it/s]

Validset: mse=0.15920, recall=1.00000, ndcg=0.93263



2228it [01:52, 19.77it/s]
100%|██████████| 2228/2228 [00:23<00:00, 93.08it/s]
  0%|          | 0/248 [00:00<?, ?it/s]

Trainset: mse=0.14967, recall=1.00000, ndcg=0.93363


100%|██████████| 248/248 [00:02<00:00, 87.43it/s]
  0%|          | 0/619 [00:00<?, ?it/s]

Validset: mse=0.16065, recall=1.00000, ndcg=0.93263


100%|██████████| 619/619 [00:06<00:00, 91.74it/s]


### douban without filter

In [22]:
sum(mse_list)/len(mse_list)

0.15536124332294682

In [23]:
sum(recall_list)/len(recall_list)

1.0

In [24]:
sum(ndcg_list)/len(ndcg_list)

0.9335744432272624

### douban with filter

In [36]:
sum(mse_list)/len(mse_list)

0.1586426511035266

In [37]:
sum(recall_list)/len(recall_list)

1.0

In [38]:
sum(ndcg_list)/len(ndcg_list)

0.9335744432272624

In [None]:
valid_auc = 0
for epoch in range(100):

    loss = train(model, optimizer, trainloader, device)
    train_auc = test(model, trainloader, device)
    auc = test(model, validloader, device)
   
    if auc > valid_auc:
        valid_auc = auc
        
        checkpoint = {
            'model_stat': model.state_dict(),
            'optimizer_stat': optimizer.state_dict(),
        }
        
        torch.save(checkpoint, os.path.join(root, "results", dataset, "interactE", 
                                                "{}_{}.pth".format(date_time, dataset)))
        print("\nSave Model")
    
    print("Epoch: {}, loss={:.5f}, valid_auc={:.5f}, train_auc={:.5f}".format(epoch+1, loss, auc, train_auc))
    

In [89]:
model, optimizer = load_checkpoint(os.path.join(root, "results", dataset, "interactE", "weight.pth"), device=device)

Embedding(877, 97)
pretrained finded


In [90]:
pred = predict(model, testloader, device)

100%|██████████| 21/21 [00:00<00:00, 61.72it/s]


In [91]:
pred = pred.detach().cpu().numpy()
# pred = np.round(pred, 3)
pred

array([9.99296665e-01, 9.99974608e-01, 9.99904752e-01, 9.81246769e-01,
       9.41093604e-05, 3.11104268e-05, 9.99999166e-01, 2.24112961e-02,
       9.99894381e-01, 9.99646306e-01, 4.32964180e-06, 9.97670829e-01,
       2.91336537e-03, 9.98132527e-01, 4.62299140e-05, 9.99988914e-01,
       1.76502326e-07, 9.99999762e-01, 1.65335496e-03, 9.99194682e-01,
       1.00000000e+00, 9.99995828e-01, 9.99987245e-01, 3.67666507e-04,
       2.47902626e-05, 9.99567568e-01, 9.99998808e-01, 2.89711033e-06,
       9.99994516e-01, 9.99999881e-01, 9.98152077e-01, 9.99993324e-01,
       4.22160685e-01, 9.99997497e-01, 9.99997377e-01, 9.99977469e-01,
       1.86935576e-05, 9.99999762e-01, 9.99999881e-01, 9.99915004e-01,
       9.99989629e-01, 9.99951839e-01, 9.97748911e-01, 8.24332631e-07,
       7.14886427e-01, 5.42289818e-05, 9.99966025e-01, 3.76991329e-06,
       9.99999642e-01, 2.09501650e-05, 9.99803841e-01, 5.23268938e-01,
       3.84405375e-01, 3.82165535e-06, 9.99978065e-01, 8.77665798e-07,
      

In [92]:
test_df["prob"] = pred
test_df.to_csv("conv_all.csv", index=False)

In [93]:
upload = test_df[["id", "prob"]]
upload.to_csv("upload.csv", index=False)