In [None]:
import os
import re
import shutil
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

In [None]:
from datasets import load_dataset
# Upper bound on how many examples are kept for this experiment.
max_samples = 10000

# Load the "train" split, then keep only the leading `max_samples` rows
# (or every row when the split is smaller than that).
dataset = load_dataset("100suping/korean_unlabeled_web_text", split="train")
dataset = dataset.select(range(min(max_samples, len(dataset))))

def clean_text(text):
    """Reduce ``text`` to Hangul characters separated by single spaces.

    Whitespace runs (newlines, tabs, ...) are collapsed to one space first so
    that words on different lines stay separated. Every character that is not
    a Hangul jamo/syllable or a space is then removed, and the runs of spaces
    left behind by that removal are collapsed again.

    Args:
        text: Raw input string.

    Returns:
        The cleaned string. Leading and trailing spaces are not stripped.
    """
    # Normalise whitespace BEFORE filtering: the filter below deletes "\n" and
    # "\t", which would otherwise glue neighbouring words together.
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"[^ㄱ-ㅎㅏ-ㅣ가-힣 ]", "", text)
    # Dropping non-Hangul tokens can leave several spaces in a row.
    text = re.sub(r" +", " ", text)
    return text

output_file='korean_webtext_cleaned.txt'
# Write explicitly as UTF-8: the training code reads this file back with
# encoding='utf-8', and the platform default encoding is not guaranteed to be
# able to represent Hangul.
with open(output_file, 'w', encoding='utf-8') as f:
    for example in dataset:
        # One cleaned example per line.
        text = clean_text(example['text'])
        f.write(text + '\n')
print(f"Saved cleaned text to {output_file}")

In [None]:
from transformers import AutoTokenizer
# Load the pretrained tokenizer that the dataset, model and chatbot below share.
# NOTE(review): this identifier has no "<org>/" prefix; the public Hub repository
# appears to be "LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct" — confirm that a local
# directory with this exact name is really what is intended here.
tokenizer = AutoTokenizer.from_pretrained('LGAI-EXAONE-3.5-7.8B-Instruct')
# len(tokenizer) is used as the default vocabulary size of the model defined below.
VOCAB_SIZE = len(tokenizer)
print(f"Vocab size: {VOCAB_SIZE}")

