In [None]:
%pip install hydra-core==0.11.3
%pip install omegaconf==1.4.1
%pip install loguru==0.5.0



In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy
import numpy as np
import pandas as pd
import pickle
from collections import defaultdict, Counter
from tqdm import tqdm
import random

In [None]:
!git clone https://github.com/riktor/KGPL/

fatal: destination path 'KGPL' already exists and is not an empty directory.


In [None]:
!cd KGPL &&\
python preprocess/preprocess.py -d "movie" &&\
python preprocess/make_path_list.py lp_depth=6 dataset=music kg_path=data/music/kg_final.npy rating_path=data/music/ratings_final.npy num_neighbor_samples=32

!cd KGPL &&\
python preprocess/preprocess.py -d "music" &&\
python preprocess/make_path_list.py lp_depth=6 dataset=music kg_path=data/music/kg_final.npy rating_path=data/music/ratings_final.npy num_neighbor_samples=32

!cd KGPL &&\
python preprocess/preprocess.py -d "book" &&\
python preprocess/make_path_list.py lp_depth=6 dataset=book kg_path=data/book/kg_final.npy rating_path=data/book/ratings_final.npy num_neighbor_samples=8

reading item index to entity id file: data/movie/item_index2entity_id.txt ...
reading rating file ...
converting rating file ...
number of users: 6040
number of items: 2347
converting kg file ...
number of entities (containing items): 7008
number of relations: 7
done
adj_entity_path: data/music/adj_entity_6_32.npy
adj_relation_path: data/music/adj_relation_6_32.npy
data_path: data/music/fold1.pkl
dataset: music
kg_path: data/music/kg_final.npy
lp_depth: 6
num_neighbor_samples: 32
pathlist_path: data/music/path_list_6_32.pkl
rating_path: data/music/ratings_final.npy
reachable_items_path: data/music/reachable_items.pkl

