In [1]:
import pickle
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchtext.legacy import data
import dgl
import tqdm

import layers
import sampler as sampler_module
import evaluation

import optuna

Добавление механизма внимания вместо весов, полученных на основе случайного блуждания

In [2]:
import torch.nn.functional as F

class AttentionLayer(nn.Module):
    def __init__(self, input_dims, hidden_dims, output_dims):
        super(AttentionLayer, self).__init__()

        self.Q = nn.Linear(input_dims, hidden_dims)
        self.ATTN = nn.Linear(2 * hidden_dims, 1)
        self.W = nn.Linear(input_dims + hidden_dims, output_dims)
        self.reset_parameters()
        self.dropout = nn.Dropout(0.5)
        
    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(self.Q.weight, gain=gain)
        nn.init.xavier_uniform_(self.W.weight, gain=gain)
        nn.init.xavier_normal_(self.ATTN.weight, gain=gain)
        nn.init.constant_(self.Q.bias, 0)
        nn.init.constant_(self.W.bias, 0)
        nn.init.constant_(self.ATTN.bias, 0)
        
    def attention_score(self, edges):
        a = self.ATTN(torch.cat([edges.src['n'], edges.dst['n']], dim=1))
        return {'e': F.leaky_relu(a)}
    
    def message_func(self, edges):
        return {'n': edges.src['n'], 'e': edges.data['e']}
    
    def reduce_func(self, nodes):
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        h = torch.sum(alpha * nodes.mailbox['n'], dim=1)
        return {'h': h}
    
    def forward(self, block, h):
        h_src, h_dst = h
        with block.local_scope():
            # add n_u
            block.srcdata['n'] = self.Q(self.dropout(h_src))
            block.dstdata['n'] = block.srcdata['n'][:len(block.dstdata[dgl.NID])]
            
            # add aggregation (attention) n_u
            block.apply_edges(self.attention_score)

            block.update_all(self.message_func, self.reduce_func)
            
            z = F.relu(self.W(self.dropout(torch.cat([block.dstdata['h'], h_dst], 1))))
            z_norm = z.norm(2, 1, keepdim=True)
            z_norm = torch.where(z_norm == 0, torch.tensor(1.).to(z_norm), z_norm)
            z = z / z_norm
            return z
        
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, input_dims, hidden_dims, output_dims, num_heads, merge='cat'):
        super(MultiHeadAttentionLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(AttentionLayer(input_dims, hidden_dims, output_dims))
        self.merge = merge

    def forward(self, block, h):
        z_items = []
        for head in self.heads:
            z_item = head(block, h)
            z_items.append(z_item)
    
        if self.merge == 'cat':
            return torch.cat(z_items, dim=1)
        else:
            return torch.mean(torch.stack(z_items), 0)
       
    
class Net(nn.Module):
    def __init__(self, hidden_dims, num_heads, num_layers):
        super(Net, self).__init__()
        
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(MultiHeadAttentionLayer(hidden_dims, hidden_dims, hidden_dims, num_heads=2, merge='mean'))

    def forward(self, blocks, h):
        for layer, block in zip(self.convs, blocks):
            h_dst = h[:block.number_of_nodes('DST/' + block.ntypes[0])]
            h = layer(block, (h, h_dst))     
        return h

In [3]:
class PinSAGEModel(nn.Module):
    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers, n_heads):
        super().__init__()

        self.proj = layers.LinearProjector(full_graph, ntype, textsets, hidden_dims)
        self.net = Net(hidden_dims, n_heads, n_layers)
        self.scorer = layers.ItemToItemScorer(full_graph, ntype)

    def forward(self, pos_graph, neg_graph, blocks):
        h_item = self.get_repr(blocks)
        pos_score = self.scorer(pos_graph, h_item)
        neg_score = self.scorer(neg_graph, h_item)
        return (neg_score - pos_score + 1).clamp(min=0)

    def get_repr(self, blocks):
        h_item = self.proj(blocks[0].srcdata)
        h_item_dst = self.proj(blocks[-1].dstdata)
        return h_item_dst + self.net(blocks, h_item)

In [4]:
from dataclasses import dataclass

