In [None]:
import csv
import torch
from sentence_transformers import SentenceTransformer, util
import pandas as pd
# Load the 'allenai-specter' embedding model once for the whole notebook.
# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA instead of failing at load time.
model = SentenceTransformer(
    'allenai-specter',
    device='cuda' if torch.cuda.is_available() else 'cpu',
)


In [None]:
# Prerequisite (run once, before the import cell): %pip install sentence-transformers

In [None]:
# Load the corpus CSV. Each row is kept as a list of strings in `corpus`.
# Column 1 and column 2 are the text fields that get embedded (presumably the
# paper title and abstract -- confirm against the CSV schema); column 0 is not
# used in this cell.
# NOTE(review): every non-blank row is loaded, so if the file has a header row
# it is embedded as a corpus entry too -- confirm whether it should be skipped.
corpus = []
# newline='' is what the csv module expects, so quoted fields that contain
# line breaks are parsed as a single field.
with open('data/cs_entity_2022.csv', 'r', newline='') as f:
    for row in csv.reader(f):
        # csv.reader yields [] for blank lines; they carry no text to embed
        # and would otherwise raise IndexError below.
        if row:
            corpus.append(row)

# Build one string per corpus row: "<col 1>[SEP]<col 2>". A row without a
# third column contributes an empty second part instead of raising.
corpus_texts = [
    row[1] + '[SEP]' + (row[2] if len(row) > 2 else '')
    for row in corpus
]

# Compute embeddings for all corpus rows. encode() runs on the device the
# model was loaded on, so no explicit device is passed here.
corpus_embeddings = model.encode(corpus_texts, convert_to_tensor=True)

In [None]:
def calculate_similarity(input_file, output_file, entity1_col, entity2_col, model_name):
    """Score every entity pair in a CSV for novelty against the loaded corpus.

    Each row's two entity columns are joined as ``entity1[SEP]entity2`` (the
    same layout used to build ``corpus_texts``), embedded with ``model_name``
    and compared to every corpus embedding with cosine similarity.

    For a pair whose joined text appears verbatim in ``corpus_texts`` the score
    is the mean cosine similarity to all corpus entries; for any other pair the
    score is 0. The written ``average_novelty`` column is ``1 - score``.
    NOTE(review): this means pairs not found in the corpus always get novelty
    1.0 -- confirm that this gating is the intended definition.

    Relies on the notebook-level ``corpus_texts`` and ``corpus_embeddings``;
    the corpus-loading cell must have been run first.

    Args:
        input_file: Path of the CSV to read.
        output_file: Path of the CSV to write (input columns plus
            ``average_novelty``).
        entity1_col: Name of the column holding the first entity.
        entity2_col: Name of the column holding the second entity.
        model_name: Name or path passed to ``SentenceTransformer``.

    Returns:
        None. The result is written to ``output_file`` and a completion message
        is printed.
    """
    # Load input data from CSV file into a pandas DataFrame
    df = pd.read_csv(input_file)

    # Build the query strings with pandas so that missing cells become '' and
    # non-string cells are converted, instead of breaking the concatenation.
    first = df[entity1_col].fillna('').astype(str)
    second = df[entity2_col].fillna('').astype(str)
    queries = (first + '[SEP]' + second).tolist()

    if queries:
        # Load model
        model = SentenceTransformer(model_name)

        # Encode on the same device as the corpus embeddings so cos_sim never
        # receives tensors that live on different devices.
        query_embeddings = model.encode(
            queries,
            convert_to_tensor=True,
            device=str(corpus_embeddings.device),
        )

        # Cosine similarity matrix of shape (len(queries), len(corpus_texts)),
        # reduced to one mean similarity per query.
        cosine_scores = util.cos_sim(query_embeddings, corpus_embeddings)
        row_means = cosine_scores.mean(dim=1).tolist()

        # Set lookup keeps the "is this query in the corpus" check O(1) per row.
        known_texts = set(corpus_texts)
        average_scores = [
            mean if query in known_texts else 0.0
            for query, mean in zip(queries, row_means)
        ]
    else:
        # Nothing to embed; still write the (empty) output with the new column.
        average_scores = []

    # average_scores is a plain list, so subtract element-wise
    # (`1 - list` raises TypeError).
    df['average_novelty'] = [1.0 - score for score in average_scores]

    # Save output data to CSV file
    df.to_csv(output_file, index=False)

    print("Similarity calculation completed and saved to file:", output_file)

In [None]:
# Run configuration: input/output CSV paths, the two entity column names, and
# the embedding model name.
input_file, output_file = "input.csv", "output.csv"
entity1_col, entity2_col = "entity1", "entity2"
model_name = "allenai-specter"

# Score every entity pair in the input file and write the result CSV.
calculate_similarity(
    input_file=input_file,
    output_file=output_file,
    entity1_col=entity1_col,
    entity2_col=entity2_col,
    model_name=model_name,
)