[Parallel(n_jobs=32)]: Using backend MultiprocessingBackend with 32 concurrent workers.
[Parallel(n_jobs=32)]: Batch computation too fast (0.02765059471130371s.) Setting batch_size=2.
[Parallel(n_jobs=32)]: Done   8 tasks      | elapsed:    0.2s
[Parallel(n_jobs=32)]: Done  21 tasks      | elapsed:    0.6s
[Parallel(n_jobs=32)]: Done  34 tasks      | elapsed:    1.0s
[Par

In [None]:
# from KGPL.models.kgpl import KGPL_COT

In [153]:
torch.unique(torch.Tensor([1,2,3]), return_counts=True)

(tensor([1., 2., 3.]), tensor([1, 1, 1]))

In [249]:
class KGPL_Config():
  '''
  KGPL model configuration for each dataset.
  '''
  def __init__(self, dataset_name:str, model_type:str, neighbor_sample_size:int, dropout_rate:float, emb_dim:int, n_iter:int, plabel:dict, optimize:dict, log:dict, evaluate:dict, model:dict):
    self.dataset_name = dataset_name
    self.model_type = model_type
    self.neighbor_sample_size = neighbor_sample_size
    self.dropout_rate = dropout_rate
    self.emb_dim = emb_dim
    self.n_iter = n_iter
    self.plabel = plabel
    self.optimize = optimize
    self.log = log
    self.evaluate = evaluate
    self.model = model



def build_user_train_dict_from_tensor(ratings):
    user_train = {}
    for i in range(ratings.size(0)):
        user, item, rating = ratings[i]
        if rating >= 1:  # or whatever you define as "positive" interaction
            user = user.item()
            item = item.item()
            if user not in user_train:
                user_train[user] = []
            user_train[user].append(item)
    return user_train

# def compute_reachable_items_torch(args_list):
def compute_reachable_items_nodist(seed_items, kg):
    source = kg[:,0]
    target = kg[:,2]
    # print(source)
    # print(target)
    mask = (source.unsqueeze(1) == seed_items).any(dim=1) | (target.unsqueeze(1) == seed_items).any(dim=1)
    connected_edges = kg[mask]
    neighbors = torch.cat([connected_edges[:, 0], connected_edges[:, 2]])
    neighbors = neighbors[~torch.isin(neighbors, seed_items)]
    return torch.unique(neighbors)


def compute_reachable_items_torch(args_list):
    idd = {}
    _, _, dst_dict, item_freq, pn = args_list[0]
    for args in args_list:
        if args is None:
            continue
        user, seed_items, _, _, _ = args

        dst = Counter()
        for item in seed_items:
            if item in dst_dict:
                dst += dst_dict[item]

        if len(dst) != 0:
            udst = torch.tensor(list(dst.keys()), dtype=torch.long)
            F = torch.tensor(list(dst.values()), dtype=torch.float) ** pn

            mask = ~torch.isin(udst, torch.tensor(list(seed_items), dtype=torch.long))
            udst = udst[mask]
            F = F[mask]

            udst_set = set(udst.tolist())
            unreachable_items = [i for i in item_freq if i not in udst_set]
            udst = torch.cat([udst, torch.tensor(unreachable_items, dtype=torch.long)])

            F = torch.cat([F, torch.ones(len(unreachable_items)) * 0.5])

            sort_inds = torch.argsort(F)
            udst = udst[sort_inds]
            F = F[sort_inds]

            F = F / F.sum()
            F = torch.cumsum(F, dim=0)

            idd[user] = (udst, F)
    return idd

def set_item_candidates(
        self, n_user, n_item, train_data, eval_data, path_list_dict
    ):
        """Construct the sampling distrbiutions for negative/pseudo-labelled instances for each user
        """
        all_users = tuple(set(train_data[:, 0]))
        self.all_users = all_users

        self.n_item = n_item
        self.all_items = set(range(n_item))
        self.neg_c_dict_user = self._build_freq_dict(
            np.concatenate([train_data[:, 0], eval_data[:, 0]]), self.all_users
        )
        self.neg_c_dict_item = self._build_freq_dict(
            np.concatenate([train_data[:, 1], eval_data[:, 1]]), self.all_items
        )

        item_cands = tuple(self.neg_c_dict_item.keys())
        F = np.array(tuple(self.neg_c_dict_item.values())) ** self.cfg.plabel.neg_pn
        sort_inds = np.argsort(F)
        item_cands = [item_cands[i] for i in sort_inds]
        F = F[sort_inds]
        F = (F / F.sum()).cumsum()
        self.item_freq = (item_cands, F)

        for u, i in tqdm(train_data[:, 0:2]):
            self.user_seed_dict[u].add(i)

        path = hydra.utils.to_absolute_path(self.cfg.reachable_items_path)
        logger.info("calculating reachable items for users")
        self._setup_dst_dict(path_list_dict)
        item_dist_dict = {}
        src_itr = map(
            lambda iu: (
                all_users[iu],
                tuple(self.user_seed_dict[all_users[iu]]),
                self.dst_dict,
                self.neg_c_dict_item,
                self.cfg.plabel.pl_pn,
            ),
            range(len(all_users)),
        )
        grouped = grouper(self.cfg.plabel.chunk_size, src_itr, squash=set([2, 3]))
        with mp.Pool(self.cfg.plabel.par) as pool:
            for idd in pool.imap_unordered(compute_reachable_items_, grouped):
                item_dist_dict.update(idd)
        self.item_dist_dict = item_dist_dict


In [250]:
# def compute_reachable_items(ratings):
#     """
#     ratings: torch.Tensor of shape (N, 3), where each row is (user, item, rating)
#     """
#     reachable_items = dict()

#     for user, item, rating in ratings.tolist():
#         if rating == 1:  # positive interaction
#             if user not in reachable_items:
#                 reachable_items[user.item()] = set()
#             reachable_items[user.item()].add(item.item())

#     return reachable_items

def compute_reachable_items(users, user_seed_dict, path_dict, item_freq, power):
    """
    For each user, find reachable items via KG paths and compute a pseudo-label distribution.
    This is adapted from compute_reachable_items_.
    Returns: dict of user -> (item_list, cumulative_probs)
    """
    reachable = {}
    for user in users:
        seed_items = user_seed_dict.get(user, [])
        # Count reachable paths from each seed
        dst = Counter()
        for item in seed_items:
            dst.update(path_dict.get(item, {}))
        if not dst:
            continue
        udst = np.array(list(dst.keys()))
        freq = np.array(list(dst.values())) ** power
        # Remove seeds from candidates
        mask = ~np.isin(udst, list(seed_items))
        udst = udst[mask]
        freq = freq[mask]
        # Add 'unreachable' items with small weight
        all_items = np.array(list(item_freq.keys()))
        unreachable = np.setdiff1d(all_items, udst, assume_unique=True)
        freq = np.concatenate([freq, np.ones(len(unreachable)) * 0.5])
        udst = np.concatenate([udst, unreachable])
        # Sort and make CDF
        order = np.argsort(freq)
        udst = udst[order]; freq = freq[order]
        cdf = (freq / freq.sum()).cumsum()
        reachable[user] = (udst.tolist(), cdf.tolist())
    return reachable

In [251]:
for u, i in train_data[:, 0:2]:
    self.user_seed_dict[u].add(i)

NameError: name 'self' is not defined

In [262]:
class KGPL_Dataset(Dataset):
  '''
  Custom dataset class which includes all datasets and parameters per model.
  Specified under "data" directory
  '''

  base_data_path = 'KGPL/data/'

  def readjust_counts(self):
    unique_users = torch.unique(self.ratings[:,0], return_counts=True)
    self.users = unique_users[0]
    self.n_user = unique_users[1][0].item()
    self.n_item = torch.unique(self.ratings[:,1]).numel()
    self.reachable_items = compute_reachable_items_nodist(self.users, self.ratings)

  def __init__(self,dataset_name:str):
    self.dataset_name = dataset_name
    self.entity_adj = torch.from_numpy(np.load(self.base_data_path + self.dataset_name + '/adj_entity_6_32.npy'))
    self.relation_adj = torch.from_numpy(np.load(self.base_data_path + self.dataset_name + '/adj_relation_6_32.npy'))
    self.ratings = torch.from_numpy(np.load(self.base_data_path + self.dataset_name + '/ratings_final.npy'))
    self.path_list_dict = pickle.load(open(self.base_data_path + self.dataset_name + '/path_list_6_32.pkl', 'rb'))
    #kg is a 3-column matrix of undirected relations: (head, relation, tail)
    self.kg = np.load(self.base_data_path + self.dataset_name + '/kg_final.npy')
    self.readjust_counts()

  def __len__(self):
    # doing interactions
    return self.n_user

  def sample_positive(self, user):
        return random.choice(self.user_train[user])

  def sample_negative(self, user):
      seen = set(self.user_train[user])
      while True:
          item = random.randint(0, self.n_item - 1)
          if item not in seen:
              return item

  # def sample_pseudo_label(self, user):
  #       F = compute_reachable_items_torch(args_list)
  #       # udst, F = self.reachable_items[user]
  #       r = random.random()
  #       idx = torch.searchsorted(F, r, right=True).item()
  #       return udst[idx].item()
  def sample_pseudo_label(self, user):
        udst, F = self.reachable_items[user]
        r = random.random()
        idx = torch.searchsorted(F, r, right=True).item()
        return udst[idx].item()

  def __getitem__(self,idx):
    if self.train_set:
      user = self.users[idx]
      pos_item = self.sample_positive(idx)
      neg_item = self.sample_negative(idx)
      pseudo_label = self.sample_pseudo_label(user)
      return user, pos_item, neg_item, pseudo_label
    else:
      # still need to implement
      return None

  def _split_data(self, split_ratio=0.2):
    #split dataset
    n_ratings = len(self.ratings)
    split_indices = torch.randperm(n_ratings)[:int(n_ratings * split_ratio)]
    splitted_data = self.ratings[split_indices]
    rest_data = self.ratings[~torch.isin(torch.arange(n_ratings), split_indices)]
    #create new objects
    splitted_dataset, rest_dataset = deepcopy(self), deepcopy(self)
    splitted_dataset.ratings = splitted_data
    rest_dataset.ratings = rest_data
    splitted_dataset.readjust_counts()
    rest_dataset.readjust_counts()

    return rest_dataset, splitted_dataset

  def train_val_test_split(self):
    exp_dataset, test = self._split_data()
    train, val = exp_dataset._split_data()
    n_user = torch.unique(self.ratings[:,0]).numel()
    n_item = torch.unique(self.ratings[:,1]).numel()
    #readjust counts
    train.readjust_counts()
    val.readjust_counts()
    test.readjust_counts()
    train.user_train = build_user_train_dict_from_tensor(train.ratings)
    train.train_set = True
    return (n_user, n_item, train, val, test)

In [None]:
data = KGPL_Dataset('music').train_val_test_split()

cfg = KGPL_Config(
    'music',
    'KGPL_COT',
    neighbor_sample_size=32,
    dropout_rate=0.5,
    emb_dim=64,
    n_iter=1,
    plabel={},
    optimize={'iter_per_epoch':100, 'lr': 3e-3, 'batch_size':3333},
    log={'show_loss':True},
    evaluate={'user_num_topk':1000},
    model={'n_iter':1, 'neighbor_sample_size':32, 'dropout_rate':0.5}
)

In [258]:
def kgpl_loss(pos_scores, neg_scores, pseudo_scores):
    # BCE loss like TensorFlow version
    pos_labels = torch.ones_like(pos_scores)
    neg_labels = torch.zeros_like(neg_scores)
    pseudo_labels = torch.ones_like(pseudo_scores)

    loss = F.binary_cross_entropy_with_logits(pos_scores, pos_labels) + \
           F.binary_cross_entropy_with_logits(neg_scores, neg_labels) + \
           F.binary_cross_entropy_with_logits(pseudo_scores, pseudo_labels)
    return loss

In [259]:
class SumAggregatorWithDropout(nn.Module):
    def __init__(self, emb_dim, dropout_rate, activation, cfg):
        super().__init__()
        self.linear = nn.Linear(emb_dim * 2, emb_dim)
        self.activation = activation
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, self_vectors, neighbor_vectors, neighbor_relations, user_embeddings):
        # self_vectors: [batch_size, emb_dim]
        # neighbor_vectors: [batch_size, n_neighbors, emb_dim]
        neighbor_mean = neighbor_vectors.mean(dim=1)  # [batch_size, emb_dim]
        out = torch.cat([self_vectors, neighbor_mean], dim=-1)  # [batch_size, emb_dim * 2]
        out = self.linear(out)
        out = self.dropout(out)
        return self.activation(out)