@dataclass
class TrainArgs:
    output_model_path: str
    random_walk_length: int = 2
    random_walk_restart_prob: float = 0.5
    num_random_walks: int = 10
    num_neighbors: int = 5
    num_layers: int = 2
    num_heads: int = 2
    hidden_dims: int = 16
    batch_size: int = 64
    device: str = 'cpu'
    num_epochs: int = 1
    batches_per_epoch: int = 20000
    num_workers: int = 0
    lr: float = 3e-5
    k: int = 10
    n_latest_items: int = 10
        

ML

In [5]:
with open('data/data_ml.pkl', 'rb') as f:
    dataset = pickle.load(f)
    
args = TrainArgs(output_model_path='models_attention/model_ml_optuna', num_epochs=5, 
                 hidden_dims=64, batches_per_epoch=10000, device='cuda', k=20)


g = dataset['train-graph']
val_matrix = dataset['val-matrix'].tocsr()
test_matrix = dataset['test-matrix'].tocsr()
item_texts = dataset['item-texts']
user_ntype = dataset['user-type']
item_ntype = dataset['item-type']
user_to_item_etype = dataset['user-to-item-type']
timestamp = dataset['timestamp-edge-column']
device = torch.device(args.device)
# Assign user and movie IDs and use them as features (to learn an individual trainable
# embedding for each entity)
g.nodes[user_ntype].data['id'] = torch.arange(g.number_of_nodes(user_ntype))
g.nodes[item_ntype].data['id'] = torch.arange(g.number_of_nodes(item_ntype))
# Prepare torchtext dataset and vocabulary
if item_texts is not None:
    fields = {}
    examples = []
    for key, texts in item_texts.items():
        fields[key] = data.Field(include_lengths=True, lower=True, batch_first=True)
    for i in range(g.number_of_nodes(item_ntype)):
        example = data.Example.fromlist(
            [item_texts[key][i] for key in item_texts.keys()],
            [(key, fields[key]) for key in item_texts.keys()])
        examples.append(example)
    textset = data.Dataset(examples, fields)
    for key, field in fields.items():
        field.build_vocab(getattr(textset, key))
        #field.build_vocab(getattr(textset, key), vectors='fasttext.simple.300d')
else:
    textset = None
# Sampler

In [None]:
def objective(trial):

    # 2. Suggest values of the hyperparameters using a trial object.
    n_layers = trial.suggest_int('n_layers', 1, 3)
    n_heads = trial.suggest_int('n_heads', 1, 3)
    hidden_dims = trial.suggest_int('hidden_dims', 32, 128)
    #num_epochs = trial.suggest_int('num_epochs', 3, 10)
    learning_rate = trial.suggest_float("learning_rate_init", 1e-5, 1e-3)
    num_neighbors = trial.suggest_int('num_neighbors', 1, 15)
    
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, args.batch_size)
    neighbor_sampler = sampler_module.NeighborSampler(
        g, user_ntype, item_ntype, args.random_walk_length,
        args.random_walk_restart_prob, args.num_random_walks, num_neighbors,
        n_layers)
    collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=args.num_workers)
    dataloader_test = DataLoader(
        torch.arange(g.number_of_nodes(item_ntype)),
        batch_size=args.batch_size,
        collate_fn=collator.collate_test,
        num_workers=args.num_workers)
    dataloader_it = iter(dataloader)
    
    model = PinSAGEModel(g, item_ntype, textset, hidden_dims, n_layers, n_heads).to(args.device)
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
    layers = []

    for epoch_id in range(2):
        model.train()
        for batch_id in tqdm.trange(args.batches_per_epoch):
            pos_graph, neg_graph, blocks = next(dataloader_it)
            # Copy to GPU
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to(device)
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)

            loss = model(pos_graph, neg_graph, blocks).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            
        model.eval()
        with torch.no_grad():
            item_batches = torch.arange(g.number_of_nodes(item_ntype)).split(args.batch_size)
            h_item_batches = []
            for blocks in dataloader_test:
                for i in range(len(blocks)):
                    blocks[i] = blocks[i].to(device)

                h_item_batches.append(model.get_repr(blocks))
            h_item = torch.cat(h_item_batches, 0)
            metrics = evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size, args.n_latest_items)
            
    return metrics[0][2]

# 3. Create a study object and optimize the objective function.
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=1, timeout=10)

