## Install Packages

In [1]:
!pip install google-generativeai
!pip install python-dotenv



## Import Packages

In [2]:
import os
import google.generativeai as genai
from dotenv import load_dotenv
import threading
import time
import concurrent

## Load API keys
NB! to use this model at least 6 google API keys are needed.<br>
This can either be done by using `.env` file or inputing them into the code.<br>
The env file should be formated like
```
GOOGLE_API_KEY_1=abc
GOOGLE_API_KEY_2=abc
GOOGLE_API_KEY_3=abc
GOOGLE_API_KEY_4=abc
GOOGLE_API_KEY_5=abc
GOOGLE_API_KEY_6=abc
```

In [3]:
api_keys = []
if os.path.isfile(".env"):
    load_dotenv()
    api_key_1 = os.getenv('GOOGLE_API_KEY_1')
    api_key_2 = os.getenv('GOOGLE_API_KEY_2')
    api_key_3 = os.getenv('GOOGLE_API_KEY_3')
    api_key_4 = os.getenv('GOOGLE_API_KEY_4')
    api_key_5 = os.getenv('GOOGLE_API_KEY_5')
    api_key_6 = os.getenv('GOOGLE_API_KEY_6')
    
    api_keys = [
        api_key_1,
        api_key_2,
        api_key_3,
        api_key_4,
        api_key_5,
        api_key_6
    ]
else:
    api_keys = [
        "<GOOGLE_API_KEY_1>",
        "<GOOGLE_API_KEY_2>",
        "<GOOGLE_API_KEY_3>",
        "<GOOGLE_API_KEY_4>",
        "<GOOGLE_API_KEY_5>",
        "<GOOGLE_API_KEY_6>",
    ]

## Loading datasets

In [4]:
filenames = [
    "group_1.txt",
    "group_2.txt",
    "group_3.txt",
    "group_4.txt",
    "group_5.txt",
    "group_6.txt",
    "group_7.txt",
    "group_8.txt",
    "group_9.txt",
    "group_10.txt",
    "group_11.txt",
    "group_12.txt",
]
file_contents = []
for filename in filenames:
    file = open(f"./data/{filename}", "r")
    file_content = file.read()
    file.close()
    file_contents.append(file_content)

