In [None]:
# %pip install hydra-core==0.11.3
# %pip install omegaconf==1.4.1
# %pip install loguru==0.5.0

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy
import numpy as np
import pandas as pd
import pickle
from collections import defaultdict, Counter
from tqdm import tqdm
import random
import multiprocessing as mp
import itertools
import torch.nn as nn
import torch.nn.functional as F

In [6]:
# !git clone https://github.com/riktor/KGPL/

In [7]:
# !cd KGPL &&\
# python preprocess/preprocess.py -d "music" &&\
# python preprocess/make_path_list.py lp_depth=6 dataset=music kg_path=data/music/kg_final.npy rating_path=data/music/ratings_final.npy num_neighbor_samples=32

## Helper Functions

In [8]:
def setup_dst_dict(path_list_dict):
        """
        Transform path representations:
        `list of nodes` to `dictionary of source to sink (dst_dict)`
        """
        print("Setting up dst dict...")
        dst_dict = {}
        for item in tqdm(path_list_dict):
            dst = []
            paths = path_list_dict[item]
            for p in paths:
                dst.append(p[-1])
            dst_dict[item] = Counter(dst)

        print("Start updating path info...")
        print("Path info updated.")
        return dst_dict

def grouper(n, iterable, squash=None):
    it = iter(iterable)
    while True:
        if squash:
            chunk = [
                [None if (j != 0 and i in squash) else el[i] for i in range(len(el))]
                for j, el in enumerate(itertools.islice(it, n))
            ]
        else:
            chunk = list(itertools.islice(it, n))

        if not chunk:
            return
        elif len(chunk) != n:
            chunk += [None] * (n - len(chunk))
        yield chunk

def compute_reachable_items_(args_list):
  """Construct the sampling distributions based on paths in KG.
  Args:
      args_list: list of list of arguments. Each arguments' list must contains;
      (1) user_id;
      (2) user's interacted item ids (seed items);
      (3) item-to-(item, #paths) dict found in the BFS (start and end points of some paths);
      (4) item-to-frequency dict;
      (5) power coefficient to control the skewness of sampling distributions
  Returns:
      dict in which (key, value) = (item list, np.array of sampling distribution).
      sampling distribution is transformed to CDF for fast sampling.
  """
  idd = {}
  _, _, dst_dict, item_freq, pn = args_list[0]
  for args in args_list:
      if args is None:
          continue
      user, seed_items, _, _, _ = args

      # print('User:', user)
      # print('Seed Items:', seed_items)

      # Collect user's reachable items with the number of reachable paths
      dst = Counter()
      for item in seed_items:
          if item in dst_dict:
              dst += dst_dict[item]

      if len(dst) != 0:
          # Unique reachable items for the user
          udst = np.array(tuple(dst.keys()))

          # Histogram of paths with power transform
          F = np.array(tuple(dst.values())) ** pn

          # Remove the seed (positve) items
          inds = ~np.isin(udst, seed_items)
          udst = udst[inds]
          F = F[inds]

          # Compute unreachable items and concat those to the end of item lists
          udst = set(udst)
          unreachable_items = [i for i in item_freq if i not in udst]
          udst = list(udst) + unreachable_items

          # For unreachable items, assume 0.5 virtual paths for full support
          F = np.concatenate([F, np.ones(len(unreachable_items)) * 0.5])

          # Transform histogram to CDF
          sort_inds = np.argsort(F)
          udst = [udst[i] for i in sort_inds]
          F = F[sort_inds]
          F = (F / np.sum(F)).cumsum()
          idd[user] = (udst, F)
  return idd

# HELPER FUNCTION STRAIGHT FROM CHATGPT
def build_user_train_dict_from_tensor(ratings):
    user_train = defaultdict(list)
    for user, item, rating in ratings:
        if rating >= 1:  # or whatever you define as "positive" interaction
            user = user.item()
            item = item.item()
            # if user not in user_train:
            #     user_train[user] = []
            user_train[user].append(item)
    return user_train

## Main Dataset Class

