### Format Knowledge-Graph Embeddings for Hopwise `dataset.get_preload_weight()` function
This notebook shows how to format the embeddings learned by a knowledge-graph embedding (KGE) method so that they can be loaded with `dataset.get_preload_weight()`.


📚 [Load Pretrained Embedding Documentation](https://recbole.io/docs/user_guide/usage/load_pretrained_embedding.html)

**Load Libraries**

In [1]:
import torch
import os
import pandas as pd
import numpy as np
import torch.nn as nn
from hopwise.data import create_dataset

### Load Checkpoint

In [2]:
checkpoint_name = 'saved/TransE-Mar-24-2025_21-44-17.pth'

In [3]:
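# weights_only=False: the checkpoint also stores the config object, not only tensors
# (PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle it)
checkpoint = torch.load(checkpoint_name, map_location='cpu', weights_only=False)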



**The embeddings found in the checkpoint are**

In [4]:
checkpoint['state_dict'].keys()

odict_keys(['user_embedding.weight', 'entity_embedding.weight', 'relation_embedding.weight'])

**Do you want to exclude some embeddings?**

In [5]:
excluded = ['relation_bias_embedding.weight']
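The export loop further down takes the embedding type from the text before the first underscore of each key and names the output file after it. A key such as `relation_bias_embedding.weight` would therefore also map to `relation` and overwrite the `.relationemb` file, which is why it belongs in `excluded`. The sketch below previews that mapping on the key names alone; `plan_exports` and `find_collisions` are hypothetical helpers (not part of hopwise), and the key list is illustrative, including a bias key this particular checkpoint does not contain.

```python
from collections import defaultdict


def plan_exports(keys, excluded, dataset_name):
    """Map each output file to the state_dict keys that would be written to it.

    Mirrors the naming used by the export loop: the embedding type is the text
    before the first underscore, and the file is `<dataset>.<type>emb`.
    """
    plan = defaultdict(list)
    for key in keys:
        if key in excluded:
            continue
        emb_type = key.split("_")[0]
        plan[f"{dataset_name}.{emb_type}emb"].append(key)
    return dict(plan)


def find_collisions(plan):
    """Return the files that more than one key would write (later keys overwrite earlier ones)."""
    return {fname: ks for fname, ks in plan.items() if len(ks) > 1}


# Illustrative key list (hypothetical: includes a relation bias embedding)
keys = ["user_embedding.weight", "entity_embedding.weight",
        "relation_embedding.weight", "relation_bias_embedding.weight"]
print(find_collisions(plan_exports(keys, [], "lfm1m_small")))
# {'lfm1m_small.relationemb': ['relation_embedding.weight', 'relation_bias_embedding.weight']}
print(find_collisions(plan_exports(keys, ["relation_bias_embedding.weight"], "lfm1m_small")))
# {}
```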

**The detected dataset is**

In [6]:
dataset_name = checkpoint['config']['dataset']
dataset_name

'lfm1m_small'

**The detected dataset folder is**

In [7]:
data_path = checkpoint['config']['data_path']
data_path

'/home/recsysdatasets/lfm1m_small'

**Create the mappings between embedding rows and the original entity/relation/user tokens**

- Users have a 1-to-1 mapping, so they do not need a separate mapping.

- We assume that indexing starts at 1 (index 0 is typically reserved for `[PAD]`).
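Both assumptions can be checked on a token-to-id mapping before anything is exported: `[PAD]` must sit at index 0 and the ids must be contiguous, otherwise row `i` of a weight matrix would not correspond to id `i`. A minimal sketch, assuming the mapping has the shape of one entry of `dataset.field2token_id`; `check_pad_and_build_reverse` is a hypothetical helper and the toy dict mirrors the first entries of `dataset.field2token_id['tail_id']`.

```python
def check_pad_and_build_reverse(token2id):
    """Verify that [PAD] is index 0 and ids are contiguous, then invert the mapping.

    `token2id` has the shape of one entry of `dataset.field2token_id`,
    e.g. {'[PAD]': 0, '1': 1, '8': 2, ...}.
    """
    if token2id.get('[PAD]') != 0:
        raise ValueError("expected '[PAD]' to be mapped to index 0")
    ids = sorted(token2id.values())
    if ids != list(range(len(ids))):
        raise ValueError("ids are not contiguous from 0; row i may not match id i")
    # Reverse mapping: internal id -> original token
    return {idx: token for token, idx in token2id.items()}


toy = {'[PAD]': 0, '1': 1, '8': 2, '9': 3}
id2token = check_pad_and_build_reverse(toy)
# Tokens for the rows that get exported (row 0 is padding)
print([id2token[i] for i in range(1, len(id2token))])  # ['1', '8', '9']
```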

In [8]:
dataset = create_dataset(checkpoint['config'])

In [9]:
dataset.field2token_id['tail_id']

{'[PAD]': 0,
 '1': 1,
 '8': 2,
 '9': 3,
 '11': 4,
 '14': 5,
 ...}

In [10]:
# Create the reverse mappings (internal index -> original token)
uid2token = {idx: token for token, idx in dataset.field2token_id['user_id'].items()}
print(uid2token)
eid2token = {idx: token for token, idx in dataset.field2token_id['tail_id'].items()}
print(eid2token)
rid2token = {idx: token for token, idx in dataset.field2token_id['relation_id'].items()}
print(rid2token)

{0: '[PAD]', 1: '21072247', 2: '4630834', 3: '7575716', 4: '8832481', 5: '8049814', ...}

In [11]:
# Uncomment only if your KGE code adds a dummy user-item relation:
# rid2token[len(rid2token)] = 'ui_dummy_relation'

In [12]:
assert len(eid2token) == checkpoint['state_dict']['entity_embedding.weight'].shape[0], "entity count mismatch"
assert len(rid2token) == checkpoint['state_dict']['relation_embedding.weight'].shape[0], "relation count mismatch"

*If an assertion fails, make sure you trained the KGE model without explicitly adding dummy relations/entities when creating the relation/entity embeddings.*

### Create the new embeddings

In [13]:
def format_embedding(weight, columns, emb_type):
    """Write one embedding matrix as a tab-separated `<dataset>.<emb_type>emb` file."""
    weight = weight.detach().cpu().numpy()

    # Pick the reverse mapping (internal index -> original token) for this embedding type
    if emb_type == 'entity':
        mapping = eid2token
    elif emb_type == 'relation':
        mapping = rid2token
    elif emb_type == 'user':
        mapping = uid2token
    else:
        mapping = None  # unknown type: fall back to the raw internal index

    # Row 0 is reserved for [PAD], so both columns start from row 1
    new_emb_dict = {
        # Token column: original token of every non-padding row
        columns[0]: [mapping[idx] if mapping is not None else idx for idx in range(1, len(weight))],
        # Embedding column: each row serialized as a space-separated float sequence
        columns[1]: [" ".join(f"{x}" for x in row) for row in weight[1:]],
    }

    filename = f'{dataset_name}.{emb_type}emb'
    df = pd.DataFrame(new_emb_dict)
    print(f"[+] Saving the new {dataset_name} {columns[0]} embedding in {data_path}/{filename}!")
    df.to_csv(os.path.join(data_path, filename), sep='\t', index=False)
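To make the output layout concrete, here is a dependency-free sketch of the text `format_embedding` produces for a toy three-row weight matrix, using plain lists instead of a tensor and a DataFrame. `to_emb_tsv` is a hypothetical helper; the layout follows the function above (tab between the two columns, spaces between floats, row 0 skipped as padding), and the token values are borrowed from the `uid2token` output purely as labels.

```python
def to_emb_tsv(weight, id2token, columns):
    """Serialize rows 1..N-1 of `weight` in the tab-separated layout format_embedding writes."""
    lines = ["\t".join(columns)]
    for idx in range(1, len(weight)):  # row 0 is the [PAD] row and is skipped
        floats = " ".join(f"{x}" for x in weight[idx])
        lines.append(f"{id2token[idx]}\t{floats}")
    return "\n".join(lines) + "\n"


weight = [[0.0, 0.0],    # [PAD]
          [0.1, -0.2],   # token '21072247'
          [0.3, 0.4]]    # token '4630834'
id2token = {0: '[PAD]', 1: '21072247', 2: '4630834'}
tsv = to_emb_tsv(weight, id2token, ['userid:token', 'user_embedding:float_seq'])
print(tsv, end="")
# userid:token	user_embedding:float_seq
# 21072247	0.1 -0.2
# 4630834	0.3 0.4
```

The float sequence lives in a single column, so the whole vector is one field that the `:float_seq` type can parse back into a list.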

In [14]:
for emb_name, emb in checkpoint['state_dict'].items():
    if emb_name in excluded:
        continue
    # Infer the embedding type (user, entity, relation, ...) from the key prefix
    emb_type = emb_name.split("_")[0]
    # Column headers of the new embedding file
    columns = [f'{emb_type}id:token', f'{emb_type}_embedding:float_seq']
    print(f"[+] Formatting {emb_name} with columns {columns}")
    format_embedding(emb, columns, emb_type)

[+] Formatting user_embedding.weight with columns ['userid:token', 'user_embedding:float_seq']
[+] Saving the new lfm1m_small userid:token embedding in /home/recsysdatasets/lfm1m_small/lfm1m_small.useremb!
[+] Formatting entity_embedding.weight with columns ['entityid:token', 'entity_embedding:float_seq']
[+] Saving the new lfm1m_small entityid:token embedding in /home/recsysdatasets/lfm1m_small/lfm1m_small.entityemb!
[+] Formatting relation_embedding.weight with columns ['relationid:token', 'relation_embedding:float_seq']
[+] Saving the new lfm1m_small relationid:token embedding in /home/recsysdatasets/lfm1m_small/lfm1m_small.relationemb!


### Next?

The dataset folder now contains these files:

In [15]:
os.listdir(data_path)

['lfm1m_small.link',
 'lfm1m_small.inter',
 'lfm1m_small.kg',
 'lfm1m_small.relationemb',
 'lfm1m_small.useremb',
 'lfm1m_small.entityemb']
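Before editing the configuration, it can help to check that each generated file is well formed. The sketch below uses only the standard library; `validate_emb_file` is a hypothetical helper, and its checks encode the layout `format_embedding` writes (a tab-separated header with `:token` and `:float_seq` columns, no `[PAD]` row, one embedding size), not a hopwise requirement.

```python
import csv
import os
import tempfile


def validate_emb_file(path):
    """Sanity-check an exported embedding file and return (num_rows, embedding_dim)."""
    with open(path, newline="") as f:
        reader = csv.reader(f, delimiter="\t")
        header = next(reader)
        if len(header) != 2 or not header[0].endswith(":token") or not header[1].endswith(":float_seq"):
            raise ValueError(f"unexpected header: {header}")
        dims = set()
        num_rows = 0
        for token, seq in reader:
            if token == "[PAD]":
                raise ValueError("the [PAD] row should not be exported")
            # float() also rejects non-numeric values
            dims.add(len([float(x) for x in seq.split()]))
            num_rows += 1
    if num_rows == 0:
        raise ValueError("no embedding rows found")
    if len(dims) != 1:
        raise ValueError(f"inconsistent embedding sizes: {sorted(dims)}")
    return num_rows, dims.pop()


# Self-contained demo on a toy file laid out like the exported ones
with tempfile.TemporaryDirectory() as tmp:
    toy_path = os.path.join(tmp, "toy.useremb")
    with open(toy_path, "w") as f:
        f.write("userid:token\tuser_embedding:float_seq\n"
                "21072247\t0.1 -0.2\n"
                "4630834\t0.3 0.4\n")
    result = validate_emb_file(toy_path)
print(result)  # (2, 2)
```

On the real files you would call, for example, `validate_emb_file(os.path.join(data_path, f'{dataset_name}.entityemb'))`; the row count should equal `len(eid2token) - 1`, since the `[PAD]` row is not exported.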

**Next, make sure the dataset configuration loads the new files.**

Suppose the output of the formatting step is:

```text
    [+] Formatting user_embedding.weight with columns ['userid:token', 'user_embedding:float_seq']
    [+] Saving the new ml-1m userid:token embedding in /home/recsysdatasets/ml-1m/ml-1m.useremb!
    [+] Formatting entity_embedding.weight with columns ['entityid:token', 'entity_embedding:float_seq']
    [+] Saving the new ml-1m entityid:token embedding in /home/recsysdatasets/ml-1m/ml-1m.entityemb!
    [+] Formatting relation_embedding.weight with columns ['relationid:token', 'relation_embedding:float_seq']
    [+] Saving the new ml-1m relationid:token embedding in /home/recsysdatasets/ml-1m/ml-1m.relationemb!
```

Then open the dataset configuration file (in our case `hopwise/properties/dataset/ml-1m.yaml`) and add the new files to be loaded:


```yaml
additional_feat_suffix: [useremb, entityemb, relationemb]
load_col:
  useremb: [userid, user_embedding]
  entityemb: [entityid, entity_embedding]
  relationemb: [relationid, relation_embedding]

alias_of_user_id: [userid]
alias_of_entity_id: [entityid]
alias_of_relation_id: [relationid]

preload_weight:
  userid: user_embedding
  entityid: entity_embedding
  relationid: relation_embedding
```
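The names in this configuration must match the column headers that the export loop wrote (`<type>id:token` and `<type>_embedding:float_seq`). If you export a different set of embeddings, the same entries can be derived from the list of types. A minimal sketch: `build_preload_config` is a hypothetical helper that returns the entries as a Python dict (you could dump it with a YAML library). Assuming hopwise keeps RecBole's `load_col` semantics, where a suffix missing from `load_col` is not loaded, these `load_col` entries should be merged into the existing mapping rather than replace it.

```python
def build_preload_config(emb_types):
    """Return the dataset-config entries needed to preload the given embedding types.

    Field names follow the export loop: `<type>id` for the token column and
    `<type>_embedding` for the float sequence.
    """
    config = {
        "additional_feat_suffix": [f"{t}emb" for t in emb_types],
        "load_col": {f"{t}emb": [f"{t}id", f"{t}_embedding"] for t in emb_types},
        "preload_weight": {f"{t}id": f"{t}_embedding" for t in emb_types},
    }
    # Alias each token column to the matching id field so both share one token remapping.
    # Assumption: only types with an `alias_of_<type>_id` option (e.g. user, entity, relation) apply.
    for t in emb_types:
        config[f"alias_of_{t}_id"] = [f"{t}id"]
    return config


cfg = build_preload_config(["user", "entity", "relation"])
print(cfg["preload_weight"])
```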



### The end

Your model can now access the pretrained embeddings through:

*Torch*
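```python
    # Inside your model's __init__, where `dataset` is the hopwise Dataset
    pretrained_user_emb = dataset.get_preload_weight('userid')
    pretrained_entity_emb = dataset.get_preload_weight('entityid')
    pretrained_relation_emb = dataset.get_preload_weight('relationid')

    # get_preload_weight returns a NumPy array: cast to float32 to match the model's parameters.
    # from_pretrained freezes the weights by default; pass freeze=False to fine-tune them.
    self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb).float())
    self.entity_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_entity_emb).float())
    self.relation_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_relation_emb).float())
```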

*NumPy*
```python
    self.pretrained_user_emb = dataset.get_preload_weight('userid')
    self.pretrained_entity_emb = dataset.get_preload_weight('entityid')
    self.pretrained_relation_emb = dataset.get_preload_weight('relationid')
```