class KGPLStudent(nn.Module):
    def __init__(self, cfg, n_user, n_entity, n_relation, adj_entity, adj_relation, path_list_dict, name, eval_mode=False):
        super().__init__()
        self.cfg = cfg
        self.name = name
        self.n_user = n_user
        self.n_entity = n_entity
        self.n_relation = n_relation
        self.batch_size = cfg.optimize['batch_size']
        self.adj_entity = adj_entity
        self.adj_relation = adj_relation
        self.path_list_dict = path_list_dict
        self.eval_mode = eval_mode

        self.user_emb_matrix = nn.Embedding(n_user, cfg.emb_dim)
        self.entity_emb_matrix = nn.Embedding(n_entity, cfg.emb_dim)
        self.relation_emb_matrix = nn.Embedding(n_relation, cfg.emb_dim)

        self.aggregators = nn.ModuleList([
            SumAggregatorWithDropout(cfg.emb_dim, cfg.dropout_rate, activation=nn.Tanh(), cfg=cfg)
            if i == cfg.n_iter - 1 else
            SumAggregatorWithDropout(cfg.emb_dim, cfg.dropout_rate, activation=nn.LeakyReLU(), cfg=cfg)
            for i in range(cfg.n_iter)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.user_emb_matrix.weight)
        nn.init.xavier_uniform_(self.entity_emb_matrix.weight)
        nn.init.xavier_uniform_(self.relation_emb_matrix.weight)

    def forward(self, user_indices, item_indices):
        user_embeds = self.user_emb_matrix(user_indices)
        item_embeds = self.get_item_embeddings(item_indices)
        scores = (user_embeds * item_embeds).sum(dim=-1)
        return scores

    def get_item_embeddings(self, item_indices):
        entities = [item_indices.unsqueeze(1)]
        relations = []
        for i in range(self.cfg.n_iter):
            neighbor_entities = self.adj_entity[entities[-1]].view(item_indices.size(0), -1)
            neighbor_relations = self.adj_relation[entities[-1]].view(item_indices.size(0), -1)
            entities.append(neighbor_entities)
            relations.append(neighbor_relations)

        entity_vectors = [self.entity_emb_matrix(e) for e in entities]
        relation_vectors = [self.relation_emb_matrix(r) for r in relations]

        for i in range(self.cfg.n_iter):
            new_vectors = []
            for hop in range(self.cfg.n_iter - i):
                batch_size, neighbor_size, emb_dim = entity_vectors[hop+1].size(0), entity_vectors[hop+1].size(1) // self.cfg.model.neighbor_sample_size, entity_vectors[hop+1].size(2)
                neighbor_vecs = entity_vectors[hop+1].view(batch_size, neighbor_size, self.cfg.model.neighbor_sample_size, emb_dim)
                relation_vecs = relation_vectors[hop].view(batch_size, neighbor_size, self.cfg.model.neighbor_sample_size, emb_dim)
                vector = self.aggregators[i](
                    self_vectors=entity_vectors[hop],
                    neighbor_vectors=neighbor_vecs,
                    neighbor_relations=relation_vecs,
                    user_embeddings=None  # optional
                )
                new_vectors.append(vector)
            entity_vectors = new_vectors

        return entity_vectors[0]

In [260]:
def train_one_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0

    for batch in dataloader:
        user, pos_item, neg_item, pseudo_item = batch
        user = user.to(device)
        pos_item = pos_item.to(device)
        neg_item = neg_item.to(device)
        pseudo_item = pseudo_item.to(device)

        optimizer.zero_grad()

        pos_scores = model(user, pos_item)
        neg_scores = model(user, neg_item)
        pseudo_scores = model(user, pseudo_item)

        loss = kgpl_loss(pos_scores, neg_scores, pseudo_scores)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [261]:
train_dataset = data[2]

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,  # optional for faster loading
    pin_memory=True
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = KGPLStudent(cfg, 1872, 3846, 60, train_dataset.entity_adj, train_dataset.relation_adj, train_dataset.path_list_dict, name='student').to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=cfg.optimize['lr'])

