In [None]:
import random
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import PreTrainedTokenizer
import os
import logging

def load_jaccard_scores(file_path):
    """Read a whitespace-separated score matrix from a text file.

    Every line of the file becomes one row: the line is split on
    whitespace and each token is converted with ``float``. A blank line
    produces an empty row, so row positions always match line positions.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list of rows, each row being a list of floats.

    Raises:
        ValueError: If a token cannot be parsed as a float.
    """
    with open(file_path, 'r') as handle:
        # str.split() with no argument already ignores leading/trailing
        # whitespace (including the newline), so no explicit strip is needed.
        return [[float(token) for token in row.split()] for row in handle]

# Pick the K best-scoring example indices for one target from precomputed scores
def select_k_examples_precomputed(k, jaccard_scores, target_index):
    """Return the positions of the ``k`` highest scores in one score row.

    Args:
        k: Number of positions to keep (applied as a ``[:k]`` slice).
        jaccard_scores: Indexable collection of score rows.
        target_index: Which row of ``jaccard_scores`` to rank.

    Returns:
        A list of positions within the selected row, ordered from highest
        to lowest score. Positions with equal scores keep ascending order.
    """
    row = jaccard_scores[target_index]
    ranked_positions = list(range(len(row)))
    # list.sort is stable even with reverse=True, so ties stay in index order.
    ranked_positions.sort(key=lambda pos: row[pos], reverse=True)
    return ranked_positions[:k]


# Build the few-shot prompt for one target from its K best-scoring examples
def create_few_shot_prompt(k, examples, target_example, jaccard_scores, target_index):
    """Assemble the text prompt for a single target example.

    The prompt is: a fixed two-sentence header, immediately followed by the
    selected examples joined with newlines, then a newline and the final
    question that embeds ``target_example``.

    Args:
        k: Number of examples to include.
        examples: Sequence of example strings, indexed by the positions that
            ``select_k_examples_precomputed`` returns.
        target_example: Text describing the target; interpolated into the
            final question.
        jaccard_scores: Score rows passed through to the selector.
        target_index: Row of ``jaccard_scores`` belonging to this target.

    Returns:
        The complete prompt as one string.
    """
    chosen_positions = select_k_examples_precomputed(k, jaccard_scores, target_index)
    shots = "\n".join([examples[pos] for pos in chosen_positions])

    header = (
        "This dynamic graph that the interaction of node change overtime."
        " Belows are some examples of the node have similar interactions with target node."
    )
    question = f"Predict top 5 node ids will interaction with this target nodes: {target_example}."

    # Note: there is no separator between the header and the first example.
    return header + shots + "\n" + question

# Custom dataset that turns each target example into a few-shot prompt
class PromptDataset(Dataset):
    """Dataset whose items are few-shot prompt strings.

    Both input files are read as UTF-8 text; empty and whitespace-only lines
    are dropped. Item ``idx`` is the prompt built for target ``idx`` using
    row ``idx`` of ``jaccard_scores``.

    Args:
        examples_file_path: Text file with one candidate example per line.
        target_example_file_path: Text file with one target per line.
        k: Number of examples to put into each prompt.
        jaccard_scores: Score rows; row ``idx`` is used for target ``idx``.
        max_targets: Keep only the first ``max_targets`` targets. Defaults to
            5, which matches the previously hard-coded limit; pass ``None``
            to keep every target.

    Raises:
        ValueError: If ``max_targets`` is negative.
    """

    def __init__(self, examples_file_path, target_example_file_path, k, jaccard_scores, max_targets=5):
        if max_targets is not None and max_targets < 0:
            raise ValueError(f"max_targets must be None or >= 0, got {max_targets}")

        self.examples = self._read_nonblank_lines(examples_file_path)

        target_examples = self._read_nonblank_lines(target_example_file_path)
        if max_targets is not None:
            target_examples = target_examples[:max_targets]
        self.target_examples = target_examples

        self.k = k
        self.jaccard_scores = jaccard_scores

    @staticmethod
    def _read_nonblank_lines(path):
        """Return the lines of a UTF-8 text file, skipping empty/whitespace-only ones."""
        with open(path, encoding="utf-8") as f:
            return [line for line in f.read().splitlines() if line and not line.isspace()]

    def __len__(self):
        """Number of targets kept after applying ``max_targets``."""
        return len(self.target_examples)

    def __getitem__(self, idx):
        """Build and return the prompt string for target ``idx``."""
        target_example = self.target_examples[idx]
        # idx doubles as the row index into the precomputed score matrix.
        return create_few_shot_prompt(self.k, self.examples, target_example, self.jaccard_scores, idx)

