In [1]:
import torch
import torch.nn as nn
import json

from huggingface_hub import hf_hub_download

In [2]:
# Hugging Face Hub repo hosting the trained recommender and its paper metadata.
REC_MODEL_HF_REPO_ID = "bluebalam/paper-rec"
# Files fetched from that repo.
REC_MODEL_STATE_DICT = "paper-rec-model.pth"  # PyTorch state dict for the MF model
REC_PAPERS_DATA = "papers.jsonl"  # paper metadata, one JSON record per line
# Presumably the sentence-embedding model behind the item embeddings; it is not
# used by the cells shown here.
LANG_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

In [3]:
# Fetch (or reuse from the local Hugging Face cache) the model weights file; the
# cell output is the cached file path.
hf_hub_download(
    repo_id=REC_MODEL_HF_REPO_ID,
    filename=REC_MODEL_STATE_DICT,
)

'/Users/bluebalam/.cache/huggingface/hub/7fe676ab4d8ee06dab9eb9aa21867e1cfebc7e56f548795ae2d5a45823281bea.aff0cf7fc39b571a6869b74ab512ee4811bd34fb696168471a2a2c34f5037700'

In [4]:
class MF(nn.Module):
    """Dot-product scorer over a fixed table of item content embeddings.

    A user is represented by a vector in the same space as the item
    embeddings, and the score of a (user, item) pair is their dot product.

    Args:
        item_content_embeddings: 2-D float tensor of shape ``(n_items, dim)``;
            row ``i`` is the embedding of item ``i``. It is wrapped with
            ``nn.Embedding.from_pretrained`` using its defaults, so the table is
            frozen (``requires_grad=False``).
    """

    def __init__(self, item_content_embeddings):
        super().__init__()
        self.item_embeddings = nn.Embedding.from_pretrained(item_content_embeddings)

    def forward(self, user_embedding, item):
        """Score items for users.

        Args:
            user_embedding: tensor of shape ``(..., dim)``.
            item: integer tensor of item indices whose shape, with ``dim``
                appended, broadcasts against ``user_embedding``.

        Returns:
            Dot-product scores with the embedding dimension reduced away, e.g.
            ``(batch,)`` for ``(batch, dim)`` users and ``(batch,)`` items.
        """
        # Reduce over the last (embedding) dimension rather than a fixed axis 1:
        # identical for 2-D batches, but also correct for unbatched ``(dim,)``
        # queries and for ``(batch, k, dim)`` candidate lists.
        return (user_embedding * self.item_embeddings(item)).sum(dim=-1)

In [5]:
# Local (cached) path of the paper metadata file from the model repo.
papers_data = hf_hub_download(
    repo_id=REC_MODEL_HF_REPO_ID,
    filename=REC_PAPERS_DATA,
)
papers_data

'/Users/bluebalam/.cache/huggingface/hub/823cf8b79d7777b91b6b2d8060119bead3010b4f9feae232c64c16db44db17b6.35aba8c41e11947f3ed330877a7523b238600c692ab801da67914265e30db842'

In [6]:
# Local (cached) path of the trained model weights.
state_dict_path = hf_hub_download(
    repo_id=REC_MODEL_HF_REPO_ID,
    filename=REC_MODEL_STATE_DICT,
)

In [7]:
# Shape of the item-embedding table stored in the checkpoint (362 items x 384
# dims, presumably the output size of LANG_MODEL). These initial weights are only
# a placeholder -- `load_state_dict` below overwrites them -- so use deterministic
# zeros rather than unseeded random values.
N_ITEMS, EMBEDDING_DIM = 362, 384
model = MF(torch.zeros(N_ITEMS, EMBEDDING_DIM))

In [8]:
# Load the trained weights. map_location="cpu" lets a checkpoint saved from a GPU
# session load on a CPU-only machine.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code; only
# load checkpoints from a trusted repo (on torch>=1.13 consider weights_only=True).
model.load_state_dict(torch.load(state_dict_path, map_location="cpu"))

<All keys matched successfully>

In [9]:
def load_papers_data(papers_data_path=None):
    """Load the paper metadata records stored as JSON Lines.

    Args:
        papers_data_path: Optional path to a local JSON-Lines file. When ``None``
            (the default), the ``REC_PAPERS_DATA`` file is fetched (or reused from
            the local cache) from ``REC_MODEL_HF_REPO_ID`` via ``hf_hub_download``.

    Returns:
        A list with one parsed JSON object per non-blank line, in file order.
        Row ``i`` is presumably the metadata for item index ``i`` of the model's
        embedding table -- TODO confirm against how the model was built.
    """
    if papers_data_path is None:
        papers_data_path = hf_hub_download(repo_id=REC_MODEL_HF_REPO_ID, filename=REC_PAPERS_DATA)
    # Titles, abstracts and author names contain non-ASCII text; don't depend on
    # the platform's default encoding.
    with open(papers_data_path, encoding="utf-8") as fin:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        return [json.loads(line) for line in fin if line.strip()]