In [None]:
class MyDataset(Dataset):
    """Sliding-window next-token-prediction dataset over one tokenized text.

    The text is tokenized once with the module-level ``tokenizer``. Window
    ``i`` yields ``max_len`` input ids starting at offset ``i * stride`` and,
    as target, the same window shifted one token to the right.

    Args:
        txt: The full training text.
        max_len: Number of tokens per window.
        stride: Offset in tokens between the starts of consecutive windows.

    Raises:
        ValueError: If ``max_len`` or ``stride`` is not positive.
    """
    def __init__(self,txt,max_len,stride):
        if max_len<=0 or stride<=0:
            raise ValueError("max_len and stride must be positive")
        encoded=tokenizer(txt)
        # Calling a Hugging Face tokenizer returns a mapping (input_ids,
        # attention_mask, ...), not the id list itself; len() of that mapping
        # would count its keys and produce an empty dataset. A plain sequence
        # of ids is still accepted as-is.
        token_ids=encoded["input_ids"] if hasattr(encoded,"keys") else encoded
        # Convert once; the windows below are views into this single tensor,
        # so overlapping windows do not duplicate the token data in memory.
        ids=torch.tensor(token_ids,dtype=torch.long)
        self.inputs_ids=[]
        self.target_ids=[]
        # The last usable start i satisfies i+max_len+1 <= len(token_ids).
        for i in range(0,len(token_ids)-max_len,stride):
            self.inputs_ids.append(ids[i:i+max_len])
            self.target_ids.append(ids[i+1:i+max_len+1])
    def __len__(self):
        """Return the number of windows."""
        return len(self.inputs_ids)
    def __getitem__(self,idx):
        """Return the ``(input_ids, target_ids)`` pair of window ``idx``."""
        return self.inputs_ids[idx],self.target_ids[idx]

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary position embedding applied to the first half of each head.

    Only the first ``head_dim // 2`` features of every head are rotated; the
    remaining features pass through unchanged. Sine/cosine tables for
    positions ``0 .. max_seq_len-1`` are precomputed and stored as
    non-persistent buffers.

    Args:
        head_dim: Feature size of one attention head; must be divisible by 4
            so the rotated half splits into two equal parts.
        base: Base of the geometric frequency progression.
        max_seq_len: Number of positions to precompute.

    Raises:
        ValueError: If ``head_dim`` is not divisible by 4.
    """
    def __init__(self,head_dim,base=10000,max_seq_len=2048):
        super().__init__()
        if head_dim%4!=0:
            # Otherwise the split/chunk in apply_rope yields unequal pieces
            # that silently broadcast into a wrongly shaped result.
            raise ValueError(f"head_dim must be divisible by 4, got {head_dim}")
        self.dim=head_dim//2
        theta=1.0/(base**(torch.arange(0,self.dim,2).float()/self.dim))
        seq_pos=torch.arange(max_seq_len,dtype=torch.float)
        # freqs[p, j] = p * theta[j]
        freqs=torch.einsum("i,j->ij",seq_pos,theta)

        self.register_buffer('sin_table',torch.sin(freqs),persistent=False)
        self.register_buffer('cos_table',torch.cos(freqs),persistent=False)
    def forward(self,x,start_pos=0):
        """Rotate ``x`` whose sequence axis is dimension 2.

        Args:
            x: Tensor indexed as ``(..., ..., seq, head_dim)``; callers in this
                file pass ``(batch, heads, seq, head_dim)``.
            start_pos: Absolute position of the first element along the
                sequence axis (non-zero when earlier tokens are cached).

        Returns:
            Tensor of the same shape as ``x``.
        """
        seq_len=x.size(2)
        return self.apply_rope(x,seq_len,start_pos)
    def apply_rope(self,x,seq_len,start_pos=0):
        """Rotate positions ``start_pos .. start_pos+seq_len-1`` of ``x``.

        Raises:
            ValueError: If the requested positions fall outside the
                precomputed tables.
        """
        end_pos=start_pos+seq_len
        if start_pos<0 or end_pos>self.sin_table.size(0):
            raise ValueError(
                f"positions [{start_pos}, {end_pos}) exceed the precomputed "
                f"range of {self.sin_table.size(0)}"
            )
        # First half is rotated, second half is passed through untouched.
        x_rope,x_pass=x.split(self.dim,dim=-1)
        x1,x2=x_rope.chunk(2,dim=-1)
        # Tables broadcast over the two leading (batch, head) axes.
        sin_table=self.sin_table[start_pos:end_pos].unsqueeze(0).unsqueeze(0)
        cos_table=self.cos_table[start_pos:end_pos].unsqueeze(0).unsqueeze(0)
        x1_rot=x1*cos_table-x2*sin_table
        x2_rot=x1*sin_table+x2*cos_table
        x_rope=torch.cat((x1_rot,x2_rot),dim=-1)
        return torch.cat((x_rope,x_pass),dim=-1)

In [None]:
class GQA(nn.Module):
    """Causal grouped-query attention with rotary positions and a KV cache.

    ``n_query_heads`` query heads share ``n_kv_heads`` key/value heads; query
    head ``h`` attends with key/value head ``h // n_rep`` where
    ``n_rep = n_query_heads // n_kv_heads``.

    Args:
        d_in: Input feature size.
        d_out: Output feature size (``n_query_heads * head_dim``).
        n_query_heads: Number of query heads; must divide ``d_out``.
        n_kv_heads: Number of key/value heads; must divide ``n_query_heads``.
        max_seq_len: Longest total sequence (cached + new tokens) supported.

    Raises:
        ValueError: If the head counts do not divide as described above.
    """
    def __init__(self,d_in,d_out,n_query_heads,n_kv_heads,max_seq_len):
        super().__init__()
        if d_out%n_query_heads!=0:
            raise ValueError("d_out must be divisible by n_query_heads")
        if n_query_heads%n_kv_heads!=0:
            raise ValueError("n_query_heads must be divisible by n_kv_heads")
        self.n_query_heads=n_query_heads
        self.n_kv_heads=n_kv_heads
        # Number of query heads served by each key/value head.
        self.n_rep=n_query_heads//n_kv_heads
        self.head_dim=d_out//n_query_heads
        self.d_out=d_out
        # Keys/values only need n_kv_heads heads, not n_query_heads.
        kv_dim=n_kv_heads*self.head_dim
        self.q_proj=nn.Linear(d_in,d_out,bias=False)
        self.k_proj=nn.Linear(d_in,kv_dim,bias=False)
        self.v_proj=nn.Linear(d_in,kv_dim,bias=False)
        self.o_proj=nn.Linear(d_out,d_out,bias=True)
        self.dropout=nn.Dropout(0.1)

        # True strictly above the diagonal, i.e. at the (future) positions a
        # query must NOT see. Derived data, so kept out of the state dict.
        causal_mask=torch.triu(torch.ones(max_seq_len,max_seq_len),diagonal=1).bool()
        self.register_buffer('causal_mask',causal_mask,persistent=False)

        self.rope=RotaryEmbedding(self.head_dim,max_seq_len=max_seq_len)

    def forward(self,x,past_kv=None,use_cache=False):
        """Attend over ``x`` and, optionally, previously cached keys/values.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_in)`` holding the new tokens.
            past_kv: Optional ``(k, v)`` pair from an earlier call, each of
                shape ``(batch, n_kv_heads, past_len, head_dim)``.
            use_cache: When true, return the updated ``(k, v)`` pair.

        Returns:
            ``(out, (k, v))`` with ``out`` of shape ``(batch, seq_len, d_out)``;
            the second element is ``(None, None)`` when ``use_cache`` is false.

        Raises:
            ValueError: If cached plus new tokens exceed ``max_seq_len``.
        """
        b,seq_len,_=x.shape
        q=self.q_proj(x).view(b,seq_len,self.n_query_heads,self.head_dim).transpose(1,2)
        k=self.k_proj(x).view(b,seq_len,self.n_kv_heads,self.head_dim).transpose(1,2)
        v=self.v_proj(x).view(b,seq_len,self.n_kv_heads,self.head_dim).transpose(1,2)

        past_len=past_kv[0].size(2) if past_kv is not None else 0
        kv_seq_len=past_len+seq_len
        if kv_seq_len>self.causal_mask.size(0):
            raise ValueError(
                f"sequence length {kv_seq_len} exceeds max_seq_len {self.causal_mask.size(0)}"
            )

        # Rotate only the NEW queries/keys, at their absolute positions. Cached
        # keys were already rotated when they were produced and must not be
        # rotated again. apply_rope is called directly because it takes the
        # position offset explicitly.
        q=self.rope.apply_rope(q,seq_len,start_pos=past_len)
        k=self.rope.apply_rope(k,seq_len,start_pos=past_len)

        if past_kv is not None:
            past_k,past_v=past_kv
            k=torch.cat((past_k,k),dim=2)
            v=torch.cat((past_v,v),dim=2)

        # Repeat each key/value head n_rep times so that query head h lines up
        # with key/value head h // n_rep; the cache keeps the compact k/v.
        k_rep=k.repeat_interleave(self.n_rep,dim=1)
        v_rep=v.repeat_interleave(self.n_rep,dim=1)

        attn_scores=torch.matmul(q,k_rep.transpose(-2,-1))/math.sqrt(self.head_dim)
        # Mask rows correspond to the absolute positions of the new queries.
        causal_mask_slice=self.causal_mask[past_len:kv_seq_len,:kv_seq_len]
        attn_scores=attn_scores.masked_fill(causal_mask_slice,float('-inf'))

        attn_weights=F.softmax(attn_scores,dim=-1)
        attn_weights=self.dropout(attn_weights)

        context=torch.matmul(attn_weights,v_rep)
        context=context.transpose(1,2).contiguous().view(b,seq_len,self.d_out)

        out=self.o_proj(context)
        if use_cache:
            return out,(k,v)
        else:
            return out,(None,None)

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    Args:
        dim: Size of the last axis; one gain value (initialised to 1) per feature.
        eps: Constant added to the mean square before the reciprocal square root.
    """
    def __init__(self,dim,eps=1e-6):
        super().__init__()
        self.weight=nn.Parameter(torch.ones(dim))
        self.eps=eps
    def forward(self,x):
        """Scale ``x`` by the inverse RMS of its last axis, then by ``weight``."""
        mean_square=x.pow(2).mean(dim=-1,keepdim=True)
        inv_rms=torch.rsqrt(mean_square+self.eps)
        return self.weight*(x*inv_rms)
