## Imports

In [40]:
import pandas as pd

import torch
from sentence_transformers import SentenceTransformer
import pickle

In [41]:
# Reading Input Dataset
df = pd.read_csv("../data/input.csv")
df.tail()

Unnamed: 0.1,Unnamed: 0,id,input,url
5813,5818,5813,TEXT: Posting memes until I get my master's de...,https://i.redd.it/p5t3m0ujt24a1.jpg
5814,5819,5814,TEXT: *pretends to think*; A woman is deep in ...,https://i.redd.it/yyru9fqbu24a1.jpg
5815,5820,5815,TEXT: overreaction indeed; Captain America tal...,https://i.redd.it/n2qnaoztu24a1.jpg
5816,5821,5816,TEXT: Anime watchers be like; two drawn faces ...,https://farm66.staticflickr.com/65535/52761413...
5817,5822,5817,"TEXT: Gentlemen, is with great pleasure to inf...",https://i.redd.it/f6b53ppyx24a1.jpg


In [42]:
# Loading model
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
model = model.to("mps")

In [43]:
# Creating embeddings
embeddings = model.encode(df["input"].to_list(), show_progress_bar=True)

Batches:   0%|          | 0/182 [00:00<?, ?it/s]

In [44]:
# Saving the embeddings
with open('../data/meme-embeddings.pkl', "wb") as fOut:
    pickle.dump(embeddings, fOut)

In [45]:
print(f"Shape of one embedding{embeddings[0].shape}")

Shape of one embedding(768,)


In [46]:
# Giving a prompt
from sentence_transformers import util
prompt = "Spiderman giving lecture"
prompt_embedding = model.encode(prompt, convert_to_tensor=True)
hits = util.semantic_search(prompt_embedding,embeddings, top_k=5) # using cosine simarlity to compare embeddings
hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])

In [47]:
# Output memes
desired_ids = hits["corpus_id"]
filtered_df = df.loc[df['id'].isin(desired_ids)]
retrieved_memes = list(filtered_df["url"])
retrieved_memes

['https://i.redd.it/m16dhaqyply21.jpg',
 'https://i.redd.it/obobf6u1exx91.jpg',
 'https://i.redd.it/gmnwi9dxdfy91.jpg',
 'https://farm66.staticflickr.com/65535/52761820325_bc9f10f3c6.png',
 'https://i.redd.it/wm1nxjuehy3a1.jpg']

In [49]:
from IPython.display import Image,display
[display(Image(url=x,width=200, height=200)) for x in retrieved_memes]

[None, None, None, None, None]