In [10]:
# Load one metadata record per line of the repo's papers.jsonl file.
papers = load_papers_data()

In [11]:
# Recommendations index `papers` by item-embedding row, so the metadata file and
# the model's embedding table must contain the same number of entries.
assert len(papers) == model.item_embeddings.num_embeddings, (
    "papers.jsonl and the model's item embeddings are out of sync"
)
len(papers)

362

In [12]:
# Toy top-k demo on a hand-made score vector (not model output): scores 1..5,
# keep the 3 largest. Their positions are used below as item indices.
x = torch.arange(1.0, 6.0)
preds = torch.topk(x, k=3)

In [13]:
# Display the toy score vector used for the top-k demo.
x

tensor([1., 2., 3., 4., 5.])

In [14]:
# Top-3 values of `x` and their positions; the positions are used below to pick papers.
preds

torch.return_types.topk(
values=tensor([5., 4., 3.]),
indices=tensor([4, 3, 2]))

In [44]:
# Convert the top-k positions to plain Python ints for list indexing; unlike
# `.numpy()`, `.tolist()` also works when the tensor lives on a GPU.
indices = preds.indices.tolist()

In [46]:
# Look up the metadata record for each selected item index.
recs = [papers[int(idx)] for idx in indices]

In [47]:
# Display the paper records selected by the top-k indices.
recs

[{'id': 'http://arxiv.org/abs/2202.01258',
  'title': 'Accelerated Quality-Diversity for Robotics through Massive Parallelism.',
  'authors': 'Bryan Lim, Maxime Allard, Luca Grillotti, Antoine Cully',
  'abstract': 'Quality-Diversity (QD) algorithms are a well-known approach to generate large collections of diverse and high-quality policies. However, QD algorithms are also known to be data-inefficient, requiring large amounts of computational resources and are slow when used in practice for robotics tasks. Policy evaluations are already commonly performed in parallel to speed up QD algorithms but have limited capabilities on a single machine as most physics simulators run on CPUs. With recent advances in simulators that run on accelerators, thousands of evaluations can performed in parallel on single GPU/TPU. In this paper, we present QDax, an implementation of MAP-Elites which leverages massive parallelism on accelerators to make QD algorithms more accessible. We first demonstrate the

# ---

In [1]:
import pickle

In [10]:
import os
from pathlib import Path

# Location of the pickled embeddings. Override with the EMBEDDINGS_PATH environment
# variable instead of hardcoding one user's home directory; the default still
# resolves to ~/Downloads/embeddings.pkl.
EMBEDDINGS_PATH = Path(os.environ.get("EMBEDDINGS_PATH", "~/Downloads/embeddings.pkl")).expanduser()

# SECURITY: pickle.load can execute arbitrary code -- only load files you created
# or otherwise trust.
with open(EMBEDDINGS_PATH, "rb") as fin:
    embeddings = pickle.load(fin)

In [12]:
# Inspect the first entry: a (paper id URL, embedding array) pair.
embeddings[0]

('http://arxiv.org/abs/2202.01208',
 array([-8.96670576e-03, -8.88680145e-02,  6.79420913e-03, -1.95084829e-02,
        -3.77330631e-02, -2.45450642e-02, -7.14130551e-02, -6.67363927e-02,
        -3.05073522e-02, -3.59350592e-02, -2.55718529e-02,  1.74439792e-02,
        -8.11414677e-04,  2.94881724e-02, -6.02874160e-02, -2.93579660e-02,
         5.70216775e-02,  3.47148366e-02, -4.20632958e-02, -6.61579240e-03,
        -2.53497586e-02,  1.32781286e-02, -5.80083905e-03, -6.03286829e-03,
         3.91915590e-02, -2.60639880e-02, -2.62534227e-02, -3.03233936e-02,
         1.39738237e-02, -3.04945800e-02,  2.78779306e-02, -7.00920634e-03,
         6.41157292e-03, -4.87140380e-02, -1.24554113e-02, -5.03311343e-02,
        -1.04708243e-02,  7.32226018e-03,  1.43467123e-02,  1.01317205e-02,
         1.86174624e-02,  1.43358810e-02,  7.62074627e-03, -1.62231885e-02,
         5.33821844e-02, -4.43339767e-03,  8.61919206e-03, -6.22869991e-02,
        -3.61568900e-03,  4.66444045e-02, -5.6243829