# Train
for epoch in tqdm(range(10)):
    loss = train_one_epoch(model, train_loader, optimizer, device)
    print(f"Epoch {epoch} | Loss: {loss:.4f}")

  0%|          | 0/10 [00:00<?, ?it/s]


IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "<ipython-input-252-bbffcb7c2058>", line 57, in __getitem__
    pseudo_label = self.sample_pseudo_label(user)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-252-bbffcb7c2058>", line 47, in sample_pseudo_label
    udst, F = self.reachable_items[user]
              ~~~~~~~~~~~~~~~~~~~~^^^^^^
IndexError: index 9 is out of bounds for dimension 0 with size 0


# JUNK BELOW

In [130]:
import itertools


def grouper(n, iterable, squash=None):
    it = iter(iterable)
    while True:
        if squash:
            chunk = [
                [None if (j != 0 and i in squash) else el[i] for i in range(len(el))]
                for j, el in enumerate(itertools.islice(it, n))
            ]
        else:
            chunk = list(itertools.islice(it, n))

        if not chunk:
            return
        elif len(chunk) != n:
            chunk += [None] * (n - len(chunk))
        yield chunk

In [131]:
def compute_reachable_items_(args_list):
    """Construct the sampling distributions based on paths in KG.
    Args:
        args_list: list of list of arguments. Each arguments' list must contains;
        (1) user_id;
        (2) user's interacted item ids (seed items);
        (3) item-to-(item, #paths) dict found in the BFS (start and end points of some paths);
        (4) item-to-frequency dict;
        (5) power coefficient to control the skewness of sampling distributions
    Returns:
        dict in which (key, value) = (item list, np.array of sampling distribution).
        sampling distribution is transformed to CDF for fast sampling.
    """
    idd = {}
    _, _, dst_dict, item_freq, pn = args_list[0]
    for args in args_list:
        if args is None:
            continue
        user, seed_items, _, _, _ = args

        # Collect user's reachable items with the number of reachable paths
        dst = Counter()
        for item in seed_items:
            if item in dst_dict:
                dst += dst_dict[item]

        if len(dst) != 0:
            # Unique reachable items for the user
            udst = np.array(tuple(dst.keys()))

            # Histogram of paths with power transform
            F = np.array(tuple(dst.values())) ** pn

            # Remove the seed (positve) items
            inds = ~np.isin(udst, seed_items)
            udst = udst[inds]
            F = F[inds]

            # Compute unreachable items and concat those to the end of item lists
            udst = set(udst)
            unreachable_items = [i for i in item_freq if i not in udst]
            udst = list(udst) + unreachable_items

            # For unreachable items, assume 0.5 virtual paths for full support
            F = np.concatenate([F, np.ones(len(unreachable_items)) * 0.5])

            # Transform histogram to CDF
            sort_inds = np.argsort(F)
            udst = [udst[i] for i in sort_inds]
            F = F[sort_inds]
            F = (F / np.sum(F)).cumsum()
            idd[user] = (udst, F)
    return idd