In [126]:
class KGPL_Dataset(Dataset):
  '''
  Custom dataset class which includes all datasets and parameters per model.
  Specified under "data" directory
  '''

  base_data_path = 'KGPL/data/'
  datasets = {}
  n_user = 0
  n_item = 0

  class cfg:
    plabel_lp_depth = 6
    plabel_par = 16
    plabel_chunk_size = 250
    plabel_neg_pn = 0.1
    plabel_pl_pn = 1e-3

  @classmethod
  def _init_class(cls, dataset_name):
    cls.dataset_name = dataset_name
    print('Loading Entity Adjacencies...')
    cls.adj_entity = torch.from_numpy(np.load(cls.base_data_path + dataset_name + '/adj_entity_6_32.npy'))
    print('Loading Relation Adjacencies...')
    cls.adj_relation = torch.from_numpy(np.load(cls.base_data_path + dataset_name + '/adj_relation_6_32.npy'))
    print('Loading Ratings...')
    cls.ratings = torch.from_numpy(np.load(cls.base_data_path + dataset_name + '/ratings_final.npy'))
    print('Loading Path List...')
    cls.path_list_dict = pickle.load(open(cls.base_data_path + dataset_name + '/path_list_6_32.pkl', 'rb'))
    print('Loading Distances...')
    cls.dst_dict = setup_dst_dict(cls.path_list_dict)

  def _split_data(self, split_ratio=0.2):
    #split dataset
    n_ratings = len(self.ratings)
    split_indices = torch.randperm(n_ratings)[:int(n_ratings * split_ratio)]
    splitted_data = self.ratings[split_indices]
    rest_data = self.ratings[~torch.isin(torch.arange(n_ratings), split_indices)]
    #create new objects
    splitted_dataset, rest_dataset = deepcopy(self), deepcopy(self)
    splitted_dataset.ratings = splitted_data
    rest_dataset.ratings = rest_data
    return rest_dataset, splitted_dataset

  def _readjust_counts(self):
    self.users = torch.unique(self.ratings[:,0])
    self.items = torch.unique(self.ratings[:,1])

  def _train_val_test_split(self):
    exp_dataset, test = self._split_data()
    train, val = exp_dataset._split_data()
    train.train_set = True
    val.train_set = False
    test.train_set = False
    for fold in (train, val, test):
      fold._readjust_counts()
    n_user = torch.unique(self.ratings[:,0]).numel()
    n_item = torch.unique(self.ratings[:,1]).numel()
    train._readjust_counts()
    val._readjust_counts()
    test._readjust_counts()
    self.__class__.n_user = n_user
    self.__class__.n_item = n_item
    self.__class__.datasets['train'] = train
    self.__class__.datasets['val'] = val
    self.__class__.datasets['test'] = test


  @staticmethod
  def _build_freq_dict(seq, all_candidates): # need all_candidates
    _freq = Counter(seq)
    for i in all_candidates:
        if i not in _freq:
            _freq[i] += 1
    freq = [_freq[i] for i in all_candidates]
    return dict(zip(all_candidates, freq))

  #### THIS IS THE BIG ONE #####
  def set_item_candidates(self):
        """
        Construct the sampling distrbiutions for negative/pseudo-labelled instances for each user
        """
        train_data = self.datasets['train'].ratings
        eval_data = self.datasets['val'].ratings
        # all_users = tuple(set(train_data[:, 0])) # train data sets the users - not predicting unseen users
        all_users = torch.unique(train_data[:,0])
        self.all_users = all_users
        self.all_items = set(torch.arange(self.__class__.n_item))
        self.neg_c_dict_user = self._build_freq_dict(
            torch.concat([train_data[:, 0], eval_data[:, 0]]), self.all_users
        )

        print('Neg C Dict User:', len(self.neg_c_dict_user))

        self.neg_c_dict_item = self._build_freq_dict(
            np.concatenate([train_data[:, 1], eval_data[:, 1]]), self.all_items
        )

        print('Neg C Dict Item:', len(self.neg_c_dict_item))

        item_cands = tuple(self.neg_c_dict_item.keys())
        F = np.array(tuple(self.neg_c_dict_item.values())) ** self.cfg.plabel_neg_pn
        sort_inds = np.argsort(F)
        item_cands = [item_cands[i] for i in sort_inds]
        F = F[sort_inds]
        F = (F / F.sum()).cumsum()
        self.item_freq = (item_cands, F)

        for u, i in tqdm(train_data[:, 0:2]):
            self.user_seed_dict[u.item()].add(i.item())

        # path = hydra.utils.to_absolute_path(self.cfg.reachable_items_path)
        # logger.info("calculating reachable items for users")
        # self._setup_dst_dict(self.path_list_dict)
        item_dist_dict = {}
        src_itr = map(
            lambda iu: (
                all_users[iu].item(),
                tuple(self.user_seed_dict[all_users[iu].item()]),
                self.dst_dict,
                self.neg_c_dict_item,
                self.cfg.plabel_pl_pn,
            ),
            range(len(all_users)),
        )

        # print('SRC ITR:', len(src_itr))

        grouped = grouper(self.cfg.plabel_chunk_size, src_itr, squash=set([2, 3]))

        # with mp.Pool(self.cfg.plabel_par) as pool:
        #     for idd in pool.imap_unordered(compute_reachable_items_, grouped):
        #         item_dist_dict.update(idd)
        #         print(idd)
        # self.item_dist_dict = item_dist_dict
        print('Populating item dist dict...')
        item_dist_dict = {}
        for group in tqdm(grouped):
            # print('Group sample:', group[0])
            idd = compute_reachable_items_(group)
            item_dist_dict.update(idd)
        self.item_dist_dict = item_dist_dict

  def __init__(self, dataset_name):
    self._init_class(dataset_name)
    # self_ratings = self.all_ratings
    self.user_seed_dict = defaultdict(set)
    self._train_val_test_split()
    self.__class__.datasets['train'].set_item_candidates()
    self.__class__.datasets['train'].user_train = build_user_train_dict_from_tensor(self.__class__.datasets['train'].ratings)

  def sample_positive(self, user):
        choice_seq = self.ratings[(self.ratings[:,0] == user) & (self.ratings[:,2]) >= 1][:,1]
        if choice_seq.numel() == 0:
          print('Something went wrong.')
        return random.choice(choice_seq)

  def sample_negative(self, user):
      seen = set(self.user_train[user])
      while True:
          item = random.randint(0, self.n_item - 1)
          if item not in seen:
              return item

  def sample_pseudo_label(self, user):
        udst, F = self.item_dist_dict[user.item()]
        r = random.random()
        F = torch.as_tensor(F)
        idx = torch.searchsorted(F, r, right=True)
        return udst[idx]

  def __getitem__(self,idx):
    if self.train_set:
      user = self.users[idx-1]
      pos_item = self.sample_positive(user)
      neg_item = self.sample_negative(user)
      pseudo_label = self.sample_pseudo_label(user)
      return user, torch.tensor(pos_item), torch.tensor(neg_item), torch.tensor(pseudo_label)
    else:
      # still need to implement
      return None

  def __len__(self):
    return self.n_user

