In [None]:
# Import necessary libraries
import os
import json
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util

In [None]:
# Load the prompt-generation config and pull out the list of ontology
# categories that the processing loop iterates over.
config_path = 'config/dbpedia_webnlg_prompt_gen_config.json'
# Read explicitly as UTF-8 (as the data files in this notebook are) so the
# result does not depend on the platform's default locale encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)
# Each entry is later interpolated into the data file names
# (presumably plain strings -- verify against the config file).
ontology_list = config['onto_list']

In [None]:
# Sentence encoder used to embed both the test and the train sentences.
# Caveat: T5-XXL is a very large model (11B parameters); loading it can exceed
# the memory of standard hardware, so confirm the available resources before
# running this cell.
model_name = 'sentence-t5-xxl'
model = SentenceTransformer(model_name)

In [None]:
# Where the similarity results are written.
output_path = '../../data/dbpedia_webnlg/baselines/test_train_sent_similarity/'

# Directories containing the train and test splits.
train_data_path = '../../data/dbpedia_webnlg/train/'
test_data_path = '../../data/dbpedia_webnlg/test/'

# Make sure the output directory is present; exist_ok=True keeps a re-run of
# this cell from failing when the directory already exists.
os.makedirs(output_path, exist_ok=True)

In [None]:
# Number of top similar sentences to retrieve
top_k = 5


def _load_sentences_and_ids(jsonl_path):
    """Read a JSONL file and return (sentences, ids) as parallel lists.

    Every non-blank line is parsed as a JSON object and its 'sent' and 'id'
    fields are collected in file order. Blank lines (for example a trailing
    newline at the end of the file) are skipped instead of raising a
    JSONDecodeError.

    Raises:
        KeyError: if a record lacks the 'sent' or 'id' field.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    sentences = []
    ids = []
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            sentences.append(record['sent'])
            ids.append(record['id'])
    return sentences, ids


def _top_k_train_ids(test_embeddings, train_embeddings, train_ids, k):
    """Return, for each test embedding, the IDs of its k most similar train items.

    Similarity is cosine similarity; the IDs in each inner list are ordered
    from most to least similar. `k` is clamped to the number of train items so
    a small train set yields shorter lists instead of a torch.topk error. When
    there are no test or no train items, one empty list per test item is
    returned without touching the embeddings.
    """
    num_test = len(test_embeddings)
    k = min(k, len(train_ids))
    if num_test == 0 or k <= 0:
        return [[] for _ in range(num_test)]

    # One batched call yields the full (num_test x num_train) score matrix, so
    # the train embeddings are normalised once rather than once per test row.
    cosine_scores = util.cos_sim(test_embeddings, train_embeddings)
    # .tolist() converts the index tensor to plain Python ints for list indexing.
    top_indices = torch.topk(cosine_scores, k=k, dim=1).indices.tolist()
    return [[train_ids[i] for i in row] for row in top_indices]


# Process each ontology category
for ontology in ontology_list:
    print(f'Processing ontology: {ontology}')

    # Load test data
    test_file = os.path.join(test_data_path, f'ont_1_{ontology}_test.jsonl')
    test_sentences, test_ids = _load_sentences_and_ids(test_file)

    # Load train data
    train_file = os.path.join(train_data_path, f'ont_1_{ontology}_train.jsonl')
    train_sentences, train_ids = _load_sentences_and_ids(train_file)

    # Compute embeddings for test and train sentences
    print('Computing embeddings for test sentences...')
    test_embeddings = model.encode(test_sentences, convert_to_tensor=True, show_progress_bar=True)
    print('Computing embeddings for train sentences...')
    train_embeddings = model.encode(train_sentences, convert_to_tensor=True, show_progress_bar=True)

    # Compute similarities and find top-k similar sentences for each test sentence
    print('Computing similarities and finding top similar sentences...')
    similar_train_ids = _top_k_train_ids(test_embeddings, train_embeddings, train_ids, top_k)
    # Map test ID to similar train IDs (a repeated test ID keeps its last entry).
    similarity_results = dict(zip(test_ids, similar_train_ids))

    # Save the results to the output file
    output_file = os.path.join(output_path, f'{ontology}_test_train_similarity.json')
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(similarity_results, f, indent=4)

    print(f'Results saved to {output_file}\n')

print('Processing completed.')