In [132]:
def compute_user_unseen_items_(args_list):
    user_unseen_items = {}
    all_items, _, _ = args_list[0]
    for args in args_list:
        if args is None:
            continue
        _, user, seed_items, _, _ = args
        unseen_items = tuple(all_items - seed_items)
        user_unseen_items[user] = (unseen_items, None)
    return user_unseen_items

In [141]:
class KaPLMixin:
    def __init__(self, cfg, n_entity, n_relation, adj_entity, adj_relation, path_list_dict, eval_mode=False):
        self.cfg = cfg
        self.n_entity = n_entity
        self.n_relation = n_relation
        self.user_seed_dict = defaultdict(set)
        self.item_dist_dict = {}
        self.cand_uinds = None
        self.cand_iinds = None

    def _build_freq_dict(self, seq, all_candidates):
        _freq = Counter(seq)
        for i in all_candidates:
            if i not in _freq:
                _freq[i] += 1
        freq = [_freq[i] for i in all_candidates]
        return dict(zip(all_candidates, freq))

    def set_item_candidates(self, n_user, n_item, train_data, eval_data, path_list_dict):
        all_users = tuple(set(train_data[:, 0]))
        self.all_users = all_users
        self.n_item = n_item
        self.all_items = set(range(n_item))

        self.neg_c_dict_user = self._build_freq_dict(
            np.concatenate([train_data[:, 0], eval_data[:, 0]]), self.all_users
        )
        self.neg_c_dict_item = self._build_freq_dict(
            np.concatenate([train_data[:, 1], eval_data[:, 1]]), self.all_items
        )

        item_cands = tuple(self.neg_c_dict_item.keys())
        F = np.array(list(self.neg_c_dict_item.values())) ** self.cfg.plabel['neg_pn']
        sort_inds = np.argsort(F)
        item_cands = [item_cands[i] for i in sort_inds]
        F = F[sort_inds]
        F = F / F.sum()
        F = np.cumsum(F)
        self.item_freq = (item_cands, F)

        for u, i in train_data[:, 0:2]:
            self.user_seed_dict[u].add(i)

        self._setup_dst_dict(path_list_dict)
        item_dist_dict = {}

        src_itr = [
            (all_users[iu], tuple(self.user_seed_dict[all_users[iu]]), self.dst_dict,
             self.neg_c_dict_item, self.cfg.plabel.pl_pn)
            for iu in range(len(all_users))
        ]

        # Using multiprocess (if you want parallelism) - here shown sequentially
        for idd in map(compute_reachable_items_, [src_itr[i:i+self.cfg.plabel.chunk_size] for i in range(0, len(src_itr), self.cfg.plabel.chunk_size)]):
            item_dist_dict.update(idd)

        self.item_dist_dict = item_dist_dict

    def _setup_dst_dict(self, path_list_dict):
        dst_dict = {}
        for item, paths in path_list_dict.items():
            dst = []
            for p in paths:
                dst.append(p[-1])
            dst_dict[item] = Counter(dst)
        self.dst_dict = dst_dict

