In [1]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import datasets
from datasets import load_dataset
from transformers import BertTokenizer, DataCollatorForLanguageModeling
import numpy as np
from random import randrange, randint
import random
from torch.utils.data import DataLoader


In [2]:
datasets.disable_caching()

In [3]:
# 데이터셋 로드
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train[:1%]')

train_test_split = dataset.train_test_split(test_size=0.2)
train_set = train_test_split['train']
test_set = train_test_split['test']


In [5]:
# 토크나이즈
def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenized_train_set = train_set.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_test_set = test_set.map(tokenize_function, batched=True, remove_columns=["text"])

Map:   0%|          | 0/293 [00:00<?, ? examples/s]

Map:   0%|          | 0/74 [00:00<?, ? examples/s]

In [6]:
# 데이터셋 생성
def create_mlm_and_nsp_dataset(dataset, tokenizer, max_length=512, nsp_probability=0.5):
    examples = []
    for i in range(len(dataset['input_ids']) - 1):
        # Next Sentence Prediction (NSP)
        # 두 문장이 실제로 이어지는지 여부는 50% 비율로 참인 문장과 랜덤하게 추출되어 거짓인 문장의 비율로 구성
        if random.random() < nsp_probability:
            is_next = 1
            next_sentence = dataset['input_ids'][i + 1]
        else:
            is_next = 0
            next_sentence = random.choice(dataset['input_ids'])
        
        # current + next sentence
        input_ids = dataset['input_ids'][i] + next_sentence
        attention_mask = dataset['attention_mask'][i] + dataset['attention_mask'][i + 1]
        token_type_ids = [0] * len(dataset['input_ids'][i]) + [1] * len(next_sentence)

        input_ids = input_ids[:max_length]
        attention_mask = attention_mask[:max_length]
        token_type_ids = token_type_ids[:max_length]

        # Masked Language Modeling (MLM)
        # 마스킹은 전체 단어의 15% 정도만 진행
        # 모든 토큰을 마스킹 하는게 아니라 80% 정도만 <MASK>로 처리
        # 10%는 랜덤한 단어, 나머지 10%는 정상적인 단어를 그대로 둔다.
        labels = input_ids.copy()
        for j in range(len(input_ids)):
            if random.random() < 0.15:
                if random.random() < 0.8:
                    input_ids[j] = tokenizer.mask_token_id
                elif random.random() < 0.5:
                    input_ids[j] = random.randint(0, tokenizer.vocab_size - 1)
            else:
                labels[j] = -100

        examples.append({
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'labels': labels,
            'next_sentence_label': is_next
        })

    return examples

mlm_nsp_train_set = create_mlm_and_nsp_dataset(tokenized_train_set, tokenizer)
mlm_nsp_test_set = create_mlm_and_nsp_dataset(tokenized_test_set, tokenizer)


In [7]:
class BertDataset:
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return {key: np.array(value) for key, value in self.examples[idx].items()}

# 데이터셋 생성
train_dataset = BertDataset(mlm_nsp_train_set)
test_dataset = BertDataset(mlm_nsp_test_set)

# 데이터 로더 생성
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=True)


In [8]:
class BertEmbeddings(nn.Module):
    vocab_size: int
    hidden_size: int
    max_length: int
    type_vocab_size: int

    def setup(self):
        self.word_embeddings = nn.Embed(self.vocab_size, self.hidden_size)
        self.position_embeddings = nn.Embed(self.max_length, self.hidden_size)
        self.token_type_embeddings = nn.Embed(self.type_vocab_size, self.hidden_size)
        self.LayerNorm = nn.LayerNorm()
        self.dropout = nn.Dropout(0.1, deterministic=False)

    def __call__(self, input_ids, token_type_ids):
        seq_length = input_ids.shape[1]
        position_ids = jnp.arange(seq_length)[None, :]
        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [9]:
class BertSelfAttention(nn.Module):
    num_heads: int
    hidden_size: int

    def setup(self):
        assert self.hidden_size % self.num_heads == 0
        self.attention_head_size = int(self.hidden_size / self.num_heads)
        self.all_head_size = self.num_heads * self.attention_head_size

        self.query = nn.Dense(self.all_head_size)
        self.key = nn.Dense(self.all_head_size)
        self.value = nn.Dense(self.all_head_size)

        self.dropout = nn.Dropout(0.1, deterministic=False)

    def transpose_for_scores(self, x):
        new_x_shape = x.shape[:-1] + (self.num_heads, self.attention_head_size)
        x = x.reshape(new_x_shape)
        return x.transpose((0, 2, 1, 3))

    def __call__(self, hidden_states, attention_mask):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = jnp.einsum("...qhd,...khd->...hqk", query_layer, key_layer)
        attention_scores = attention_scores / jnp.sqrt(self.attention_head_size)
        
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, None, None]
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.softmax(attention_scores, axis=-1)
        attention_probs = self.dropout(attention_probs)

        context_layer = jnp.einsum("...hqk,...khd->...qhd", attention_probs, value_layer)
        context_layer = context_layer.transpose((0, 2, 1, 3))
        new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)
        return context_layer

