In [4]:
# Install/upgrade the notebook's dependencies into the kernel's own environment.
# pip notes that the kernel may need a restart before upgraded packages are picked up.
%pip install -U datasets
%pip install openai scikit-learn pandas matplotlib seaborn tqdm

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [5]:
from datasets import load_dataset

# Load the train split, turn the "label" column into a ClassLabel feature
# (needed for stratification), then carve out a stratified 80/20 train/test
# split with a fixed seed so the split is reproducible.
ds = (
    load_dataset("nnudee/Thai-Thangkarn-sentence", split="train")
    .class_encode_column("label")
    .train_test_split(test_size=0.2, stratify_by_column="label", seed=1122)
)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Convert the held-out split to pandas; reset_index() materialises the row
# position as an explicit "index" column, which later cells use as a row id.
ds_pd = ds["test"].to_pandas().reset_index()
ds_pd

Unnamed: 0,index,label,contact,category,type,output,reasoning,model
0,0,1,Chat,attendance issues,urgent matter,อาจารย์คะ ขอแจ้งว่าหนูมีเหตุฉุกเฉินที่บ้าน ต้อ...,ข้อความนี้ใช้ภาษาสุภาพแต่เป็นประโยคที่อ่านง่าย...,gpt-4.1-2025-04-14
1,1,1,Chat,submission notifications,technical problem,อาจารย์คะ ผมมีปัญหากับไฟล์ที่ส่งไปครับ มันเกิด...,"This text is semi-formal, using polite and app...",typhoon-v2-70b-instruct
2,2,4,Email,attendance issues,late arrival,สวัสดีค่ะ อาจารย์ วันนี้ตัวเองอาจจะเข้าชั้นเรี...,"This text is informal but polite, using everyd...",typhoon-v2-70b-instruct
3,3,1,Chat,attendance issues,request leave,สวัสดีค่ะ อาจารย์ วันนี้ผู้จัดทำอาจจะมาไม่ทันเ...,"This text is semi-formal, using polite and app...",typhoon-v2-70b-instruct
4,4,3,Email,attendance issues,absence notification,ข้าพระพุทธเจ้าใคร่ขอกราบเรียนอาจารย์ด้วยความเค...,"This text is highly ceremonial and formal, usi...",typhoon-v2-70b-instruct
...,...,...,...,...,...,...,...,...
8436,8436,1,Email,attendance issues,absence notification,เรียนอาจารย์ ฉันขอแจ้งลาการเข้าเรียนในวันพรุ...,"This text is semi-formal, using polite and app...",typhoon-v2-70b-instruct
8437,8437,4,Chat,document requests,official letter,ครูคะ หนูขอแบบฟอร์มหน่อยได้ไหมคะ ขอบคุณค่ะ,"This text is informal but polite, using everyd...",typhoon-v2-70b-instruct
8438,8438,3,Email,document requests,recommendation letter,ข้าพระพุทธเจ้าใคร่ขอกราบเรียนศาสตราจารย์ด้วยคว...,"This text is highly ceremonial and formal, usi...",typhoon-v2-70b-instruct
8439,8439,4,Email,submission notifications,delay in assignment,สวัสดีค่ะอาจารย์ ฉันขอแจ้งว่าจะมีการเลื่อนการส...,"This text is informal but polite, using everyd...",typhoon-v2-70b-instruct


In [7]:
import csv
import getpass
import os
import re
import time

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from openai import OpenAI
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from tqdm import tqdm

# Credentials must never be stored in the notebook. The key is read from the
# TYPHOON_API_KEY environment variable; if that is unset (or empty) the user is
# prompted for it without echoing. On Colab, userdata.get('Typhoon_API') can be
# used to populate the environment variable instead.
# NOTE(review): a literal key used to be hard-coded on this line - treat it as
# leaked and rotate it.
API_KEY = os.environ.get("TYPHOON_API_KEY") or getpass.getpass("Typhoon API key: ")

In [8]:
def _read_prompt(path):
    """Return the UTF-8 text stored at `path`, stripped of surrounding whitespace."""
    with open(path, "r", encoding="utf-8") as prompt_file:
        return prompt_file.read().strip()


# Prompt templates live next to the notebook as plain-text files.
system_prompt = _read_prompt("./baseline_system.txt")
base_user_prompt = _read_prompt("./baseline_user.txt")

# OpenAI-compatible client pointed at the Typhoon endpoint; the key is
# stripped so stray whitespace/newlines never reach the Authorization header.
client = OpenAI(
    base_url="https://api.opentyphoon.ai/v1",
    api_key=API_KEY.strip(),
)

def extract_label_from_response(response_text, label_list):
    """Pick the label from `label_list` that the model response mentions.

    Matching is case-insensitive. Labels are matched longest-first so that a
    label which is a substring of another one (e.g. "Formal" inside
    "Informal" or "Semi-formal") is only reported when it occurs on its own;
    plain substring testing would otherwise make "Informal" unreachable.

    Args:
        response_text: Raw text returned by the model (may be empty or None).
        label_list: Candidate labels, in priority order.

    Returns:
        The first label in `label_list` order that occurs in the response,
        spelled exactly as in `label_list`, or "UNKNOWN" when none occurs.
    """
    if not response_text or not label_list:
        return "UNKNOWN"

    # Longest candidates first: at any position the regex alternation then
    # prefers "informal"/"semi-formal" over the shorter "formal".
    candidates = sorted(
        {label.lower() for label in label_list if label},
        key=lambda text: (-len(text), text),
    )
    if not candidates:
        return "UNKNOWN"

    pattern = re.compile("|".join(re.escape(text) for text in candidates))
    # finditer yields non-overlapping, leftmost matches, so the "formal" that
    # sits inside an already matched "informal" is never counted separately.
    mentioned = {match.group(0) for match in pattern.finditer(response_text.lower())}

    # Keep the original tie-break: priority follows label_list order.
    for label in label_list:
        if label and label.lower() in mentioned:
            return label
    return "UNKNOWN"