In [142]:
import torch.nn as nn
import torch.nn.functional as F

class KGPLStudent(nn.Module, KaPLMixin):
    def __init__(self, cfg, n_user, n_entity, n_relation, adj_entity, adj_relation, path_list_dict, name, eval_mode=False):
        super().__init__()
        self.name = name
        self.cfg = cfg
        self.n_user = n_user
        self.n_entity = n_entity
        self.n_relation = n_relation
        self.batch_size = cfg.optimize.batch_size

        # Embedding matrices
        self.user_emb_matrix = nn.Embedding(n_user, cfg.emb_dim)
        self.entity_emb_matrix = nn.Embedding(n_entity, cfg.emb_dim)
        self.relation_emb_matrix = nn.Embedding(n_relation, cfg.emb_dim)

        # Neighborhood info
        self.adj_entity = torch.tensor(adj_entity, dtype=torch.long)
        self.adj_relation = torch.tensor(adj_relation, dtype=torch.long)

        KaPLMixin.__init__(self, cfg, n_entity, n_relation, adj_entity, adj_relation, path_list_dict, eval_mode)

    def forward(self, user_indices, item_indices):
        user_embeddings = self.user_emb_matrix(user_indices)
        entities, relations = self.get_neighbors(item_indices)
        item_embeddings = self.aggregate(entities, relations)
        scores = (user_embeddings * item_embeddings).sum(dim=1)
        return torch.sigmoid(scores)

    def get_neighbors(self, seeds):
        entities = [seeds.unsqueeze(1)]
        relations = []
        for _ in range(self.cfg.n_iter):
            neighbor_entities = self.adj_entity[entities[-1]].view(seeds.size(0), -1)
            neighbor_relations = self.adj_relation[entities[-1]].view(seeds.size(0), -1)
            entities.append(neighbor_entities)
            relations.append(neighbor_relations)
        return entities, relations

    def aggregate(self, entities, relations):
        entity_vectors = [self.entity_emb_matrix(entity) for entity in entities]
        relation_vectors = [self.relation_emb_matrix(rel) for rel in relations]
        for i in range(self.cfg.n_iter):
            if i == self.cfg.n_iter - 1:
                act = torch.tanh
            else:
                act = F.leaky_relu
            entity_vectors_next_iter = []
            for hop in range(self.cfg.n_iter - i):
                self_vector = entity_vectors[hop]
                neighbor_vector = entity_vectors[hop + 1].view(self_vector.size(0), -1, self.cfg.neighbor_sample_size, self.cfg.emb_dim)
                neighbor_relation = relation_vectors[hop].view(self_vector.size(0), -1, self.cfg.neighbor_sample_size, self.cfg.emb_dim)
                # Simple sum aggregation
                vector = act(self_vector + neighbor_vector.mean(2) + neighbor_relation.mean(2))
                entity_vectors_next_iter.append(vector)
            entity_vectors = entity_vectors_next_iter
        res = entity_vectors[0].view(self.batch_size, self.cfg.emb_dim)
        return res