KGPL_Dataset('music')



Loading Entity Adjacencies...
Loading Relation Adjacencies...
Loading Ratings...
Loading Path List...
Loading Distances...
Setting up dst dict...


100%|██████████| 3846/3846 [00:00<00:00, 13403.91it/s]


Start updating path info...
Path info updated.
Neg C Dict User: 1872
Neg C Dict Item: 3846


100%|██████████| 27102/27102 [00:00<00:00, 185332.14it/s]


Populating item dist dict...


8it [00:04,  1.93it/s]


<__main__.KGPL_Dataset at 0x787f18716a50>

In [127]:
#testing getitem
KGPL_Dataset.datasets['train'][55]

  return user, torch.tensor(pos_item), torch.tensor(neg_item), torch.tensor(pseudo_label)


(tensor(54), tensor(302), tensor(1150), tensor(1741))

# Model

In [128]:
class KGPL_Config():
  '''
  KGPL model configuration for each dataset.
  '''
  def __init__(self, dataset_name:str, model_type:str, neighbor_sample_size:int, dropout_rate:float, emb_dim:int, n_iter:int, plabel:dict, optimize:dict, log:dict, evaluate:dict, model:dict):
    self.dataset_name = dataset_name
    self.model_type = model_type
    self.neighbor_sample_size = neighbor_sample_size
    self.dropout_rate = dropout_rate
    self.emb_dim = emb_dim
    self.n_iter = n_iter
    self.plabel = plabel
    self.optimize = optimize
    self.log = log
    self.evaluate = evaluate
    self.model = model

In [129]:
model_cfg = KGPL_Config(
    'music',
    'KGPL_COT',
    neighbor_sample_size=32,
    dropout_rate=0.5,
    emb_dim=64,
    n_iter=1,
    plabel={},
    optimize={'iter_per_epoch':100, 'lr': 3e-3, 'batch_size':3333},
    log={'show_loss':True},
    evaluate={'user_num_topk':1000},
    model={'n_iter':1, 'neighbor_sample_size':32, 'dropout_rate':0.5}
)