[32m[I 2022-05-29 15:00:47,293][0m A new study created in memory with name: no-name-f634da57-13b1-45fa-9715-5cae64abaeba[0m
 27%|██████████████████████████████████▊                                                                                               | 2677/10000 [06:33<26:54,  4.54it/s]

In [15]:
args = TrainArgs(output_model_path='models_attention/model_ml_optuna', num_epochs=5, 
                 hidden_dims=69, batches_per_epoch=10000, device='cuda', k=20,
                 lr = 3.65e-05, num_neighbors = 9, num_heads=3, num_layers=3)

In [18]:
batch_sampler = sampler_module.ItemToItemBatchSampler(
    g, user_ntype, item_ntype, args.batch_size)
neighbor_sampler = sampler_module.NeighborSampler(
    g, user_ntype, item_ntype, args.random_walk_length,
    args.random_walk_restart_prob, args.num_random_walks, args.num_neighbors,
    args.num_layers)
collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
dataloader = DataLoader(
    batch_sampler,
    collate_fn=collator.collate_train,
    num_workers=args.num_workers)
dataloader_test = DataLoader(
    torch.arange(g.number_of_nodes(item_ntype)),
    batch_size=args.batch_size,
    collate_fn=collator.collate_test,
    num_workers=args.num_workers)
dataloader_it = iter(dataloader)
# Model
model = PinSAGEModel(g, item_ntype, textset, args.hidden_dims, args.num_layers, args.num_heads).to(device)
# Optimizer
opt = torch.optim.Adam(model.parameters(), lr=args.lr)

base_ndcg = 0
# For each batch of head-tail-negative triplets...
for epoch_id in range(args.num_epochs):
    model.train()
    for batch_id in tqdm.trange(args.batches_per_epoch):
        pos_graph, neg_graph, blocks = next(dataloader_it)
        # Copy to GPU
        for i in range(len(blocks)):
            blocks[i] = blocks[i].to(device)
        pos_graph = pos_graph.to(device)
        neg_graph = neg_graph.to(device)
        loss = model(pos_graph, neg_graph, blocks).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        
    #with open(args.output_model_path, 'wb') as f:
        #pickle.dump(model, f)
        
    # Evaluate
    model.eval()
    with torch.no_grad():
        item_batches = torch.arange(g.number_of_nodes(item_ntype)).split(args.batch_size)
        h_item_batches = []
        for blocks in dataloader_test:
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to(device)
            h_item_batches.append(model.get_repr(blocks))
        h_item = torch.cat(h_item_batches, 0)
        metrics = evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size, args.n_latest_items)
        if metrics[0][2] - base_ndcg > -0.001:
            old_metrics = metrics
            base_ndcg = metrics[0][2]
            torch.save(model.state_dict(), args.output_model_path)
        
            print('Rec by latest item', metrics[0],
                  '\n',
                  'Rec by N latest items', metrics[1])
        else:
            print('Early stopping')
            break

print('RESULT')
for k in [1, 5, 10, 15, 20]:
    metrics = evaluation.evaluate_nn(dataset, h_item, k, args.batch_size, args.n_latest_items)
    
    print(f'Epoch {epoch_id}, Rec by latest item: PR@{k}: {metrics[0][0]}|REC@{k}: {metrics[0][1]}|NDCG@{k}: {metrics[0][2]}')
    print(f'Epoch {epoch_id}, Rec by {args.n_latest_items} latest items: PR@{k}: {metrics[1][0]}|REC@{k}: {metrics[1][1]}|NDCG@{k}: {metrics[1][2]}')
        

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [10:26<00:00, 15.97it/s]


Rec by latest item (0.07684602649006848, 0.06444941453818276, 0.22953112314388738) 
 Rec by N latest items (0.08195364238410799, 0.06713630699920319, 0.23830760938249615)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [10:43<00:00, 15.54it/s]


Rec by latest item (0.10547185430463589, 0.09508574268523141, 0.3649909037398669) 
 Rec by N latest items (0.11462748344370854, 0.10300242999089733, 0.37943280894274156)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [10:28<00:00, 15.91it/s]


Rec by latest item (0.1127649006622514, 0.10242504343950844, 0.4025777391102182) 
 Rec by N latest items (0.11887417218543024, 0.10662430712087431, 0.3996387766863449)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [10:25<00:00, 15.98it/s]


Early stopping
RESULT
Epoch 3, Rec by latest item: PR@1: 0.15281456953642383|REC@1: 0.007639486679528726|NDCG@1: 0.02384105960264901
Epoch 3, Rec by 10 latest items: PR@1: 0.16026490066225166|REC@1: 0.007520462346761501|NDCG@1: 0.02185430463576159
Epoch 3, Rec by latest item: PR@5: 0.14172185430463838|REC@5: 0.0345450778146431|NDCG@5: 0.14461230349135962
Epoch 3, Rec by 10 latest items: PR@5: 0.14450331125828073|REC@5: 0.03417358236809236|NDCG@5: 0.1509793804480531
Epoch 3, Rec by latest item: PR@10: 0.13064569536424264|REC@10: 0.06187579650079092|NDCG@10: 0.25798118762065086
Epoch 3, Rec by 10 latest items: PR@10: 0.13369205298013645|REC@10: 0.062318556255081245|NDCG@10: 0.2681922757820343
Epoch 3, Rec by latest item: PR@15: 0.12113686534216622|REC@15: 0.08356225039798802|NDCG@15: 0.3426607270048519
Epoch 3, Rec by 10 latest items: PR@15: 0.1263024282560739|REC@15: 0.08598403448166919|NDCG@15: 0.3409933649465797
Epoch 3, Rec by latest item: PR@20: 0.11502483443708561|REC@20: 0.1037462

Ta Feng

In [19]:
with open('data/tafeng.pkl', 'rb') as f:
    dataset = pickle.load(f)
    
args = TrainArgs(output_model_path='models_attention/model_tafeng_optuna', num_epochs=5, 
                 hidden_dims=64, batches_per_epoch=10000, device='cuda', k=20)


g = dataset['train-graph']
val_matrix = dataset['val-matrix'].tocsr()
test_matrix = dataset['test-matrix'].tocsr()
item_texts = dataset['item-texts']
user_ntype = dataset['user-type']
item_ntype = dataset['item-type']
user_to_item_etype = dataset['user-to-item-type']
timestamp = dataset['timestamp-edge-column']
device = torch.device(args.device)
# Assign user and movie IDs and use them as features (to learn an individual trainable
# embedding for each entity)
g.nodes[user_ntype].data['id'] = torch.arange(g.number_of_nodes(user_ntype))
g.nodes[item_ntype].data['id'] = torch.arange(g.number_of_nodes(item_ntype))
# Prepare torchtext dataset and vocabulary
if item_texts is not None:
    fields = {}
    examples = []
    for key, texts in item_texts.items():
        fields[key] = data.Field(include_lengths=True, lower=True, batch_first=True)
    for i in range(g.number_of_nodes(item_ntype)):
        example = data.Example.fromlist(
            [item_texts[key][i] for key in item_texts.keys()],
            [(key, fields[key]) for key in item_texts.keys()])
        examples.append(example)
    textset = data.Dataset(examples, fields)
    for key, field in fields.items():
        field.build_vocab(getattr(textset, key))
        #field.build_vocab(getattr(textset, key), vectors='fasttext.simple.300d')
else:
    textset = None
# Sampler

In [20]:
def objective(trial):

    # 2. Suggest values of the hyperparameters using a trial object.
    n_layers = trial.suggest_int('n_layers', 1, 4)
    n_heads = trial.suggest_int('n_heads', 1, 3)
    hidden_dims = trial.suggest_int('hidden_dims', 16, 128)
    num_epochs = trial.suggest_int('num_epochs', 3, 10)
    learning_rate = trial.suggest_float("learning_rate_init", 1e-5, 1e-3)
    num_neighbors = trial.suggest_int('num_neighbors', 1, 15)
    
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, args.batch_size)
    neighbor_sampler = sampler_module.NeighborSampler(
        g, user_ntype, item_ntype, args.random_walk_length,
        args.random_walk_restart_prob, args.num_random_walks, num_neighbors,
        n_layers)
    collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=args.num_workers)
    dataloader_test = DataLoader(
        torch.arange(g.number_of_nodes(item_ntype)),
        batch_size=args.batch_size,
        collate_fn=collator.collate_test,
        num_workers=args.num_workers)
    dataloader_it = iter(dataloader)
    
    model = PinSAGEModel(g, item_ntype, textset, hidden_dims, n_layers, n_heads).to(args.device)
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
    layers = []

    for epoch_id in range(num_epochs):
        model.train()
        for batch_id in tqdm.trange(args.batches_per_epoch):
            pos_graph, neg_graph, blocks = next(dataloader_it)
            # Copy to GPU
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to(device)
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)

            loss = model(pos_graph, neg_graph, blocks).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            
        model.eval()
        with torch.no_grad():
            item_batches = torch.arange(g.number_of_nodes(item_ntype)).split(args.batch_size)
            h_item_batches = []
            for blocks in dataloader_test:
                for i in range(len(blocks)):
                    blocks[i] = blocks[i].to(device)

                h_item_batches.append(model.get_repr(blocks))
            h_item = torch.cat(h_item_batches, 0)
            metrics = evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size, args.n_latest_items)
            
    return metrics[0][2]