In [5]:
def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█', printEnd = "\r"):
    """
    from: https://stackoverflow.com/questions/3173320/text-progress-bar-in-terminal-with-block-characters
    
    Call in a loop to create terminal progress bar
    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int)
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
        printEnd    - Optional  : end character (e.g. "\r", "\r\n") (Str)
    """
    percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
    filledLength = int(length * iteration // total)
    bar = fill * filledLength + '-' * (length - filledLength)
    print(f'\r{prefix} |{bar}| {percent}% {suffix}', end = printEnd)
    # Print New Line on Complete
    if iteration == total: 
        print()

In [6]:
system_prompt = "Sa oled seaduste abiline. Kui sulle antud seaduses on küsimusele vastust, siis vasta sellele. Kui vastus puudub, siis vasta '0'"
response_format = """kastuta vastamisel järgmist formaati:
<seaduse nimi>

<sinu vastus>"""

In [7]:
thread_local = threading.local()
responses_lock = threading.Lock()
api_keys_lock = threading.Lock()
responses = []


def get_session():
    if not hasattr(thread_local, "session"):
        thread_local.session = requests.Session()
    return thread_local.session

def get_response(data):
    api_keys, file_content, group_index, group_count, system_prompt, response_format, question = data
    
    with api_keys_lock:
        if not api_keys:
            print("No API keys left to process.")
            return
        api_key = api_keys.pop(0)
        
    genai.configure(api_key=api_key)

    model = genai.GenerativeModel("gemini-1.5-flash")
    if group_index % 2 == 0:
        model = genai.GenerativeModel("gemini-1.5-flash-8b")


    time.sleep(2)

    response_text = model.generate_content([
        system_prompt,
        response_format,
        file_content,
        "Küsimus on järgmine:",
        question
    ]).text

    # Safely append to the shared responses list
    with responses_lock:
        responses.append((group_index, response_text))
        
    groups_done = len(responses)
    printProgressBar(groups_done, group_count, length = 50)

def get_all_responses(api_keys, file_contents, system_prompt, response_format, question):
    group_count = len(api_keys)
    tasks = [
        ( api_keys, file_contents[i], i, group_count, system_prompt, response_format, question)
        for i in range(group_count)
    ]

    with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
        executor.map(get_response, tasks)

def contorl_threading(user_message):
    temp_api_keys = api_keys.copy()
    temp_api_keys = temp_api_keys * 2
    temp_api_keys.sort()

    start_time = time.time()

    group_count = len(api_keys)
    printProgressBar(0, group_count, length = 50)

    get_all_responses(temp_api_keys, file_contents, system_prompt, response_format, user_message)
    
    duration = time.time() - start_time

    responses.sort(key=lambda x: x[0])

    print()
    print(f"Got responses in {duration} seconds")


In [8]:
for response in responses:
    print(response)

In [9]:
def get_best_answer(user_message):
    genai.configure(api_key=api_keys[0])
    model = genai.GenerativeModel("gemini-1.5-pro")
    response_texts = [element[1] for element in responses]

    print("Choosing best answer")
    start_time = time.time()
    
    prompt = (
        ["Sulle antakse mitme erineva mudeli vastused erinevatest seadustest, tsiteeri mulle mudeli vastuseid, mis pole '0'"] +
        [response_format] +
        response_texts + 
        ["Küsimus on järgnev:"] +
        [user_message]
    )

    duration = time.time() - start_time
    print(f"Chose best answer in {duration} seconds")
    
    best_answer = model.generate_content(prompt)
    return best_answer

In [10]:
def full_flow(user_message):
    contorl_threading(user_message)
    return get_best_answer(user_message).text

---
# Testing chat ui
---

In [None]:
import os
import time


terminal_colors = {
    "purple": "\033[0;35m",
    "end": "\033[0m"
}

# tell the chatbot what its purpose is about
#messages = [
#    {"role": "system", "content" : """You are my weekly mental health doctor and will help me maintain a positive mindset. 
#            You will fill me with realistic and optimistic advice to get me through the work week. 
#            Remember that I work from 9am to 6pm and rest of the day is with my family."""}
#]

def generate_response(user_message):
    responses = []
    answer = full_flow(user_message)

    #messages.append( {
    #    "role" : "assistant",
    #    "content": answer
    #})
    return answer

print(f"{terminal_colors['purple']}Chatbot: Hello! Type 'exit' to end the conversation.{terminal_colors['end']}")

while True:
    user_input = input("You: ")
    if user_input.lower() == 'exit':
        break
    print
    response = generate_response(user_input)
    print(f"{terminal_colors['purple']}Chatbot: {response}{terminal_colors['end']}")

[0;35mChatbot: Hello! Type 'exit' to end the conversation.[0m


You:  Kui palju peab kiiruse ületamise eest trahvi maksma


 |█████████████████████████-------------------------| 50.0% 
Got responses in 248.6593677997589 seconds
Choosing best answer
Chose best answer in 2.6226043701171875e-06 seconds
[0;35mChatbot: <Väärteomenetluse seadustik>

Kiiruse ületamise eest trahvi suurust ei ole võimalik kindlaks teha ilma kontekstita.  Seadustik sisaldab kiiruse ületamise eest rahatrahvi määramise üldisi sätteid, ent konkreetne summa sõltub mitmest tegurist, mis on seaduses sätestatud.  Konkreetse kiiruse ületamise väärteo rahatrahvi suuruse määramiseks on vaja teada:  määratud piirkiirus, mitu protsenti ja millise summa võrra on kiiruse norm ületatud.
[0m
