In [2]:
import os
import os.path as osp
import sys
import fire
import json
from typing import List, Union
import torch
from torch.nn import functional as F
import transformers
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    set_peft_model_state_dict
)
from peft import PeftModel

# Notebook-wide configuration, read by the cells below.
device = "auto"
base_LLM_model = "meta-llama/Llama-2-7b-hf"  # Hugging Face Hub id of the base checkpoint



In [3]:
# 9-1. Merge the trained LoRA layers into the base LLM.

# Release cached CUDA blocks before the merged model is moved to the GPU.
torch.cuda.empty_cache()

# Load the base weights on the CPU in plain fp16. Quantized loading
# (load_in_8bit / load_in_4bit) is intentionally left disabled here.
base_model = AutoModelForCausalLM.from_pretrained(
    base_LLM_model,
    torch_dtype=torch.float16,
    device_map="cpu",
)

# Attach the fine-tuned adapter. The third positional parameter of
# PeftModel.from_pretrained is `adapter_name`, not a device, so the previous
# extra `device` argument only renamed the adapter to "auto"; it is dropped.
lora_checkpoint_dir = './output/checkpoint-761_a'
model = PeftModel.from_pretrained(base_model, lora_checkpoint_dir)

# Fold the LoRA deltas into the base weights, drop the PEFT wrappers, and
# move the resulting plain model to the GPU for inference.
model = model.merge_and_unload().to("cuda")  # Merge!

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Tokenizer for the base checkpoint, with the end-of-sequence token pinned to '</s>'.
# NOTE(review): trust_remote_code=True allows code shipped in the hub repo to run
# on load -- confirm it is actually needed for this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    base_LLM_model,
    trust_remote_code=True,
    eos_token='</s>',
)

In [5]:
def tokenize(element, max_length=128):
    """Tokenize a batch of texts into padded, truncated ``input_ids``.

    Uses the notebook-level ``tokenizer`` and, as a side effect, sets its
    ``pad_token`` to its ``eos_token`` so that ``padding=True`` has a token
    to pad with.

    Args:
        element: Mapping whose ``"input"`` entry holds the texts to encode
            (presumably a list of strings from a batched ``datasets.map``
            call -- confirm against the caller). Other entries, such as
            ``"label"``, are ignored and no longer required.
        max_length: Maximum number of tokens kept per text; longer texts are
            truncated. Defaults to 128, the previously hard-coded value.

    Returns:
        dict: ``{"input_ids": [...]}`` with one token-id sequence per text.
    """
    tokenizer.pad_token = tokenizer.eos_token
    outputs = tokenizer(
        element["input"],
        truncation=True,
        max_length=max_length,
        padding=True,
    )
    # The tokenizer already returns one id sequence per input text, so hand
    # them back directly instead of re-collecting them in a loop.
    return {"input_ids": list(outputs["input_ids"])}

In [6]:
# Smoke test: classify one hand-written headline with the merged model.
prompt = "Is this fake news? article: Ronaldo Joins Al Nassr: A New Star in Saudi answer:"

# BatchEncoding.to() moves the tensors in place and returns the same object.
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# Generate. `max_length` caps prompt tokens plus generated tokens; 27 is
# presumably sized to leave room for a single answer token -- TODO confirm.
generate_ids = model.generate(inputs.input_ids, max_length=27)
decoded_texts = tokenizer.batch_decode(
    generate_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
decoded_texts[0]

'Is this fake news? article: Ronaldo Joins Al Nassr: A New Star in Saudi answer: False'

In [7]:
import tkinter as tk
from tkinter import scrolledtext, messagebox
import random
def _build_prompt(article):
    """Return a classification prompt for ``article`` using a randomly chosen instruction."""
    instructions = (
        "Determine if the given article is fake.",
        "Is this article fake?",
        "Return True if the given article is fake.",
    )
    # The "article:" label is added exactly once here; the instructions
    # themselves must not carry it, or the prompt reads "article: article:".
    return f"{random.choice(instructions)} article: {article} answer:"


def _describe_answer(answer):
    """Turn the model's raw answer text into the chat line shown to the user."""
    normalized = answer.strip().lower()
    # Every instruction asks whether the article is fake, so "True" means fake.
    if normalized == "true":
        return "Bot: This looks like fake news.\n\n"
    if normalized == "false":
        return "Bot: This looks like real news.\n\n"
    return f"Bot: I could not classify this article (model answered {answer!r}).\n\n"


def check_fake_news():
    """Classify the article title typed in the entry field and log the verdict.

    Reads the notebook-level ``entry_var``, ``chat_history``, ``tokenizer``
    and ``model``. An empty entry only shows an informational dialog.
    Otherwise the title is appended to the transcript, the model is asked
    for a single answer token, and the resulting verdict is appended as the
    bot's reply. Any exception raised along the way is reported in an error
    dialog. The entry field is always cleared and the transcript always
    returned to read-only afterwards.
    """
    article = entry_var.get().strip()  # fetch the text with surrounding whitespace removed
    if not article:
        messagebox.showinfo("Input Required", "Please enter an article title to check.")
        return

    chat_history.config(state=tk.NORMAL)
    try:
        chat_history.insert(tk.END, f"You: {article}\n", "user")
        chat_history.see(tk.END)

        # Put the inputs on whatever device the model lives on.
        inputs = tokenizer(_build_prompt(article), return_tensors="pt").to(model.device)
        prompt_len = len(inputs.input_ids[0])

        # One new token is enough for the answer word.
        generate_ids = model.generate(inputs.input_ids, max_new_tokens=1)

        # Decode only what the model added, so an empty generation can never
        # be mistaken for the last word of the prompt.
        answer = tokenizer.decode(
            generate_ids[0][prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        chat_history.insert(tk.END, _describe_answer(answer), "bot")
        chat_history.see(tk.END)
    except Exception as e:
        messagebox.showerror("Error", str(e))
    finally:
        entry_var.set("")  # reset the input field
        chat_history.config(state=tk.DISABLED)

# Build the chat window.
root = tk.Tk()
root.title("Fake News Checker Chat")

# Chat history area: created read-only, with separate colour tags for the
# user's lines and the bot's lines.
chat_history = scrolledtext.ScrolledText(
    root,
    font=("Arial", 12),
    wrap=tk.WORD,
    state=tk.DISABLED,
    bg="#F0F0F0",
    width=60,
    height=20,
)
chat_history.tag_configure("user", foreground="#0084FF")
chat_history.tag_configure("bot", foreground="#FF5733")
chat_history.pack(padx=20, pady=10)

# Input field backed by `entry_var`; pressing Enter runs the check.
entry_var = tk.StringVar()
article_entry = tk.Entry(
    root,
    textvariable=entry_var,
    font=("Arial", 14),
    width=53,
)
article_entry.pack(side=tk.LEFT, padx=(20, 0), pady=10)
article_entry.bind("<Return>", lambda _event: check_fake_news())

# "Send" button: runs the same handler as the Enter key.
check_button = tk.Button(
    root,
    text="Send",
    command=check_fake_news,
    font=("Arial", 14),
    bg="#4CAF50",
    fg="white",
)
check_button.pack(side=tk.RIGHT, padx=(0, 20), pady=10)

# Start the Tk event loop.
root.mainloop()