# 3. Create a study object and optimize the objective function.
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=1, timeout=10)

[32m[I 2022-05-27 12:56:04,733][0m A new study created in memory with name: no-name-d3513a46-090e-4017-91e6-59cb90c2a72b[0m
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [08:57<00:00, 18.61it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [08:54<00:00, 18.69it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [08:58<00:00, 18.58it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [08:57<00:00, 18.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [08:59<00:00, 18.55it/s]
100%|███

In [21]:
args = TrainArgs(output_model_path='models_attention/model_tafeng_optuna', num_epochs=7, 
                 hidden_dims=41, batches_per_epoch=10000, device='cuda', k=20,
                 lr = 0.0001705, num_neighbors = 3, num_heads=3, num_layers=3)

In [22]:
batch_sampler = sampler_module.ItemToItemBatchSampler(
    g, user_ntype, item_ntype, args.batch_size)
neighbor_sampler = sampler_module.NeighborSampler(
    g, user_ntype, item_ntype, args.random_walk_length,
    args.random_walk_restart_prob, args.num_random_walks, args.num_neighbors,
    args.num_layers)
collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
dataloader = DataLoader(
    batch_sampler,
    collate_fn=collator.collate_train,
    num_workers=args.num_workers)
dataloader_test = DataLoader(
    torch.arange(g.number_of_nodes(item_ntype)),
    batch_size=args.batch_size,
    collate_fn=collator.collate_test,
    num_workers=args.num_workers)
dataloader_it = iter(dataloader)
# Model
model = PinSAGEModel(g, item_ntype, textset, args.hidden_dims, args.num_layers, args.num_heads).to(device)
# Optimizer
opt = torch.optim.Adam(model.parameters(), lr=args.lr)

base_ndcg = 0
# For each batch of head-tail-negative triplets...
for epoch_id in range(args.num_epochs):
    model.train()
    for batch_id in tqdm.trange(args.batches_per_epoch):
        pos_graph, neg_graph, blocks = next(dataloader_it)
        # Copy to GPU
        for i in range(len(blocks)):
            blocks[i] = blocks[i].to(device)
        pos_graph = pos_graph.to(device)
        neg_graph = neg_graph.to(device)
        loss = model(pos_graph, neg_graph, blocks).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        
    #with open(args.output_model_path, 'wb') as f:
        #pickle.dump(model, f)
        
    # Evaluate
    model.eval()
    with torch.no_grad():
        item_batches = torch.arange(g.number_of_nodes(item_ntype)).split(args.batch_size)
        h_item_batches = []
        for blocks in dataloader_test:
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to(device)
            h_item_batches.append(model.get_repr(blocks))
        h_item = torch.cat(h_item_batches, 0)
        metrics = evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size, args.n_latest_items)
        if metrics[0][2] - base_ndcg > -0.001:
            old_metrics = metrics
            base_ndcg = metrics[0][2]
            torch.save(model.state_dict(), args.output_model_path)
        
            print('Rec by latest item', metrics[0],
                  '\n',
                  'Rec by N latest items', metrics[1])
        else:
            print('Early stopping')
            break

print('RESULT')
for k in [1, 5, 10, 15, 20]:
    metrics = evaluation.evaluate_nn(dataset, h_item, k, args.batch_size, args.n_latest_items)
    
    print(f'Epoch {epoch_id}, Rec by latest item: PR@{k}: {metrics[0][0]}|REC@{k}: {metrics[0][1]}|NDCG@{k}: {metrics[0][2]}')
    print(f'Epoch {epoch_id}, Rec by {args.n_latest_items} latest items: PR@{k}: {metrics[1][0]}|REC@{k}: {metrics[1][1]}|NDCG@{k}: {metrics[1][2]}')
        

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [09:03<00:00, 18.41it/s]


Rec by latest item (0.009947594273345975, 0.043934115004317265, 0.012954903382886821) 
 Rec by N latest items (0.00996468309725514, 0.04404577646436295, 0.012783288988530275)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [09:01<00:00, 18.46it/s]


Rec by latest item (0.01037291611286301, 0.04465013060431617, 0.013723015482117278) 
 Rec by N latest items (0.010323548399347642, 0.044249817833571424, 0.013709238430334465)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [09:04<00:00, 18.35it/s]


Rec by latest item (0.009894429043406348, 0.04378504413243423, 0.013404707998392324) 
 Rec by N latest items (0.009712148255041896, 0.04301798834873037, 0.013045309557448564)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [09:05<00:00, 18.34it/s]


Rec by latest item (0.010813428018077057, 0.04704222819672648, 0.013679321209389875) 
 Rec by N latest items (0.01098811377359299, 0.04746504304389698, 0.014567431996900375)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [09:04<00:00, 18.36it/s]


Rec by latest item (0.010612159647591324, 0.04551032710362727, 0.015398602455821631) 
 Rec by N latest items (0.010665324877530953, 0.04522377620428991, 0.015352433978709404)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [09:03<00:00, 18.40it/s]


Early stopping
RESULT
Epoch 5, Rec by latest item: PR@1: 0.024228154786769453|REC@1: 0.006149617382432653|NDCG@1: 0.00043038519474930065
Epoch 5, Rec by 10 latest items: PR@1: 0.0266965404625375|REC@1: 0.0065156225765048755|NDCG@1: 0.0003797516424258535
Epoch 5, Rec by latest item: PR@5: 0.016815402726616205|REC@5: 0.019918779398371778|NDCG@5: 0.003818187268652324
Epoch 5, Rec by 10 latest items: PR@5: 0.016785022595222143|REC@5: 0.019743063790528103|NDCG@5: 0.003720155397081149
Epoch 5, Rec by latest item: PR@10: 0.012600159495689994|REC@10: 0.028425258279355714|NDCG@10: 0.006629287576578591
Epoch 5, Rec by 10 latest items: PR@10: 0.012303953214597794|REC@10: 0.02762779818972859|NDCG@10: 0.006413826194229203
Epoch 5, Rec by latest item: PR@15: 0.010939378979480371|REC@15: 0.03619559039361133|NDCG@15: 0.009476081148280182
Epoch 5, Rec by 10 latest items: PR@15: 0.010891277104773101|REC@15: 0.03543934105503367|NDCG@15: 0.009690581408944481
Epoch 5, Rec by latest item: PR@20: 0.010146963

Amazon Video Games

In [23]:
with open('data/amazon.pkl', 'rb') as f:
    dataset = pickle.load(f)
    

args = TrainArgs(output_model_path='models_attention/model_amazon_optuna', num_epochs=5, 
                 hidden_dims=64, batches_per_epoch=10000, device='cuda', k=20)


g = dataset['train-graph']
val_matrix = dataset['val-matrix'].tocsr()
test_matrix = dataset['test-matrix'].tocsr()
item_texts = dataset['item-texts']
user_ntype = dataset['user-type']
item_ntype = dataset['item-type']
user_to_item_etype = dataset['user-to-item-type']
timestamp = dataset['timestamp-edge-column']
device = torch.device(args.device)
# Assign user and movie IDs and use them as features (to learn an individual trainable
# embedding for each entity)
g.nodes[user_ntype].data['id'] = torch.arange(g.number_of_nodes(user_ntype))
g.nodes[item_ntype].data['id'] = torch.arange(g.number_of_nodes(item_ntype))
# Prepare torchtext dataset and vocabulary
if item_texts is not None:
    fields = {}
    examples = []
    for key, texts in item_texts.items():
        fields[key] = data.Field(include_lengths=True, lower=True, batch_first=True)
    for i in range(g.number_of_nodes(item_ntype)):
        example = data.Example.fromlist(
            [item_texts[key][i] for key in item_texts.keys()],
            [(key, fields[key]) for key in item_texts.keys()])
        examples.append(example)
    textset = data.Dataset(examples, fields)
    for key, field in fields.items():
        field.build_vocab(getattr(textset, key))
        #field.build_vocab(getattr(textset, key), vectors='fasttext.simple.300d')
else:
    textset = None
# Sampler

In [25]:
def objective(trial):

    # 2. Suggest values of the hyperparameters using a trial object.
    n_layers = trial.suggest_int('n_layers', 1, 4)
    n_heads = trial.suggest_int('n_heads', 1, 3)
    hidden_dims = trial.suggest_int('hidden_dims', 64, 256)
    num_epochs = trial.suggest_int('num_epochs', 3, 10)
    learning_rate = trial.suggest_float("learning_rate_init", 1e-5, 1e-3)
    num_neighbors = trial.suggest_int('num_neighbors', 1, 15)
    
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, args.batch_size)
    neighbor_sampler = sampler_module.NeighborSampler(
        g, user_ntype, item_ntype, args.random_walk_length,
        args.random_walk_restart_prob, args.num_random_walks, num_neighbors,
        n_layers)
    collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=args.num_workers)
    dataloader_test = DataLoader(
        torch.arange(g.number_of_nodes(item_ntype)),
        batch_size=args.batch_size,
        collate_fn=collator.collate_test,
        num_workers=args.num_workers)
    dataloader_it = iter(dataloader)
    
    model = PinSAGEModel(g, item_ntype, textset, hidden_dims, n_layers, n_heads).to(args.device)
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
    layers = []

    for epoch_id in range(num_epochs):
        model.train()
        for batch_id in tqdm.trange(args.batches_per_epoch):
            pos_graph, neg_graph, blocks = next(dataloader_it)
            # Copy to GPU
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to(device)
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)

            loss = model(pos_graph, neg_graph, blocks).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            
        model.eval()
        with torch.no_grad():
            item_batches = torch.arange(g.number_of_nodes(item_ntype)).split(args.batch_size)
            h_item_batches = []
            for blocks in dataloader_test:
                for i in range(len(blocks)):
                    blocks[i] = blocks[i].to(device)

                h_item_batches.append(model.get_repr(blocks))
            h_item = torch.cat(h_item_batches, 0)
            metrics = evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size, args.n_latest_items)
            
    return metrics[0][2]

# 3. Create a study object and optimize the objective function.
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=1, timeout=10)

[32m[I 2022-05-27 15:37:14,371][0m A new study created in memory with name: no-name-fbfa7fa5-b993-4fd6-b396-25c493a12642[0m
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:09<00:00, 40.01it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:09<00:00, 40.05it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:11<00:00, 39.75it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:11<00:00, 39.78it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:02<00:00, 41.23it/s]
100%|███

In [26]:
args = TrainArgs(output_model_path='models_attention/model_amazon_optuna', num_epochs=8, 
                 hidden_dims=76, batches_per_epoch=10000, device='cuda', k=20,
                 lr = 0.000552, num_neighbors = 8, num_heads=2, num_layers=1)

In [27]:
batch_sampler = sampler_module.ItemToItemBatchSampler(
    g, user_ntype, item_ntype, args.batch_size)
neighbor_sampler = sampler_module.NeighborSampler(
    g, user_ntype, item_ntype, args.random_walk_length,
    args.random_walk_restart_prob, args.num_random_walks, args.num_neighbors,
    args.num_layers)
collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
dataloader = DataLoader(
    batch_sampler,
    collate_fn=collator.collate_train,
    num_workers=args.num_workers)
dataloader_test = DataLoader(
    torch.arange(g.number_of_nodes(item_ntype)),
    batch_size=args.batch_size,
    collate_fn=collator.collate_test,
    num_workers=args.num_workers)
dataloader_it = iter(dataloader)
# Model
model = PinSAGEModel(g, item_ntype, textset, args.hidden_dims, args.num_layers, args.num_heads).to(device)
# Optimizer
opt = torch.optim.Adam(model.parameters(), lr=args.lr)

base_ndcg = 0
# For each batch of head-tail-negative triplets...
for epoch_id in range(args.num_epochs):
    model.train()
    for batch_id in tqdm.trange(args.batches_per_epoch):
        pos_graph, neg_graph, blocks = next(dataloader_it)
        # Copy to GPU
        for i in range(len(blocks)):
            blocks[i] = blocks[i].to(device)
        pos_graph = pos_graph.to(device)
        neg_graph = neg_graph.to(device)
        loss = model(pos_graph, neg_graph, blocks).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        
    #with open(args.output_model_path, 'wb') as f:
        #pickle.dump(model, f)
        
    # Evaluate
    model.eval()
    with torch.no_grad():
        item_batches = torch.arange(g.number_of_nodes(item_ntype)).split(args.batch_size)
        h_item_batches = []
        for blocks in dataloader_test:
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to(device)
            h_item_batches.append(model.get_repr(blocks))
        h_item = torch.cat(h_item_batches, 0)
        metrics = evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size, args.n_latest_items)
        if metrics[0][2] - base_ndcg > -0.001:
            old_metrics = metrics
            base_ndcg = metrics[0][2]
            torch.save(model.state_dict(), args.output_model_path)
        
            print('Rec by latest item', metrics[0],
                  '\n',
                  'Rec by N latest items', metrics[1])
        else:
            print('Early stopping')
            break

print('RESULT')
for k in [1, 5, 10, 15, 20]:
    metrics = evaluation.evaluate_nn(dataset, h_item, k, args.batch_size, args.n_latest_items)
    
    print(f'Epoch {epoch_id}, Rec by latest item: PR@{k}: {metrics[0][0]}|REC@{k}: {metrics[0][1]}|NDCG@{k}: {metrics[0][2]}')
    print(f'Epoch {epoch_id}, Rec by {args.n_latest_items} latest items: PR@{k}: {metrics[1][0]}|REC@{k}: {metrics[1][1]}|NDCG@{k}: {metrics[1][2]}')
        

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:16<00:00, 38.95it/s]


Rec by latest item (0.00835560174188448, 0.06615705923103982, 0.009057618506539878) 
 Rec by N latest items (0.008704473475851307, 0.06493858757309359, 0.009601862724181334)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:20<00:00, 38.34it/s]


