In [1]:
import os

# Azure OpenAI connection settings. Each value is read from the environment
# variable of the same name so that credentials do not have to be typed into
# (and saved with) the notebook. An unset variable falls back to "".
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT", "")
AZURE_OPENAI_KEY = os.environ.get("AZURE_OPENAI_KEY", "")
AZURE_OPENAI_VERSION = os.environ.get("AZURE_OPENAI_VERSION", "")

# https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
# Names passed as `model=` to the Azure chat-completions calls below.
AZURE_OPENAI_MODEL_35_0301 = "gpt35turbo_v0301"  # despite the "0301" name, it is actually 1106 (has JSON)
AZURE_OPENAI_MODEL_4_1106 = "gpt4_v1106"  # it is actually 1106

## Libraries

In [2]:
# !pip3 install ollama
from openai import AzureOpenAI
import time
import json
import ollama

## GPT 3.5 Turbo (version 1106 with JSON mode)

In [3]:
def gpt35turbo_v1106(key, system_prompt, user_prompt):
    """Send one system + user exchange to the Azure GPT-3.5 Turbo deployment.

    The request asks for a JSON-object response; the returned text is parsed
    and the wall-clock duration of the API call is attached to the result.

    Args:
        key: API key handed to the ``AzureOpenAI`` client.
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.

    Returns:
        The parsed response with an extra ``'execution_time'`` entry (seconds),
        or ``None`` if anything raised along the way (the error is printed).
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    try:
        azure_client = AzureOpenAI(
            azure_endpoint=AZURE_OPENAI_ENDPOINT,
            api_key=key,
            api_version=AZURE_OPENAI_VERSION,
        )

        # Only the completion request itself is timed.
        started = time.time()
        completion = azure_client.chat.completions.create(
            model=AZURE_OPENAI_MODEL_35_0301,
            messages=conversation,
            response_format={"type": "json_object"},
        )
        elapsed = time.time() - started

        payload = completion.choices[0].message.content
        parsed = json.loads(payload) if isinstance(payload, str) else payload
        parsed['execution_time'] = elapsed
    except Exception as err:
        # Best-effort: report the failure and signal it with None.
        print(f"Error during chat: {err}")
        return None

    return parsed

## GPT 4 (version 1106 with JSON mode)

In [4]:
def gpt4_v1106(key, system_prompt, user_prompt):
    """Send one system + user exchange to the Azure GPT-4 deployment.

    Mirrors ``gpt35turbo_v1106`` but targets ``AZURE_OPENAI_MODEL_4_1106``.
    The request asks for a JSON-object response; the returned text is parsed
    and the wall-clock duration of the API call is attached to the result.

    Relies on the notebook's import cell for ``AzureOpenAI``, ``time`` and
    ``json`` (no function-local imports).

    Args:
        key: API key handed to the ``AzureOpenAI`` client.
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.

    Returns:
        The parsed response with an extra ``'execution_time'`` entry (seconds),
        or ``None`` if anything raised along the way (the error is printed).
    """
    try:
        client = AzureOpenAI(azure_endpoint=AZURE_OPENAI_ENDPOINT,
                             api_key=key,
                             api_version=AZURE_OPENAI_VERSION)

        # Only the completion request itself is timed.
        start_time = time.time()

        response = client.chat.completions.create(
            model=AZURE_OPENAI_MODEL_4_1106,
            messages=[{"role": "system", "content": system_prompt},
                      {"role": "user", "content": user_prompt}],
            response_format={"type": "json_object"}
        )

        execution_time = time.time() - start_time

        # The message content arrives as text; decode it into a Python object.
        response_content = response.choices[0].message.content
        if isinstance(response_content, str):
            response_content = json.loads(response_content)

        response_content['execution_time'] = execution_time

        return response_content

    except Exception as e:
        # Best-effort: report the failure and signal it with None.
        print(f"Error during chat: {e}")
        return None

## Models via Ollama

In [5]:
def ollama(model, system_prompt, user_prompt):
    """Send one system + user exchange to a model through the Ollama client.

    The request asks for JSON-formatted output; the returned text is parsed
    and the wall-clock duration of the chat call is attached to the result.

    Args:
        model: Model name passed to ``ollama.chat`` (e.g. ``'llama3:latest'``).
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.

    Returns:
        The parsed response with an extra ``'execution_time'`` entry (seconds),
        or ``None`` if anything raised inside the call (the error is printed).
    """
    # This function's name shadows the module-level `ollama` import, so the
    # package is re-imported locally under an unambiguous alias.
    import ollama as ollama_client

    try:
        start_time = time.time()

        response = ollama_client.chat(
            model=model,
            # The system prompt goes out with the 'system' role, matching the
            # Azure helpers (it was previously sent as an 'assistant' turn).
            messages=[{'role': 'system', 'content': system_prompt},
                      {'role': 'user', 'content': user_prompt}],
            format='json',
        )

        execution_time = time.time() - start_time

        response_content = json.loads(response['message']['content'])
        response_content['execution_time'] = execution_time

        return response_content
    except Exception as e:
        # Best-effort: report the failure and signal it with None.
        print(f"Error during chat: {e}")
        return None

In [6]:
# model = 'llama3:latest'
# system_prompt ="You are a helpful agent. Your answers must be a json with key 'output'"
# user_prompt = "hi!"
# ollama(model, system_prompt, user_prompt)

## Function to validate the output of an LLM

In [7]:
def llm_output_validator(response_content, required_keys=None):
    """Check an LLM response and return it as a parsed Python object.

    Args:
        response_content: Raw response. A ``str`` is decoded with
            ``json.loads``; any other non-``None`` value is used as-is.
        required_keys: Optional iterable of keys that must be present. When
            given (and non-empty), the response must be a ``dict``.

    Returns:
        The parsed (or passed-through) response.

    Raises:
        ValueError: If the response is ``None``, is not valid JSON, is not a
            dict while ``required_keys`` is given, or lacks a required key.
    """
    if response_content is None:
        raise ValueError("Response content is None")

    try:
        if isinstance(response_content, str):
            response_json = json.loads(response_content)
        else:
            response_json = response_content

        if required_keys:
            # Key checks are only meaningful on a JSON object. Without this
            # guard a JSON string would pass via substring matching and a
            # JSON number would raise TypeError instead of ValueError.
            if not isinstance(response_json, dict):
                raise ValueError(
                    f"Expected a JSON object, got {type(response_json).__name__}"
                )

            missing_keys = [key for key in required_keys if key not in response_json]
            if missing_keys:
                raise ValueError(f"Missing required keys: {missing_keys}")

        return response_json
    except ValueError as e:
        # json.JSONDecodeError is a ValueError subclass, so decoding failures
        # and the checks above are all reported with one uniform prefix.
        raise ValueError(f"Invalid response content: {e}") from e

# Merging all models

In [8]:
def llm(model, system_prompt, user_prompt, key=None, max_attempts=1, required_keys=None):
    """Query the chosen model and return its validated output.

    Dispatches to ``gpt35turbo_v1106``, ``gpt4_v1106`` or ``ollama`` based on
    ``model`` and runs the result through ``llm_output_validator``. A failed
    validation is printed and the model is queried again, until
    ``max_attempts`` validations have failed.

    Args:
        model: ``'gpt35turbo_v1106'``, ``'gpt4_v1106'``, or one of the listed
            Ollama model tags.
        system_prompt: System message forwarded to the backend.
        user_prompt: User message forwarded to the backend.
        key: API key forwarded to the Azure backends; unused for Ollama.
        max_attempts: Number of failed validations tolerated before giving up.
        required_keys: Forwarded to ``llm_output_validator``.

    Returns:
        The validated output, or ``{}`` when no attempt produced a valid one.

    Raises:
        ValueError: If ``model`` is not supported (raised on the first attempt).
    """
    ollama_models = ('mixtral:latest', 'phi3:medium', 'llama3:latest', 'llama2:latest')

    def query_backend():
        # Resolved on every attempt; backend names are looked up only when used.
        if model == 'gpt35turbo_v1106':
            return gpt35turbo_v1106(key, system_prompt, user_prompt)
        if model == 'gpt4_v1106':
            return gpt4_v1106(key, system_prompt, user_prompt)
        if model in ollama_models:
            return ollama(model, system_prompt, user_prompt)
        raise ValueError("Unsupported model specified.")

    failures = 0
    while failures < max_attempts:
        # Called outside the try so an unsupported model is not retried.
        raw_output = query_backend()

        try:
            return llm_output_validator(raw_output, required_keys)
        except ValueError as err:
            failures += 1
            print(f"Attempt {failures} failed: {err}")

    print(f"Output could not be validated within {max_attempts} attempts")
    return {}

# Testing

In [12]:
# system_prompt ="You are a helpful assistant. Your answers must be a json with key 'output'"
# user_prompt = "hi!"

# print('gpt35turbo_v1106:\n',llm('gpt35turbo_v1106', system_prompt, user_prompt, key=AZURE_OPENAI_KEY,max_attempts=3,required_keys=['output']))
# print()
# print('gpt4_v1106:\n',llm('gpt4_v1106', system_prompt, user_prompt, key=AZURE_OPENAI_KEY,max_attempts=3,required_keys=['output']))
# print()
# print('phi3:medium:\n',llm('phi3:medium', system_prompt, user_prompt,max_attempts=3,required_keys=['output']))
# print()
# print('llama3:latest:\n',llm('llama3:latest', system_prompt, user_prompt,max_attempts=3,required_keys=['output']))
# print()
# print('llama2:latest:\n',llm('llama2:latest', system_prompt, user_prompt,max_attempts=3,required_keys=['output']))
# print()
# print('mixtral:latest:\n',llm('mixtral:latest', system_prompt, user_prompt))