In [130]:
def kgpl_loss(pos_scores, neg_scores, pseudo_scores):
    # BCE loss like TensorFlow version
    pos_labels = torch.ones_like(pos_scores)
    neg_labels = torch.zeros_like(neg_scores)
    pseudo_labels = torch.ones_like(pseudo_scores)
    loss = F.binary_cross_entropy_with_logits(pos_scores, pos_labels) + \
           F.binary_cross_entropy_with_logits(neg_scores, neg_labels) + \
           F.binary_cross_entropy_with_logits(pseudo_scores, pseudo_labels)
    return loss

In [131]:
class SumAggregatorWithDropout(nn.Module):
    def __init__(self, emb_dim, dropout_rate, activation, cfg):
        super().__init__()
        self.linear = nn.Linear(emb_dim * 2, emb_dim)
        self.activation = activation
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, self_vectors, neighbor_vectors, neighbor_relations, user_embeddings):
        # self_vectors: [batch_size, emb_dim]
        # neighbor_vectors: [batch_size, n_neighbors, emb_dim]
        # neighbor_mean = neighbor_vectors.mean(dim=1)  # [batch_size, emb_dim]
        neighbor_mean = neighbor_vectors.mean(dim=[1, 2])
        out = torch.cat([self_vectors, neighbor_mean], dim=-1)  # [batch_size, emb_dim * 2]
        out = self.linear(out)
        out = self.dropout(out)
        return self.activation(out)


class KGPLStudent(nn.Module):
    def __init__(self, cfg, n_user, n_entity, n_relation, adj_entity, adj_relation, path_list_dict, name, eval_mode=False):
        super().__init__()
        self.cfg = cfg
        self.name = name
        self.n_user = n_user
        self.n_entity = n_entity
        self.n_relation = n_relation
        self.batch_size = cfg.optimize['batch_size']
        self.adj_entity = adj_entity
        self.adj_relation = adj_relation
        self.path_list_dict = path_list_dict
        self.eval_mode = eval_mode

        self.user_emb_matrix = nn.Embedding(n_user, cfg.emb_dim)
        self.entity_emb_matrix = nn.Embedding(n_entity, cfg.emb_dim)
        self.relation_emb_matrix = nn.Embedding(n_relation, cfg.emb_dim)

        self.aggregators = nn.ModuleList([
            SumAggregatorWithDropout(cfg.emb_dim, cfg.dropout_rate, activation=nn.Tanh(), cfg=cfg)
            if i == cfg.n_iter - 1 else
            SumAggregatorWithDropout(cfg.emb_dim, cfg.dropout_rate, activation=nn.LeakyReLU(), cfg=cfg)
            for i in range(cfg.n_iter)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.user_emb_matrix.weight)
        nn.init.xavier_uniform_(self.entity_emb_matrix.weight)
        nn.init.xavier_uniform_(self.relation_emb_matrix.weight)

    def forward(self, user_indices, item_indices):
        user_embeds = self.user_emb_matrix(user_indices)
        item_embeds = self.get_item_embeddings(item_indices)
        scores = (user_embeds * item_embeds).sum(dim=-1)
        return scores

    def get_item_embeddings(self, item_indices):
        entities = [item_indices]
        relations = []
        for i in range(self.cfg.n_iter):
            neighbor_entities = self.adj_entity[entities[-1]].view(item_indices.size(0), -1)
            neighbor_relations = self.adj_relation[entities[-1]].view(item_indices.size(0), -1)
            entities.append(neighbor_entities)
            relations.append(neighbor_relations)

        entity_vectors = [self.entity_emb_matrix(e) for e in entities]
        relation_vectors = [self.relation_emb_matrix(r) for r in relations]

        for i in range(self.cfg.n_iter):
            new_vectors = []
            for hop in range(self.cfg.n_iter - i):
                batch_size, neighbor_size, emb_dim = entity_vectors[hop+1].size(0), entity_vectors[hop+1].size(1) // self.cfg.model['neighbor_sample_size'], entity_vectors[hop+1].size(2)
                neighbor_vecs = entity_vectors[hop+1].view(batch_size, neighbor_size, self.cfg.model['neighbor_sample_size'], emb_dim)
                relation_vecs = relation_vectors[hop].view(batch_size, neighbor_size, self.cfg.model['neighbor_sample_size'], emb_dim)
                vector = self.aggregators[i](
                    self_vectors=entity_vectors[hop],
                    neighbor_vectors=neighbor_vecs,
                    neighbor_relations=relation_vecs,
                    user_embeddings=None  # optional
                )
                new_vectors.append(vector)
            entity_vectors = new_vectors

        return entity_vectors[0]

In [132]:
def train_one_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0

    for batch in dataloader:
        user, pos_item, neg_item, pseudo_item = batch
        user = user.to(device)
        pos_item = pos_item.to(device)
        neg_item = neg_item.to(device)
        pseudo_item = pseudo_item.to(device)

        optimizer.zero_grad()

        pos_scores = model(user, pos_item)
        neg_scores = model(user, neg_item)
        pseudo_scores = model(user, pseudo_item)

        loss = kgpl_loss(pos_scores, neg_scores, pseudo_scores)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [133]:
train_dataset = KGPL_Dataset.datasets['train']

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,  # optional for faster loading
    pin_memory=True
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = KGPLStudent(model_cfg, 1872, 9366, 3846, train_dataset.adj_entity, train_dataset.adj_relation, train_dataset.path_list_dict, name='student').to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=model_cfg.optimize['lr'])

