**Nicole Guo**

**June 12th, 2022**

# KGAT Demo

Since there was no publicly available implementation of KGAT-SR, it was decided that I would instead attempt to reproduce KGAT and demonstrate how to run the program within this notebook. This code is not the exact implementation of KGAT that the authors created; rather, it is a re-implementation in PyTorch. The original code for the paper can be found [here](https://github.com/xiangwang1223/knowledge_graph_attention_network), and the PyTorch implementation from which this project is forked can be found [here](https://github.com/LunaBlack/KGAT-pytorch).

To get this program to run, I forked the [GitHub repository](https://github.com/LunaBlack/KGAT-pytorch) and created an Anaconda environment to use as a kernel. The required packages are listed in the repository's `README.md` file [here](https://github.com/LunaBlack/KGAT-pytorch/blob/master/README.md).

# Code breakdown
The main files I will be referring to and breaking down are located in [KGAT.py](./model/KGAT.py) and [main_kgat.py](./main_kgat.py). 

## `KGAT.py`
![](media/illustration_kgat_model.PNG)

The `KGAT.py` file defines the model itself, with comments throughout the file pointing to the relevant equations in the paper. It is broken down into two main classes: the `Aggregator` class and the `KGAT` class. I'll start by breaking down the `KGAT` class.

### `KGAT` class
The `KGAT` class serves as the primary module, implementing the equations from the paper. Each function is associated with its corresponding equation in the original paper.
- **`calc_cf_embeddings`** - this refers to section 3.3, and works "to concatenate the representations at each step into a single vector" (5). It corresponds to Equation 11.

```python
    def calc_cf_embeddings(self):
        ego_embed = self.entity_user_embed.weight
        all_embed = [ego_embed]

        for idx, layer in enumerate(self.aggregator_layers):
            ego_embed = layer(ego_embed, self.A_in)
            norm_embed = F.normalize(ego_embed, p=2, dim=1)
            all_embed.append(norm_embed)

        # Equation (11)
        all_embed = torch.cat(all_embed, dim=1)         # (n_users + n_entities, concat_dim)
        return all_embed
```
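To make the shape bookkeeping concrete, here is a minimal NumPy sketch of the same loop. It assumes the layer sizes from this run's log (`embed_dim=64`, `conv_dim_list=[64, 32, 16]`) and replaces each `Aggregator` with a hypothetical stand-in (`toy_layer`: propagate over a small random adjacency matrix, apply a random linear map, then LeakyReLU). Only the propagate, normalize, and concatenate pattern of Equation 11 is meant to be faithful; the toy graph and weights are made up.

```python
import numpy as np

rng = np.random.default_rng(0)

n_nodes = 5                       # hypothetical: users + entities in a toy graph
dims = [64, 64, 32, 16]           # embed_dim followed by conv_dim_list

# Toy row-normalized adjacency matrix standing in for the attentive A_in.
A = rng.random((n_nodes, n_nodes))
A = A / A.sum(axis=1, keepdims=True)

def toy_layer(x, A, W):
    """Hypothetical stand-in for an Aggregator: propagate over A, project, LeakyReLU."""
    h = (A @ x) @ W
    return np.where(h > 0, h, 0.01 * h)

ego = rng.standard_normal((n_nodes, dims[0]))
all_embed = [ego]
for d_in, d_out in zip(dims[:-1], dims[1:]):
    W = rng.standard_normal((d_in, d_out)) * 0.1
    ego = toy_layer(ego, A, W)    # the unnormalized output feeds the next layer
    # L2-normalize each row, as F.normalize(p=2, dim=1) does.
    norm = ego / np.maximum(np.linalg.norm(ego, axis=1, keepdims=True), 1e-12)
    all_embed.append(norm)

# Equation (11): concatenate the representation from every layer.
final = np.concatenate(all_embed, axis=1)
print(final.shape)   # (5, 176)
```

With this run's settings, the same arithmetic gives `concat_dim = 64 + 64 + 32 + 16 = 176` for each of the 184,166 users and entities reported in the log.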



- **`calc_kg_loss`** - this function calculates the knowledge graph (KG) loss. Equation 1 defines the "plausibility score" of a triplet under TransR (a lower score means a more plausible triplet), and Equation 2 turns it into a pairwise ranking loss (3-4): during training, the model learns to rank valid triplets above broken ones. This KG loss is one half of the total loss of the GNN.

```python
    def calc_kg_loss(self, h, r, pos_t, neg_t):
        """
        h:      (kg_batch_size)
        r:      (kg_batch_size)
        pos_t:  (kg_batch_size)
        neg_t:  (kg_batch_size)
        """
        r_embed = self.relation_embed(r)                                                # (kg_batch_size, relation_dim)
        W_r = self.trans_M[r]                                                           # (kg_batch_size, embed_dim, relation_dim)

        h_embed = self.entity_user_embed(h)                                             # (kg_batch_size, embed_dim)
        pos_t_embed = self.entity_user_embed(pos_t)                                     # (kg_batch_size, embed_dim)
        neg_t_embed = self.entity_user_embed(neg_t)                                     # (kg_batch_size, embed_dim)

        r_mul_h = torch.bmm(h_embed.unsqueeze(1), W_r).squeeze(1)                       # (kg_batch_size, relation_dim)
        r_mul_pos_t = torch.bmm(pos_t_embed.unsqueeze(1), W_r).squeeze(1)               # (kg_batch_size, relation_dim)
        r_mul_neg_t = torch.bmm(neg_t_embed.unsqueeze(1), W_r).squeeze(1)               # (kg_batch_size, relation_dim)

        # Equation (1)
        pos_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_pos_t, 2), dim=1)     # (kg_batch_size)
        neg_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_neg_t, 2), dim=1)     # (kg_batch_size)

        # Equation (2)
        # kg_loss = F.softplus(pos_score - neg_score)
        kg_loss = (-1.0) * F.logsigmoid(neg_score - pos_score)
        kg_loss = torch.mean(kg_loss)

        l2_loss = _L2_loss_mean(r_mul_h) + _L2_loss_mean(r_embed) + _L2_loss_mean(r_mul_pos_t) + _L2_loss_mean(r_mul_neg_t)
        loss = kg_loss + self.kg_l2loss_lambda * l2_loss
        return loss
```
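The scoring and loss above can be checked on a single hand-made triplet. The NumPy sketch below uses arbitrary toy vectors (all values and helper names are illustrative assumptions) and the same row-vector convention as the code (`e @ W_r`). It leaves out the L2 regularization term and also confirms that the commented-out `softplus(pos_score - neg_score)` and the active `-logsigmoid(neg_score - pos_score)` compute the same quantity.

```python
import numpy as np

def transr_score(h, r, t, W_r):
    """Equation (1): squared distance ||h W_r + e_r - t W_r||^2 (lower = more plausible)."""
    return float(np.sum((h @ W_r + r - t @ W_r) ** 2))

def kg_pair_loss(pos_score, neg_score):
    """Equation (2) for one pair: -log(sigmoid(neg_score - pos_score))."""
    x = neg_score - pos_score
    return float(np.logaddexp(0.0, -x))   # -log(sigmoid(x)) == log(1 + exp(-x))

embed_dim, relation_dim = 4, 3
W_r = np.eye(embed_dim, relation_dim)    # toy projection: keeps the first 3 coordinates
e_r = np.array([1.0, 0.0, 0.0])
h = np.array([0.0, 0.0, 0.0, 5.0])
pos_t = np.array([1.0, 0.0, 0.0, -2.0])  # satisfies h W_r + e_r == pos_t W_r exactly
neg_t = np.array([0.0, 2.0, 0.0, 0.0])   # a "broken" tail far from h W_r + e_r

pos = transr_score(h, e_r, pos_t, W_r)
neg = transr_score(h, e_r, neg_t, W_r)
print(pos, neg)   # 0.0 5.0
```

Because the valid triplet scores 0.0 and the broken one 5.0, `kg_pair_loss(pos, neg)` is small (about 0.0067), while swapping the two scores would make it large.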



- **`calc_cf_loss`** - this function calculates the collaborative filtering (CF) loss as defined in Equation 13, using the prediction scores from Equation 12. This is the other half of the total loss of the GNN.

```python
    def calc_cf_loss(self, user_ids, item_pos_ids, item_neg_ids):
        """
        user_ids:       (cf_batch_size)
        item_pos_ids:   (cf_batch_size)
        item_neg_ids:   (cf_batch_size)
        """
        all_embed = self.calc_cf_embeddings()                       # (n_users + n_entities, concat_dim)
        user_embed = all_embed[user_ids]                            # (cf_batch_size, concat_dim)
        item_pos_embed = all_embed[item_pos_ids]                    # (cf_batch_size, concat_dim)
        item_neg_embed = all_embed[item_neg_ids]                    # (cf_batch_size, concat_dim)

        # Equation (12)
        pos_score = torch.sum(user_embed * item_pos_embed, dim=1)   # (cf_batch_size)
        neg_score = torch.sum(user_embed * item_neg_embed, dim=1)   # (cf_batch_size)

        # Equation (13)
        # cf_loss = F.softplus(neg_score - pos_score)
        cf_loss = (-1.0) * F.logsigmoid(pos_score - neg_score)
        cf_loss = torch.mean(cf_loss)

        l2_loss = _L2_loss_mean(user_embed) + _L2_loss_mean(item_pos_embed) + _L2_loss_mean(item_neg_embed)
        loss = cf_loss + self.cf_l2loss_lambda * l2_loss
        return loss
```
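Equations 12 and 13 can be illustrated with plain Python on a batch of two hypothetical (user, observed item, unobserved item) triples. The embeddings are made-up three-dimensional vectors (with this run's settings the real ones are 176-dimensional concatenations), the helper names are illustrative, and the L2 term is omitted.

```python
import math

def dot(u, v):
    """Equation (12): the predicted preference is an inner product."""
    return sum(a * b for a, b in zip(u, v))

def bpr_loss(user, pos_item, neg_item):
    """Equation (13) for one (u, i, j) triple: -ln(sigmoid(y_ui - y_uj))."""
    diff = dot(user, pos_item) - dot(user, neg_item)
    return math.log1p(math.exp(-diff))

batch = [
    # (user, observed item, sampled unobserved item), all hypothetical
    ([1.0, 0.0, 1.0], [1.0, 0.5, 1.0], [0.0, 1.0, 0.0]),   # positive ranked higher
    ([0.0, 1.0, 0.0], [0.2, 0.1, 0.0], [0.0, 2.0, 0.0]),   # negative ranked higher
]
losses = [bpr_loss(u, i, j) for u, i, j in batch]
cf_loss = sum(losses) / len(losses)      # mirrors torch.mean over the batch
```

The first triple, where the observed item already scores higher, costs about 0.13; the second, where the ranking is inverted, costs about 2.04. The pairwise form therefore concentrates the training signal on mis-ranked pairs.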



- **`update_attention_batch`** - this refers to the **Knowledge-aware attention mechanism**, and corresponds to Equation 4. This "makes the attention score dependent on the distance between head and tail entity in the specific relation space allowing for more information to be propagated between closer entities" (4).

```python
    def update_attention_batch(self, h_list, t_list, r_idx):
        r_embed = self.relation_embed.weight[r_idx]
        W_r = self.trans_M[r_idx]

        h_embed = self.entity_user_embed.weight[h_list]
        t_embed = self.entity_user_embed.weight[t_list]

        # Equation (4)
        r_mul_h = torch.matmul(h_embed, W_r)
        r_mul_t = torch.matmul(t_embed, W_r)
        v_list = torch.sum(r_mul_t * torch.tanh(r_mul_h + r_embed), dim=1)
        return v_list

```
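For a single triplet, Equation 4 reduces to a few vector operations: the score is the inner product of the projected tail `t W_r` with `tanh(h W_r + e_r)`. The NumPy sketch below uses random toy embeddings (sizes and values are assumptions) and the same row-vector convention as the code, and checks that the batched `sum(..., axis=1)` form matches the explicit inner product for one triplet.

```python
import numpy as np

def attention_scores(h_embed, t_embed, r_embed, W_r):
    """Equation (4) for a batch of triplets that share relation r.

    h_embed, t_embed: (batch, embed_dim); r_embed: (relation_dim,);
    W_r: (embed_dim, relation_dim). Returns one unnormalized score per triplet.
    """
    r_mul_h = h_embed @ W_r
    r_mul_t = t_embed @ W_r
    return np.sum(r_mul_t * np.tanh(r_mul_h + r_embed), axis=1)

rng = np.random.default_rng(42)
embed_dim, relation_dim = 8, 4
W_r = rng.standard_normal((embed_dim, relation_dim))
r_embed = rng.standard_normal(relation_dim)
h = rng.standard_normal((3, embed_dim))   # three hypothetical head entities
t = rng.standard_normal((3, embed_dim))   # their tail entities

scores = attention_scores(h, t, r_embed, W_r)

# The same value for the first triplet, written as an explicit inner product.
manual = (t[0] @ W_r) @ np.tanh(h[0] @ W_r + r_embed)
```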



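- **`update_attention`** - this function recomputes the attention scores for every triplet, one relation at a time via `update_attention_batch`, and then normalizes them across all triplets connected to each head entity using the softmax function. The result is stored in the sparse adjacency matrix `A_in`, which the aggregator layers use for propagation. It corresponds to Equation 5.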

```python
    def update_attention(self, h_list, t_list, r_list, relations):
        device = self.A_in.device

        rows = []
        cols = []
        values = []

        for r_idx in relations:
            index_list = torch.where(r_list == r_idx)
            batch_h_list = h_list[index_list]
            batch_t_list = t_list[index_list]

            batch_v_list = self.update_attention_batch(batch_h_list, batch_t_list, r_idx)
            rows.append(batch_h_list)
            cols.append(batch_t_list)
            values.append(batch_v_list)

        rows = torch.cat(rows)
        cols = torch.cat(cols)
        values = torch.cat(values)

        indices = torch.stack([rows, cols])
        shape = self.A_in.shape
        A_in = torch.sparse.FloatTensor(indices, values, torch.Size(shape))

        # Equation (5)
        A_in = torch.sparse.softmax(A_in.cpu(), dim=1)
        self.A_in.data = A_in.to(device)
```
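`torch.sparse.softmax` treats unstored entries as if they were minus infinity, so each head entity's weights sum to one across its own neighbours, and absent edges stay at zero. The dense NumPy sketch below mimics that behaviour on a hypothetical four-edge graph with made-up scores; the helper name is illustrative and not part of the repository.

```python
import numpy as np

def row_softmax_over_edges(rows, cols, values, n_nodes):
    """Equation (5): softmax of edge scores over the out-edges of each head.

    Mirrors torch.sparse.softmax(A, dim=1): only stored (row, col) entries
    take part; missing entries remain 0 in the dense result.
    """
    A = np.zeros((n_nodes, n_nodes))
    for h in np.unique(rows):
        mask = rows == h
        v = values[mask]
        w = np.exp(v - v.max())           # subtract the max for numerical stability
        A[h, cols[mask]] = w / w.sum()
    return A

rows = np.array([0, 0, 0, 2])             # head entities
cols = np.array([1, 2, 3, 0])             # tail entities
values = np.array([1.0, 2.0, 3.0, -4.0])  # unnormalized scores from Equation (4)

A_in = row_softmax_over_edges(rows, cols, values, n_nodes=4)
print(A_in.round(3))
```

Head 0 spreads its weight over three neighbours in favour of the highest-scoring tail, head 2 gives its single neighbour a weight of 1.0 regardless of the raw score, and heads 1 and 3, which have no edges, keep all-zero rows.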


- **`calc_score`** - This is the prediction function, which uses the inner product of user and item representations to form a prediction (5). This corresponds to Equation 12. 

```python
    def calc_score(self, user_ids, item_ids):
        """
        user_ids:  (n_users)
        item_ids:  (n_items)
        """
        all_embed = self.calc_cf_embeddings()           # (n_users + n_entities, concat_dim)
        user_embed = all_embed[user_ids]                # (n_users, concat_dim)
        item_embed = all_embed[item_ids]                # (n_items, concat_dim)

        # Equation (12)
        cf_score = torch.matmul(user_embed, item_embed.transpose(0, 1))    # (n_users, n_items)
        return cf_score
```
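The score matrix returned by `calc_score` is what a top-K recommendation list is built from. The NumPy sketch below uses hypothetical final embeddings for two users and four items and ranks items per user by descending score. It does not exclude items the user already interacted with during training, which an evaluation pipeline would normally do before ranking.

```python
import numpy as np

# Hypothetical final (concatenated) embeddings: 2 users, 4 items, concat_dim = 3.
user_embed = np.array([[1.0, 0.0, 0.5],
                       [0.0, 1.0, 0.0]])
item_embed = np.array([[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0],
                       [1.0, 0.0, 1.0],
                       [0.5, 0.5, 0.5]])

# Equation (12) for every user-item pair at once: (n_users, n_items).
cf_score = user_embed @ item_embed.T

# Rank items per user by descending score and keep the top K.
K = 2
top_k = np.argsort(-cf_score, axis=1)[:, :K]
```

For the first user the scores are 1.0, 0.0, 1.5 and 0.75, so the top two items are 2 and 0.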



### `Aggregator` class
The `Aggregator` class corresponds to section 3.2 in the original KGAT paper. This section discusses three kinds of aggregators: the **GCN Aggregator** (Equation 6), the **GraphSage Aggregator** (Equation 7), and the **Bi-Interaction Aggregator** (Equation 8). Likewise, the `aggregator_type` values available to the user (set through the `--aggregation_type` command line argument) are `gcn`, `graphsage`, and `bi-interaction`, which correspond to each equation respectively.

```python
import torch.nn as nn


class Aggregator(nn.Module):

    def __init__(self, in_dim, out_dim, dropout, aggregator_type):
        super(Aggregator, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dropout = dropout
        self.aggregator_type = aggregator_type

        self.message_dropout = nn.Dropout(dropout)
        self.activation = nn.LeakyReLU()

        if self.aggregator_type == 'gcn':
            self.linear = nn.Linear(self.in_dim, self.out_dim)       # W in Equation (6)
            nn.init.xavier_uniform_(self.linear.weight)

        elif self.aggregator_type == 'graphsage':
            # GraphSage concatenates e_h and e_Nh, hence the doubled input size.
            self.linear = nn.Linear(self.in_dim * 2, self.out_dim)   # W in Equation (7)
            nn.init.xavier_uniform_(self.linear.weight)

        elif self.aggregator_type == 'bi-interaction':
            self.linear1 = nn.Linear(self.in_dim, self.out_dim)      # W1 in Equation (8)
            self.linear2 = nn.Linear(self.in_dim, self.out_dim)      # W2 in Equation (8)
            nn.init.xavier_uniform_(self.linear1.weight)
            nn.init.xavier_uniform_(self.linear2.weight)

        else:
            raise NotImplementedError
```
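The constructor only defines the weights; the three aggregation rules themselves come from Equations 6 to 8. The NumPy sketch below writes them out, assuming `ego` holds the node representations e_h and `side` holds the neighbourhood representation e_Nh (in the model, the attentive adjacency matrix `A_in` applied to the embeddings). Dropout and biases are omitted, and the helper names are hypothetical, so treat this as a reading aid for the equations rather than the repository's `forward` method.

```python
import numpy as np

def leaky_relu(x, slope=0.01):          # matches the nn.LeakyReLU() default slope
    return np.where(x > 0, x, slope * x)

def gcn(ego, side, W):
    """Equation (6): LeakyReLU(W (e_h + e_Nh)); W maps in_dim -> out_dim."""
    return leaky_relu((ego + side) @ W)

def graphsage(ego, side, W):
    """Equation (7): LeakyReLU(W (e_h || e_Nh)); W maps 2 * in_dim -> out_dim."""
    return leaky_relu(np.concatenate([ego, side], axis=1) @ W)

def bi_interaction(ego, side, W1, W2):
    """Equation (8): LeakyReLU(W1 (e_h + e_Nh)) + LeakyReLU(W2 (e_h * e_Nh))."""
    return leaky_relu((ego + side) @ W1) + leaky_relu((ego * side) @ W2)

rng = np.random.default_rng(0)
n, in_dim, out_dim = 5, 64, 32
A_in = rng.random((n, n))
A_in /= A_in.sum(axis=1, keepdims=True)  # toy row-normalized attention weights
ego = rng.standard_normal((n, in_dim))
side = A_in @ ego                        # neighbourhood representation e_Nh

W = rng.standard_normal((in_dim, out_dim))
W_sage = rng.standard_normal((2 * in_dim, out_dim))
W2 = rng.standard_normal((in_dim, out_dim))
outputs = {
    "gcn": gcn(ego, side, W),
    "graphsage": graphsage(ego, side, W_sage),
    "bi-interaction": bi_interaction(ego, side, W, W2),
}
```

All three produce a `(5, 32)` output here; only GraphSage needs a weight matrix with twice as many input rows, which is why its `nn.Linear` uses `in_dim * 2`.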




## `main_kgat.py`
`main_kgat.py` serves as the training script for the KGAT model. The main functions we will focus on are `train` and `evaluate`. The **`train`** function sets up logging, loads the data, creates the model and optimizers, and then begins training. At the end of training, the weights from the best iteration of the model are saved to a `.pth` file. The **`evaluate`** function calculates the metrics used in the paper (recall and NDCG), in addition to precision, and also returns the CF (collaborative filtering) scores that the model predicts for each user-item pair.
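For reference, here is a minimal sketch of precision@K, recall@K, and NDCG@K for a single user, using the common binary-relevance definitions (a hit at rank i contributes 1 / log2(i + 1)). The function name and toy data are hypothetical, and the repository's own metric code may differ in details such as how the ideal DCG is computed, so this only shows what the reported numbers measure.

```python
import math

def metrics_at_k(ranked_items, relevant_items, k):
    """Precision@K, recall@K and NDCG@K with binary relevance for one user."""
    hits = [1 if item in relevant_items else 0 for item in ranked_items[:k]]
    n_hits = sum(hits)
    precision = n_hits / k
    recall = n_hits / len(relevant_items)
    # DCG: a hit at 1-based rank i contributes 1 / log2(i + 1).
    dcg = sum(h / math.log2(i + 2) for i, h in enumerate(hits))
    # Ideal DCG: all relevant items (up to K) placed at the top of the list.
    ideal_hits = min(len(relevant_items), k)
    idcg = sum(1 / math.log2(i + 2) for i in range(ideal_hits))
    return precision, recall, dcg / idcg

ranked = ["b", "x", "a", "y", "z"]       # hypothetical top-5 recommendations
relevant = {"a", "b", "c"}               # hypothetical held-out test items
p, r, n = metrics_at_k(ranked, relevant, k=5)
```

Here two of the three test items appear in the top 5, giving precision 0.4 and recall 2/3; NDCG is about 0.70 because one hit sits at rank 3 instead of rank 2.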

Running this as a script is easy; running `python main_kgat.py` will train the KGAT model using the default parameters as defined in [`parser_kgat.py`](argparsers/parser_kgat.py). These parameters can be changed by adding in command line arguments. Here are all available command line parameters:
```
usage: main_kgat.py [-h] [--seed SEED] [--data_name [DATA_NAME]] [--data_dir [DATA_DIR]] [--use_pretrain USE_PRETRAIN] [--pretrain_embedding_dir [PRETRAIN_EMBEDDING_DIR]] [--pretrain_model_path [PRETRAIN_MODEL_PATH]]
                    [--cf_batch_size CF_BATCH_SIZE] [--kg_batch_size KG_BATCH_SIZE] [--test_batch_size TEST_BATCH_SIZE] [--embed_dim EMBED_DIM] [--relation_dim RELATION_DIM] [--laplacian_type LAPLACIAN_TYPE]
                    [--aggregation_type AGGREGATION_TYPE] [--conv_dim_list [CONV_DIM_LIST]] [--mess_dropout [MESS_DROPOUT]] [--kg_l2loss_lambda KG_L2LOSS_LAMBDA] [--cf_l2loss_lambda CF_L2LOSS_LAMBDA] [--lr LR] [--n_epoch N_EPOCH]
                    [--stopping_steps STOPPING_STEPS] [--cf_print_every CF_PRINT_EVERY] [--kg_print_every KG_PRINT_EVERY] [--evaluate_every EVALUATE_EVERY] [--Ks [KS]]

Run KGAT.

optional arguments:
  -h, --help            show this help message and exit
  --seed SEED           Random seed.
  --data_name [DATA_NAME]
                        Choose a dataset from {yelp2018, last-fm, amazon-book}
  --data_dir [DATA_DIR]
                        Input data path.
  --use_pretrain USE_PRETRAIN
                        0: No pretrain, 1: Pretrain with the learned embeddings, 2: Pretrain with stored model.
  --pretrain_embedding_dir [PRETRAIN_EMBEDDING_DIR]
                        Path of learned embeddings.
  --pretrain_model_path [PRETRAIN_MODEL_PATH]
                        Path of stored model.
  --cf_batch_size CF_BATCH_SIZE
                        CF batch size.
  --kg_batch_size KG_BATCH_SIZE
                        KG batch size.
  --test_batch_size TEST_BATCH_SIZE
                        Test batch size (the user number to test every batch).
  --embed_dim EMBED_DIM
                        User / entity Embedding size.
  --relation_dim RELATION_DIM
                        Relation Embedding size.
  --laplacian_type LAPLACIAN_TYPE
                        Specify the type of the adjacency (laplacian) matrix from {symmetric, random-walk}.
  --aggregation_type AGGREGATION_TYPE
                        Specify the type of the aggregation layer from {gcn, graphsage, bi-interaction}.
  --conv_dim_list [CONV_DIM_LIST]
                        Output sizes of every aggregation layer.
  --mess_dropout [MESS_DROPOUT]
                        Dropout probability w.r.t. message dropout for each deep layer. 0: no dropout.
  --kg_l2loss_lambda KG_L2LOSS_LAMBDA
                        Lambda when calculating KG l2 loss.
  --cf_l2loss_lambda CF_L2LOSS_LAMBDA
                        Lambda when calculating CF l2 loss.
  --lr LR               Learning rate.
  --n_epoch N_EPOCH     Number of epoch.
  --stopping_steps STOPPING_STEPS
                        Number of epoch for early stopping
  --cf_print_every CF_PRINT_EVERY
                        Iter interval of printing CF loss.
  --kg_print_every KG_PRINT_EVERY
                        Iter interval of printing KG loss.
  --evaluate_every EVALUATE_EVERY
                        Epoch interval of evaluating CF.
  --Ks [KS]             Calculate metric@K when evaluating.
```
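Two of these options, `--conv_dim_list` and `--mess_dropout`, take list-valued strings (the run below logs `conv_dim_list='[64, 32, 16]'` and `mess_dropout='[0.1, 0.1, 0.1]'`). The sketch below assumes that each layer's input size is `embed_dim` followed by the parsed list, which reproduces the 64→64, 64→32, and 32→16 aggregator layers printed in the model summary of the log below. It uses `ast.literal_eval` as a safe way to parse the strings; the helper name is hypothetical, and the repository may parse them differently.

```python
import ast

def aggregator_shapes(embed_dim, conv_dim_list, mess_dropout):
    """Pair each layer's (in_dim, out_dim, dropout) from the CLI-style strings."""
    out_dims = ast.literal_eval(conv_dim_list)     # "[64, 32, 16]" -> [64, 32, 16]
    dropouts = ast.literal_eval(mess_dropout)      # "[0.1, 0.1, 0.1]" -> [0.1, 0.1, 0.1]
    if len(out_dims) != len(dropouts):
        raise ValueError("need one dropout value per aggregation layer")
    dims = [embed_dim] + out_dims
    return [(dims[i], dims[i + 1], dropouts[i]) for i in range(len(out_dims))]

layers = aggregator_shapes(64, "[64, 32, 16]", "[0.1, 0.1, 0.1]")
for in_dim, out_dim, p in layers:
    print(f"Aggregator: {in_dim} -> {out_dim}, dropout={p}")
```

Under this reading, a shallower network would be requested with, for example, `--conv_dim_list "[64, 32]" --mess_dropout "[0.1, 0.1]"`, keeping one dropout value per layer.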


# Experiment description
The original KGAT paper investigates a number of research questions thoroughly; for this demonstration, however, I will focus on one:
- *Is it possible to reproduce the results presented in the paper? Additionally, how do the results of the original paper and of this PyTorch implementation compare with the results achieved in our local run?*

To this end, results were compared across three different environments, using the same configuration as detailed in the paper. Besides the metrics reported in the paper (recall@20 and ndcg@20), precision@20 was also calculated for comparison. The dataset used is [`amazon-book`](http://jmcauley.ucsd.edu/data/amazon/), as it was the only dataset for which results were reported for both the original paper and the PyTorch implementation. The program was run with the following command:
```batch
python main_kgat.py --data_name amazon-book
```

```
2022-06-12 12:26:02,454 - root - INFO - Namespace(Ks='[20, 40, 60, 80, 100]', aggregation_type='bi-interaction', cf_batch_size=1024, cf_l2loss_lambda=1e-05, cf_print_every=1, conv_dim_list='[64, 32, 16]', data_dir='datasets/', data_name='amazon-book', embed_dim=64, evaluate_every=10, kg_batch_size=2048, kg_l2loss_lambda=1e-05, kg_print_every=1, laplacian_type='random-walk', lr=0.0001, mess_dropout='[0.1, 0.1, 0.1]', n_epoch=5, pretrain_embedding_dir='datasets/pretrain/', pretrain_model_path='trained_model/model.pth', relation_dim=64, save_dir='trained_model/KGAT/amazon-book/embed-dim64_relation-dim64_random-walk_bi-interaction_64-32-16_lr0.0001_pretrain1/', seed=2019, stopping_steps=10, test_batch_size=10000, use_pretrain=1)
2022-06-12 12:29:09,670 - root - INFO - n_users:           70679
2022-06-12 12:29:09,671 - root - INFO - n_items:           24915
2022-06-12 12:29:09,671 - root - INFO - n_entities:        113487
2022-06-12 12:29:09,671 - root - INFO - n_users_entities:  184166
2022-06-12 12:29:09,671 - root - INFO - n_relations:       80
2022-06-12 12:29:09,671 - root - INFO - n_h_list:          6420520
2022-06-12 12:29:09,672 - root - INFO - n_t_list:          6420520
2022-06-12 12:29:09,672 - root - INFO - n_r_list:          6420520
2022-06-12 12:29:09,672 - root - INFO - n_cf_train:        652514
2022-06-12 12:29:09,672 - root - INFO - n_cf_test:         193920
2022-06-12 12:29:09,672 - root - INFO - n_kg_train:        6420520
2022-06-12 12:29:14,953 - root - INFO - KGAT(
  (entity_user_embed): Embedding(184166, 64)
  (relation_embed): Embedding(80, 64)
  (aggregator_layers): ModuleList(
    (0): Aggregator(
      (message_dropout): Dropout(p=0.1, inplace=False)
      (activation): LeakyReLU(negative_slope=0.01)
      (linear1): Linear(in_features=64, out_features=64, bias=True)
      (linear2): Linear(in_features=64, out_features=64, bias=True)
    )
    (1): Aggregator(
      (message_dropout): Dropout(p=0.1, inplace=False)
      (activation): LeakyReLU(negative_slope=0.01)
      (linear1): Linear(in_features=64, out_features=32, bias=True)
      (linear2): Linear(in_features=64, out_features=32, bias=True)
    )
    (2): Aggregator(
      (message_dropout): Dropout(p=0.1, inplace=False)
      (activation): LeakyReLU(negative_slope=0.01)
      (linear1): Linear(in_features=32, out_features=16, bias=True)
      (linear2): Linear(in_features=32, out_features=16, bias=True)
    )
  )
)
2022-06-12 12:29:19,025 - root - INFO - CF Training: Epoch 0001 Iter 0001 / 0638 | Time 4.1s | Iter Loss 0.0285 | Iter Mean Loss 0.0285
2022-06-12 12:29:22,951 - root - INFO - CF Training: Epoch 0001 Iter 0002 / 0638 | Time 3.9s | Iter Loss 0.0238 | Iter Mean Loss 0.0261
2022-06-12 12:29:26,886 - root - INFO - CF Training: Epoch 0001 Iter 0003 / 0638 | Time 3.9s | Iter Loss 0.0264 | Iter Mean Loss 0.0262
.
.
.
2022-06-12 16:41:43,259 - root - INFO - KG Training: Epoch 0005 Iter 3130 / 3136 | Time 0.2s | Iter Loss 0.0071 | Iter Mean Loss 0.0177
2022-06-12 16:41:43,449 - root - INFO - KG Training: Epoch 0005 Iter 3131 / 3136 | Time 0.2s | Iter Loss 0.0098 | Iter Mean Loss 0.0177
2022-06-12 16:41:43,622 - root - INFO - KG Training: Epoch 0005 Iter 3132 / 3136 | Time 0.2s | Iter Loss 0.0134 | Iter Mean Loss 0.0177
2022-06-12 16:41:43,829 - root - INFO - KG Training: Epoch 0005 Iter 3133 / 3136 | Time 0.2s | Iter Loss 0.0186 | Iter Mean Loss 0.0177
2022-06-12 16:41:44,011 - root - INFO - KG Training: Epoch 0005 Iter 3134 / 3136 | Time 0.2s | Iter Loss 0.0118 | Iter Mean Loss 0.0177
2022-06-12 16:41:44,204 - root - INFO - KG Training: Epoch 0005 Iter 3135 / 3136 | Time 0.2s | Iter Loss 0.0148 | Iter Mean Loss 0.0177
2022-06-12 16:41:44,394 - root - INFO - KG Training: Epoch 0005 Iter 3136 / 3136 | Time 0.2s | Iter Loss 0.0194 | Iter Mean Loss 0.0177
2022-06-12 16:41:44,394 - root - INFO - KG Training: Epoch 0005 Total Iter 3136 | Total Time 567.6s | Iter Mean Loss 0.0177
2022-06-12 16:41:47,868 - root - INFO - Update Attention: Epoch 0005 | Total Time 3.5s
2022-06-12 16:41:47,868 - root - INFO - CF + KG Training: Epoch 0005 | Total Time 2990.8s
```
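The iteration counts in the log follow from the dataset and batch sizes it reports: each epoch appears to pass over every CF training interaction and every KG triplet once, so the number of iterations is the ceiling of the training size divided by the batch size. A quick check in Python, using only numbers copied from the log:

```python
import math

# Values copied from the log above.
n_cf_train, cf_batch_size = 652514, 1024
n_kg_train, kg_batch_size = 6420520, 2048

cf_iters = math.ceil(n_cf_train / cf_batch_size)   # CF iterations per epoch
kg_iters = math.ceil(n_kg_train / kg_batch_size)   # KG iterations per epoch
print(cf_iters, kg_iters)   # 638 3136
```

With each CF iteration taking roughly 3.9 s, the CF phase accounts for most of the 2990.8 s epoch; the KG phase took 567.6 s and the attention update only 3.5 s.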

# Results
The results could not be reproduced exactly, but similar results were obtained. After running the command above, a `.tsv` file is created storing the metric results for the entire run, using the best iteration of the model. I retrieve these results and display them below:


```python
import pandas as pd

kgat_amazon_filepath = "trained_model/KGAT/amazon-book/embed-dim64_relation-dim64_random-walk_bi-interaction_64-32-16_lr0.0001_pretrain1/metrics.tsv"
kgat_amazon_df = pd.read_csv(kgat_amazon_filepath, sep="\t")
kgat_amazon_paper_metrics = ["epoch_idx", "precision@20", "recall@20", "ndcg@20"]
kgat_amazon_df = kgat_amazon_df[kgat_amazon_paper_metrics]
kgat_amazon_df.head()
```

|   | epoch_idx | precision@20 | recall@20 | ndcg@20  |
|---|-----------|--------------|-----------|----------|
| 0 | 5.0       | 0.013484     | 0.127648  | 0.067181 |


Placing these results alongside those reported for the original paper and the PyTorch implementation gives the following table:

| Implementation                    | Best Epoch   | Precision@20  | Recall@20 | NDCG@20   |
| --------------------------------- |--------------|---------------|-----------|-----------|
| Orig. Paper Implementation        | not reported | not reported  |   0.1489  |  0.1006   |
| PyTorch Implementation            |     280      |    0.0150     |   0.1440  |  0.0766   |
| Our Re-implementation             |      5       |    0.0135     |   0.1276  |  0.0672   |
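The size of the gaps in the table can be made explicit by dividing each of our metrics by the corresponding reference value. A small plain-Python check, using only the numbers in the table (precision@20 is skipped for the paper, which does not report it):

```python
# Values copied from the table above.
reported = {
    "paper":   {"recall@20": 0.1489, "ndcg@20": 0.1006},
    "pytorch": {"precision@20": 0.0150, "recall@20": 0.1440, "ndcg@20": 0.0766},
}
ours = {"precision@20": 0.0135, "recall@20": 0.1276, "ndcg@20": 0.0672}

# Ratio of our metric to each reference, for the metrics that reference reports.
ratios = {
    ref: {m: round(ours[m] / v, 3) for m, v in metrics.items()}
    for ref, metrics in reported.items()
}
print(ratios)
```

Recall@20 reaches roughly 86% of the paper's value and 89% of the PyTorch implementation's, while NDCG@20 reaches about 67% and 88% respectively. Since the PyTorch implementation's own NDCG@20 is also well below the paper's while its recall@20 is close, part of the NDCG gap may stem from differences between the two codebases rather than from our run alone; this is a hypothesis that was not checked here.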

Because our run starts from pretrained embeddings (`use_pretrain=1` in the log above), reaching these results after only 5 epochs of training is plausible. The results are not perfectly aligned with either reference, but they are close enough to suggest that, at a foundational level, the paper's results and the test environment can be reproduced.

# Future work
As this repository is very flexible in its training configurations, I would have liked to explore them further. In particular, I wanted to look at RQ2 from the paper and focus on the different aggregation types as well as the different Laplacian types. With the help of `wandb` and a VM supplied by the university, it would have been nice to run a grid search over different configurations, as the authors did in the paper. However, even a grid search over just these two arguments would require considerable computing power, so for time and resource reasons it was not explored in this notebook. Since this investigation of parameters was one of the experiments in the paper, recreating it in this environment would have been a fun addition to the assignment.
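As a starting point for such a grid search, the six combinations of `--aggregation_type` and `--laplacian_type` listed in the usage text could be enumerated as commands. This is only a sketch: it builds the command lines without running them, fixes `--data_name amazon-book` as an assumption, and leaves out any `wandb` logging, which would have to be added separately.

```python
import itertools
import shlex

aggregation_types = ["gcn", "graphsage", "bi-interaction"]
laplacian_types = ["symmetric", "random-walk"]

commands = [
    ["python", "main_kgat.py",
     "--data_name", "amazon-book",
     "--aggregation_type", agg,
     "--laplacian_type", lap]
    for agg, lap in itertools.product(aggregation_types, laplacian_types)
]

for cmd in commands:
    print(shlex.join(cmd))   # each list could be passed to subprocess.run(cmd, check=True)
```

Each of the six runs would repeat a full training job like the one logged above, which is where the computational cost comes from.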

# References
- *Wang, X., He, X., Cao, Y., Liu, M., & Chua, T.-S. (2019a). KGAT: Knowledge Graph Attention Network for Recommendation. Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. doi:10.1145/3292500.3330989*

- *Lin, Y., Liu, Z., Sun, M., Liu, Y., & Zhu, X. (2015). Learning Entity and Relation Embeddings for Knowledge Graph Completion. Proceedings of the Twenty-Ninth AAAI Conference on Artificial Intelligence, 2181–2187. Austin, Texas: AAAI Press.*

- *Wang, X., He, X., Cao, Y., Liu, M., & Chua, T.-S. (2019b). KGAT: Knowledge Graph Attention Network for Recommendation. KDD, 950–958.*

- *Zhao, W. X., He, G., Yang, K., Dou, H.-J., Huang, J., Ouyang, S., & Wen, J.-R. (2019). KB4Rec: A Data Set for Linking Knowledge Bases with Recommender Systems. Data Intelligence, 1(2), 121–136. doi:10.1162/dint_a_00008*

- *Huang, J., Zhao, W. X., Dou, H.-J., Wen, J.-R., & Chang, E. Y. (2018). Improving Sequential Recommendation with Knowledge-Enhanced Memory Networks. The 41st International ACM SIGIR Conference on Research & Development in Information Retrieval, SIGIR 2018, Ann Arbor, MI, USA, July 08-12, 2018, 505–514. doi:10.1145/3209978.3210017*

- *Zhao, W. X., Dou, H.-J., Zhao, Y., Dong, D., & Wen, J.-R. (2019). Neural Network Based Popularity Prediction by Linking Online Content with Knowledge Bases. Advances in Knowledge Discovery and Data Mining - 23rd Pacific-Asia Conference, PAKDD 2019, Macau, China, April 14-17, 2019, Proceedings, Part II, 16–28. doi:10.1007/978-3-030-16145-3_2*

- *Mining Implicit Entity Preference from User-Item Interaction Data for Knowledge Graph Completion via Adversarial Learning. (2020). Proceedings of The Web Conference.*