In [None]:
class KaPLMixin(object):
    def __init__(
        self, cfg, n_entity, n_relation, adj_entity, adj_relation, path_list_dict, eval_mode=False
    ):
        self.user_seed_dict = defaultdict(set)
        self.item_dist_dict = {}
        self.cand_uinds = None
        self.cand_iinds = None
        self.n_entity = n_entity
        self.n_relation = n_relation

    def _build_freq_dict(self, seq, all_candidates):
        _freq = Counter(seq)
        for i in all_candidates:
            if i not in _freq:
                _freq[i] += 1
        freq = [_freq[i] for i in all_candidates]
        return dict(zip(all_candidates, freq))

    def set_item_candidates(
        self, n_user, n_item, train_data, eval_data, path_list_dict
    ):
        """
        Construct the sampling distrbiutions for negative/pseudo-labelled instances for each user
        """
        all_users = tuple(set(train_data[:, 0]))
        self.all_users = all_users

        self.n_item = n_item
        self.all_items = set(range(n_item))
        self.neg_c_dict_user = self._build_freq_dict(
            np.concatenate([train_data[:, 0], eval_data[:, 0]]), self.all_users
        )
        self.neg_c_dict_item = self._build_freq_dict(
            np.concatenate([train_data[:, 1], eval_data[:, 1]]), self.all_items
        )

        item_cands = tuple(self.neg_c_dict_item.keys())
        F = np.array(tuple(self.neg_c_dict_item.values())) ** self.cfg.plabel.neg_pn
        sort_inds = np.argsort(F)
        item_cands = [item_cands[i] for i in sort_inds]
        F = F[sort_inds]
        F = (F / F.sum()).cumsum()
        self.item_freq = (item_cands, F)

        for u, i in tqdm(train_data[:, 0:2]):
            self.user_seed_dict[u].add(i)

        # path = hydra.utils.to_absolute_path(self.cfg.reachable_items_path)
        print("calculating reachable items for users")
        self._setup_dst_dict(path_list_dict)
        item_dist_dict = {}
        src_itr = map(
            lambda iu: (
                all_users[iu],
                tuple(self.user_seed_dict[all_users[iu]]),
                self.dst_dict,
                self.neg_c_dict_item,
                self.cfg.plabel.pl_pn,
            ),
            range(len(all_users)),
        )
        grouped = grouper(self.cfg.plabel.chunk_size, src_itr, squash=set([2, 3]))
        with mp.Pool(self.cfg.plabel.par) as pool:
            for idd in pool.imap_unordered(compute_reachable_items_, grouped):
                item_dist_dict.update(idd)
        self.item_dist_dict = item_dist_dict

    def _setup_dst_dict(self, path_list_dict):
        """
        Transform path representations:
        `list of nodes` to `dictionaly of source to sink (dst_dict)`
        """

        print("setup dst dict...")
        dst_dict = {}
        for item in tqdm(path_list_dict):
            dst = []
            paths = path_list_dict[item]
            for i, p in enumerate(paths):
                dst.append(p[-1])
            dst_dict[item] = Counter(dst)
        print("start updating path info...")
        self.dst_dict = dst_dict
        print.info("path info updated.")

    # def _get_user_rel_scores(self, sess, users):
    def _get_user_rel_scores(self, users)
        uembs = self.user_embeddings()

        # self.user_indices = users

        uembs =

        # sess.run(
        #     self.user_embeddings, feed_dict={self.user_indices: users, self.dropout_rate: 0.0}
        # )  # nu, legth
        # rembs = sess.run(self.relation_emb_matrix)  # nr, length
        rembs = self.relation_emb_matrix()

        return np.dot(uembs, rembs.T)  # nu, nr

    def _get_mini_batch_pl(self, sess, users):
        """
        Create pseudo-labelled instances for users
        """
        pl_users, pl_items = [], []
        ind = 0
        cands, freq_F = self.item_freq
        while True:
            u = users[ind % len(users)]
            ind += 1
            if u in self.item_dist_dict and len(self.item_dist_dict[u][0]) != 0:
                udst, F = self.item_dist_dict[u]
                i = udst[np.searchsorted(F, torch.rand(1).item())] #changed random to torch
            else:
                while True:
                    i = cands[np.searchsorted(freq_F, torch.rand(1).item())] #changed random to torch
                    if i not in self.user_seed_dict[u]:
                        break
            pl_users.append(u)
            pl_items.append(i)
            if len(pl_users) == len(users):
                break

        pl_users_pad = list(pl_users) + [0] * (self.batch_size - len(pl_users))
        pl_items_pad = list(pl_items) + [0] * (self.batch_size - len(pl_items))

        # start taylor added for PyTorch
        # self.user_indices = pl_users_pad
        # self.item_indices = pl_items_pad
        # self.scores_normalized = self._build_model(n_user, n_entity, n_relation)
        # end taylor addded pytorch

        pl_labels_pad = sess.run(
            self.scores_normalized,
            feed_dict={
                self.user_indices: pl_users_pad,
                self.item_indices: pl_items_pad,
                self.dropout_rate: 0.0,
            },
        )
        pl_users = pl_users_pad[: len(pl_users)]
        pl_items = pl_items_pad[: len(pl_items)]
        pl_labels = pl_labels_pad[: len(pl_users)]
        return pl_users, pl_items, pl_labels