Rec by latest item (0.008541171813143419, 0.068235198523233, 0.009342833919685127) 
 Rec by N latest items (0.008348178939034136, 0.06418700628322747, 0.008554205180549817)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:04<00:00, 40.96it/s]


Rec by latest item (0.009018705463183063, 0.06903598863446714, 0.010028264581840023) 
 Rec by N latest items (0.00854612034837699, 0.06440055759647947, 0.009023646637802096)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:04<00:00, 40.86it/s]


Rec by latest item (0.009770882818685938, 0.07458141564823727, 0.011400505240784736) 
 Rec by N latest items (0.009508610451306658, 0.069985502874879, 0.010655284377570278)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:05<00:00, 40.72it/s]


Rec by latest item (0.009513558986540244, 0.07465711270621028, 0.012349293204482355) 
 Rec by N latest items (0.010020783847981319, 0.07600981805376356, 0.012280222351615262)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:03<00:00, 41.13it/s]


Rec by latest item (0.009884699129058098, 0.07675211864808085, 0.012025480040597742) 
 Rec by N latest items (0.010080166270784159, 0.07528500854102727, 0.010411217621072187)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:17<00:00, 38.81it/s]


Rec by latest item (0.00940716547901844, 0.07334511486920466, 0.011033831581580684) 
 Rec by N latest items (0.009117676167854482, 0.0682706855519407, 0.008792502972168287)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [05:10<00:00, 32.24it/s]


