# Intelligent Chat Bot for Korean Live Streaming Platforms

Video live streaming services are interactive media in which streamers and viewers communicate through live chat. As streamers want to manage their chats more intelligently, intelligent chat bots are needed. This project builds live-stream chat bots with the Gemma2 2B model as an AfreecaTV (SOOP) extension program, in order to detect questions in live chats and answer them automatically, or to check for inappropriate texts and remove various types of spam.

The primary goal is to classify questions in live chats and answer them with a fine-tuned Gemma model.

Project: Gemma Sprint Project @ 2024 Google Machine Learning Bootcamp Korea

Author: Seonghyeok Jo, Kangwon National University.

In [None]:
!pip install -q websockets api numpy pandas altair torch scikit_learn==1.4.2 pyarrow==15.0.2 datasets peft trl accelerate transformers bitsandbytes

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.1/12.1 MB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.3/38.3 MB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m164.1/164.1 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 kB[0m [31m17.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m322.5/322.5 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m318.4/318.4 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.4/122.4 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.6 MB/s[0m eta [36

## 1. Live Chat Crawling with AfreecaTV API


We need live chats as input. There is an excellent existing crawler for AfreecaTV chats, so we use it here.

AfreecaTV Chat Crawler by **Soohyun-Chae(cha2hyun)** : [GitHub Repo](https://github.com/cha2hyun/afreecatv-chat-crawler)

In [None]:
import certifi
import ssl
import asyncio
import websockets
import requests

# Unicode and other protocol constants.
# Form feed: the field separator inside AfreecaTV chat frames.
# NOTE: a later cell's `import torch.nn.functional as F` rebinds this name.
F = "\x0c"
# Escape + tab prefix that starts every outgoing packet.
ESC = "\x1b\t"
# Horizontal rule printed around chat output (72 characters wide).
SEPARATOR = f"+{'-' * 70}+"

In [None]:
# 아프리카TV에서 제공하는 API로 채팅 정보를 받습니다.
# Receives chat information using the API provided by AfreecaTV.
def get_player_live(bno, bid, timeout=10):
    """Look up the chat-server connection info for a live broadcast.

    Args:
        bno: Broadcast number (last path segment of the broadcast URL).
        bid: Broadcaster ID (second-to-last path segment of the URL).
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled API call cannot hang the caller forever.

    Returns:
        ``(CHDOMAIN, CHATNO, FTK, TITLE, BJID, CHPT)`` on success. CHDOMAIN is
        lower-cased and CHPT is the API's ``CHPT`` value plus one, as a string
        (the caller uses it as the ``wss://`` port). Returns ``None`` when the
        request fails or the response lacks the expected ``CHANNEL`` fields;
        the reason is printed.
    """
    url = 'https://live.afreecatv.com/afreeca/player_live_api.php'
    data = {
        'bid': bid,
        'bno': bno,
        'type': 'live',
        'confirm_adult': 'false',
        'player_type': 'html5',
        'mode': 'landing',
        'from_api': '0',
        'pwd': '',
        'stream_type': 'common',
        'quality': 'HD'
    }

    try:
        response = requests.post(f'{url}?bjid={bid}', data=data, timeout=timeout)
        # Raise on HTTP error status codes.
        response.raise_for_status()
        channel = response.json()["CHANNEL"]

        return (
            channel["CHDOMAIN"].lower(),
            channel["CHATNO"],
            channel["FTK"],
            channel["TITLE"],
            channel["BJID"],
            str(int(channel["CHPT"]) + 1),
        )

    except requests.RequestException as e:
        print(f"  ERROR: API 요청 중 오류 발생: {e}")
        return None
    except (KeyError, TypeError, ValueError) as e:
        # Missing keys, a null CHANNEL, a non-JSON body (older `requests`),
        # or a non-numeric CHPT all mean the response is unusable.
        print(f"  ERROR: 응답에서 필요한 데이터를 찾을 수 없습니다: {e}")
        return None

In [None]:
# SSL 컨텍스트 생성
# Create SSL context.
def create_ssl_context(verify=False):
    """Build the SSL context for the chat WebSocket connection.

    Args:
        verify: When False (the default, matching the original behavior),
            certificate and hostname verification are disabled. When True,
            the certifi CA bundle is loaded and full verification is kept.

    Returns:
        An ``ssl.SSLContext``.
    """
    ssl_context = ssl.create_default_context()
    if verify:
        ssl_context.load_verify_locations(certifi.where())
        return ssl_context
    # With verification disabled, a CA bundle would never be consulted, so
    # none is loaded. check_hostname must be cleared before CERT_NONE is set.
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
    return ssl_context

In [None]:
# 메시지 디코드 및 출력
# Decode messages and print.
def decode_message(bytes):
    """Decode one raw chat-server frame and print it if it is a chat message.

    Args:
        bytes: The raw WebSocket frame (``bytes``; a ``str`` frame is UTF-8
            encoded first). The name shadows the builtin inside this function
            but is kept for caller compatibility.

    Fields are separated by form feed (``\\x0c``). A frame counts as a chat
    message when it has at least 7 fields, field 1 (the comment) is neither
    ``'-1'`` nor ``'1'``, and field 1 contains no ``'|'``. Chat messages are
    printed as ``| nickname[user_id] - comment``; every other frame is ignored.
    """
    frame = bytes.encode('utf-8') if isinstance(bytes, str) else bytes
    # Decode leniently so one malformed byte cannot raise out of the receive
    # loop and end the whole connection.
    fields = [part.decode('utf-8', errors='replace') for part in frame.split(b'\x0c')]
    # fields[6] (nickname) is read below, so at least 7 fields are required.
    if len(fields) > 6 and fields[1] not in ('-1', '1') and '|' not in fields[1]:
        user_id, comment, user_nickname = fields[2], fields[1], fields[6]
        print(SEPARATOR)
        print(f"| {user_nickname}[{user_id}] - {comment}")
    # Otherwise: not only chat messages arrive on this socket; ignore the rest.

In [None]:
# 바이트 크기 계산
# Calculate byte size.
def calculate_byte_size(string):
    """Return the UTF-8 byte length of *string* plus 6.

    Used as the zero-padded length field of JOIN_PACKET. The extra 6 presumably
    covers the six form-feed separators around CHATNO in that packet's body
    (TODO confirm against the protocol).
    """
    encoded = string.encode('utf-8')
    return 6 + len(encoded)

In [None]:
# 채팅에 연결
# Connect to chat.
async def connect_to_chat(url, ssl_context):
    """Connect to a broadcast's chat server and print its chat messages.

    Args:
        url: Broadcast URL ending in ``<bid>/<bno>``; a trailing slash is
            tolerated.
        ssl_context: SSL context for the ``wss://`` connection, e.g. the one
            returned by ``create_ssl_context()``.

    Runs until the connection fails. API and WebSocket errors are printed,
    not raised.
    """
    # Form-feed field separator. Kept local rather than using the global `F`,
    # which a later notebook cell rebinds via `import torch.nn.functional as F`.
    sep = "\x0c"

    try:
        BID, BNO = url.rstrip('/').split('/')[-2:]
        channel_info = get_player_live(BNO, BID)
    except Exception as e:
        # API call failure.
        print(f"  ERROR: API 호출 실패 - {e}")
        return
    if channel_info is None:
        # get_player_live has already printed why the lookup failed.
        return

    CHDOMAIN, CHATNO, FTK, TITLE, BJID, CHPT = channel_info
    print(f"{SEPARATOR}\n"
          f"  CHDOMAIN: {CHDOMAIN}\n  CHATNO: {CHATNO}\n  FTK: {FTK}\n"
          f"  TITLE: {TITLE}\n  BJID: {BJID}\n  CHPT: {CHPT}\n"
          f"{SEPARATOR}")

    try:
        async with websockets.connect(
            f"wss://{CHDOMAIN}:{CHPT}/Websocket/{BID}",
            subprotocols=['chat'],
            ssl=ssl_context,
            ping_interval=None,  # keep-alive is handled by ping() below
        ) as websocket:
            # Packet sent during the initial connection.
            CONNECT_PACKET = f'{ESC}000100000600{sep*3}16{sep}'
            # Packet sent to join the chat room and start receiving messages.
            JOIN_PACKET = f'{ESC}0002{calculate_byte_size(CHATNO):06}00{sep}{CHATNO}{sep*5}'
            # Keep-alive packet sent periodically so messages keep arriving.
            PING_PACKET = f'{ESC}000000000100{sep}'

            await websocket.send(CONNECT_PACKET)
            print(f"  연결 성공, 채팅방 정보 수신 대기중...")
            await asyncio.sleep(2)
            await websocket.send(JOIN_PACKET)

            async def ping():
                while True:
                    # The socket is dropped if no ping arrives for 5 minutes,
                    # so send one every 60 seconds.
                    await asyncio.sleep(60)
                    await websocket.send(PING_PACKET)

            async def receive_messages():
                while True:
                    decode_message(await websocket.recv())

            tasks = [
                asyncio.create_task(receive_messages()),
                asyncio.create_task(ping()),
            ]
            try:
                await asyncio.gather(*tasks)
            finally:
                # If one task fails, stop the other before the socket closes so
                # it does not linger and later fail with an unretrieved exception.
                for task in tasks:
                    task.cancel()

    except Exception as e:
        # Web socket connection error.
        print(f"  ERROR: 웹소켓 연결 오류 - {e}")

This is the main function for crawling all live chats from the broadcast at the given URL. The broadcast must not be private or age-restricted.

In [None]:
# Enter AfreecaTV URL.
# url = input("아프리카TV URL을 입력해주세요: ")
# ssl_context = create_ssl_context()
# await connect_to_chat(url, ssl_context)

# 2. Fine Tuning Gemma2-2b-it


Google provides an instruction-tuned ("it") version of Gemma 2 2B, which we use as the base model.

Gemma2-2b-it model by **Google** : [Hugging Face](https://huggingface.co/google/gemma-2-2b-it)

In [None]:
from huggingface_hub import login
from google.colab import userdata

# Authenticate with Hugging Face using the HF_TOKEN Colab secret; presumably
# required to download google/gemma-2-2b-it below.
login(userdata.get('HF_TOKEN'))

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
import torch

# Load the tokenizer and the instruction-tuned Gemma 2 2B model in bfloat16;
# device_map='auto' chooses where the weights are placed.
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2-2b-it",
)

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    torch_dtype=torch.bfloat16,
    device_map='auto',
)

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

In [None]:
# streamer = TextStreamer(tokenizer)

# messages = [
#     {"role": "user", "content": "대한민국의 수도에 대해 알려줘"},
# ]

# input_ids = tokenizer.apply_chat_template(
#     messages,
#     add_generation_prompt=True,
#     return_tensors="pt"
# ).to(model.device)

# terminators = [
#     tokenizer.eos_token_id,
#     tokenizer.convert_tokens_to_ids("<|end_of_turn|>")
# ]

# outputs = model.generate(
#     input_ids,
#     max_new_tokens=512,
#     eos_token_id=terminators,
#     do_sample=False,
#     repetition_penalty=1.05,
#     streamer = streamer
# )
# response = outputs[0][input_ids.shape[-1]:]
# print(tokenizer.decode(response, skip_special_tokens=True))


Here is an example of how to classify specific kinds of statements.

Question vs. statement classifier: a fine-tuned BERT-mini model by **Shahrukh Khan**:
[Hugging Face](https://huggingface.co/shahrukhx01/bert-mini-finetune-question-detection)

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# tokenizer = AutoTokenizer.from_pretrained("shahrukhx01/bert-mini-finetune-question-detection")
# model = AutoModelForSequenceClassification.from_pretrained("shahrukhx01/bert-mini-finetune-question-detection")

## 3-1. Loading the Chat Dataset for Training


In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive, forcing a fresh mount.
# NOTE(review): every dataset path below starts with /content/drive, so this
# mount point appears unused.
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [None]:
# Mount Google Drive at /content/drive; the dataset paths below read from here.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pandas as pd
import numpy as np

# Load the labelled chat dataset (~300K rows) as a DataFrame for a quick look.
train = pd.read_csv('/content/drive/MyDrive/datasets/data_split_300K.csv')

In [None]:
# Class balance as proportions (≈95.5% label 0 vs ≈4.5% label 1 in the run shown).
train["label"].value_counts(normalize=True)

Unnamed: 0_level_0,proportion
label,Unnamed: 1_level_1
0,0.954987
1,0.045013


In [None]:
import csv
import json
import random
import torch
from datasets import Dataset, DatasetDict

# Read .csv: column 0 holds the chat text, column 1 its integer label.
csv_file_path = '/content/drive/MyDrive/datasets/data_split_300K.csv'

with open(csv_file_path, 'r', encoding='utf-8') as f:
    reader = csv.reader(f)
    next(reader)  # skip the header row
    # One {'input': text, 'label': int} record per remaining row.
    data = [{'input': row[0], 'label': int(row[1])} for row in reader]

# Serialize the records as indented JSON, keeping Korean text unescaped.
json_string = json.dumps(data, ensure_ascii=False, indent=2)
json_file_path = '/content/drive/MyDrive/datasets/train_chatbot.json'

with open(json_file_path, 'w', encoding='utf-8') as f:
    f.write(json_string)

## 3-2. Training Model

In [None]:
import os
import copy
from dataclasses import dataclass
from trl import SFTTrainer

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import (
    BitsAndBytesConfig,
    AutoTokenizer,
    PreTrainedTokenizerBase,
    EvalPrediction,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftModel
import warnings

In [None]:
# Let the Hugging Face `tokenizers` library use parallelism.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from accelerate import Accelerator

In [None]:
# Reload the base model for fine-tuning, in bfloat16 with 'eager' attention
# (the implementation transformers recommends for training Gemma 2, per the
# warning text in the training log).
model_id = "google/gemma-2-2b-it"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    attn_implementation='eager'
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Load the JSON records written earlier as a Hugging Face Dataset with
# 'input' and 'label' columns.
data_path = json_file_path
train_data = Dataset.from_json(data_path)

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Setting LoRA
lora_config = LoraConfig(
    r=6,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM"
)

In [None]:
# Instruction prepended to every chat in the fine-tuning prompts. English translation:
# "You must find malicious comments. Malicious comments are short sentences in the
# input involving personal insults and disparagement, political remarks about gender
# or regional sentiment, violent or sexual remarks, threats, and leaks of personal
# information. If a comment is suspected to be malicious, answer only with an integer
# as follows. No additional explanation is needed. Answer 1 if the following input
# is a malicious comment, otherwise 0."
context = "당신은 악성 댓글을 찾아내야 합니다. 악성 댓글이란 입력 중 인격 모독 및 비하, 성별이나 지역감정 관련 정치적 발언, 폭력적 및 선정적 발언, 협박 및 개인정보 유출과 관련된 짧은 문장들입니다. 악성 댓글로 의심되면 다음과 같이 정수형으로만 대답하시기 바랍니다. 부연적인 설명은 없어도 됩니다. 다음 입력이 악성 댓글이라면 1, 아니라면 0으로 대답하시오. "

In [None]:
def generate_prompt(example, system_prompt=None):
    """Format a batch of labelled chats as Gemma chat-format training texts.

    Args:
        example: A batch mapping with parallel lists ``example['input']``
            (chat text) and ``example['label']`` (0/1 label), e.g. a sliced
            ``datasets.Dataset`` such as ``train_data[:1]``.
        system_prompt: Instruction placed before each chat. Defaults to the
            module-level ``context`` string.

    Returns:
        One training string per example, laid out as
        ``<bos><start_of_turn>user\\n{instruction}\\n{input}<end_of_turn>\\n``
        ``<start_of_turn>model\\n{label}<end_of_turn><eos>``.
    """
    instruction = context if system_prompt is None else system_prompt
    # Build the text explicitly rather than with an indented triple-quoted
    # f-string, which would embed the source indentation (8 spaces) at the
    # start of every turn line and diverge from Gemma's chat-template layout.
    return [
        f"<bos><start_of_turn>user\n{instruction}\n{text}<end_of_turn>\n"
        f"<start_of_turn>model\n{label}<end_of_turn><eos>"
        for text, label in zip(example['input'], example['label'])
    ]

In [None]:
# Preview the formatted training text for the first example.
print(generate_prompt(train_data[:1])[0])

<bos><start_of_turn>user
        당신은 악성 댓글을 찾아내야 합니다. 악성 댓글이란 입력 중 인격 모독 및 비하, 성별이나 지역감정 관련 정치적 발언, 폭력적 및 선정적 발언, 협박 및 개인정보 유출과 관련된 짧은 문장들입니다. 악성 댓글로 의심되면 다음과 같이 정수형으로만 대답하시기 바랍니다. 부연적인 설명은 없어도 됩니다. 다음 입력이 악성 댓글이라면 1, 아니라면 0으로 대답하시오. 
        저희 참치 접었어요<end_of_turn>
        <start_of_turn>model
        0<end_of_turn><eos>


In [None]:
# Setting Model
training_args = SFTConfig(
    output_dir = 'output',
    overwrite_output_dir = True,
    do_train = True,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0.01,
    num_train_epochs = 3,
    max_steps = -1,
    lr_scheduler_type = "cosine",
    warmup_ratio = 0.1,
    log_level = "info",
    logging_steps = 10,
    save_strategy = "epoch",
    bf16 = True,
    gradient_checkpointing = False,
    gradient_checkpointing_kwargs = {"use_reentrant": False},
    # Each formatted example (long instruction + chat + label) is far longer
    # than 8 tokens; a limit of 8 truncated every example inside the
    # instruction, so the chat text and label were never trained on.
    # 512 is assumed to fit instruction + a short chat (TODO: verify with the tokenizer).
    max_seq_length = 512,
    seed = 42,
    report_to = "none",
)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_data,
    args = training_args,
    formatting_func = generate_prompt,
    # Without peft_config the LoRA settings are ignored and every model
    # parameter (~2.6B) is trained.
    peft_config = lora_config,
)

Map:   0%|          | 0/300004 [00:00<?, ? examples/s]

Using cpu_amp half precision backend


In [None]:
# Disable the KV cache while training; it only speeds up generation.
model.config.use_cache = False

trainer.train()

***** Running training *****
  Num examples = 300,004
  Num Epochs = 3
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 900,012
  Number of trainable parameters = 2,614,341,888
  ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
It is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.


In [None]:
# Save Adapter Model (a LoRA adapter when the trainer was given a peft_config).
ADAPTER_MODEL = "lora_adapter"
trainer.model.save_pretrained(ADAPTER_MODEL)

# Merging Final Model
# Use the trained weights for inference. Reloading the base checkpoint here
# would discard the fine-tuning; when training used LoRA, fold the adapter
# into the base weights instead. The model stays in its training dtype (bfloat16).
trained_model = trainer.model
if isinstance(trained_model, PeftModel):
    model = trained_model.merge_and_unload()
else:
    model = trained_model
# The KV cache was disabled for training; re-enable it for generation.
model.config.use_cache = True
model.eval()

## 3-3. Load Model for Chat Filtering

In [None]:
# Text-generation pipeline around the final model and tokenizer.
# max_new_tokens=256 is the default budget; individual calls may override it.
cleanbot_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)

## 3-4. Approved Chat Crawler Code with Filter

In [None]:
def filter_with_model(input_text, instruction=None, max_new_tokens=4):
    """Ask the fine-tuned model to label one chat message.

    Args:
        input_text: The chat message to classify.
        instruction: Instruction placed before the message. Defaults to the
            module-level ``context`` prompt used for fine-tuning.
        max_new_tokens: Generation budget; the expected answer is a single
            digit, so only a few tokens are needed.

    Returns:
        The model's reply with surrounding whitespace stripped. Per the
        training prompt this should be ``"1"`` (malicious) or ``"0"`` (not).
    """
    if instruction is None:
        instruction = context
    # Prepend the same instruction the model was fine-tuned with, followed by
    # the chat text on its own line.
    messages = [{"role": "user", "content": f"{instruction}\n{input_text}"}]

    # Render the Gemma chat format, ending with the model-turn header.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Gemma's chat template already emits <bos>, so add_special_tokens=False
    # avoids a second one. Greedy decoding keeps the label deterministic.
    response = cleanbot_pipeline(
        prompt,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        return_full_text=False,
        add_special_tokens=False,
    )

    return response[0]["generated_text"].strip()

In [None]:
# 아프리카TV에서 제공하는 API로 채팅 정보를 받습니다.
# Receives chat information using the API provided by AfreecaTV.
def get_player_live(bno, bid, timeout=10):
    """Look up the chat-server connection info for a live broadcast.

    Args:
        bno: Broadcast number (last path segment of the broadcast URL).
        bid: Broadcaster ID (second-to-last path segment of the URL).
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled API call cannot hang the caller forever.

    Returns:
        ``(CHDOMAIN, CHATNO, FTK, TITLE, BJID, CHPT)`` on success. CHDOMAIN is
        lower-cased and CHPT is the API's ``CHPT`` value plus one, as a string
        (the caller uses it as the ``wss://`` port). Returns ``None`` when the
        request fails or the response lacks the expected ``CHANNEL`` fields;
        the reason is printed.
    """
    url = 'https://live.afreecatv.com/afreeca/player_live_api.php'
    data = {
        'bid': bid,
        'bno': bno,
        'type': 'live',
        'confirm_adult': 'false',
        'player_type': 'html5',
        'mode': 'landing',
        'from_api': '0',
        'pwd': '',
        'stream_type': 'common',
        'quality': 'HD'
    }

    try:
        response = requests.post(f'{url}?bjid={bid}', data=data, timeout=timeout)
        # Raise on HTTP error status codes.
        response.raise_for_status()
        channel = response.json()["CHANNEL"]

        return (
            channel["CHDOMAIN"].lower(),
            channel["CHATNO"],
            channel["FTK"],
            channel["TITLE"],
            channel["BJID"],
            str(int(channel["CHPT"]) + 1),
        )

    except requests.RequestException as e:
        print(f"  ERROR: API 요청 중 오류 발생: {e}")
        return None
    except (KeyError, TypeError, ValueError) as e:
        # Missing keys, a null CHANNEL, a non-JSON body (older `requests`),
        # or a non-numeric CHPT all mean the response is unusable.
        print(f"  ERROR: 응답에서 필요한 데이터를 찾을 수 없습니다: {e}")
        return None

In [None]:
# SSL 컨텍스트 생성
# Create SSL context.
def create_ssl_context(verify=False):
    """Build the SSL context for the chat WebSocket connection.

    Args:
        verify: When False (the default, matching the original behavior),
            certificate and hostname verification are disabled. When True,
            the certifi CA bundle is loaded and full verification is kept.

    Returns:
        An ``ssl.SSLContext``.
    """
    ssl_context = ssl.create_default_context()
    if verify:
        ssl_context.load_verify_locations(certifi.where())
        return ssl_context
    # With verification disabled, a CA bundle would never be consulted, so
    # none is loaded. check_hostname must be cleared before CERT_NONE is set.
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
    return ssl_context

In [None]:
# 검열된 메시지 디코드 및 출력
# Decode filtered messages and print.
def decode_message_approved(bytes, show_label=1):
    """Decode one raw chat frame, classify chat messages, and print matching ones.

    Args:
        bytes: The raw WebSocket frame (``bytes``; a ``str`` frame is UTF-8
            encoded first). The name shadows the builtin inside this function
            but is kept for caller compatibility.
        show_label: Only chats whose model label equals this value are
            printed. Defaults to 1, which the training prompt defines as
            "malicious"; pass 0 to print only chats judged clean.
            NOTE(review): the section is titled "Approved" chats — confirm
            which label is meant to be shown.

    Frames that are not chat messages (fewer than 7 form-feed-separated
    fields, field 1 equal to ``'-1'``/``'1'``, or containing ``'|'``) are
    ignored without calling the model.
    """
    frame = bytes.encode('utf-8') if isinstance(bytes, str) else bytes
    # Decode leniently so one malformed byte cannot end the receive loop.
    fields = [part.decode('utf-8', errors='replace') for part in frame.split(b'\x0c')]
    # fields[6] (nickname) is read below, so at least 7 fields are required.
    if len(fields) <= 6 or fields[1] in ('-1', '1') or '|' in fields[1]:
        # Not only chat messages arrive on this socket; ignore the rest.
        return

    user_id, comment, user_nickname = fields[2], fields[1], fields[6]
    model_reply = filter_with_model(comment)  # Reply from model with comment.
    # The model reply is text (e.g. "1"), so compare against the label as a string.
    if model_reply.strip().startswith(str(show_label)):
        print(SEPARATOR)
        print(f"| {user_nickname}[{user_id}] - {comment}")

In [None]:
# 바이트 크기 계산
# Calculate byte size.
def calculate_byte_size(string):
    """Return the UTF-8 byte length of *string* plus 6.

    Used as the zero-padded length field of JOIN_PACKET. The extra 6 presumably
    covers the six form-feed separators around CHATNO in that packet's body
    (TODO confirm against the protocol).
    """
    encoded = string.encode('utf-8')
    return 6 + len(encoded)

In [None]:
# 채팅에 연결
# Connect to chat.
async def connect_to_chat_approved(url, ssl_context):
    """Connect to a broadcast's chat server and print model-filtered chat messages.

    Args:
        url: Broadcast URL ending in ``<bid>/<bno>``; a trailing slash is
            tolerated.
        ssl_context: SSL context for the ``wss://`` connection, e.g. the one
            returned by ``create_ssl_context()``.

    Each received frame is passed to ``decode_message_approved``. Runs until
    the connection fails. API and WebSocket errors are printed, not raised.
    """
    # Form-feed field separator. Kept local rather than using the global `F`,
    # which an earlier notebook cell rebinds via `import torch.nn.functional as F`
    # (that rebinding made `F*3` below raise TypeError).
    sep = "\x0c"

    try:
        BID, BNO = url.rstrip('/').split('/')[-2:]
        channel_info = get_player_live(BNO, BID)
    except Exception as e:
        # API call failure.
        print(f"  ERROR: API 호출 실패 - {e}")
        return
    if channel_info is None:
        # get_player_live has already printed why the lookup failed.
        return

    CHDOMAIN, CHATNO, FTK, TITLE, BJID, CHPT = channel_info
    print(f"{SEPARATOR}\n"
          f"  CHDOMAIN: {CHDOMAIN}\n  CHATNO: {CHATNO}\n  FTK: {FTK}\n"
          f"  TITLE: {TITLE}\n  BJID: {BJID}\n  CHPT: {CHPT}\n"
          f"{SEPARATOR}")

    try:
        async with websockets.connect(
            f"wss://{CHDOMAIN}:{CHPT}/Websocket/{BID}",
            subprotocols=['chat'],
            ssl=ssl_context,
            ping_interval=None,  # keep-alive is handled by ping() below
        ) as websocket:
            # Packet sent during the initial connection.
            CONNECT_PACKET = f'{ESC}000100000600{sep*3}16{sep}'
            # Packet sent to join the chat room and start receiving messages.
            JOIN_PACKET = f'{ESC}0002{calculate_byte_size(CHATNO):06}00{sep}{CHATNO}{sep*5}'
            # Keep-alive packet sent periodically so messages keep arriving.
            PING_PACKET = f'{ESC}000000000100{sep}'

            await websocket.send(CONNECT_PACKET)
            print(f"  연결 성공, 채팅방 정보 수신 대기중...")
            await asyncio.sleep(2)
            await websocket.send(JOIN_PACKET)

            async def ping():
                while True:
                    # The socket is dropped if no ping arrives for 5 minutes,
                    # so send one every 60 seconds.
                    await asyncio.sleep(60)
                    await websocket.send(PING_PACKET)

            async def receive_messages_approved():
                while True:
                    decode_message_approved(await websocket.recv())

            tasks = [
                asyncio.create_task(receive_messages_approved()),
                asyncio.create_task(ping()),
            ]
            try:
                await asyncio.gather(*tasks)
            finally:
                # If one task fails, stop the other before the socket closes so
                # it does not linger and later fail with an unretrieved exception.
                for task in tasks:
                    task.cancel()

    except Exception as e:
        # Web socket connection error.
        print(f"  ERROR: 웹소켓 연결 오류 - {e}")

# 4. Final Process

In [None]:
# Enter AfreecaTV URL.
# Prompt for a broadcast URL ("Please enter the AfreecaTV URL"), then stream and
# filter its chat until the connection ends. Top-level `await` works here because
# IPython/Jupyter cells can await directly.
url = input("아프리카TV URL을 입력해주세요: ")
ssl_context = create_ssl_context()
await connect_to_chat_approved(url, ssl_context)