<a href="https://colab.research.google.com/github/ljg7234/BERT/blob/main/BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch

import math

In [2]:
class Attention(nn.Module):
    """Scaled dot-product attention.

    Scores are ``query @ key^T`` divided by ``sqrt(query.size(-1))``, converted
    to weights by a softmax over the last axis and used to mix ``value``.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """Return ``(mixed_values, attention_weights)``.

        Args:
            query, key, value: tensors multiplied as matrices over their last
                two axes.
            mask: optional tensor; score positions where ``mask == 0`` are
                filled with -1e9 before the softmax.
            dropout: optional callable applied to the attention weights.
        """
        scale = math.sqrt(query.size(-1))
        scores = (query @ key.transpose(-1, -2)) / scale

        if mask is not None:
            # A large negative score becomes a ~0 weight after the softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = F.softmax(scores, dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        mixed = weights @ value
        return mixed, weights

In [3]:
class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on ``Attention``.

    Query, key and value each pass through their own ``d_model -> d_model``
    linear layer, are split into ``h`` heads of width ``d_model // h``,
    attended, merged back and sent through a final linear layer.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """
        Args:
            h: number of heads; must divide ``d_model``.
            d_model: width of the inputs and of the output.
            dropout: probability of the dropout applied to attention weights.
        """
        super().__init__()
        assert d_model % h == 0

        self.d_k = d_model // h
        self.h = h

        projections = [nn.Linear(d_model, d_model) for _ in range(3)]
        self.linear_layers = nn.ModuleList(projections)
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)

    def _split_heads(self, projected, batch_size):
        """Reshape to (batch, -1, h, d_k) and swap axes 1 and 2."""
        return projected.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value`` and return the projected result."""
        batch_size = query.size(0)

        query_proj, key_proj, value_proj = self.linear_layers
        query = self._split_heads(query_proj(query), batch_size)
        key = self._split_heads(key_proj(key), batch_size)
        value = self._split_heads(value_proj(value), batch_size)

        # The attention weights are returned as well but are not needed here.
        context, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # Undo the head split: back to (batch, -1, h * d_k).
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.output_linear(merged)

In [4]:
class TokenEmbedding(nn.Embedding):
    """Embedding table for token ids; index 0 is registered as the padding index."""

    def __init__(self, vocab_size, embed_size=512):
        super().__init__(num_embeddings=vocab_size,
                         embedding_dim=embed_size,
                         padding_idx=0)

In [5]:
class SegmentEmbedding(nn.Embedding):
    """Three-row embedding table for segment labels; index 0 is the padding index."""

    def __init__(self, embed_size=512):
        super().__init__(num_embeddings=3,
                         embedding_dim=embed_size,
                         padding_idx=0)

In [6]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal position table.

    Even columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)`` with
    ``w_i = exp(-(2i) * ln(10000) / d_model)``. The table is stored as the
    non-trainable buffer ``pe`` of shape ``(1, max_len, d_model)``.
    """

    def __init__(self, d_model, max_len=512):
        """
        Args:
            d_model: width of each position vector (odd values are supported).
            max_len: number of positions precomputed.
        """
        super().__init__()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model).float()
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # the cosine half only takes the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table (at most ``max_len``).

        Only the size of axis 1 of ``x`` is used; its values are ignored.
        """
        return self.pe[:, :x.size(1)]


In [7]:
class BERTEmbedding(nn.Module):
    """Sum of token, positional and segment embeddings followed by dropout."""

    def __init__(self, vocab_size, embed_size, dropout=0.1):
        """
        Args:
            vocab_size: number of rows in the token embedding table.
            embed_size: width of every embedding.
            dropout: dropout probability applied to the summed embedding.
        """
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)

        # The other two tables take their width from the token table.
        width = self.token.embedding_dim
        self.position = PositionalEmbedding(d_model=width)
        self.segment = SegmentEmbedding(embed_size=width)

        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    def forward(self, sequence, segment_label):
        """Embed token ids ``sequence`` together with their ``segment_label`` ids."""
        token_part = self.token(sequence)
        position_part = self.position(sequence)
        segment_part = self.segment(segment_label)
        return self.dropout(token_part + position_part + segment_part)

