In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='6'

import gradio as gr
from types import SimpleNamespace
import torch
from transformers import pipeline, AutoTokenizer, GPT2LMHeadModel
from langdetect import detect
from torch.nn import functional as F
from itertools import chain
import logging
import tensorflow as tf

# Set TensorFlow logging level to ERROR
logging.getLogger('tensorflow').setLevel(logging.ERROR)

import torch
import numpy as np
import argparse
import random


# Torch device used for the models and every tensor built at inference time:
# the first visible GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
#######################################################################################################
class Inferencer():
    """GPT-2 multi-turn chatbot that answers with nucleus (top-p) sampling.

    ``args`` is a namespace that must provide ``language`` ('ko' or 'en'),
    ``model_path_ko``, ``model_path_en`` and ``top_p``; ``seed`` is optional
    and defaults to 0. The constructor writes derived settings back onto the
    same namespace: ``model_path``, the bos/eos/speaker tokens and their ids,
    and ``max_len`` (the model's ``n_ctx``).
    """

    def __init__(self, args):
        """Select the checkpoint for ``args.language`` and load tokenizer and model.

        Raises:
            ValueError: if ``args.language`` is neither 'ko' nor 'en'.
        """
        self.args = args
        if self.args.language == 'ko':
            self.args.model_path = self.args.model_path_ko
        elif self.args.language == 'en':
            self.args.model_path = self.args.model_path_en
        else:
            # Fail fast: continuing would only crash later on the missing model_path.
            raise ValueError(f"Not supported! (language={self.args.language!r})")

        # Tokenizer & vocab. The first additional special token (sp1) prefixes
        # user utterances and the second (sp2) prefixes bot utterances.
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_path)
        special_tokens = self.tokenizer.special_tokens_map
        self.args.bos_token = special_tokens['bos_token']
        self.args.eos_token = special_tokens['eos_token']
        self.args.sp1_token = special_tokens['additional_special_tokens'][0]
        self.args.sp2_token = special_tokens['additional_special_tokens'][1]

        vocab = self.tokenizer.get_vocab()
        self.vocab_size = len(vocab)
        self.args.bos_id = vocab[self.args.bos_token]
        self.args.eos_id = vocab[self.args.eos_token]
        self.args.sp1_id = vocab[self.args.sp1_token]
        self.args.sp2_id = vocab[self.args.sp2_token]

        # Model. max_len bounds the prompt plus the generated reply.
        self.fix_seed(self._seed())
        self.model = GPT2LMHeadModel.from_pretrained(self.args.model_path).to(device)
        self.args.max_len = self.model.config.n_ctx
        # Make the embedding matrix cover every token id the tokenizer can emit.
        self.model.resize_token_embeddings(len(self.tokenizer))

    def _seed(self):
        """Return the configured RNG seed (``args.seed``), defaulting to 0."""
        return getattr(self.args, 'seed', 0)

    def _encode_history(self, history):
        """Encode chat history as ``[speaker_id] + token_ids`` lists, oldest first.

        User text is prefixed with ``sp1_id`` and bot text with ``sp2_id``.
        Accepts ``[user, bot]`` pairs (Gradio "tuples" format) or
        ``{"role", "content"}`` dicts (Gradio "messages" format). Roles other
        than "user"/"assistant" and non-string text (e.g. a ``None`` reply
        slot) are skipped, since they cannot be tokenized.
        """
        role_to_speaker = {'user': self.args.sp1_id, 'assistant': self.args.sp2_id}
        utterances = []
        for turn in history or []:
            if isinstance(turn, dict):
                utterances.append((role_to_speaker.get(turn.get('role')), turn.get('content')))
            else:
                utterances.append((self.args.sp1_id, turn[0]))
                utterances.append((self.args.sp2_id, turn[1]))
        return [
            [speaker] + self.tokenizer.encode(text)
            for speaker, text in utterances
            if speaker is not None and isinstance(text, str)
        ]

    def _fit_context(self, input_hists):
        """Shrink the utterance list so the full prompt fits the model context.

        The prompt is ``[bos] + utterances + [sp2]`` and must be strictly
        shorter than ``max_len`` so at least one reply token can be sampled
        without running past the model's positions. The oldest utterances
        are dropped first. If a single utterance is still too long, its
        speaker token and its most recent tokens are kept. Returns a new
        list and does not mutate ``input_hists``.
        """
        budget = self.args.max_len - 3  # room after bos, the sp2 prompt token and one reply token
        hists = list(input_hists)
        total = sum(len(hist) for hist in hists)
        while len(hists) > 1 and total > budget:
            total -= len(hists.pop(0))
        if total > budget:
            last = hists[-1]
            hists[-1] = [last[0]] + last[len(last) - budget + 1:]
        return hists

    def infer(self, input, history):
        """Generate the bot reply to ``input`` given the earlier turns in ``history``.

        Args:
            input: the new user message.
            history: earlier turns, as ``[user, bot]`` pairs or
                ``{"role", "content"}`` dicts; see ``_encode_history``.

        Returns:
            The decoded reply with special tokens removed.
        """
        self.model.eval()
        # Re-seed on every call so identical conversations sample reproducibly.
        self.fix_seed(self._seed())
        with torch.no_grad():
            input_hists = self._encode_history(history)
            input_hists.append([self.args.sp1_id] + self.tokenizer.encode(input))
            # Without this, long conversations/messages overflow the model's
            # positions (crash) or leave no room to generate (empty reply).
            input_hists = self._fit_context(input_hists)

            # Token types: each token carries the speaker id that opens its
            # utterance; BOS takes the first speaker's id and the trailing sp2
            # prompt token announces the bot's turn.
            start_sp_id = input_hists[0][0]
            input_ids = [self.args.bos_id] + list(chain.from_iterable(input_hists)) + [self.args.sp2_id]
            token_type_ids = (
                [start_sp_id]
                + [hist[0] for hist in input_hists for _ in hist]
                + [self.args.sp2_id]
            )
            assert len(input_ids) == len(token_type_ids)
            input_len = len(input_ids)

            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device)
            token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(device)

            output_ids = self.nucleus_sampling(input_ids, token_type_ids, input_len)
            return self.tokenizer.decode(output_ids, skip_special_tokens=True)

    def nucleus_sampling(self, input_ids, token_type_ids, input_len):
        """Sample reply token ids one at a time with top-p (nucleus) filtering.

        Args:
            input_ids: (1, input_len) prompt token ids on ``device``.
            token_type_ids: (1, input_len) speaker ids aligned with ``input_ids``.
            input_len: prompt length; generation stops before ``args.max_len``.

        Returns:
            The list of sampled token ids, ending with the eos id if it was sampled.
        """
        output_ids = []
        # Every generated token belongs to the bot, so it is typed with sp2.
        next_type_id = torch.LongTensor([[self.args.sp2_id]]).to(device)
        for pos in range(input_len, self.args.max_len):
            # Logits of the last position; the sequence length equals pos here.
            logits = self.model(input_ids=input_ids, token_type_ids=token_type_ids)[0][:, pos - 1]  # (1, V)
            probs = F.softmax(logits, dim=-1)  # (1, V)

            sorted_probs, sorted_idxs = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, V)
            # Remove tokens once cumulative mass exceeds top_p. The mask is shifted
            # right by one so the token crossing the threshold (and always the
            # most likely token) survives.
            idx_remove = cumsum_probs > self.args.top_p
            idx_remove[:, 1:] = idx_remove[:, :-1].clone()
            idx_remove[:, 0] = False
            sorted_probs[idx_remove] = 0.0
            sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, V)

            # Scatter the renormalised probabilities back to vocabulary order.
            filtered = torch.zeros_like(probs).scatter_(-1, sorted_idxs, sorted_probs)  # (1, V)
            idx = torch.multinomial(filtered, 1)  # (1, 1)

            idx_item = idx.item()
            output_ids.append(idx_item)
            if idx_item == self.args.eos_id:
                break

            input_ids = torch.cat((input_ids, idx), dim=-1)
            token_type_ids = torch.cat((token_type_ids, next_type_id), dim=-1)

        return output_ids

    def fix_seed(self, seed):
        """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        random.seed(seed)



def _kaist_args(language):
    """Return a fresh Inferencer settings namespace for ``language`` ('ko' or 'en').

    Both chatbots share the same checkpoint paths, seed and top-p. A new
    namespace is built per chatbot because Inferencer writes derived
    settings (model_path, special-token ids, max_len) back onto it.
    """
    return SimpleNamespace(
        seed=0,
        language=language,
        model_path_ko="path/to/korean/chat/model",
        model_path_en="path/to/english/chat/model",
        top_p=0.8,
    )


inferencer_kaist_ko = Inferencer(_kaist_args('ko'))

# The module-level ``args`` name stays bound to the English settings, as before.
args = _kaist_args('en')
inferencer_kaist_en = Inferencer(args)



# Maximum number of turns sent to the model, counting the incoming message.
MAX_TURNS = 4


def trim_history(input_hists, max_turns=MAX_TURNS):
    """Return only the most recent ``max_turns - 1`` entries of ``input_hists``.

    The incoming message counts as one turn, so at most ``max_turns - 1``
    earlier turns are kept. ``None`` is treated as an empty history. A new
    list is returned and the caller's history is left untouched.
    """
    hists = list(input_hists or [])
    keep = max(max_turns - 1, 0)
    return hists[len(hists) - keep:] if len(hists) > keep else hists


def chat_kaist_en(message, input_hists):
    """Gradio ChatInterface callback: answer ``message`` with the English chatbot."""
    return inferencer_kaist_en.infer(message, trim_history(input_hists))


def chat_kaist_ko(message, input_hists):
    """Gradio ChatInterface callback: answer ``message`` with the Korean chatbot."""
    return inferencer_kaist_ko.infer(message, trim_history(input_hists))



#######################################################################################################

# Secondary palette for the shared theme: white for the lightest shades
# (c50-c200), then progressively darker greys down to c950.
_secondary_palette = gr.themes.Color(
    c50="#ffffff",
    c100="#ffffff",
    c200="#ffffff",
    c300="#d4d4d4",
    c400="#a3a3a3",
    c500="#737373",
    c600="#525252",
    c700="#404040",
    c800="#262626",
    c900="#171717",
    c950="#0f0f0f",
)

# Theme shared by every chat tab: indigo primary, the palette above as
# secondary, slate neutrals, and body text mapped to the neutral 950/900 shades.
theme = gr.themes.Base(
    primary_hue="indigo",
    secondary_hue=_secondary_palette,
    neutral_hue="slate",
).set(
    body_text_color='*neutral_950',
    body_text_color_subdued='*neutral_900',
)

#######################################################################################################

def _kaist_chat_tab(fn, examples, title):
    """Build one chat tab that calls ``fn`` and uses the shared ``theme``."""
    return gr.ChatInterface(fn=fn, examples=examples, title=title, theme=theme)


kaistchat_en_interface = _kaist_chat_tab(
    chat_kaist_en,
    ["Hello!", "What is KI Building?", "When was KAIST established?"],
    "KAIST Chatbot (English)",
)

# Korean example prompts: "Hi!", "Tell me about the KI Building.",
# "When was KAIST founded?"
kaistchat_ko_interface = _kaist_chat_tab(
    chat_kaist_ko,
    ["안녕!", "KI 빌딩에 대해서 알려줘.", "KAIST는 언제 설립됐어?"],
    "KAIST Chatbot (Korean)",
)



#######################################################################################################

# Combine both chat interfaces into one app with a tab per language.
_tab_interfaces = [kaistchat_en_interface, kaistchat_ko_interface]
_tab_names = ['English', 'Korean']
app = gr.TabbedInterface(_tab_interfaces, _tab_names, theme=theme)



  warn(
2024-08-26 03:24:56.672712: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Start the web server. share=True also opens a temporary public *.gradio.live
# URL (printed in the cell output), so anyone holding that link can use the
# chatbots on this machine's GPU. Drop share=True for local-only access.
app.launch(share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://eedd95dbd4089aae2b.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




Traceback (most recent call last):
  File "/home/sslunder24/env/transformer/lib/python3.8/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
  File "/home/sslunder24/env/transformer/lib/python3.8/site-packages/gradio/route_utils.py", line 285, in call_process_api
    output = await app.get_blocks().process_api(
  File "/home/sslunder24/env/transformer/lib/python3.8/site-packages/gradio/blocks.py", line 1923, in process_api
    result = await self.call_function(
  File "/home/sslunder24/env/transformer/lib/python3.8/site-packages/gradio/blocks.py", line 1506, in call_function
    prediction = await fn(*processed_input)
  File "/home/sslunder24/env/transformer/lib/python3.8/site-packages/gradio/utils.py", line 785, in async_wrapper
    response = await f(*args, **kwargs)
  File "/home/sslunder24/env/transformer/lib/python3.8/site-packages/gradio/chat_interface.py", line 607, in _submit_fn
    response = await anyio.to_thread.