### Format Knowledge-Graph Embeddings for the Hopwise `dataset.get_preload_weight()` Function
This notebook shows how to format the embeddings learned by a knowledge-graph embedding (KGE) method so that they can be loaded with `dataset.get_preload_weight`.


📚 [Load Pretrained Embedding Documentation](https://recbole.io/docs/user_guide/usage/load_pretrained_embedding.html)

**Load Libraries**

In [1]:
import torch
import os
import pandas as pd
import numpy as np
import torch.nn as nn
from hopwise.data import create_dataset

### Load Checkpoint

In [2]:
checkpoint_name = 'saved/TransE-Mar-21-2025_10-23-35.pth'

In [3]:
checkpoint = torch.load(checkpoint_name)



**The Embeddings detected are**

In [4]:
checkpoint['state_dict'].keys()

odict_keys(['user_embedding.weight', 'entity_embedding.weight', 'relation_embedding.weight'])

**Do you want to exclude some embeddings?**

In [5]:
excluded = ['relation_bias_embedding.weight']

**The Dataset detected is**

In [6]:
dataset_name = checkpoint['config']['dataset']
dataset_name

'ml-100k'

**The Dataset folder detected is**

In [7]:
data_path = checkpoint['config']['data_path']
data_path

'/home/asoccol/hopwise/hopwise/config/../dataset_example/ml-100k'

**Create the mappings between embedding and original entity/relation/user**

- Users have a one-to-one mapping, so they need nothing beyond the reverse ID-to-token lookup.

- We assume that indexing starts at 1 (typically, index 0 is reserved for `[PAD]`).
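
As a minimal sketch of the two points above, the toy table below mimics the shape of `dataset.field2token_id[...]` (a token-to-ID dictionary with `[PAD]` at index 0). The table is illustrative, not taken from the real dataset object; it only shows how inverting the table and skipping row 0 recovers the original tokens in embedding-row order.

```python
# Toy token -> ID table shaped like dataset.field2token_id[...] (illustrative only).
token2id = {"[PAD]": 0, "m.028r88": 1, "m.05xss5": 2}

# Invert it so that an embedding row index gives back the original token.
id2token = {idx: token for token, idx in token2id.items()}

# Row 0 belongs to [PAD], so real tokens start at row 1.
real_tokens = [id2token[idx] for idx in range(1, len(id2token))]
print(real_tokens)
```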

In [8]:
dataset = create_dataset(checkpoint['config'])

In [9]:
dataset.field2token_id['tail_id']

{'[PAD]': 0,
 'm.028r88': 1,
 'm.05xss5': 2,
 'm.0pb33': 3,
 'm.04j34g3': 4,
 'm.071v0j': 5,
 'm.0320gs': 6,
 'm.03hjw08': 7,
 'm.0599rp': 8,
 'm.06wmvw': 9,
 'm.07chp9': 10,
 'm.07l5v6': 11,
 'm.047nzpn': 12,
 'm.047q4xt': 13,
 'm.02qf7sl': 14,
 'm.02dpl9': 15,
 'm.060r16': 16,
 'm.0838sl': 17,
 'm.09zxs9': 18,
 'm.0pvms': 19,
 'm.0d0hvd': 20,
 'm.02pny3k': 21,
 'm.0bq6zl': 22,
 'm.0676l_': 23,
 'm.0dy575': 24,
 'm.0dyb1': 25,
 'm.01dybc': 26,
 'm.02lgqm': 27,
 'm.0byshz5': 28,
 'm.051xs93': 29,
 'm.0fpdlt': 30,
 'm.01z9n6': 31,
 'm.0cts7b': 32,
 'm.0n83s': 33,
 'm.040mt6': 34,
 'm.05wt50': 35,
 'm.01k5y0': 36,
 'm.02rv7cr': 37,
 'm.06sgvs': 38,
 'm.07j6w': 39,
 'm.033hmj': 40,
 'm.0bpbgy': 41,
 'm.0bbqb58': 42,
 'm.03m3y1s': 43,
 'm.0287z_j': 44,
 'm.03m4mm7': 45,
 'm.059xf1': 46,
 'm.033_kx': 47,
 'm.03vny7': 48,
 'm.05r1q4': 49,
 'm.03_w9b': 50,
 'm.08qb0q': 51,
 'm.05m55b': 52,
 'm.0344xk': 53,
 'm.032_76': 54,
 'm.0c03cl_': 55,
 'm.0dtknf': 56,
 'm.0265qmj': 57,
 'm.08ljr5': 58,


In [10]:
# Create the reverse (ID -> token) mappings
uid2token = {idx: token for token, idx in dataset.field2token_id['user_id'].items()}
print(uid2token)
eid2token = {idx: token for token, idx in dataset.field2token_id['tail_id'].items()}
print(eid2token)
rid2token = {idx: token for token, idx in dataset.field2token_id['relation_id'].items()}
print(rid2token)

{0: '[PAD]', 1: '196', 2: '186', 3: '22', 4: '244', 5: '166', 6: '298', 7: '115', 8: '253', 9: '305', 10: '6', 11: '62', 12: '286', 13: '200', 14: '210', 15: '224', 16: '303', 17: '122', 18: '194', 19: '291', 20: '234', 21: '119', 22: '167', 23: '299', 24: '308', 25: '95', 26: '38', 27: '102', 28: '63', 29: '160', 30: '50', 31: '301', 32: '225', 33: '290', 34: '97', 35: '157', 36: '181', 37: '278', 38: '276', 39: '7', 40: '10', 41: '284', 42: '201', 43: '287', 44: '246', 45: '242', 46: '249', 47: '99', 48: '178', 49: '251', 50: '81', 51: '260', 52: '25', 53: '59', 54: '72', 55: '87', 56: '42', 57: '292', 58: '20', 59: '13', 60: '138', 61: '60', 62: '57', 63: '223', 64: '189', 65: '243', 66: '92', 67: '241', 68: '254', 69: '293', 70: '127', 71: '222', 72: '267', 73: '11', 74: '8', 75: '162', 76: '279', 77: '145', 78: '28', 79: '135', 80: '32', 81: '90', 82: '216', 83: '250', 84: '271', 85: '265', 86: '198', 87: '168', 88: '110', 89: '58', 90: '237', 91: '94', 92: '128', 93: '44', 94: '2

In [11]:
# # add dummy relation, check kge code
# rid2token[len(rid2token)] = 'ui_dummy_relation'

In [12]:
assert (len(eid2token.keys()) == checkpoint['state_dict']['entity_embedding.weight'].shape[0])
assert (len(rid2token.keys()) == checkpoint['state_dict']['relation_embedding.weight'].shape[0])

*If an assertion fails, make sure that the KGE model was trained without explicitly adding dummy relations/entities when creating the relation/entity embeddings!*
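
When an assertion fails, it can help to see which mapping disagrees and by how much. The helper below is a hypothetical sketch (`find_size_mismatches` is not part of Hopwise) that compares mapping sizes against embedding row counts. The toy relation token and the row counts are made up; in the notebook you would pass `len(...)` of the real mappings and `.shape[0]` of the checkpoint tensors.

```python
def find_size_mismatches(mappings, row_counts):
    """Describe every mapping whose size differs from its embedding row count."""
    problems = []
    for name, mapping in mappings.items():
        rows = row_counts[name]
        if len(mapping) != rows:
            problems.append(f"{name}: {len(mapping)} tokens but {rows} embedding rows")
    return problems


# Toy data: the relation embedding has one extra row, as a dummy relation would cause.
toy_mappings = {
    "entity": {0: "[PAD]", 1: "m.028r88"},
    "relation": {0: "[PAD]", 1: "rel_a"},  # "rel_a" is a made-up token
}
toy_row_counts = {"entity": 2, "relation": 3}

mismatches = find_size_mismatches(toy_mappings, toy_row_counts)
print(mismatches)
```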

### Create the new embeddings

In [13]:
def format_embedding(weight, columns, emb_type):
    """Write one embedding matrix as a tab-separated `<dataset>.<emb_type>emb` file."""
    weight = weight.detach().cpu().numpy()

    # Pick the ID -> token mapping that matches the embedding type
    mappings = {'entity': eid2token, 'relation': rid2token, 'user': uid2token}
    if emb_type not in mappings:
        raise ValueError(f"Unsupported embedding type: {emb_type!r}")
    mapping = mappings[emb_type]

    # Row 0 is reserved for [PAD], so start from row 1
    new_emb_dict = {
        # Original token of each embedding row
        columns[0]: [mapping[idx] for idx in range(1, len(weight))],
        # Embedding values as a space-separated string
        columns[1]: [" ".join(f"{x}" for x in row) for row in weight[1:]],
    }

    filename = f'{dataset_name}.{emb_type}emb'
    df = pd.DataFrame(new_emb_dict)
    print(f"[+] Saving the new {dataset_name} {columns[0]} embedding in {data_path}/{filename}!")
    df.to_csv(os.path.join(data_path, filename), sep='\t', index=False)

In [14]:
for emb_name, emb in checkpoint['state_dict'].items():
    if emb_name in excluded:
        continue
    # Infer the embedding type (user, entity or relation) from the key prefix
    emb_type = emb_name.split("_")[0]
    # Create the new embedding file columns
    columns = [f'{emb_type}id:token', f'{emb_type}_embedding:float_seq']
    print(f"[+] Formatting {emb_name} with columns {columns}")
    format_embedding(emb, columns, emb_type)

[+] Formatting user_embedding.weight with columns ['userid:token', 'user_embedding:float_seq']
[+] Saving the new ml-100k userid:token embedding in /home/asoccol/hopwise/hopwise/config/../dataset_example/ml-100k/ml-100k.useremb!
[+] Formatting entity_embedding.weight with columns ['entityid:token', 'entity_embedding:float_seq']
[+] Saving the new ml-100k entityid:token embedding in /home/asoccol/hopwise/hopwise/config/../dataset_example/ml-100k/ml-100k.entityemb!
[+] Formatting relation_embedding.weight with columns ['relationid:token', 'relation_embedding:float_seq']
[+] Saving the new ml-100k relationid:token embedding in /home/asoccol/hopwise/hopwise/config/../dataset_example/ml-100k/ml-100k.relationemb!
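
To make the resulting file layout concrete, here is a stdlib-only sketch of the text that such a file should contain: a header with the two typed column names, then one tab-separated row per token with the vector written as space-separated floats. The function name and the vectors are hypothetical. The notebook itself writes the files with pandas, so treat this as an approximation of that output, not as the code that produces it.

```python
import csv
import io


def rows_to_embedding_text(tokens, vectors, id_col, emb_col):
    """Render tokens and vectors as tab-separated text with a typed header."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, delimiter="\t", lineterminator="\n")
    writer.writerow([id_col, emb_col])
    for token, vector in zip(tokens, vectors):
        # The whole vector goes into one field, values separated by spaces.
        writer.writerow([token, " ".join(str(x) for x in vector)])
    return buffer.getvalue()


# Made-up two-dimensional vectors for two entity tokens.
text = rows_to_embedding_text(
    ["m.028r88", "m.05xss5"],
    [[0.1, -0.2], [0.3, 0.4]],
    "entityid:token",
    "entity_embedding:float_seq",
)
print(text)
```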


### Next?

Now the dataset folder contains the following files:

In [15]:
os.listdir(data_path)

['ml-100k.user',
 'ml-100k.relationemb',
 'ml-100k.item',
 'ml-100k.inter',
 'ml-100k.useremb',
 'ml-100k.link',
 'ml-100k.entityemb',
 'ml-100k.kg']
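
Before touching the configuration, you may want to read one of the generated files back and confirm that every row parses and has the same dimension. The sketch below does this with the standard library on an in-memory sample; `read_embedding_text` is a hypothetical helper and the sample values are made up. For a real file you would pass the contents of, for example, the `.entityemb` file instead.

```python
import csv
import io


def read_embedding_text(text):
    """Parse tab-separated embedding text into a header and a token -> vector dict."""
    reader = csv.reader(io.StringIO(text), delimiter="\t")
    header = next(reader)
    table = {}
    for row in reader:
        if not row:  # skip blank lines
            continue
        token, sequence = row
        table[token] = [float(x) for x in sequence.split()]
    return header, table


sample = (
    "entityid:token\tentity_embedding:float_seq\n"
    "m.028r88\t0.1 -0.2\n"
    "m.05xss5\t0.3 0.4\n"
)
header, table = read_embedding_text(sample)

# All vectors should share a single dimension.
dims = {len(vector) for vector in table.values()}
print(header, len(table), dims)
```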

**We want to make sure that the dataset configuration is ok.**

Suppose that the output of the format embedding phase is:

```text
    [+] Formatting user_embedding.weight with columns ['userid:token', 'user_embedding:float_seq']
    [+] Saving the new ml-1m userid:token embedding in /home/recsysdatasets/ml-1m/ml-1m.useremb!
    [+] Formatting entity_embedding.weight with columns ['entityid:token', 'entity_embedding:float_seq']
    [+] Saving the new ml-1m entityid:token embedding in /home/recsysdatasets/ml-1m/ml-1m.entityemb!
    [+] Formatting relation_embedding.weight with columns ['relationid:token', 'relation_embedding:float_seq']
    [+] Saving the new ml-1m relationid:token embedding in /home/recsysdatasets/ml-1m/ml-1m.relationemb!
```

Then open the dataset configuration file (in our case `hopwise/properties/dataset/ml-1m.yaml`) and add the new files to be loaded:


```text
    additional_feat_suffix: [useremb, entityemb, relationemb]  
    load_col:                                                  
        useremb: [userid, user_embedding]
        entityemb: [entityid, entity_embedding]
        relationemb: [relationid, relation_embedding]
    
    alias_of_user_id: [userid]
    alias_of_entity_id: [entityid]
    alias_of_relation_id: [relationid]
    
    preload_weight:
      userid: user_embedding
      entityid: entity_embedding
      relationid: relation_embedding

```
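
A typo in one of these keys is easy to miss, so a quick consistency check can be useful. The sketch below mirrors the configuration above as a plain Python dictionary and verifies that every suffix has a `load_col` entry and that every `preload_weight` field is actually loaded. `config_problems` is a hypothetical helper, not a Hopwise validator, and it only checks the relationships visible in this example.

```python
config = {
    "additional_feat_suffix": ["useremb", "entityemb", "relationemb"],
    "load_col": {
        "useremb": ["userid", "user_embedding"],
        "entityemb": ["entityid", "entity_embedding"],
        "relationemb": ["relationid", "relation_embedding"],
    },
    "preload_weight": {
        "userid": "user_embedding",
        "entityid": "entity_embedding",
        "relationid": "relation_embedding",
    },
}


def config_problems(cfg):
    """List inconsistencies between suffixes, loaded columns and preload fields."""
    problems = []
    loaded = set()
    for suffix in cfg["additional_feat_suffix"]:
        columns = cfg["load_col"].get(suffix)
        if columns is None:
            problems.append(f"no load_col entry for suffix '{suffix}'")
        else:
            loaded.update(columns)
    for id_field, emb_field in cfg["preload_weight"].items():
        for field in (id_field, emb_field):
            if field not in loaded:
                problems.append(f"preload_weight field '{field}' is not loaded")
    return problems


print(config_problems(config))
```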



### The end

Now your model code should be able to access the pretrained embeddings through:

*Torch*
```python
    pretrained_user_emb = dataset.get_preload_weight('userid')
    pretrained_entity_emb = dataset.get_preload_weight('entityid')
    pretrained_relation_emb = dataset.get_preload_weight('relationid')
    
    self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb))
    self.entity_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_entity_emb))
    self.relation_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_relation_emb))
```

*Numpy*:
```python
    self.user_embedding = dataset.get_preload_weight('userid')
    self.entity_embedding = dataset.get_preload_weight('entityid')
    self.relation_embedding = dataset.get_preload_weight('relationid')
```


