In [1]:
import transformers
from openai import OpenAI, AzureOpenAI
import json
import os
from PIL import Image
from tqdm import tqdm
import wandb
import random

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class CFG:
    """Run-wide settings for the conversation-generation experiment.

    Used purely as a namespace of class attributes; never instantiated.
    """

    # Reproducibility / bookkeeping (seed and version also name the W&B group/run).
    seed = 47
    version = "v.1.6"
    project_name = "Skin_LLaVA_Convgen"

    # Output switches.
    log = True     # log a results table and the prompt files to Weights & Biases
    print = False  # echo image name, image, query and output for every sample

    # Generation.
    model_name = "gpt-4-1106-preview"
    num_test_samples = 10  # how many label files are drawn from root_dir
    num_example = 2        # number of few-shot examples placed in the prompt

    # Paths (the few-shot files are versioned alongside the prompt version).
    root_dir = "/data2/ArtLab_LLM/label/train_231109/JPEGImages/"
    system_path = "./prompt/system_message.txt"
    sample_1_path = "./prompt/sample_1_" + version + ".txt"
    sample_2_path = "./prompt/sample_2_" + version + ".txt"

In [4]:
def generate_query_imgdir(sample, root_dir):
    """Build the LLM query text and image path from one JSON label file.

    Args:
        sample: file name of the JSON label, relative to ``root_dir``.
        root_dir: directory holding both the JSON labels and their images
            (a trailing path separator is optional).

    Returns:
        dict with
          - ``"query"``: ``str()`` of the label dict after removing the keys in
            ``keys_to_drop`` and renaming ``"hydration"`` to ``"dryness"``;
          - ``"img_name"``: full path of the image named by the label's
            ``"file_name"`` field.

    Raises:
        KeyError: if the label has no ``"file_name"`` field.
    """
    # os.path.join works whether or not root_dir ends with a separator;
    # plain concatenation silently produced a broken path when it did not.
    with open(os.path.join(root_dir, sample), 'r', encoding='utf-8') as file:
        data = json.load(file)
    img_path = os.path.join(root_dir, data["file_name"])

    # Fields removed from the query before it is shown to the model.
    keys_to_drop = ("caption", "part", "file_name", "rosacea", "acne", "eczema")
    for key in keys_to_drop:
        data.pop(key, None)

    # Rename "hydration" -> "dryness" (moved to the end of the dict).
    # "dryness" is always present; it is None when the label lacks "hydration".
    data["dryness"] = data.pop("hydration", None)

    # NOTE(review): str() gives a Python repr, not JSON; the few-shot examples
    # presumably use the same format — confirm before switching to json.dumps.
    return {"query": str(data), "img_name": img_path}

random.seed(CFG.seed)

# os.listdir() returns entries in arbitrary, filesystem-dependent order, so the
# listing is sorted before sampling; otherwise the same seed can pick different
# label files on another machine or after the directory is rewritten.
label_files = sorted(f for f in os.listdir(CFG.root_dir) if f.endswith(".json"))
if len(label_files) < CFG.num_test_samples:
    raise ValueError(
        f"Found {len(label_files)} label files in {CFG.root_dir}, "
        f"fewer than CFG.num_test_samples={CFG.num_test_samples}"
    )
test_samples = random.sample(label_files, CFG.num_test_samples)

test_container = [generate_query_imgdir(sample, CFG.root_dir) for sample in test_samples]

In [5]:
# Azure OpenAI model names this notebook knows how to call.
_AZURE_MODELS = frozenset({
    "gpt-35-turbo", "gpt-35-turbo-16k", "gpt-35-turbo-instruct",
    "gpt-4", "gpt-4-32k", "gpt-4-1106-preview", "gpt-4-vision-preview",
})

# USD per token as (completion, prompt). Public-API spellings are listed too,
# so the table still applies if a non-Azure client is added later.
_PRICE_PER_TOKEN = {
    "gpt-35-turbo": (0.000002, 0.0000015),
    "gpt-3.5-turbo-1106": (0.000002, 0.0000015),
    "gpt-35-turbo-instruct": (0.000002, 0.0000015),
    "gpt-4": (0.00006, 0.00003),
    "gpt-4-1106-preview": (0.00003, 0.00001),
    "gpt-4-vision-preview": (0.00003, 0.00001),
    "gpt-4-32k": (0.00012, 0.00006),
    "gpt-35-turbo-16k": (0.000004, 0.000003),
    "gpt-3.5-turbo-16k": (0.000004, 0.000003),
}


def _build_messages(system_message, samples, query):
    """Assemble the chat transcript: system prompt, few-shot pairs, then the query."""
    messages = [{"role": "system", "content": system_message}]
    for sample in samples:
        messages.append({"role": "user", "content": sample['context']})
        messages.append({"role": "assistant", "content": sample['response']})
    messages.append({"role": "user", "content": query})
    return messages


