# Semantic search with FAISS

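Install the Transformers and Datasets libraries to run this notebook. Building the FAISS index further down additionally requires the `faiss` package (`faiss-cpu` or `faiss-gpu`).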

In [None]:
!pip install datasets transformers[sentencepiece]

In [None]:
from huggingface_hub import hf_hub_url

# Build the download URL of the issues file stored in the dataset repo on the Hub.
data_files = hf_hub_url(
    repo_id="lewtun/github-issues",
    filename="datasets-issues-with-comments.jsonl",
    repo_type="dataset",
)

In [None]:
# Imported next to its first use so the cell runs on a fresh kernel.
from datasets import load_dataset

# Parse the file behind `data_files` with the generic "json" loader and take the "train" split.
issues_dataset = load_dataset("json", data_files=data_files, split="train")
issues_dataset

Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 2855
})

In [None]:
# Keep rows that are not pull requests and that carry more than one comment.
issues_dataset = issues_dataset.filter(
    lambda issue: issue["is_pull_request"] == False and len(issue["comments"]) > 1
)

In [None]:
columns = issues_dataset.column_names
# Only these fields are needed for building and displaying search results.
columns_to_keep = ["title", "body", "html_url", "comments"]
# Drop every existing column that is not on the keep list. A plain membership
# test (rather than a symmetric difference) never yields a name that is absent
# from the dataset, and the result is an ordered list as remove_columns expects.
columns_to_remove = [name for name in columns if name not in columns_to_keep]
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 582
})

In [None]:
# Imported next to its first use so the cell runs on a fresh kernel.
from datasets import Dataset

# With the "pandas" format set, slicing the dataset returns a DataFrame.
issues_dataset.set_format("pandas")
df = issues_dataset[:]
# explode() gives every element of the list-valued "comments" column its own row;
# reset_index() then keeps the originating row number as an "index" column.
comments_df = df.explode("comments").reset_index()
comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset

Dataset({
    features: ['index', 'html_url', 'title', 'comments', 'body'],
    num_rows: 2653
})

In [None]:
# Store the number of whitespace-separated tokens of each comment in a new
# "comment_length" column.
comments_dataset = comments_dataset.map(
    lambda row: {"comment_length": len(row["comments"].split())}
)

In [None]:
# Discard short comments: only rows with more than 15 tokens are kept.
comments_dataset = comments_dataset.filter(
    lambda row: row["comment_length"] > 15
)
comments_dataset

Dataset({
    features: ['index', 'html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 1979
})

In [None]:
def concatenate_text(examples):
    """Merge the title, comment and body of one row into a single string.

    Args:
        examples: Mapping with string values under "title", "comments" and "body".

    Returns:
        A dict with one key, "text", holding the three fields in that order
        joined by " \\n ".
    """
    parts = [examples["title"], examples["comments"], examples["body"]]
    return {"text": " \n ".join(parts)}

# Add the combined "text" column to every row.
comments_dataset = comments_dataset.map(function=concatenate_text)

In [None]:
from transformers import AutoTokenizer, AutoModel

# Single source of truth for the checkpoint, so the tokenizer and the model can
# never drift apart. The value is the checkpoint this cell has been loading.
# NOTE(review): an msmarco-distilbert-base-v4 id was previously assigned here but
# never used; confirm which checkpoint is actually wanted.
model_ckpt = "sentence-transformers/nq-distilbert-base-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

In [None]:
import torch

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings while ignoring masked-out positions.

    Args:
        model_output: Indexable model result whose first element holds the token
            embeddings (assumed to be (batch, tokens, hidden) - TODO confirm).
        attention_mask: Mask tensor with one entry per token; it is broadcast
            over the last dimension of the embeddings.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total, with
        the divisor clamped to at least 1e-9 so fully masked rows do not divide
        by zero.
    """
    hidden_states = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_sum = torch.sum(hidden_states * mask, 1)
    token_counts = torch.clamp(mask.sum(1), min=1e-9)
    return weighted_sum / token_counts

In [None]:
def get_embeddings(text_list):
    """Embed text with the notebook-level ``tokenizer`` and ``model``.

    Args:
        text_list: A string or list of strings handed to ``tokenizer``; inputs
            are padded to a common length and truncated.

    Returns:
        The mean-pooled model output as a tensor on ``device``. It is computed
        without gradient tracking, so it can be converted to NumPy directly.
    """
    encoded_input = tokenizer(text_list, padding=True, truncation=True, return_tensors="pt")
    # Move every input tensor to the device the model lives on.
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # Inference only: without no_grad() the result tracks gradients, and callers
    # converting it with .numpy() hit a RuntimeError.
    with torch.no_grad():
        model_output = model(**encoded_input)
    return mean_pooling(model_output, encoded_input["attention_mask"])

In [None]:
# Sanity check on the first row of the dataset that carries the "text" column.
embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape

torch.Size([1, 768])

In [None]:
# Embed the "text" of every row. detach() comes before the NumPy conversion
# because a tensor that tracks gradients cannot be converted with .numpy();
# [0] unwraps the single-row batch returned for one text.
embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)

In [None]:
# Build a FAISS index over the "embeddings" column so it can be searched by nearest neighbour.
embeddings_dataset.add_faiss_index("embeddings")

In [None]:
question = "how to use dataset splits?"
# detach() comes before the NumPy conversion because a tensor that tracks
# gradients cannot be converted with .numpy().
question_embedding = get_embeddings([question]).detach().cpu().numpy()
question_embedding.shape

(1, 768)

In [None]:
# Retrieve the five rows whose embeddings are nearest to the question embedding.
scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings",
    question_embedding,
    k=5,
)

In [None]:
# Imported next to its first use so the cell runs on a fresh kernel.
import pandas as pd

# One row per retrieved example, with its search score alongside.
samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
# Reassign rather than sort in place so re-running the cell stays idempotent.
# NOTE(review): if the index uses FAISS's default L2 metric, a smaller score
# means a closer match and this order lists the best hit last - confirm intent.
samples_df = samples_df.sort_values("scores", ascending=False)

In [None]:
# Print a short report for each retrieved comment, in the frame's current order.
for _, hit in samples_df.iterrows():
    print(f"COMMENT: {hit['comments']}")
    print(f"SCORE: {hit['scores']}")
    print(f"TITLE: {hit['title']}")
    print(f"URL: {hit['html_url']}")
    print("=" * 50)
    print()