# Train
for epoch in tqdm(range(10)):
    loss = train_one_epoch(model, train_loader, optimizer, device)
    print(f"Epoch {epoch} | Loss: {loss:.4f}")

  return user, torch.tensor(pos_item), torch.tensor(neg_item), torch.tensor(pseudo_label)
  return user, torch.tensor(pos_item), torch.tensor(neg_item), torch.tensor(pseudo_label)
  return user, torch.tensor(pos_item), torch.tensor(neg_item), torch.tensor(pseudo_label)
  return user, torch.tensor(pos_item), torch.tensor(neg_item), torch.tensor(pseudo_label)


Something went wrong.
Something went wrong.


  0%|          | 0/10 [00:00<?, ?it/s]


IndexError: Caught IndexError in DataLoader worker process 2.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "<ipython-input-126-d28979224a6c>", line 175, in __getitem__
    pos_item = self.sample_positive(user)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-126-d28979224a6c>", line 156, in sample_positive
    return random.choice(choice_seq)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/random.py", line 373, in choice
    raise IndexError('Cannot choose from an empty sequence')
IndexError: Cannot choose from an empty sequence


In [None]:
def _get_mini_batch_pl(users):
        """
        Create pseudo-labelled instances for users
        """
        pl_users, pl_items = [], []
        ind = 0
        cands, freq_F = self.item_freq # what's item_freq?
        while True:
            u = users[ind % len(users)]
            ind += 1
            if u in self.item_dist_dict and len(self.item_dist_dict[u][0]) != 0:
                udst, F = self.item_dist_dict[u]
                i = udst[np.searchsorted(F, random.random())]
            else:
                while True:
                    i = cands[np.searchsorted(freq_F, random.random())]
                    if i not in self.user_seed_dict[u]:
                        break
            pl_users.append(u)
            pl_items.append(i)
            if len(pl_users) == len(users):
                break

        pl_users_pad = list(pl_users) + [0] * (self.batch_size - len(pl_users))
        pl_items_pad = list(pl_items) + [0] * (self.batch_size - len(pl_items))
        pl_labels_pad = sess.run(
            self.scores_normalized,
            feed_dict={
                self.user_indices: pl_users_pad,
                self.item_indices: pl_items_pad,
                self.dropout_rate: 0.0,
            },
        )
        pl_users = pl_users_pad[: len(pl_users)]
        pl_items = pl_items_pad[: len(pl_items)]
        pl_labels = pl_labels_pad[: len(pl_users)]
        return pl_users, pl_items, pl_labels

In [None]:
def get_neighbors(self, seeds):
        seeds = tf.expand_dims(seeds, axis=1)
        entities = [seeds]
        relations = []
        for i in range(self.cfg.n_iter):
            neighbor_entities = tf.reshape(
                tf.gather(self.adj_entity, entities[i]), [self.batch_size, -1]
            )
            neighbor_relations = tf.reshape(
                tf.gather(self.adj_relation, entities[i]), [self.batch_size, -1]
            )
            entities.append(neighbor_entities)
            relations.append(neighbor_relations)
        return entities, relations