## Train.py

In [124]:
class KGPL_COT():
  def __init__(self, *args):
    pass

def topk_settings(*args, **kwargs):
  pass

In [126]:
def train(cfg, data):
    (n_user, n_item, train_data, eval_data, test_data) = data

    adj_entity = train_data.entity_adj
    adj_relation = train_data.relation_adj
    n_entity = adj_entity.shape[0]
    n_relation = len(torch.unique(adj_relation.reshape(-1)))
    path_list_dict = train_data.path_list_dict

    print(f"num train records: {len(train_data)}")
    print(f"num adj entities: {len(adj_entity)}, num entities: {n_entity}")
    print(f"num adj relations: {len(adj_relation)}, num relations: {n_relation}")

    model = KGPL_COT(
        cfg,
        n_user,
        n_item,
        n_entity,
        n_relation,
        adj_entity,
        adj_relation,
        path_list_dict,
        train_data,
        eval_data,
    )

    _pos_inds = train_data[:, 2] == 1
    train_data = train_data[_pos_inds]
    print("model type:", cfg.model_type)

    topk_config = topk_settings(
        train_data,
        eval_data,
        test_data,
        n_item,
        test_mode=True,
        user_num=cfg.evaluate['user_num_topk'],
    )

    batch_size = cfg.optimize['batch_size']

    print(model)

train(cfg, data)

num train records: 27102
num adj entities: 9366, num entities: 9366
num adj relations: 9366, num relations: 60
model type: KGPL_COT
<__main__.KGPL_COT object at 0x7d6c76b49dd0>
