In [8]:
import gemini_f
import pandas as pd

import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

import sys, os, time, random
sys.path.append("..")
import config

In [10]:
# Read the Gemini API key from the environment so it is never written into the notebook.
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)

# Show every model that can serve generateContent requests, then a blank separator line.
for available_model in genai.list_models():
    if 'generateContent' not in available_model.supported_generation_methods:
        continue
    print(available_model.name)
print()

# Model used for all predictions below.
model = genai.GenerativeModel('gemini-1.5-flash')

# Prompt techniques to compare: general vs. domain-specific, each in a simple
# and a complex variant.
prompt_techniques = [gemini_f.general_simple, gemini_f.general_complex, gemini_f.domain_simple, gemini_f.domain_complex]

# Values passed as `force` to generate_prompt_sentence:
# 1 forces a yes/no response, 0 does not.
extra_correction = [1]

# Done
#   dirty_iTunes-Amazon               110
#   dirty_DBLP-ACM                    2474
#   dirty_Walmart-Amazon              2050
#   structured_Beer                   92
#   structured_Fodor-Zagats           190
#   structured_Itunes-Amazon          110
#   structured_Walmart-Amazon         2050
#   structured_Amazon-Google          2294
#   structured_DBLP-ACM               2474
#   textual_Abt-Buy                   1917

# Folders and datasets to run. Every folder x dataset combination is attempted;
# combinations that fail to load are reported and skipped.
folders = []
datasets = []

# Output directory: one CSV of predictions per folder/dataset combination.
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
save_folder = 'gemini_predictions'
os.makedirs(save_folder, exist_ok=True)

# Estimate the workload before calling the API: one prediction per test pair for
# every combination of prompt technique and force setting.
predictions_per_pair = len(prompt_techniques) * len(extra_correction)

total_preds = 0
for folder_name in folders:
    for dataset_name in datasets:
        try:
            train, val, test = config.load_datasets(folder_name, dataset_name)
            total_preds += len(test) * predictions_per_pair
        except Exception as e:
            # Catch Exception (not a bare except) so KeyboardInterrupt still stops the cell.
            print(f"Dataset {folder_name}_{dataset_name} does not exist. {e}")
            continue
print(f"Total predictions: {total_preds} \n")

# Main prediction loop. For every folder x dataset combination that loads, each
# test pair (ltable_id, rtable_id) is sent to Gemini once per prompt technique and
# force setting, and every answer whose text could be read is appended as one CSV
# row. The header is only written when the CSV does not exist yet, so re-running
# a dataset appends (and can duplicate) rows.

# Built once instead of on every request; the dict never changes.
# NOTE(review): BLOCK_LOW_AND_ABOVE is the *strictest* HarmBlockThreshold (it also
# blocks low-probability content), so these settings make blocked responses more
# likely, not less; BLOCK_NONE is the permissive option. Left unchanged so new
# results stay comparable with datasets that were already processed.
safety_settings = {
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
}

for folder_name in folders:
    for dataset_name in datasets:
        csv_name = f"{folder_name}_{dataset_name}"
        csv_path = f"{save_folder}/{csv_name}.csv"

        # Load before touching the output file, so no header-only CSV is created
        # for combinations that do not exist.
        try:
            train, val, test = config.load_datasets(folder_name, dataset_name)
            tableA_df, tableB_df = config.tableA_tableB(folder_name, dataset_name)
        except Exception as e:
            print(f"Dataset {csv_name} does not exist. {e}")
            continue

        try:
            if not os.path.exists(csv_path):
                with open(csv_path, 'w') as f:
                    f.write("general_or_domain,simple_or_complex,force_or_not,tableA_id,tableB_id,pred,label,time\n")

            columns = tableA_df.columns
            if 'id' in columns:
                columns = columns.drop('id')
            # The template only depends on the columns, so build it once per dataset.
            sentence_template = gemini_f.format_columns_string(*columns)

            tableA, tableB, label = test['ltable_id'], test['rtable_id'], test['label']

            for idA, idB, single_label in zip(tableA, tableB, label):
                rowA = tableA_df[tableA_df['id'] == idA].drop(columns='id')
                rowB = tableB_df[tableB_df['id'] == idB].drop(columns='id')
                if rowA.empty or rowB.empty:
                    # to_dict('records')[0] would raise IndexError and abort the dataset.
                    print(f"{csv_name}: id pair ({idA}, {idB}) not found in tableA/tableB, skipped.")
                    continue
                sentenceA = sentence_template.format(**rowA.to_dict('records')[0])
                sentenceB = sentence_template.format(**rowB.to_dict('records')[0])

                for prompt in prompt_techniques:
                    # These depend only on the prompt technique, not on the force setting.
                    domain = gemini_f.determine_domain(dataset_name) if prompt not in [gemini_f.general_simple, gemini_f.general_complex] else None
                    simple_or_complex = gemini_f.determine_complexity(prompt)
                    general_or_domain = 'domain' if prompt in [gemini_f.domain_simple, gemini_f.domain_complex] else 'general'

                    for force in extra_correction:
                        prompt_sentence = gemini_f.generate_prompt_sentence(sentenceA, sentenceB, force, prompt, domain)

                        # Random pause of 0.25-0.75 s (float) between requests,
                        # presumably to stay under the API rate limit.
                        sleep_time = random.uniform(0.25, 0.75)
                        time.sleep(sleep_time)

                        start = time.time()
                        try:
                            response = model.generate_content(
                                prompt_sentence,
                                generation_config=genai.types.GenerationConfig(
                                    # Only one candidate for now.
                                    candidate_count=1,
                                    # NOTE(review): generation stops at the first 'x' in
                                    # the output; confirm this stop sequence is intended.
                                    stop_sequences=['x'],
                                    max_output_tokens=10),
                                safety_settings=safety_settings
                            )
                        except Exception as e:
                            # A failed request (rate limit, network, ...) skips only this
                            # prediction instead of the rest of the dataset.
                            print(f"Gemini request failed for {csv_name} ({idA}, {idB}): {e}")
                            continue
                        end = time.time()

                        # .text raises when the response has no text part
                        # (for example when it was blocked).
                        try:
                            response = response.text
                        except Exception:
                            response = "Gemini response.text failed."
                            print(response)
                            continue

                        time_taken = end - start
                        pred = gemini_f.parse_response(response)
                        yes_or_no = 1 if force else 0

                        gemini_f.save_predictions(csv_path, general_or_domain, simple_or_complex, yes_or_no, idA, idB, pred, single_label, time_taken)
        except Exception as e:
            print(f"Processing {csv_name} failed; skipping the rest of this dataset. {e}")

models/gemini-1.0-pro
models/gemini-1.0-pro-001
models/gemini-1.0-pro-latest
models/gemini-1.0-pro-vision-latest
models/gemini-1.5-flash
models/gemini-1.5-flash-001
models/gemini-1.5-flash-latest
models/gemini-1.5-pro
models/gemini-1.5-pro-001
models/gemini-1.5-pro-latest
models/gemini-pro
models/gemini-pro-vision

Total predictions: 9172 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

Yes 

No 

Yes 

No 

No 

No 

No 

No 

Yes 

Yes 

Yes 

No 

Yes 

No 

No 

No 

No 

No 

No 

No 

Yes 

No 

Yes 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 

No 
