In [1]:
!pip -q install langchain_huggingface

In [17]:
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = "/content/drive/MyDrive/MuTual"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain.prompts import PromptTemplate
import torch, pandas as pd, os
from tqdm.auto import tqdm
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

csv_url  = ("https://raw.githubusercontent.com/beefed-up-geek/"
            "HCLT-KACL-2025/main/MuTual/data/mutual_dev.csv")

csv_path = "mutual_dev.csv"

if not os.path.exists(csv_path):
    r = requests.get(csv_url)
    r.raise_for_status()
    with open(csv_path, "wb") as f:
        f.write(r.content)
else:
    print("이미 다운로드된 파일이 존재합니다 – 건너뜀.")

이미 다운로드된 파일이 존재합니다 – 건너뜀.


In [13]:
model_id = "maywell/Llama-3-Ko-8B-Instruct"

tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto"
)

gen_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tok,
    max_new_tokens=2,
    do_sample=False,
    eos_token_id=[
        tok.eos_token_id,
        tok.convert_tokens_to_ids("<|eot_id|>")
    ]
)

llm = HuggingFacePipeline(pipeline=gen_pipe)

sys_prompt = '''
You are an AI assistant that predicts the next line of a dialogue.
Given the conversation and four candidate replies,
respond with ONLY the single capital letter (A, B, C, or D).'''


template = '''
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system}\n<|eot_id|><|start_header_id|>user<|end_header_id|>

### Conversation
{article}

### Candidates
{options}

### Question
Which option is the most natural next utterance?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
'''

prompt = PromptTemplate(
    input_variables=["system", "article", "options"],
    template=template,
)

chain = prompt | llm

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


In [4]:
import re

def extract_choice(raw: str) -> str | None:
    tail = raw.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
    m = re.search(r"\b([A-D])\b", tail)
    return m.group(1) if m else None


In [5]:
df = pd.read_csv("mutual_dev.csv")
first = df.iloc[0]
answer = chain.invoke({
    "system": sys_prompt,
    "article": first.article,
    "options": first.options,
})
print("Model answer:", extract_choice(answer.strip()))

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Model answer: A


In [14]:
responses, correctness = [], []
for _, row in tqdm(df.iterrows(), total=len(df), desc="Running MuTual dev"):
    reply = chain.invoke({
        "system": sys_prompt,
        "article": row.article,
        "options": row.options,
    })
    choice = extract_choice(reply)
    responses.append(choice)
    correctness.append(choice == row.answers)

df["llm_response"] = responses
df["is_right"]     = correctness

Running MuTual dev:   0%|          | 0/886 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignore

In [18]:
out_csv = model_id.replace("/", "_") + ".csv"
save_path = os.path.join(SAVE_DIR, out_csv)
df.to_csv(save_path, index=False, encoding="utf-8")

acc = df["is_right"].mean()
print(f"✅  Saved to: {save_path}")
print(f"📊  Accuracy: {acc:.3%}")

✅  Saved to: /content/drive/MyDrive/MuTual/maywell_Llama-3-Ko-8B-Instruct.csv
📊  Accuracy: 24.492%
