In [1]:
# moe.p y
import torch
import torch.nn as nn
from typing import Optional
from transformers import BertPreTrainedModel, BertModel
from transformers.models.bert.modeling_bert import (
    BertLayer,
    BertOutput,
    BertLMPredictionHead,
)


class MoEFFN(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_experts: int = 4,
        expert_size: Optional[int] = None,
        k: int = 2,
        dropout_prob: float = 0.1,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.k = k
        self.expert_size = expert_size or hidden_size * 4

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, self.expert_size),
                nn.GELU(),
                nn.Dropout(dropout_prob),
                nn.Linear(self.expert_size, hidden_size),
            )
            for _ in range(num_experts)
        ])

        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape
        assert hidden_dim == self.hidden_size

        x = hidden_states.view(-1, hidden_dim)  # [N, H]
        gate_logits = self.gate(x)  # [N, E]
        top_k_logits, top_k_indices = torch.topk(gate_logits, self.k, dim=1)  # [N, k]
        top_k_weights = torch.softmax(top_k_logits, dim=1)  # [N, k]

        final_output = torch.zeros_like(x)

        for i in range(self.num_experts):
            expert_mask = (top_k_indices == i)  # [N, k]
            if expert_mask.any():
                token_indices = expert_mask.nonzero(as_tuple=True)[0]  # [M]
                pos_in_topk = expert_mask.nonzero(as_tuple=True)[1]    # [M]

                expert_inputs = x[token_indices]  # [M, H]
                expert_weights = top_k_weights[token_indices, pos_in_topk]  # [M]
                expert_out = self.experts[i](expert_inputs)  # [M, H]
                weighted_out = expert_out * expert_weights.unsqueeze(-1)  # [M, H]

                final_output.index_add_(0, token_indices, weighted_out)

        return final_output.view(batch_size, seq_len, hidden_dim)


from transformers.models.bert.modeling_bert import BertLayer
import torch.nn as nn

class BertLayerWithMoE(BertLayer):
    def __init__(self, config):
        super().__init__(config)
        # –£–¥–∞–ª—è–µ–º —Å—Ç–∞–Ω–¥–∞—Ä—Ç–Ω—ã–π FFN
        del self.intermediate
        del self.output

        self.moe_ffn = MoEFFN(
            hidden_size=config.hidden_size,
            num_experts=getattr(config, "num_experts", 4),
            expert_size=config.intermediate_size,  # –∏—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è –≤–Ω—É—Ç—Ä–∏ MoE
            k=getattr(config, "moe_k", 2),
            dropout_prob=config.hidden_dropout_prob,
        )

        # –í–º–µ—Å—Ç–æ BertOutput ‚Äî —Å–æ–∑–¥–∞—ë–º —Å–≤–æ–π –ø—Ä–æ—Å—Ç–æ–π LayerNorm + Dropout
        self.moe_output_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.moe_output_dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        **kwargs,
    ):
        self_attn_output = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attn_output = self_attn_output[0]

        moe_output = self.moe_ffn(attn_output)  # [B, L, hidden_size]

        # Residual + Dropout + LayerNorm (–∫–∞–∫ –≤ –æ—Ä–∏–≥–∏–Ω–∞–ª—å–Ω–æ–º BERT)
        moe_output = self.moe_output_dropout(moe_output)
        layer_output = self.moe_output_layer_norm(attn_output + moe_output)

        outputs = (layer_output,) + self_attn_output[1:]
        return outputs


class BertMoEForMaskedLM(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # –°–æ–∑–¥–∞—ë–º BERT –∏ –∑–∞–º–µ–Ω—è–µ–º —Å–ª–æ–∏ –Ω–∞ MoE
        self.bert = BertModel(config, add_pooling_layer=False)
        for layer in self.bert.encoder.layer:
            layer.__class__ = BertLayerWithMoE
            layer.__init__(config)

        self.cls = BertLMPredictionHead(config)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        **kwargs,
    ):
        # –ü–µ—Ä–µ–¥–∞—ë–º –¢–û–õ–¨–ö–û –ø–æ–¥–¥–µ—Ä–∂–∏–≤–∞–µ–º—ã–µ –∞—Ä–≥—É–º–µ–Ω—Ç—ã –≤ BertModel
        bert_kwargs = {
            k: v for k, v in {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "position_ids": position_ids,
                "head_mask": head_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }.items() if v is not None
        }

        outputs = self.bert(**bert_kwargs)

        sequence_output = outputs.last_hidden_state
        prediction_scores = self.cls(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1)
            )

        return {
            "loss": loss,
            "logits": prediction_scores,
            "hidden_states": outputs.hidden_states,
            "attentions": outputs.attentions,
        }