In [10]:
class BertSelfOutput(nn.Module):
    hidden_size: int

    def setup(self):
        self.dense = nn.Dense(self.hidden_size)
        self.LayerNorm = nn.LayerNorm()
        self.dropout = nn.Dropout(0.1, deterministic=False)

    def __call__(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

In [11]:
class BertAttention(nn.Module):
    num_heads: int
    hidden_size: int

    def setup(self):
        self.self = BertSelfAttention(self.num_heads, self.hidden_size)
        self.output = BertSelfOutput(self.hidden_size)

    def __call__(self, hidden_states, attention_mask):
        self_outputs = self.self(hidden_states, attention_mask)
        attention_output = self.output(self_outputs, hidden_states)
        return attention_output

In [12]:
class BertIntermediate(nn.Module):
    hidden_size: int
    intermediate_size: int

    def setup(self):
        self.dense = nn.Dense(self.intermediate_size)
        self.intermediate_act_fn = nn.gelu

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

In [13]:
class BertOutput(nn.Module):
    hidden_size: int
    intermediate_size: int

    def setup(self):
        self.dense = nn.Dense(self.hidden_size)
        self.LayerNorm = nn.LayerNorm()
        self.dropout = nn.Dropout(0.1, deterministic=False)

    def __call__(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

In [14]:
class BertLayer(nn.Module):
    num_heads: int
    hidden_size: int
    intermediate_size: int

    def setup(self):
        self.attention = BertAttention(self.num_heads, self.hidden_size)
        self.intermediate = BertIntermediate(self.hidden_size, self.intermediate_size)
        self.output = BertOutput(self.hidden_size, self.intermediate_size)

    def __call__(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

In [15]:
class BertEncoder(nn.Module):
    hidden_size: int
    num_heads: int
    num_layers: int
    intermediate_size: int

    def setup(self):
        self.layers = [BertLayer(self.num_heads, self.hidden_size, self.intermediate_size) for _ in range(self.num_layers)]

    def __call__(self, hidden_states, attention_mask):
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)  # Simplified
        return hidden_states

In [16]:
class BertPooler(nn.Module):
    hidden_size: int

    def setup(self):
        self.dense = nn.Dense(self.hidden_size)
        self.activation = nn.tanh

    def __call__(self, hidden_states):
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

In [17]:
class BertForPreTraining(nn.Module):
    vocab_size: int
    hidden_size: int
    max_length: int
    num_heads: int
    num_layers: int
    intermediate_size: int
    type_vocab_size: int

    def setup(self):
        self.embeddings = BertEmbeddings(self.vocab_size, self.hidden_size, self.max_length, self.type_vocab_size)
        self.encoder = BertEncoder(self.hidden_size, self.num_heads, self.num_layers, self.intermediate_size)
        self.pooler = BertPooler(self.hidden_size)
        self.cls = nn.Dense(self.vocab_size)
        self.seq_relationship = nn.Dense(2)

    def __call__(self, input_ids, attention_mask, token_type_ids):
        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoder_output = self.encoder(embedding_output, attention_mask)
        pooled_output = self.pooler(encoder_output)
        prediction_scores = self.cls(encoder_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score

In [18]:
class BertForPreTraining(nn.Module):
    vocab_size: int
    hidden_size: int
    max_length: int
    num_heads: int
    num_layers: int
    intermediate_size: int
    type_vocab_size: int

    def setup(self):
        self.embeddings = BertEmbeddings(self.vocab_size, self.hidden_size, self.max_length, self.type_vocab_size)
        self.encoder = BertEncoder(self.hidden_size, self.num_heads, self.num_layers, self.intermediate_size)
        self.pooler = BertPooler(self.hidden_size)
        self.cls = nn.Dense(self.vocab_size)
        self.seq_relationship = nn.Dense(2)

    def __call__(self, input_ids, attention_mask, token_type_ids):
        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoder_output = self.encoder(embedding_output, attention_mask)
        pooled_output = self.pooler(encoder_output)
        prediction_scores = self.cls(encoder_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score

In [24]:
def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.ones((1, 512), jnp.int32), jnp.ones((1, 512), jnp.int32), jnp.zeros((1, 512), jnp.int32))
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


In [20]:
# Loss 함수 정의
def compute_loss(prediction_scores, seq_relationship_score, masked_lm_labels, next_sentence_labels):
    masked_lm_loss = optax.softmax_cross_entropy(prediction_scores, jax.nn.one_hot(masked_lm_labels, num_classes=30522))
    next_sentence_loss = optax.softmax_cross_entropy(seq_relationship_score, jax.nn.one_hot(next_sentence_labels, num_classes=2))
    return jnp.mean(masked_lm_loss) + jnp.mean(next_sentence_loss)

# 평가 메트릭 정의
def compute_metrics(prediction_scores, seq_relationship_score, masked_lm_labels, next_sentence_labels):
    masked_lm_accuracy = jnp.mean(jnp.argmax(prediction_scores, -1) == masked_lm_labels)
    next_sentence_accuracy = jnp.mean(jnp.argmax(seq_relationship_score, -1) == next_sentence_labels)
    return {'masked_lm_accuracy': masked_lm_accuracy, 'next_sentence_accuracy': next_sentence_accuracy}

In [31]:
# 학습 루프 정의
@jax.jit
def train_step(state, batch, dropout_key):
    dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)
    def loss_fn(params):
        prediction_scores, seq_relationship_score = state.apply_fn(params, batch['input_ids'], batch['attention_mask'], batch['token_type_ids'], rngs={'dropout': dropout_train_key})
        loss = compute_loss(prediction_scores, seq_relationship_score, batch['labels'], batch['next_sentence_label'])
        return loss, (prediction_scores, seq_relationship_score)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (prediction_scores, seq_relationship_score)), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(prediction_scores, seq_relationship_score, batch['labels'], batch['next_sentence_label'])
    return state, metrics


In [32]:
@jax.jit
def eval_step(state, batch, dropout_key):
    dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)
    prediction_scores, seq_relationship_score = state.apply_fn(state.params, batch['input_ids'], batch['attention_mask'], batch['token_type_ids'], rngs={'dropout': dropout_train_key})
    return compute_metrics(prediction_scores, seq_relationship_score, batch['labels'], batch['next_sentence_label'])