In [None]:
import random
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer


# Load the tokenizer and the causal-LM weights from the same local checkpoint.
# `model_name` is a directory path here, not a hub id; both loaders read from it.
model_name = "/kaggle/input/llama2-7b-hf/Llama2-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name)


In [None]:
# Run configuration for the few-shot prompting experiment.
# K: number of examples inserted into each prompt (passed to PromptDataset as `k`).
K = 1
# batch_size: number of prompts the DataLoader hands to the model per step.
batch_size = 1
# Candidate examples that prompts are drawn from (one example per line).
examples_file_path = '/kaggle/input/uci-13/UCI_13/12/train.link_prediction'
# Targets to predict for (one target per line).
target_example_file_path='/kaggle/input/uci-13/UCI_13/12/test.link_prediction'
# Precomputed scores: row i is used for target i, and positions within a row
# are used as indices into the examples list.
jaccard_scores_file = '/kaggle/input/uci-13/UCI_13/12/test.similar_score_jarcard'
jaccard_scores = load_jaccard_scores(jaccard_scores_file)

In [None]:
# print(jaccard_scores)

In [None]:
# Generate a completion for every prompt in a batch
def predict_next_interaction_batch(model, tokenizer, prompts, max_length=1024):
    """Run generation for a batch of prompts and decode the results.

    Args:
        model: Object exposing ``device`` and ``generate(**inputs, max_length=...)``
            (e.g. a Hugging Face causal LM).
        tokenizer: Callable tokenizer exposing ``batch_decode``, ``pad_token``,
            ``eos_token`` and ``padding_side`` (e.g. a Hugging Face tokenizer).
        prompts: A single prompt string or a sequence of prompt strings.
        max_length: Forwarded to ``model.generate``. In Hugging Face
            ``generate`` this bounds prompt plus generated tokens.

    Returns:
        A list with one decoded string per prompt (special tokens removed).
        An empty batch returns an empty list without calling the model.
    """
    prompts = [prompts] if isinstance(prompts, str) else list(prompts)
    if not prompts:
        return []

    # Prompts of different lengths cannot be stacked into one tensor without
    # padding, so padding is enabled as soon as there is more than one prompt.
    # A single prompt is tokenized exactly as before (no padding).
    pad_batch = len(prompts) > 1
    previous_side = getattr(tokenizer, "padding_side", None)
    if pad_batch:
        if tokenizer.pad_token is None:
            # Some tokenizers ship without a pad token; reuse EOS as padding.
            tokenizer.pad_token = tokenizer.eos_token
        # Decoder-only generation continues from the last position, so the
        # padding has to go on the left.
        tokenizer.padding_side = "left"
    try:
        # Prompts are truncated to 1000 tokens; with the default max_length
        # of 1024 this leaves room for at least 24 generated tokens.
        inputs = tokenizer(prompts, return_tensors="pt", padding=pad_batch, truncation=True, max_length=1000)
    finally:
        if pad_batch:
            tokenizer.padding_side = previous_side

    # Input tensors must live on the same device as the model weights.
    inputs = inputs.to(model.device)
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Build the prompt dataset and wrap it in a DataLoader; order is kept
# (shuffle=False) so prediction i corresponds to target i.
prompt_dataset = PromptDataset(examples_file_path, target_example_file_path, K, jaccard_scores)
prompt_loader = DataLoader(prompt_dataset, batch_size=batch_size, shuffle=False)

# Generate for every batch and keep one continuation per prompt.
all_predictions = []
for batch in prompt_loader:
    predictions = predict_next_interaction_batch(model, tokenizer, batch)
    # Collect every prompt in the batch, not just the first one, so nothing
    # is dropped when batch_size > 1.
    for prompt, prediction in zip(batch, predictions):
        # Drop the echoed prompt by character count; this assumes the decoded
        # text starts with the prompt verbatim.
        all_predictions.append(prediction[len(prompt):])

for i, prediction in enumerate(all_predictions):
    print(f"Target Example {i}: {prediction}")