In [19]:
class PretrainConfig:
    model_name = "your-moe-bert"
    dataset_name = "wikimedia/wikipedia"
    dataset_config = "20231101.en"
    text_column = "text"
    tokenizer = "bert-base-uncased"
    output_dir = "."
    seq_len = 128
    batch_size = 32

    masking_prob = 0.15

    lr = 5e-5
    weight_decay = 0.01
    warmup_steps = 1000
    max_steps = 50_000

    save_steps = 5_000
    logging_steps = 100
    eval_steps = 5_000

    # BERT / MoE –ø–∞—Ä–∞–º–µ—Ç—Ä—ã
    bert_hidden_size = 256
    bert_intermediate_size = 1024
    bert_num_hidden_layers = 4
    bert_num_attention_heads = 4
    num_experts = 4


class ClassificationConfig:
    # –ü—É—Ç–∏
    backbone_path = "./final_model/pytorch_model.bin"  # –∏–ª–∏ –≤–∞—à .pt —Ñ–∞–π–ª
    dataset_path = "your_dataset"  # –Ω–∞–ø—Ä–∏–º–µ—Ä, "json", "csv", –∏–ª–∏ HuggingFace dataset
    dataset_split = "train"
    text_column = "text"
    label_columns = ["label1", "label2", "label3", "label4", "label5", "label6"]  # ‚Üê 6 –º–µ—Ç–æ–∫

    # –ú–æ–¥–µ–ª—å
    num_labels = 6
    tokenizer = "bert-base-uncased"
    hidden_size = 256
    num_hidden_layers = 4
    num_attention_heads = 4
    intermediate_size = 1024
    num_experts = 4

    # –û–±—É—á–µ–Ω–∏–µ
    output_dir = "./cls_output"
    seq_len = 128
    batch_size = 32
    lr = 2e-5
    weight_decay = 0.01
    num_train_epochs = 3
    logging_steps = 100
    save_steps = 500
    eval_steps = 500

In [None]:
import os
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers import BertConfig


def main():
    set_seed(42)
    cfg = PretrainConfig()

    # --- –ú–æ–¥–µ–ª—å –∏ —Ç–æ–∫–µ–Ω–∏–∑–∞—Ç–æ—Ä ---
    model_config = BertConfig(
        vocab_size=30522,
        hidden_size=cfg.bert_hidden_size,
        num_hidden_layers=cfg.bert_num_hidden_layers,
        num_attention_heads=cfg.bert_num_attention_heads,
        intermediate_size=cfg.bert_intermediate_size,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        pad_token_id=0,
        num_experts=cfg.num_experts,
        moe_k=2,
    )

    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer)
    model = BertMoEForMaskedLM(model_config)

    print(f"Model initialized with {cfg.num_experts} experts.")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # --- –ó–ê–ì–†–£–ó–ö–ê –î–ê–¢–ê–°–ï–¢–ê –í STREAMING –†–ï–ñ–ò–ú–ï ---
    print("Loading dataset in streaming mode...")
    # dataset = load_dataset(
    #     cfg.dataset_name,
    #     cfg.dataset_config,
    #     split="train",
    #     streaming=True  # üî• –∫–ª—é—á–µ–≤–æ–µ –∏–∑–º–µ–Ω–µ–Ω–∏–µ!
    # )

    train_dataset = load_dataset(
    cfg.dataset_name,
    cfg.dataset_config,
    split="train",
    streaming=True# –ø—Ä–æ—Å—Ç–æ train
    )

    raw_dataset = load_dataset(
    cfg.dataset_name,
    cfg.dataset_config,
    split="train",
    streaming=False# –ø—Ä–æ—Å—Ç–æ train
    )

    ds = raw_dataset.train_test_split(
        test_size=0.01,
        seed=42,
    )

    eval_raw = ds["test"]

    def tokenize_function(examples):
        return tokenizer(
            examples[cfg.text_column],
            truncation=True,
            padding=False,  # collator —Å–∞–º —Å–¥–µ–ª–∞–µ—Ç padding –¥–æ batch max
            max_length=cfg.seq_len,
            return_special_tokens_mask=True,
        )

    # –ü—Ä–∏–º–µ–Ω—è–µ–º —Ç–æ–∫–µ–Ω–∏–∑–∞—Ü–∏—é –∏ —É–¥–∞–ª—è–µ–º –í–°–ï –∏—Å—Ö–æ–¥–Ω—ã–µ –∫–æ–ª–æ–Ω–∫–∏
    # original_columns = dataset.column_names  # ['id', 'text', 'url'] ‚Äî –¥–ª—è wikipedia
    # tokenized_dataset = dataset.map(
    #     tokenize_function,
    #     batched=True,
    #     remove_columns=original_columns,  # ‚Üê —É–¥–∞–ª—è–µ–º –í–°–Å, –∫—Ä–æ–º–µ output tokenizer'–∞
    # )
    tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names,
    )

    tokenized_eval = eval_raw.map(
        tokenize_function,
        batched=True,
        remove_columns=eval_raw.column_names,
    )

    # --- Data collator ---
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=cfg.masking_prob,
    )

    # --- Training args ---
    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        overwrite_output_dir=True,
        max_steps=cfg.max_steps,
        per_device_train_batch_size=cfg.batch_size,
        gradient_accumulation_steps=1,
        learning_rate=cfg.lr,
        weight_decay=cfg.weight_decay,
        warmup_steps=cfg.warmup_steps,
        logging_steps=cfg.logging_steps,
        save_steps=cfg.save_steps,
        save_strategy="steps",
        load_best_model_at_end=False,
        fp16=True,
        dataloader_num_workers=2,  # –º–æ–∂–Ω–æ 0‚Äì4, –Ω–æ –≤ streaming –ª—É—á—à–µ 0‚Äì2
        remove_unused_columns=False,
        report_to="none",
        # ‚ö†Ô∏è –í–ê–ñ–ù–û: –æ—Ç–∫–ª—é—á–∞–µ–º shuffle –¥–ª—è streaming (–∏–ª–∏ –∏—Å–ø–æ–ª—å–∑—É–µ–º –±—É—Ñ–µ—Ä)
        dataloader_drop_last=True,
        save_safetensors=False,
        eval_strategy="steps",
        eval_steps=cfg.eval_steps,
    )

    # --- –°–æ–∑–¥–∞—ë–º Trainer ---
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,  # ‚Üê streaming dataset!
        eval_dataset=tokenized_eval,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    # --- –û–±—É—á–µ–Ω–∏–µ ---
    print("Starting pretraining (streaming)...")
    trainer.train()

    # --- –°–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ ---
    final_dir = os.path.join(cfg.output_dir, "final_model")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"Model saved to {final_dir}")


