## Imports

In [1]:
import pandas as pd

import torch
from sentence_transformers import SentenceTransformer
import pickle

In [2]:
# Reading Input Dataset
df = pd.read_csv("../data/input.csv")
df.head()

Unnamed: 0.1,Unnamed: 0,id,input,url
0,0,1,TEXT: For real though; Person in Spider Man ou...,https://i.redd.it/m16dhaqyply21.jpg
1,1,2,TEXT: And that's a fact; Two dogs carry a whit...,https://i.redd.it/z9oh7ligb0i31.jpg
2,2,3,TEXT: It was horrible; man is very dissatisfie...,https://i.redd.it/yves3izsbsj31.jpg
3,3,4,TEXT: This is why Reddit is better; A man that...,https://i.redd.it/y594n8exi6k31.jpg
4,4,5,TEXT: The Area 51 raid is still happening righ...,https://i.redd.it/4hrn18t4lck31.jpg


In [26]:
# Loading model
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
model = model.to("mps")

In [25]:
# Creating embeddings
embeddings = model.encode(df["input"].to_list(), show_progress_bar=True)

Batches:   0%|          | 0/182 [00:00<?, ?it/s]

In [27]:
# Saving the embeddings
with open('../data/meme-embeddings.pkl', "wb") as fOut:
    pickle.dump(embeddings, fOut)

In [28]:
print(f"Shape of one embedding{embeddings[0].shape}")

Shape of one embedding(768,)


In [33]:
# Giving a prompt
from sentence_transformers import util
prompt = "Spiderman"
prompt_embedding = model.encode(prompt, convert_to_tensor=True)
hits = util.semantic_search(prompt_embedding,embeddings, top_k=5) # using cosine simarlity to compare embeddings
hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])

In [34]:
# Output memes
desired_ids = hits["corpus_id"]
filtered_df = df.loc[df['id'].isin(desired_ids)]
retrieved_memes = list(filtered_df["url"])
retrieved_memes

['https://i.redd.it/mi8d2dw3t6u91.jpg',
 'https://i.redd.it/aqmlh6evnjz91.jpg',
 'https://i.redd.it/uqnx4xp7grz91.jpg',
 'https://i.redd.it/oa5ihtb6ti2a1.jpg',
 'https://i.redd.it/ljozr55e9n3a1.jpg']

In [35]:
desired_ids

0    5387
1    4580
2    2615
3    2816
4     479
Name: corpus_id, dtype: int64

In [36]:
from IPython.display import Image,display
[display(Image(url=x,width=200, height=200)) for x in retrieved_memes]

[None, None, None, None, None]