In [14]:
# imports
import torch
from typing import List
import requests
from retry import retry

import config

In [15]:
# Embedding API configuration
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

# Hugging Face API token, read from the project's `config` module instead of
# being written into the notebook itself.
hf_token = config.hf_key

# Feature-extraction endpoint for EMBEDDING_MODEL, plus the bearer-token
# header sent with each request.
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{EMBEDDING_MODEL}"
headers = dict(Authorization=f"Bearer {hf_token}")

In [96]:
@retry(tries=3, delay=10)
def query(texts: List[str], timeout: float = 60) -> list:
    '''
    Embed `texts` with the Hugging Face feature-extraction endpoint.

    Args:
        texts: strings to embed, sent as {"inputs": texts}.
        timeout: seconds to wait for each HTTP attempt before giving up on it.
            Without a timeout a stalled connection would block forever and
            the @retry decorator would never get a chance to re-attempt.

    Returns:
        The decoded JSON list returned by the API (expected: one embedding,
        i.e. a list of floats, per input text).

    Raises:
        RuntimeError: if the API answers with anything other than a list,
            e.g. an {"error": ...} object while the model is still loading.
            The @retry decorator re-attempts the call (up to 3 tries,
            10 s apart) before the error propagates.
    '''
    response = requests.post(
        api_url, headers=headers, json={"inputs": texts}, timeout=timeout
    )
    result = response.json()
    if isinstance(result, list):
        return result

    # Any non-list payload is a failure. Raise (so @retry re-attempts) rather
    # than falling through and handing None to the caller.
    if isinstance(result, dict) and "error" in result:
        raise RuntimeError(
            f"Embedding API returned an error: {result['error']!r}. "
            "If the model is currently loading, please re-run the query."
        )
    raise RuntimeError(f"Unexpected response from embedding API: {result!r}")

def get_ranks(texts: List[str], embeddings=None) -> dict:
    '''
    Rank texts[1:] by cosine similarity to the reference text texts[0].

    Args:
        texts: the reference text first, followed by the candidates to rank.
        embeddings: optional precomputed embeddings, one row per entry of
            `texts` (anything `torch.as_tensor` accepts). When omitted they
            are fetched with `query(texts)`. Passing them lets the notebook
            cache the API call and re-rank without touching the network.

    Returns:
        dict mapping each candidate text to its cosine similarity with the
        reference, ordered from most to least similar. Duplicate candidate
        texts collapse to one key (the last occurrence's score wins).
        An empty `texts` list returns {} without calling the API.

    Raises:
        ValueError: if the embeddings are not 2-D (num_texts, channels) or
            their count does not match len(texts).
    '''
    if not texts:
        return {}

    # get embeddings (list => tensor)
    if embeddings is None:
        embeddings = query(texts)
    out = torch.as_tensor(embeddings)
    if out.ndim != 2:
        raise ValueError(
            f"expected 2-D embeddings (num_texts, channels), got shape {tuple(out.shape)}"
        )
    nExamples, channels = out.shape
    if nExamples != len(texts):
        # zip() would otherwise silently drop or misalign texts.
        raise ValueError(f"got {nExamples} embeddings for {len(texts)} texts")
    print(f"computed {nExamples} embeddings with {channels} channels each")

    # similarity function; dim=0 because each call compares two 1-D rows
    cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)

    # score every candidate against the reference (row 0)
    reference = out[0]
    res = {text: cos(reference, emb).item() for text, emb in zip(texts[1:], out[1:])}

    # sort descending by similarity (sorted() is stable, so ties keep input order)
    return dict(sorted(res.items(), key=lambda item: item[1], reverse=True))

In [97]:
prompt = "A lonely cat staring at the moon"
pred1, pred2, pred3 = (
    "A wistful cat looking at the stars",
    "A philosophical cat gazing into the universe",
    "A lonely cat looking at mars",
)

# The reference prompt must be first: get_ranks scores every later entry
# against texts[0].
texts = [prompt, pred1, pred2, pred3]
res = get_ranks(texts)

computed 4 embeddings with 384 channels each


In [112]:

# Show the reference prompt, then each candidate with its 1-based rank and
# cosine-similarity score (res is already sorted best-first).
print(f'{"Prompt":>11}: {prompt}\n')
for rank, (text, score) in enumerate(res.items(), start=1):
    line = "{:>12}  {:<50}  {:<12}".format(f"Rank {rank}:", f'"{text}"', f"Score: {score}")
    print(line)


     Prompt: A lonely cat staring at the moon

     Rank 1:  "A lonely cat looking at mars"                      Score: 0.7721485495567322
     Rank 2:  "A wistful cat looking at the stars"                Score: 0.6814246773719788
     Rank 3:  "A philosophical cat gazing into the universe"      Score: 0.6462261080741882