if __name__ == "__main__":
    main()

Model initialized with 4 experts.
Total parameters: 25,326,138
Loading dataset in streaming mode...


Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

Map:   0%|          | 0/64079 [00:00<?, ? examples/s]

  trainer = Trainer(


Starting pretraining (streaming)...


Step,Training Loss,Validation Loss
5000,6.8027,6.874363


In [5]:
import torch
from transformers import AutoTokenizer

model_dir = "/content/checkpoint-10000"  # –∏–ª–∏ –∫–æ–Ω–∫—Ä–µ—Ç–Ω—ã–π –ø—É—Ç—å

# 1. –ó–∞–≥—Ä—É–∂–∞–µ–º —Ç–æ–∫–µ–Ω–∏–∑–∞—Ç–æ—Ä
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# 2. –ó–∞–≥—Ä—É–∂–∞–µ–º –º–æ–¥–µ–ª—å
model = BertMoEForMaskedLM.from_pretrained(model_dir)
model.eval()  # –≤—ã–∫–ª—é—á–∞–µ–º dropout –∏ —Ç.–ø.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

import torch

text = "Paris is the [MASK] of France."

inputs = tokenizer(
    text,
    return_tensors="pt"
)

inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)   # —Ç–≤–æ–π forward –≤–æ–∑–≤—Ä–∞—â–∞–µ—Ç dict
    logits = outputs["logits"]  # [batch_size, seq_len, vocab_size]

# –ù–∞—Ö–æ–¥–∏–º –ø–æ–∑–∏—Ü–∏—é [MASK]
mask_token_id = tokenizer.mask_token_id
mask_positions = (inputs["input_ids"] == mask_token_id).nonzero(as_tuple=False)

# –î–ª—è –ø—Ä–æ—Å—Ç–æ—Ç—ã –ø—Ä–µ–¥–ø–æ–ª–∞–≥–∞–µ–º –æ–¥–∏–Ω [MASK] –≤ –æ–¥–Ω–æ–º –ø—Ä–∏–º–µ—Ä–µ
batch_idx, mask_pos = mask_positions[0].tolist()

mask_logits = logits[batch_idx, mask_pos, :]       # [vocab_size]
top_k = torch.topk(mask_logits, k=20)               # —Ç–æ–ø-5 —Ç–æ–∫–µ–Ω–æ–≤
top_ids = top_k.indices.tolist()
top_scores = top_k.values.tolist()

print("Input:", text)
print("Top predictions for [MASK]:")
for token_id, score in zip(top_ids, top_scores):
    token = tokenizer.decode([token_id])
    print(f"{token!r}  logit={score:.3f}")


Input: Paris is the [MASK] of France.
Top predictions for [MASK]:
'a'  logit=8.039
'the'  logit=7.469
'is'  logit=7.435
'of'  logit=6.474
'in'  logit=6.015
'an'  logit=5.873
'-'  logit=5.772
','  logit=5.717
'.'  logit=5.342
'('  logit=4.887
'and'  logit=4.829
'or'  logit=4.449
'district'  logit=4.437
'are'  logit=4.394
'##s'  logit=4.357
'##i'  logit=4.321
'was'  logit=4.299
'to'  logit=4.258
"'"  logit=4.152
'##o'  logit=4.149