def call_typhoon(model_name, user_input, client, system_prompt, base_user_prompt):
    """Send one classification request to the chat-completions endpoint.

    The user message is `base_user_prompt`, a blank line, then `user_input`.

    Args:
        model_name: Model identifier passed to the API.
        user_input: The text to classify.
        client: Object exposing `chat.completions.create(...)`.
        system_prompt: Content of the system message.
        base_user_prompt: Instruction text placed before `user_input`.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or an empty string when the API returns no content.
    """
    full_prompt = f"{base_user_prompt}\n\n{user_input}"
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": full_prompt},
        ],
        stream=False,
    )
    # `content` is optional in the response schema; calling .strip() on None
    # would abort the whole run with an AttributeError.
    content = response.choices[0].message.content
    return (content or "").strip()


label_list = ["Casual", "Semi-formal", "Formal", "Ceremonial", "Informal"]
model_name = "typhoon-v2.1-12b-instruct"
results_filename = "baseline_results.csv"
REQUEST_DELAY_SECONDS = 5  # pause between API calls
CHECKPOINT_EVERY = 50      # rewrite the results file after this many new predictions


def _save_progress(existing_df, indices, labels, model, path):
    """Persist earlier rows plus the new predictions to `path`.

    Args:
        existing_df: Rows loaded from a previous run, or None.
        indices: Values of the "index" column for the new predictions.
        labels: Predicted labels, aligned with `indices`.
        model: Model name stored alongside every new row.
        path: Destination CSV file.

    Returns:
        DataFrame holding only the new rows (index, predicted_label, model_name).
    """
    new_df = pd.DataFrame({"index": indices, "predicted_label": labels})
    new_df["model_name"] = model
    if new_df.empty:
        # Nothing new: leave whatever is already on disk untouched.
        return new_df

    if existing_df is None or existing_df.empty:
        combined = new_df
    else:
        combined = pd.concat([existing_df, new_df], ignore_index=True)

    # Write to a temporary file and swap it in, so an interruption in the
    # middle of a write cannot corrupt previously saved results.
    tmp_path = f"{path}.tmp"
    combined.to_csv(tmp_path, index=False, encoding="utf-8")
    os.replace(tmp_path, path)
    return new_df


# Retrieve existing results if available
existing_results_df = None
processed_indices = set()
if os.path.exists(results_filename):
    existing_results_df = pd.read_csv(results_filename, encoding="utf-8")
    processed_indices = set(existing_results_df["index"].astype(int))

if processed_indices:
    print(f"✅ Found {len(processed_indices)} existing results. Resuming...")
else:
    print("📝 No existing results found. Starting fresh.")

# Resume by skipping exactly the rows that already have a prediction; unlike
# "max(index) + 1" this also picks up any gaps left by an earlier run.
data_to_process = ds_pd[~ds_pd["index"].isin(processed_indices)]
predicted_labels = []
indexed_list = []
try:
    for _, sample in tqdm(data_to_process.iterrows(), total=len(data_to_process), desc="Processing"):
        # Make the API call
        raw_response = call_typhoon(model_name, sample["output"], client, system_prompt, base_user_prompt)
        predicted_label = extract_label_from_response(raw_response, label_list)
        print(f"Index {sample['index']}: Predicted label: {predicted_label}")

        predicted_labels.append(predicted_label)
        indexed_list.append(sample["index"])

        # Periodic checkpoint so a kernel crash loses at most a few results.
        if len(indexed_list) % CHECKPOINT_EVERY == 0:
            _save_progress(existing_results_df, indexed_list, predicted_labels, model_name, results_filename)

        time.sleep(REQUEST_DELAY_SECONDS)

except KeyboardInterrupt:
    print("\n🛑 Process interrupted by user. Progress has been saved.")

except Exception as e:
    # Deliberately best-effort: report the failure and keep what was collected.
    print(f"\n💥 A critical error occurred: {e}. Progress has been saved. Please check logs.")

finally:
    # Runs on normal completion as well as on interruption/error; previously
    # a run that finished without an exception never wrote its results.
    results_df = _save_progress(existing_results_df, indexed_list, predicted_labels, model_name, results_filename)

print(f"\n🎉 Processing complete or script stopped. All progress is saved in {results_filename}.")

✅ Found 10 existing results. Resuming...


Processing:   0%|          | 0/8431 [00:00<?, ?it/s]

Index 10: Predicted label: Formal


Processing:   0%|          | 1/8431 [00:08<20:08:45,  8.60s/it]

Index 11: Predicted label: Casual


Processing:   0%|          | 2/8431 [00:15<17:18:12,  7.39s/it]

Index 12: Predicted label: Semi-formal


Processing:   0%|          | 3/8431 [00:21<16:35:34,  7.09s/it]

Index 13: Predicted label: Ceremonial


Processing:   0%|          | 4/8431 [00:28<16:13:46,  6.93s/it]

Index 14: Predicted label: Semi-formal


Processing:   0%|          | 5/8431 [00:35<15:54:30,  6.80s/it]

Index 15: Predicted label: Semi-formal


Processing:   0%|          | 6/8431 [00:41<15:47:48,  6.75s/it]

Index 16: Predicted label: Ceremonial


Processing:   0%|          | 7/8431 [00:48<15:52:19,  6.78s/it]

Index 17: Predicted label: Formal


Processing:   0%|          | 7/8431 [00:51<17:18:24,  7.40s/it]


🛑 Process interrupted by user. Progress has been saved.

🎉 Processing complete or script stopped. All progress is saved in results.csv.