Rec by latest item (0.0103226444972292, 0.07962849165103553, 0.013726529848540408) 
 Rec by N latest items (0.010268210609659864, 0.07617107080761024, 0.01173753361804386)
RESULT
Epoch 7, Rec by latest item: PR@1: 0.02093230403800475|REC@1: 0.00875798398135897|NDCG@1: 0.0006927949326999208
Epoch 7, Rec by 10 latest items: PR@1: 0.014251781472684086|REC@1: 0.005708347762858308|NDCG@1: 9.897070467141726e-05
Epoch 7, Rec by latest item: PR@5: 0.014845605700712252|REC@5: 0.029623199596183883|NDCG@5: 0.0032239283095019436
Epoch 7, Rec by 10 latest items: PR@5: 0.013687648456056746|REC@5: 0.025491047522049805|NDCG@5: 0.001742295796727916
Epoch 7, Rec by latest item: PR@10: 0.012752375296911705|REC@10: 0.04968138626679048|NDCG@10: 0.0073241993338093395
Epoch 7, Rec by 10 latest items: PR@10: 0.012227830562153206|REC@10: 0.04556517258844035|NDCG@10: 0.005062716689409276
Epoch 7, Rec by latest item: PR@15: 0.011332145684876959|REC@15: 0.06603788944401971|NDCG@15: 0.010562359022544526
Epoch 7, R