In [8]:
class GELU(nn.Module):
    """Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""

    def forward(self, x):
        cubic = x + 0.044715 * torch.pow(x, 3)
        gate = 1 + torch.tanh(math.sqrt(2 / math.pi) * cubic)
        return 0.5 * x * gate

In [9]:
class LayerNorm(nn.Module):
    """Normalise the last axis with a learnable scale (``a_2``) and shift (``b_2``).

    Computes ``a_2 * (x - mean) / (std + eps) + b_2`` where ``mean`` and ``std``
    are taken over the last axis using ``Tensor.mean`` / ``Tensor.std``.
    """

    def __init__(self, features, eps=1e-6):
        """
        Args:
            features: size of the normalised (last) axis.
            eps: constant added to the standard deviation in the denominator.
        """
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2

In [10]:
class SublayerConnection(nn.Module):
    """Residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, size, dropout):
        """
        Args:
            size: feature width handed to ``LayerNorm``.
            dropout: dropout probability applied to the sublayer output.
        """
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (a callable) to the normalised input and add the residual."""
        normalized = self.norm(x)
        branch = self.dropout(sublayer(normalized))
        return x + branch

In [11]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: linear -> GELU -> dropout -> linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Args:
            d_model: input and output width.
            d_ff: width of the hidden layer.
            dropout: dropout probability applied after the activation.
        """
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        expanded = self.activation(self.w_1(x))
        regularized = self.dropout(expanded)
        return self.w_2(regularized)

In [12]:
class TransformerBlock(nn.Module):
    """One encoder block: self-attention then feed-forward, each in a residual sublayer."""

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        Args:
            hidden: model width.
            attn_heads: number of attention heads.
            feed_forward_hidden: hidden width of the feed-forward network.
            dropout: dropout probability used by every sub-module of the block.
        """
        super().__init__()
        # Forward ``dropout`` so the attention weights honour the configured
        # rate instead of MultiHeadedAttention's own default.
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """Run the block on ``x``; ``mask`` is passed through to the attention layer."""
        # Call the module itself (not ``.forward``) so registered hooks run.
        x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

In [13]:
class BERT(nn.Module):
    """Stack of ``TransformerBlock`` layers on top of ``BERTEmbedding``."""

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        Args:
            vocab_size: size of the token vocabulary.
            hidden: model width.
            n_layers: number of transformer blocks.
            attn_heads: attention heads per block.
            dropout: dropout probability used by the embedding and every block.
        """
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # The feed-forward hidden width is four times the model width.
        self.feed_forward_hidden = hidden * 4

        # Forward ``dropout`` so the embedding uses the configured rate rather
        # than BERTEmbedding's own default.
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden, dropout=dropout)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, self.feed_forward_hidden, dropout)
             for _ in range(n_layers)]
        )

    def forward(self, x, segment_info):
        """Encode token ids ``x`` with segment ids ``segment_info``.

        Token id 0 is treated as padding: the attention mask is True only where
        ``x > 0``. ``x`` is assumed to be (batch, seq_len) -- TODO confirm with callers.
        """
        # (batch, seq) -> (batch, 1, seq, seq); the singleton axis lines up with the head axis.
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        x = self.embedding(x, segment_info)

        for transformer in self.transformer_blocks:
            # Call the module itself (not ``.forward``) so registered hooks run.
            x = transformer(x, mask)

        return x

In [14]:
class NextSentencePrediction(nn.Module):
    """Two-class head applied to position 0 of the encoder output.

    Pipeline: linear -> tanh -> linear(2) -> log-softmax.
    """

    def __init__(self, hidden):
        """
        Args:
            hidden: width of the encoder output.
        """
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)
        self.activation = nn.Tanh()
        self.output = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-probabilities over the two classes for ``x[:, 0]``."""
        first_position = x[:, 0]
        pooled = self.activation(self.linear(first_position))
        logits = self.output(pooled)
        return self.softmax(logits)

In [15]:
class MaskedLanguageModel(nn.Module):
    """Vocabulary head: linear projection to ``vocab_size`` followed by log-softmax."""

    def __init__(self, hidden, vocab_size):
        """
        Args:
            hidden: width of the encoder output.
            vocab_size: number of output classes.
        """
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-probabilities over the vocabulary along the last axis."""
        logits = self.linear(x)
        return self.softmax(logits)