def generate_conversation(system_message, samples, query, model_name):
    """Ask an Azure OpenAI chat model to answer ``query`` after few-shot examples.

    Args:
        system_message: text of the system prompt.
        samples: few-shot examples, each a dict with ``'context'`` (sent as a
            user turn) and ``'response'`` (sent as an assistant turn).
        query: the user message to answer.
        model_name: an Azure model name from ``_AZURE_MODELS``.

    Returns:
        ``(content, price)``: the first choice's message content (may be None
        if the service returns no text) and the request cost in USD computed
        from the reported token usage and ``_PRICE_PER_TOKEN``.

    Raises:
        ValueError: if the model is not a known Azure model or has no price
            entry. Both checks run before any request is sent.

    Reads the endpoint and key from the ``AZURE_OPENAI_ENDPOINT`` and
    ``AZURE_OPENAI_KEY`` environment variables.
    """
    # Validate both lookups before the paid request, so a misconfigured model
    # never costs an API call whose result would then be discarded.
    if model_name not in _AZURE_MODELS:
        raise ValueError("Model name is unrecognizable.")
    if model_name not in _PRICE_PER_TOKEN:
        raise ValueError("Unknown model name")
    completion_coef, prompt_coef = _PRICE_PER_TOKEN[model_name]

    # NOTE(review): gpt-35-turbo-instruct is a completions-style model; the chat
    # endpoint used below may reject it — confirm before using it here.
    client = AzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_KEY"),
        api_version="2023-08-01-preview",
    )
    response = client.chat.completions.create(
        model=model_name,
        messages=_build_messages(system_message, samples, query),
    )

    usage = response.usage
    price = usage.completion_tokens * completion_coef + usage.prompt_tokens * prompt_coef
    return response.choices[0].message.content, price

In [6]:
with open(CFG.system_path, 'r', encoding='utf-8') as file:
    system_message = file.read()

# Few-shot example files in order; the first CFG.num_example of them are used.
# Slicing keeps this consistent with the artifact logging (sample_2 whenever
# num_example > 1) instead of silently dropping sample_2 for values above 2.
example_paths = [CFG.sample_1_path, CFG.sample_2_path]
if not 0 <= CFG.num_example <= len(example_paths):
    raise ValueError(
        f"CFG.num_example must be between 0 and {len(example_paths)}, got {CFG.num_example}"
    )

file_container = []
for path in example_paths[:CFG.num_example]:
    with open(path, 'r', encoding='utf-8') as file:
        file_container.append(file.read())

def preprocess_example(file):
    """Split a few-shot example file into its prompt and its reply.

    The first line (without its newline) is the user ``context``; everything
    after it, with surrounding whitespace stripped, is the assistant
    ``response``. Text with no newline yields an empty response.
    """
    first_line, _newline, remainder = file.partition("\n")
    return {"context": first_line, "response": remainder.strip()}

# Parse each loaded example file into a {"context", "response"} pair.
fewshot_samples = list(map(preprocess_example, file_container))

In [7]:
import wandb
from PIL import Image
import os
from tqdm import tqdm

run = None
table = None
if CFG.log:
    run = wandb.init(
        project=CFG.project_name,
        group=f"{CFG.model_name} {CFG.version} s.{CFG.seed}",
        name=f"{CFG.model_name} {CFG.num_example}-shot {CFG.version}",
    )
    table = wandb.Table(columns=['Image Name', 'Image', 'Query', 'Generated Conversation'])

# Accumulated cost in USD ("sum" would shadow the builtin).
total_price = 0.0

try:
    for test_sample in tqdm(test_container):
        conv_generated, price = generate_conversation(
            system_message, fewshot_samples, test_sample['query'], CFG.model_name
        )
        total_price += price
        img_name = os.path.basename(test_sample['img_name'])

        # The context manager closes the image file once it has been logged/shown.
        with Image.open(test_sample['img_name']) as sample_image:
            if CFG.log:
                table.add_data(img_name, wandb.Image(sample_image), test_sample['query'], conv_generated)

            if CFG.print:
                print(img_name)
                sample_image.show()
                print(test_sample['query'])
                # f-string tolerates a None reply instead of raising TypeError.
                print(f"{conv_generated}\n\n")

    if CFG.log:
        wandb.log({"Log": table})
        run.summary["price_usd"] = total_price

        artifact = wandb.Artifact("prompts", type="dataset")
        artifact.add_file(CFG.system_path, "system_message.txt")
        if CFG.num_example > 0:
            artifact.add_file(CFG.sample_1_path, "sample_1.txt")
            if CFG.num_example > 1:
                artifact.add_file(CFG.sample_2_path, "sample_2.txt")
        wandb.log_artifact(artifact)
finally:
    # Report spend and close the W&B run even if an API call failed part-way,
    # so a re-run of this cell starts a fresh run rather than possibly
    # appending to one left open.
    print(f"Price: ${total_price:.6f}")
    if run is not None:
        run.finish()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33m2gnldud[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 10/10 [09:27<00:00, 56.73s/it]


Price: $ 0.24738000000000004
