In [None]:
import gemini_f
import pandas as pd
import textwrap

import google.generativeai as genai
from dotenv import load_dotenv

from IPython.display import display
from IPython.display import Markdown

import sys, os, inspect, time
sys.path.append("..")
import config

def to_markdown(text):
    """Render text as a Markdown block quote for display in IPython/Jupyter.

    Bullet characters (U+2022) are rewritten as Markdown list markers, then
    every line -- including whitespace-only ones -- is prefixed with '> ' so
    the whole text shows up as one continuous quoted block.

    Args:
        text: Plain text to display.

    Returns:
        An ``IPython.display.Markdown`` object wrapping the quoted text.
    """
    # Escape instead of a literal bullet so the source cannot be re-mangled by
    # an encoding round-trip: the previous literal had become the mojibake
    # 'â€¢' (UTF-8 bytes of U+2022 read as cp1252), which never matched the
    # real bullet character, so no bullets were ever converted.
    text = text.replace('\u2022', '  *')
    # textwrap.indent skips whitespace-only lines by default; the always-true
    # predicate prefixes them too, so blank lines don't break the quote.
    return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

In [None]:
# Read the API key from a local .env file (via python-dotenv) rather than
# hard-coding it in the notebook, then configure the Gemini client with it.
load_dotenv()
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)

# Print the name of every available model that supports 'generateContent'.
generation_model_names = (
    candidate.name
    for candidate in genai.list_models()
    if 'generateContent' in candidate.supported_generation_methods
)
for model_name in generation_model_names:
    print(model_name)

# Gemini model handle. The prediction loop in this cell sends its prompts
# through `llama3`, not through this object.
model = genai.GenerativeModel('gemini-pro')

# Prompt techniques to evaluate: {general, domain} x {simple, complex}.
prompt_techniques = [gemini_f.general_simple, gemini_f.general_complex, gemini_f.domain_simple, gemini_f.domain_complex]

# Values of the `force` flag handed to the prompt builder:
# 1 forces a yes/no response, 0 does not. Add 0 to also run the un-forced variant.
extra_correction = [1]

# Dataset variants (folders) and the datasets evaluated within each of them.
folders = [config.STRUCTURED_DIR, config.DIRTY_DIR, config.TEXTUAL_DIR]

datasets = [
    config.AMAZON_GOOGLE_DIR, config.BEER_DIR, config.DBLP_ACM_DIR,
    config.DBLP_GOOGLESCHOLAR_DIR, config.FODORS_ZAGATS_DIR,
    config.ITUNES_AMAZON_DIR, config.WALMART_AMAZON_DIR,
]

# Output directory for the per-dataset prediction CSVs. exist_ok=True replaces
# the racy "check exists, then create" pattern with a single call.
save_folder = 'llama3_predictions'
os.makedirs(save_folder, exist_ok=True)

# Count how many model calls the run below will make:
# one per test pair, per prompt technique, per force setting.
# Derived from the configuration instead of a hard-coded 4 so the total stays
# correct if prompt_techniques or extra_correction change.
predictions_per_pair = len(prompt_techniques) * len(extra_correction)

total_preds = 0
for folder_name in folders:
    for dataset_name in datasets:
        try:
            train, val, test = config.load_datasets(folder_name, dataset_name)
        except Exception:
            # Best-effort: skip datasets that cannot be loaded. `except
            # Exception` (not a bare except) lets KeyboardInterrupt through.
            print(f"Dataset {folder_name}_{dataset_name} does not exist.")
            continue
        total_preds += len(test) * predictions_per_pair
print(f"Total predictions: {total_preds}")

# Run every (test pair x prompt technique x force setting) through Llama 3 and
# record one prediction per call in {save_folder}/{folder}_{dataset}.csv.
# NOTE(review): `LLama3` is not imported or defined in this notebook's visible
# code -- confirm it is provided (another cell/module) before running.
llama3 = LLama3()
for folder_name in folders:
    for dataset_name in datasets:
        csv_name = f"{folder_name}_{dataset_name}"
        csv_path = f"{save_folder}/{csv_name}.csv"

        # Load before touching the output file, so a missing dataset no longer
        # leaves an empty header-only CSV behind in save_folder.
        try:
            train, val, test = config.load_datasets(folder_name, dataset_name)
            tableA_df, tableB_df = config.tableA_tableB(folder_name, dataset_name)
        except Exception:
            print(f"Dataset {folder_name}_{dataset_name} does not exist.")
            continue

        try:
            # Header only for a new file, so an existing file from an earlier
            # run is not truncated here.
            # NOTE(review): the last header field is ' time' (leading space);
            # kept unchanged so new files match previously written ones --
            # read with e.g. pandas skipinitialspace=True.
            if not os.path.exists(csv_path):
                with open(csv_path, 'w') as f:
                    f.write("general_or_domain,simple_or_complex,force_or_not,tableA_id,tableB_id,pred,label, time\n")

            # Attribute columns used to verbalise a record ('id' excluded).
            columns = tableA_df.columns
            if 'id' in columns:
                columns = columns.drop('id')
            # The same template serves both tables; build it once per dataset.
            row_template = gemini_f.format_columns_string(*columns)

            for idA, idB, single_label in zip(test['ltable_id'], test['rtable_id'], test['label']):
                rowA = tableA_df[tableA_df['id'] == idA].drop(columns='id')
                rowB = tableB_df[tableB_df['id'] == idB].drop(columns='id')
                sentenceA = row_template.format(**rowA.to_dict('records')[0])
                sentenceB = row_template.format(**rowB.to_dict('records')[0])

                for prompt in prompt_techniques:
                    # Prompt metadata does not depend on `force`; compute it once per prompt.
                    is_general = prompt in [gemini_f.general_simple, gemini_f.general_complex]
                    domain = None if is_general else gemini_f.determine_domain(dataset_name)
                    simple_or_complex = gemini_f.determine_complexity(prompt)
                    general_or_domain = 'domain' if prompt in [gemini_f.domain_simple, gemini_f.domain_complex] else 'general'

                    for force in extra_correction:
                        prompt_sentence = gemini_f.generate_prompt_sentence(sentenceA, sentenceB, force, prompt, domain)

                        # perf_counter is monotonic and high-resolution, so the
                        # measured latency is unaffected by wall-clock changes.
                        start = time.perf_counter()
                        response = llama3.llama_chat_get_response(prompt_sentence)
                        time_taken = time.perf_counter() - start

                        pred = gemini_f.parse_response(response)
                        yes_or_no = 1 if force else 0

                        gemini_f.save_predictions(csv_path, general_or_domain, simple_or_complex, yes_or_no, idA, idB, pred, single_label, time_taken)
        except Exception as e:
            # Best-effort per dataset: report the actual error and move on.
            # Previously a bare except mislabelled every failure as a missing
            # dataset and also swallowed KeyboardInterrupt.
            print(f"Error while processing {csv_name}: {e!r}")