In [16]:
class BERTLM(nn.Module):
    """BERT encoder with the next-sentence and masked-LM heads attached."""

    def __init__(self, bert : BERT, vocab_size):
        """
        Args:
            bert: encoder whose ``hidden`` attribute sizes both heads.
            vocab_size: output size of the masked-LM head.
        """
        super().__init__()
        self.bert = bert
        width = self.bert.hidden
        self.next_sentence = NextSentencePrediction(width)
        self.mask_lm = MaskedLanguageModel(width, vocab_size)

    def forward(self, x, segment_label):
        """Return ``(next_sentence_output, masked_lm_output)`` for the encoded input."""
        encoded = self.bert(x, segment_label)
        next_sentence_out = self.next_sentence(encoded)
        masked_lm_out = self.mask_lm(encoded)
        return next_sentence_out, masked_lm_out

In [17]:
import pickle
from tqdm.auto import tqdm
from collections import Counter

class TorchVocab(object):
    """Token <-> index mapping built from a frequency counter.

    Attributes:
        freqs: the counter passed in (kept as-is, special tokens included).
        itos: list of tokens; the specials come first.
        stoi: dict mapping each token in ``itos`` to its index.
        vectors: always ``None`` unless ``vectors`` was supplied.
    """

    # Instances define value equality and are mutable, so they are unhashable.
    __hash__ = None

    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>', '<oov>'),
                 vectors=None, unk_init=None, vectors_cache=None):
        """
        Args:
            counter: mapping of token -> frequency (``Counter`` or plain dict).
            max_size: maximum number of non-special tokens, or ``None`` for no limit.
            min_freq: tokens seen fewer times are dropped (values below 1 act as 1).
            specials: tokens placed at the start of the vocabulary, in order.
            vectors, unk_init, vectors_cache: forwarded to ``load_vectors`` when
                ``vectors`` is given; the other two must be ``None`` otherwise.
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        for tok in specials:
            # Specials get fixed slots; ``pop`` also tolerates plain dicts that
            # do not contain the token.
            counter.pop(tok, None)

        max_size = None if max_size is None else max_size + len(self.itos)

        # Most frequent first; ties are broken alphabetically.
        words_and_frequencies = sorted(counter.items(), key=lambda tup: (-tup[1], tup[0]))

        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            # NOTE(review): ``load_vectors`` is not defined on this class in the
            # visible source -- confirm it exists before passing ``vectors``.
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        """Compare ``freqs``, ``stoi``, ``itos`` and ``vectors``; defer for other types."""
        if not isinstance(other, TorchVocab):
            return NotImplemented
        for attr in ("freqs", "stoi", "itos", "vectors"):
            if getattr(self, attr) != getattr(other, attr):
                return False
        return True

    def __len__(self):
        """Number of tokens, specials included."""
        return len(self.itos)

    def vocab_rerank(self):
        """Rebuild ``stoi`` from the current order of ``itos``."""
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        """Append the tokens of vocabulary ``v`` that are not present yet.

        Args:
            v: object exposing an ``itos`` list.
            sort: when true, the tokens of ``v`` are visited in sorted order.
        """
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1

In [18]:
class Vocab(TorchVocab):
    """Vocabulary with five reserved tokens occupying indices 0-4.

    ``<pad>``=0, ``<unk>``=1, ``<eos>``=2, ``<sos>``=3, ``<mask>``=4.
    """

    def __init__(self, counter, max_size=None, min_freq=1):
        reserved = ["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"]
        # Indices follow the position of each token in ``reserved``.
        (self.pad_index,
         self.unk_index,
         self.eos_index,
         self.sos_index,
         self.mask_index) = range(len(reserved))
        super().__init__(counter, max_size=max_size, min_freq=min_freq, specials=reserved)

    def to_seq(self, sentence, seq_len, with_eos=False, with_sos=False) -> list:
        """Placeholder with no implementation here; returns ``None``."""

    def from_seq(self, seq, join=False, with_pad=False):
        """Placeholder with no implementation here; returns ``None``."""

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        """Unpickle and return the object stored at ``vocab_path``."""
        with open(vocab_path, "rb") as handle:
            restored = pickle.load(handle)
        return restored

    def save_vocab(self, vocab_path):
        """Pickle this vocabulary to ``vocab_path``."""
        with open(vocab_path, "wb") as handle:
            pickle.dump(self, handle)


In [19]:
class WordVocab(Vocab):
    """Whitespace-token vocabulary counted from an iterable of texts."""

    def __init__(self, texts, max_size=None, min_freq=1):
        """
        Args:
            texts: iterable whose items are either token lists or strings;
                strings have newlines removed and are split on whitespace.
            max_size, min_freq: forwarded to ``Vocab``.
        """
        print("Building Vocab")
        counter = Counter()
        for line in tqdm(texts):
            words = line if isinstance(line, list) else line.replace("\n", "").split()
            counter.update(words)
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        """Convert a sentence (string or token list) to a list of indices.

        Unknown tokens map to ``unk_index``. When ``seq_len`` is given the
        result is padded with ``pad_index`` or truncated to that length. With
        ``with_len`` the pre-padding length is returned as a second element.
        """
        tokens = sentence.split() if isinstance(sentence, str) else sentence
        ids = [self.stoi.get(tok, self.unk_index) for tok in tokens]

        if with_eos:
            ids.append(self.eos_index)
        if with_sos:
            ids.insert(0, self.sos_index)

        unpadded_len = len(ids)

        if seq_len is not None:
            if len(ids) <= seq_len:
                ids.extend([self.pad_index] * (seq_len - len(ids)))
            else:
                ids = ids[:seq_len]

        if with_len:
            return ids, unpadded_len
        return ids

    def from_seq(self, seq, join=False, with_pad=False):
        """Convert indices back to tokens.

        Indices past the end of the vocabulary are rendered as ``<idx>``. When
        ``with_pad`` is true, entries equal to ``pad_index`` are dropped.
        """
        words = []
        for idx in seq:
            if with_pad and idx == self.pad_index:
                continue
            if idx < len(self.itos):
                words.append(self.itos[idx])
            else:
                words.append("<%d>" % idx)
        if join:
            return " ".join(words)
        return words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        """Unpickle and return the object stored at ``vocab_path``."""
        with open(vocab_path, "rb") as handle:
            restored = pickle.load(handle)
        return restored

In [20]:
from torch.utils.data import Dataset
import random

class BERTDataset(Dataset):
    """Sentence-pair dataset yielding masked-LM and next-sentence samples.

    Every usable corpus line holds two sentences separated by a tab. With
    ``on_memory=True`` all pairs are loaded up front; otherwise the corpus is
    streamed through two file handles (one sequential, one for random
    negatives) that rewind when they reach the end of the file.
    """

    def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
        """
        Args:
            corpus_path: path of the tab-separated corpus file.
            vocab: vocabulary exposing ``stoi``, ``__len__`` and the
                ``pad/unk/eos/sos/mask`` index attributes.
            seq_len: fixed length of every returned sequence.
            encoding: text encoding of the corpus file.
            corpus_lines: stored, then overwritten by the number of usable
                pairs counted in the file.
            on_memory: load the whole corpus into ``self.lines`` when true.

        Raises:
            ValueError: if the corpus contains no usable sentence pair.
        """
        self.vocab = vocab
        self.seq_len = seq_len

        self.on_memory = on_memory
        self.corpus_lines = corpus_lines
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.vocab_unk_index = vocab.unk_index

        # Streaming handles; they stay None in on-memory mode.
        self.file = None
        self.random_file = None

        with open(corpus_path, "r", encoding=encoding) as f:
            # A line is usable only if it still has two fields after stripping;
            # the same rule is applied in both modes so __len__ is consistent.
            fields_per_line = (line.strip().split("\t") for line in f)
            if on_memory:
                self.lines = [fields for fields in fields_per_line if len(fields) >= 2]
                self.corpus_lines = len(self.lines)
            else:
                self.corpus_lines = sum(1 for fields in fields_per_line if len(fields) >= 2)

        if self.corpus_lines == 0:
            raise ValueError("데이터셋이 비어있습니다. corpus.txt 파일을 확인하세요 ")

        if not on_memory:
            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)
            # Start the negative-sample reader at a random offset.
            for _ in range(random.randint(0, self.corpus_lines - 1)):
                self.random_file.readline()

    def __len__(self):
        """Number of usable sentence pairs found in the corpus."""
        return self.corpus_lines

    def close(self):
        """Close the streaming file handles, if any were opened."""
        for name in ("file", "random_file"):
            handle = getattr(self, name, None)
            if handle is not None:
                handle.close()
                setattr(self, name, None)

    def _next_pair(self, handle, random_offset=False):
        """Read the next usable ``(t1, t2)`` pair from ``handle``.

        Blank or malformed lines are skipped. At end of file the handle is
        rewound; with ``random_offset`` the first rewind also skips a random
        number of lines (fewer than ``min(corpus_lines, 1000)``).
        """
        rewinds = 0
        while True:
            line = handle.readline()
            if line:
                fields = line.strip().split("\t")
                if len(fields) >= 2:
                    return fields[0], fields[1]
                continue

            # End of file. Two rewinds guarantee one complete pass from the
            # start even when the random skip left no usable line behind.
            if rewinds >= 2:
                raise RuntimeError("no sentence pair found in %s" % self.corpus_path)
            handle.seek(0)
            if random_offset and rewinds == 0:
                for _ in range(random.randrange(min(self.corpus_lines, 1000))):
                    handle.readline()
            rewinds += 1

    def random_word(self, sentence):
        """Apply masked-LM corruption to a whitespace-tokenised sentence.

        Each token is selected with probability 0.15. A selected token becomes
        the mask index (80%), a random vocabulary index (10%) or its own index
        (10%), and its label is its own index. Unselected tokens keep their
        index and get label 0.

        Returns:
            ``(token_ids, labels)`` -- two lists of equal length.
        """
        tokens = sentence.split()
        output_label = []

        for i, token in enumerate(tokens):
            token_id = self.vocab.stoi.get(token, self.vocab.unk_index)
            prob = random.random()

            if prob >= 0.15:
                tokens[i] = token_id
                output_label.append(0)
                continue

            # Rescale to [0, 1) to pick the kind of corruption.
            prob /= 0.15
            if prob < 0.8:
                tokens[i] = self.vocab.mask_index
            elif prob < 0.9:
                tokens[i] = random.randrange(len(self.vocab))
            else:
                tokens[i] = token_id
            output_label.append(token_id)

        return tokens, output_label

    def get_corpus_line(self, item):
        """Return the sentence pair for ``item``.

        In streaming mode ``item`` is ignored and the next pair of the
        sequential handle is returned, wrapping around at end of file.
        """
        if self.on_memory:
            return self.lines[item][0], self.lines[item][1]
        return self._next_pair(self.file)

    def get_random_line(self):
        """Return the second sentence of a randomly chosen pair."""
        if self.on_memory:
            return self.lines[random.randrange(len(self.lines))][1]
        return self._next_pair(self.random_file, random_offset=True)[1]

    def random_sent(self, index):
        """Return ``(t1, t2, label)``: label 0 keeps the true pair, label 1 swaps in a random ``t2``."""
        t1, t2 = self.get_corpus_line(index)

        if random.random() > 0.5:
            return t1, t2, 0
        return t1, self.get_random_line(), 1

    def __getitem__(self, item):
        """Build one training sample as a dict of tensors.

        Keys: ``bert_input``, ``bert_label``, ``segment_label`` (each of length
        ``seq_len``) and ``is_next``.
        """
        t1, t2, is_next_label = self.random_sent(item)
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # Layout: <sos> t1 <eos> t2 <eos>; the special tokens get label 0 (pad).
        t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
        t2 = t2_random + [self.vocab.eos_index]

        t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
        t2_label = t2_label + [self.vocab.pad_index]

        # Trim from the end of the longer segment until the pair fits.
        while len(t1) + len(t2) > self.seq_len:
            if len(t1) > len(t2):
                t1.pop()
                t1_label.pop()
            else:
                t2.pop()
                t2_label.pop()

        bert_input = t1 + t2
        bert_label = t1_label + t2_label
        segment_label = [1] * len(t1) + [2] * len(t2)

        padding = [self.vocab.pad_index] * (self.seq_len - len(bert_input))
        bert_input.extend(padding)
        bert_label.extend(padding)
        segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value) for key, value in output.items()}

In [21]:
from torch.optim import Adam
from torch.utils.data import DataLoader

class BERTTrainer:
    """Trains and evaluates a ``BERTLM`` on the masked-LM and next-sentence tasks.

    The total loss is the sum of the two ``NLLLoss`` terms; masked-LM targets
    equal to 0 are ignored.
    """

    def __init__(self,bert: BERT,vocab_size: int,
                 train_dataloader: DataLoader,test_dataloader: DataLoader = None,
                 lr:float = 1e-4, betas = (0.9,0.999),weight_decay: float = 0.01, warmup_steps = 10,
                 with_cuda = True,cuda_devices = None,log_freq: int = 10):
        """
        Args:
            bert: encoder wrapped into a ``BERTLM``.
            vocab_size: output size of the masked-LM head.
            train_dataloader: loader kept as ``self.train_data``.
            test_dataloader: optional loader used by ``test``.
            lr, betas, weight_decay: Adam hyper-parameters.
            warmup_steps: accepted for compatibility; not used by this class.
            with_cuda: use CUDA when it is available.
            cuda_devices: device ids handed to ``nn.DataParallel``.
            log_freq: stored as ``self.log_freq``; not used by this class.
        """
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        self.bert = bert
        self.model = BERTLM(bert, vocab_size).to(self.device)

        if with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

        self.criterion_mlm = nn.NLLLoss(ignore_index=0)
        self.criterion_nsp = nn.NLLLoss()

        self.log_freq = log_freq

        print("Total Parameters:", sum(p.nelement() for p in self.model.parameters()))

    def test(self, epoch):
        """Evaluate on the test loader and return the metrics of ``iteration``.

        Raises:
            ValueError: if no test loader was supplied.
        """
        if self.test_data is None:
            raise ValueError("no test_dataloader was supplied to BERTTrainer")
        return self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """Run one pass over ``data_loader``.

        Args:
            epoch: epoch number (used only to label the progress iterator).
            data_loader: loader yielding dicts with ``bert_input``,
                ``segment_label``, ``bert_label`` and ``is_next`` tensors.
            train: when true, parameters are updated; otherwise the model is
                put in eval mode and gradients are disabled.

        Returns:
            ``(mean_total_loss, mean_nsp_loss, mean_mlm_loss, nsp_accuracy_percent)``

        Raises:
            ValueError: if ``data_loader`` yields no batches.
        """
        num_batches = len(data_loader)
        if num_batches == 0:
            raise ValueError("data_loader is empty")

        str_code = "train" if train else "test"
        data_iter = tqdm(data_loader,
                         desc="%s:%d" % (str_code, epoch),
                         total=num_batches,
                         disable=True)

        # Dropout must be active only while training.
        self.model.train(train)

        avg_loss = 0.0
        avg_next_loss = 0.0
        avg_mask_loss = 0.0
        total_correct = 0
        total_element = 0

        # Skip building the autograd graph during evaluation.
        with torch.set_grad_enabled(train):
            for data in data_iter:
                data = {key: value.to(self.device) for key, value in data.items()}

                # Call the module (not ``.forward``) so registered hooks run.
                next_sent_output, mask_lm_output = self.model(data["bert_input"], data["segment_label"])

                next_loss = self.criterion_nsp(next_sent_output, data["is_next"])
                # NLLLoss expects the class axis at position 1.
                mask_loss = self.criterion_mlm(mask_lm_output.transpose(1, 2), data["bert_label"])
                loss = next_loss + mask_loss

                if train:
                    self.optim.zero_grad()
                    loss.backward()
                    self.optim.step()

                correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
                avg_loss += loss.item()
                avg_next_loss += next_loss.item()
                avg_mask_loss += mask_loss.item()
                total_correct += correct
                total_element += data["is_next"].nelement()

        final_loss = avg_loss / num_batches
        final_acc = total_correct * 100.0 / total_element
        final_next_loss = avg_next_loss / num_batches
        final_mask_loss = avg_mask_loss / num_batches
        return final_loss, final_next_loss, final_mask_loss, final_acc

    def save(self, epoch, file_path="output/bert_trained.model"):
        """Save the whole model object to ``<file_path>.ep<epoch>`` and return that path."""
        import os

        output_path = file_path + ".ep%d" % epoch
        # Create the target directory if needed so saving does not fail.
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        torch.save(self.model, output_path)
        print("EP:%d Model Saved on:"  %epoch,output_path)
        return output_path


In [22]:
import os
import shutil
from collections import Counter
from datasets import load_dataset
import re

# Locations of the generated corpus, the pickled vocabulary and the checkpoints.
train_dataset_path = "corpus.txt"
vocab_path = "vocab.txt"
output_dir = "output"

# Start every run from an empty output directory.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

toy_sentences = [
    "a cat sits on a mat .",
    "that mat is on the floor .",
    "one dog runs in a park .",
    "this park is very big .",
    "i like to eat apples .",
    "apples are very sweet .",
    "she reads a thick book .",
    "this book has many pages .",
    "he drives a fast car .",
    "that car is very expensive .",
    "birds fly in a sky .",
    "sky is blue and clear .",
    "fish swim in water .",
    "water is cold and deep .",
    "a chef cooks a meal .",
    "some meal is on a table .",
    "sun rises in a morning .",
    "morning is bright and warm .",
    "rain falls from clouds .",
    "clouds are dark and heavy .",
    "flowers grow in a garden .",
    "garden is full of colors .",
    "a teacher speaks to a class .",
    "every class is listening now .",
    "moon shines in a night .",
    "night is quiet and dark .",
    "a boy plays a guitar .",
    "guitar sounds very loud .",
    "we learn deep learning .",
    "learning is very interesting ."
]

# Pair every sentence with its successor (tab-separated), then repeat the
# whole list 300 times to obtain a corpus of usable size.
adjacent_pairs = [f"{first}\t{second}"
                  for first, second in zip(toy_sentences, toy_sentences[1:])]
formatted_pairs = adjacent_pairs * 300

# Use the path variables declared above so that later cells, which read
# ``train_dataset_path``, always see the file written here.
with open(train_dataset_path, "w", encoding="utf-8") as f:
    f.write("\n".join(formatted_pairs))

with open(train_dataset_path, "r", encoding="utf-8") as f:
    vocab = WordVocab(f, min_freq=1)

# NOTE: despite the .txt suffix this file is a pickle, not plain text.
vocab.save_vocab(vocab_path)
print(f"토이 데이터 생성 완료! 문장 쌍: {len(formatted_pairs)}개, 사전 크기: {len(vocab)}")


Building Vocab


0it [00:00, ?it/s]

토이 데이터 생성 완료! 문장 쌍: 8700개, 사전 크기: 94


In [23]:
# Model size.
hidden = 32
layers = 2
attn_heads = 2

# Data and optimisation settings.
seq_len = 20
batch_size = 16
epochs = 100
lr = 1e-3

# In-memory dataset over the generated corpus, served in fixed-size batches.
train_dataset = BERTDataset(corpus_path=train_dataset_path, vocab=vocab, seq_len=seq_len)
train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2)

In [24]:
print("Building BERT model...")
bert = BERT(len(vocab), hidden=hidden, n_layers=layers, attn_heads=attn_heads)

trainer = BERTTrainer(
    bert,
    len(vocab),
    train_dataloader=train_data_loader,
    lr=lr,
    with_cuda=True,
)

print("Training Start...")
for epoch in range(epochs):
    loss, next_l, mask_l, acc = trainer.iteration(epoch, train_data_loader, train=True)

    # Report and checkpoint every fifth epoch only; the checkpoint suffix is
    # the zero-based epoch index.
    if (epoch + 1) % 5 != 0:
        continue

    print(f"Epoch {epoch+1:3d} | Total: {loss:.4f} | NSP: {next_l:.4f} | MLM: {mask_l:.4f} | Acc: {acc:.2f}%")
    trainer.save(epoch, "output/bert.model")


Building BERT model...
Total Parameters: 32736
Training Start...
Epoch   5 | Total: 3.6833 | NSP: 0.6936 | MLM: 2.9896 | Acc: 50.01%
EP:4 Model Saved on: output/bert.model.ep4
Epoch  10 | Total: 3.2874 | NSP: 0.6933 | MLM: 2.5941 | Acc: 49.64%
EP:9 Model Saved on: output/bert.model.ep9
Epoch  15 | Total: 3.1986 | NSP: 0.6927 | MLM: 2.5059 | Acc: 51.05%
EP:14 Model Saved on: output/bert.model.ep14
Epoch  20 | Total: 3.1284 | NSP: 0.6919 | MLM: 2.4364 | Acc: 50.56%
EP:19 Model Saved on: output/bert.model.ep19
Epoch  25 | Total: 3.0750 | NSP: 0.6919 | MLM: 2.3830 | Acc: 51.23%
EP:24 Model Saved on: output/bert.model.ep24
Epoch  30 | Total: 3.0009 | NSP: 0.6918 | MLM: 2.3092 | Acc: 52.10%
EP:29 Model Saved on: output/bert.model.ep29
Epoch  35 | Total: 2.9998 | NSP: 0.6925 | MLM: 2.3073 | Acc: 52.21%
EP:34 Model Saved on: output/bert.model.ep34
Epoch  40 | Total: 2.9738 | NSP: 0.6924 | MLM: 2.2815 | Acc: 51.31%
EP:39 Model Saved on: output/bert.model.ep39
Epoch  45 | Total: 2.9806 | NSP: 0.

In [28]:
# Checkpoint written after the last (zero-based) training epoch.
model_path = "output/bert.model.ep99"
vocab_path = "vocab.txt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open(vocab_path, "rb") as f:
    vocab = pickle.load(f)

# NOTE: weights_only=False unpickles arbitrary Python objects; only load
# checkpoints produced by this notebook or another trusted source.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()


text_a = "birds fly in a sky ."
text_b = "sky is <mask style> and clear ."

tokens_a = text_a.split()
# The placeholder contains a space, so sentence B is tokenised by hand.
tokens_b = ["sky", "is", "<mask style>", "and", "clear", "."]

# Layout: <sos> A <eos> B <eos>, with the placeholder replaced by the mask index.
input_ids = [vocab.sos_index] + \
            [vocab.stoi.get(t, vocab.unk_index) for t in tokens_a] + \
            [vocab.eos_index] + \
            [vocab.stoi.get(t, vocab.unk_index) if t != "<mask style>" else vocab.mask_index for t in tokens_b] + \
            [vocab.eos_index]

# Segment 1 covers <sos> A <eos>; segment 2 covers B <eos>.
segment_label = [1] * (len(tokens_a) + 2) + [2] * (len(tokens_b) + 1)

input_tensor = torch.tensor([input_ids]).to(device)
segment_tensor = torch.tensor([segment_label]).to(device)

# Keep the try body minimal: only the mask lookup may legitimately raise
# ValueError; errors raised by the model must not be reported as a missing mask.
try:
    mask_pos = input_ids.index(vocab.mask_index)
except ValueError:
    print("입력 데이터에 <mask style> 토큰이 없습니다.")
else:
    with torch.no_grad():
        nsp_output, mlm_output = model(input_tensor, segment_tensor)

        # Class 0 is reported as "IsNext" and class 1 as "NotNext".
        is_next = nsp_output.argmax(dim=-1).item()
        print(f"--- 테스트 결과 ---")
        print(f"입력 문장 A: {text_a}")
        print(f"입력 문장 B: {text_b}")
        print(f"문장 관계 예측: {'연속됨(IsNext)' if is_next == 0 else '상관없음(NotNext)'}")

        predict_id = mlm_output[0, mask_pos].argmax(dim=-1).item()
        print(f"단어 예측 결과: {vocab.itos[predict_id]}")

--- 테스트 결과 ---
입력 문장 A: birds fly in a sky .
입력 문장 B: sky is <mask style> and clear .
문장 관계 예측: 연속됨(IsNext)
단어 예측 결과: blue