In [29]:
# Hyperparameters
num_epochs = 3
learning_rate = 1e-4

rng = jax.random.PRNGKey(0)

main_key, dropout_key = jax.random.split(rng)

model = BertForPreTraining(
    vocab_size=30522, hidden_size=768, max_length=512,
    num_heads=12, num_layers=12, intermediate_size=3072,
    type_vocab_size=2
)
state = create_train_state(main_key, model, learning_rate)

In [26]:
def batch_to_numpy(batch):
  return jax.tree_util.tree_map(lambda x: x.numpy(), batch)

In [35]:
# 학습 루프
for epoch in range(num_epochs):
    # Training
    for batch in train_loader:
        batch = batch_to_numpy(batch)
        state, train_metrics = train_step(state, batch, dropout_key)
    print(f"Epoch {epoch + 1}, Train Metrics: {train_metrics}")

    # Evaluation
    eval_metrics = []
    for batch in test_loader:
        batch = batch_to_numpy(batch)
        metrics = eval_step(state, batch, dropout_key)
        eval_metrics.append(metrics)
    eval_metrics = jax.device_get(eval_metrics)
    eval_metrics = {k: np.mean([metrics[k] for metrics in eval_metrics]) for k in eval_metrics[0].keys()}

    print(f"Epoch {epoch + 1}, Eval Metrics: {eval_metrics}")

Epoch 1, Train Metrics: {'masked_lm_accuracy': Array(0.13867188, dtype=float32), 'next_sentence_accuracy': Array(1., dtype=float32)}
Epoch 1, Eval Metrics: {'masked_lm_accuracy': np.float32(0.13005371), 'next_sentence_accuracy': np.float32(0.575)}
Epoch 2, Train Metrics: {'masked_lm_accuracy': Array(0.15917969, dtype=float32), 'next_sentence_accuracy': Array(0.25, dtype=float32)}
Epoch 2, Eval Metrics: {'masked_lm_accuracy': np.float32(0.13261719), 'next_sentence_accuracy': np.float32(0.575)}
Epoch 3, Train Metrics: {'masked_lm_accuracy': Array(0.15332031, dtype=float32), 'next_sentence_accuracy': Array(0.75, dtype=float32)}
Epoch 3, Eval Metrics: {'masked_lm_accuracy': np.float32(0.13398437), 'next_sentence_accuracy': np.float32(0.4875)}