class SwiGLU(nn.Module):
    """Gated activation: SiLU of the first input multiplied by the second input."""
    def forward(self,x1,x2):
        """Return ``silu(x1) * x2`` element-wise."""
        gate=F.silu(x1)
        return gate*x2
class MLP(nn.Module):
    """Feed-forward block with a SwiGLU gate.

    ``fc1`` projects to ``2 * inner_dim`` features that are split into two
    halves and combined by ``SwiGLU``; ``fc2`` projects back to ``emb_dim`` and
    dropout is applied to the result.

    Args:
        emb_dim: Input and output feature size.
        expansion_factor: ``inner_dim = emb_dim * expansion_factor``.
    """
    def __init__(self,emb_dim,expansion_factor=4):
        super().__init__()
        self.inner_dim=emb_dim*expansion_factor
        # Twice the inner width: one half is the gate input, the other the value.
        self.fc1=nn.Linear(emb_dim,2*self.inner_dim,bias=True)
        self.act=SwiGLU()
        self.fc2=nn.Linear(self.inner_dim,emb_dim,bias=True)
        self.dropout=nn.Dropout(0.1)
    def forward(self,x):
        """Apply the gated feed-forward transform along the last axis of ``x``."""
        x=self.fc1(x)
        x1,x2=x.chunk(2,dim=-1)
        x=self.act(x1,x2)
        x=self.fc2(x)
        x=self.dropout(x)
        return x

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual add.

    Args:
        emb_dim: Feature size of the residual stream.
        n_query_heads: Number of query heads passed to ``GQA``.
        n_kv_heads: Number of key/value heads passed to ``GQA``.
        max_seq_len: Maximum sequence length passed to ``GQA``.
    """
    def __init__(self,emb_dim,n_query_heads,n_kv_heads,max_seq_len):
        super().__init__()
        self.norm1=RMSNorm(emb_dim)
        self.att=GQA(emb_dim,emb_dim,n_query_heads,n_kv_heads,max_seq_len)
        self.att_drop=nn.Dropout(0.1)
        self.norm2=RMSNorm(emb_dim)
        self.mlp=MLP(emb_dim)
        self.mlp_drop=nn.Dropout(0.1)
    def forward(self,x,past_kv=None,use_cache=False):
        """Run one layer.

        Args:
            x: Residual-stream input.
            past_kv: Optional cached ``(k, v)`` pair forwarded to the attention
                module.
            use_cache: Forwarded to the attention module; when true it returns
                the updated cache.

        Returns:
            ``(x, (k, v))`` where ``(k, v)`` is whatever the attention module
            returned as its second element.
        """
        # Attention sub-layer: normalise, attend, dropout, residual add.
        hidden=self.norm1(x)
        attn_out,(k,v)=self.att(hidden,past_kv=past_kv,use_cache=use_cache)
        attn_out=self.att_drop(attn_out)
        x=x+attn_out
        # Feed-forward sub-layer with the same pattern.
        hidden=self.norm2(x)
        mlp_out=self.mlp(hidden)
        mlp_out=self.mlp_drop(mlp_out)
        x=x+mlp_out
        return x,(k,v)

In [None]:
class BigdefenceModel(nn.Module):
    """Decoder-only language model built from ``TransformerBlock`` layers.

    Pipeline: token embedding -> dropout -> ``num_layers`` blocks ->
    ``RMSNorm`` -> linear head producing ``vocab_size`` logits per position.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        emb_dim: Width of the residual stream.
        num_layers: Number of stacked blocks.
        n_query_heads: Query heads per block.
        n_kv_heads: Key/value heads per block.
        max_seq_len: Maximum sequence length handed to every block.
    """
    def __init__(self,vocab_size=VOCAB_SIZE,emb_dim=1536,num_layers=12,n_query_heads=12,n_kv_heads=4,
                 max_seq_len=2048):
        super().__init__()
        # Remember the configuration on the instance.
        self.vocab_size=vocab_size
        self.emb_dim=emb_dim
        self.max_seq_len=max_seq_len
        self.num_layers=num_layers

        # Input side.
        self.tok_emb=nn.Embedding(vocab_size,emb_dim)
        self.drop_emb=nn.Dropout(0.1)

        # Layer stack.
        layers=[]
        for _ in range(num_layers):
            layers.append(TransformerBlock(emb_dim,n_query_heads,n_kv_heads,max_seq_len))
        self.blocks=nn.ModuleList(layers)

        # Output side.
        self.final_norm=RMSNorm(emb_dim)
        self.lm_head=nn.Linear(emb_dim,vocab_size,bias=False)
    def forward(self,x,past_kv=None,use_cache=False):
        """Compute logits for the token ids in ``x``.

        Args:
            x: 2-D tensor of token ids ``(batch, seq_len)``.
            past_kv: Optional per-layer list of cached ``(k, v)`` pairs,
                indexed by layer.
            use_cache: When true, also return the per-layer ``(k, v)`` list.

        Returns:
            ``logits`` alone, or ``(logits, new_past_kv)`` when ``use_cache``
            is true.
        """
        # Unpacking requires x to be exactly 2-D.
        _batch,_seq_len=x.shape
        hidden=self.drop_emb(self.tok_emb(x))
        new_past_kv=[]
        for layer_idx,block in enumerate(self.blocks):
            layer_past=past_kv[layer_idx] if past_kv is not None else None
            hidden,layer_kv=block(hidden,past_kv=layer_past,use_cache=use_cache)
            k,v=layer_kv
            new_past_kv.append((k,v))
        logits=self.lm_head(self.final_norm(hidden))
        if not use_cache:
            return logits
        return logits,new_past_kv

In [None]:
def train_model():
    """Train ``BigdefenceModel`` on the cleaned corpus and return it.

    Reads ``korean_webtext_cleaned.txt``, builds sliding windows of 2048
    tokens with stride 128, and trains for ``EPOCHS`` epochs with AdamW,
    cosine learning-rate annealing and (on CUDA only) mixed precision. The
    state dict is saved to ``model_{epoch}.pt`` after every epoch.

    Returns:
        The trained model, left on the training device.

    Raises:
        ValueError: If the corpus is too short to form a single window.
    """
    txt_file='korean_webtext_cleaned.txt'
    with open(txt_file, 'r', encoding='utf-8') as f:
        txt=f.read()
    dataset=MyDataset(txt,max_len=2048,stride=128)
    if len(dataset)==0:
        # Without this guard the loop below runs zero steps and an untrained
        # model is saved and returned without any warning.
        raise ValueError(
            f"{txt_file} does not contain enough tokens for a single training window"
        )
    train_loader=DataLoader(dataset,batch_size=8,shuffle=True)
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torch.cuda.amp only applies to CUDA; disable it explicitly elsewhere.
    use_amp=device.type=='cuda'
    model=BigdefenceModel(
        vocab_size=VOCAB_SIZE,
        emb_dim=1536,
        num_layers=12,
        n_query_heads=12,
        n_kv_heads=4,
        max_seq_len=2048
    ).to(device)
    optimizer=torch.optim.AdamW(model.parameters(),lr=1e-4)
    # NOTE(review): T_max spans two epochs while EPOCHS below is 1, so the
    # learning rate only reaches the midpoint of the cosine — confirm intended.
    scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=len(train_loader)*2)
    criterion=nn.CrossEntropyLoss()
    EPOCHS=1
    global_step=0
    scaler=GradScaler(enabled=use_amp)
    for epoch in range(EPOCHS):
        model.train()
        running_loss=0.0
        for input_batch,target_batch in train_loader:
            input_batch=input_batch.to(device)
            target_batch=target_batch.to(device)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                logits=model(input_batch)
                # Flatten (batch, seq, vocab) and (batch, seq) for the loss.
                loss=criterion(logits.view(-1,logits.size(-1)),target_batch.view(-1))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            running_loss+=loss.item()
            global_step+=1
            if global_step%100==0:
                # Mean loss over the last 100 steps.
                print(f"Epoch {epoch+1}/{EPOCHS}, Step {global_step}, Loss: {running_loss/100}")
                running_loss=0.0
        torch.save(model.state_dict(),f"model_{epoch+1}.pt")
    return model
      

In [None]:
def generate_with_cache(
        model,idx,max_new_tokens=256,temperature=0.8,top_k=40,eos_id=None
):
    """Autoregressively extend ``idx`` using the model's key/value cache.

    The whole prompt is fed on the first step; once the model has returned a
    cache, only the newest token is fed. If the model returns no cache
    (``None``), the full sequence is fed again on every step.

    Args:
        model: Callable as ``model(ids, past_kv=..., use_cache=True)`` and
            returning ``(logits, past_kv)``. If it has a ``max_seq_len``
            attribute, generation stops once the sequence would exceed it.
        idx: Non-empty ``(batch, seq_len)`` tensor of prompt token ids.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Softmax temperature; values ``<= 0`` select the argmax.
        top_k: Keep only the ``top_k`` highest logits per row; ``None`` or a
            non-positive value disables the filter.
        eos_id: Optional token id; generation stops (without appending it)
            once every row samples this id in the same step.

    Returns:
        ``idx`` extended along dimension 1 with the generated tokens, on the
        device used for generation.

    Raises:
        ValueError: If ``idx`` is not a non-empty 2-D tensor.
    """
    if idx.dim()!=2 or idx.size(1)==0:
        raise ValueError("idx must be a non-empty (batch, seq_len) tensor of token ids")
    model.eval()
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    idx=idx.to(device)
    max_context=getattr(model,'max_seq_len',None)
    past_kv=None
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if max_context is not None and idx.size(1)>max_context:
                # The next forward pass would need positions past the window.
                break
            # Prefill with the full prompt; afterwards the cache already
            # covers everything except the newest token.
            step_input=idx if past_kv is None else idx[:,-1:]
            logits,past_kv=model(step_input,past_kv=past_kv,use_cache=True)
            logits_last=logits[:,-1,:]
            if top_k is not None and top_k>0:
                k=min(top_k,logits_last.size(-1))
                # Keep the trailing axis so the per-row threshold broadcasts
                # across the vocabulary for any batch size.
                threshold=torch.topk(logits_last,k).values[:,-1:]
                logits_last=logits_last.masked_fill(logits_last<threshold,float('-inf'))
            if temperature<=0:
                next_token=torch.argmax(logits_last,dim=-1,keepdim=True)
            else:
                probs=F.softmax(logits_last/temperature,dim=-1)
                next_token=torch.multinomial(probs,num_samples=1)
            if eos_id is not None and bool((next_token==eos_id).all()):
                break
            idx=torch.cat((idx,next_token),dim=1)
    return idx

In [None]:
def run_chatbot(model):
    """Interactive console loop: read a line, generate, print the decoded ids.

    Each input line is tokenized with the module-level ``tokenizer``, extended
    by ``generate_with_cache`` and decoded back to text. The decoded text
    contains the prompt followed by the generated continuation. Typing
    ``exit`` ends the loop.

    Args:
        model: Model handed to ``generate_with_cache``.
    """
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("간단한 챗봇을 시작합니다. 종료하려면 'exit'를 입력하세요.")
    while True:
        user_input=input("User > ")
        if user_input=='exit':
            break
        encoded=tokenizer(user_input)
        # Calling a Hugging Face tokenizer returns a mapping, not the id list;
        # torch.tensor() needs the ids themselves.
        tokens=encoded["input_ids"] if hasattr(encoded,"keys") else encoded
        if len(tokens)==0:
            # Nothing to condition on (e.g. an empty line); ask again.
            continue
        input_ids=torch.tensor(tokens,dtype=torch.long).unsqueeze(0).to(device)
        output_ids=generate_with_cache(
            model,
            input_ids,
            max_new_tokens=256,
            temperature=0.8,
            top_k=40,
            eos_id=None
        )
        response=tokenizer.decode(output_ids[0].tolist())
        print(f"Bot > {response}")

In [None]:
# Run training; train_model() builds a new model, saves a checkpoint after each epoch and returns the model.
model=train_model()

In [None]:
# Start the interactive console loop with the model returned by training (type 'exit' to quit).
run_chatbot(model)