In [0]:
%pip install --upgrade torch transformers peft accelerate datasets evaluate
%pip install py7zr rouge_score entmax
dbutils.library.restartPython()

Python interpreter will be restarted.
Python interpreter will be restarted.
Python interpreter will be restarted.
Python interpreter will be restarted.


In [0]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [0]:
top_k = None
router_type = "linear"
norm_type   = "softmax" 

In [0]:
# --------------------------- 1.  imports ---------------------------
import evaluate, nltk, torch, torch.nn as nn
import numpy as np
from nltk.tokenize import sent_tokenize
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoTokenizer,
    SwitchTransformersForConditionalGeneration,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# -------------------------------------------------------------------
# 2.  Sparsemax + attention router
# -------------------------------------------------------------------
class SoftmaxNorm(nn.Module):
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim
    def forward(self, x):
        return torch.softmax(x, dim=self.dim)
    
class Sparsemax(nn.Module):
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        x = x - x.mean(dim=self.dim, keepdim=True)
        z = torch.clamp(x, min=0)                               # ReLU
        z_sorted, _ = torch.sort(z, dim=self.dim, descending=True)
        z_cumsum = z_sorted.cumsum(dim=self.dim)
        rhos = torch.arange(1, z.size(self.dim) + 1,
                            device=x.device, dtype=x.dtype)
        condition = z_sorted * rhos > (z_cumsum - 1)
        rho = condition.sum(dim=self.dim, keepdim=True)
        tau = (z_cumsum.gather(self.dim, rho - 1) - 1) / rho
        return torch.clamp(z - tau, min=0)
    
class LinearRouter(nn.Module):
    def __init__(self, config, norm: nn.Module, top_k: int = None):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.num_experts)
        self.norm = norm
        self.top_k = top_k if top_k is not None else config.num_experts

    def forward(self, hidden_states):
        logits = self.gate(hidden_states)                # [B,T,E]
        probs = self.norm(logits)                        # [B,T,E]
        
        if self.top_k < probs.size(-1):
            # Select top-k probabilities and zero out the rest
            top_k_probs, top_k_indices = torch.topk(probs, k=self.top_k, dim=-1)  # [B,T,k]
            # Create a zero tensor with the same shape as probs
            sparse_probs = torch.zeros_like(probs)  # [B,T,E]
            # Scatter top-k probabilities back to their original indices
            sparse_probs.scatter_(-1, top_k_indices, top_k_probs)
            probs = sparse_probs

        z_loss = torch.logsumexp(logits, -1).pow(2).mean()
        return probs, z_loss
    
class AttentionRouter(nn.Module):
    def __init__(self, config, norm: nn.Module, top_k: int = None):
        super().__init__()
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.expert_keys = nn.Parameter(torch.randn(config.num_experts, config.hidden_size))
        nn.init.normal_(self.expert_keys, mean=0., std=0.02)
        self.norm = norm
        self.top_k = top_k if top_k is not None else config.num_experts

    def forward(self, hidden_states):
        B, T, H = hidden_states.shape
        q = self.query(hidden_states).view(-1, H)     # [B*T, H]
        logits = (q @ self.expert_keys.T).view(B, T, -1)  # [B,T,E]
        probs = self.norm(logits)                    # [B,T,E]
        
        if self.top_k < probs.size(-1):
            # Select top-k probabilities and zero out the rest
            top_k_probs, top_k_indices = torch.topk(probs, k=self.top_k, dim=-1)  # [B,T,k]
            # Create a zero tensor with the same shape as probs
            sparse_probs = torch.zeros_like(probs)  # [B,T,E]
            # Scatter top-k probabilities back to their original indices
            sparse_probs.scatter_(-1, top_k_indices, top_k_probs)
            probs = sparse_probs

        z_loss = torch.logsumexp(logits, -1).pow(2).mean()
        return probs, z_loss

def make_router(config, top_k: int = None):
    norm = SoftmaxNorm(dim=-1) if norm_type == "softmax" else Sparsemax(dim=-1)
    if router_type == "linear":
        return LinearRouter(config, norm, top_k=top_k)
    elif router_type == "attention":
        return AttentionRouter(config, norm, top_k=top_k)
    else:
        raise ValueError(f"Unknown router_type={router_type}")

# -------------------------------------------------------------------
# 3.  Monkey-patch the MoE layer: full sparse weighted sum
# -------------------------------------------------------------------
from transformers.models.switch_transformers import modeling_switch_transformers as mst

def sparsemlp_forward(self, hidden_states):
    """
    Fully-sparse mix of expert outputs using Sparsemax router probabilities.
    Returns (combined_hidden_states,
             (router_logits, dummy_expert_indices))
    so the HuggingFace unpacker is satisfied.
    """
    # 1) run every expert
    if isinstance(self.experts, nn.ModuleDict):
        expert_outputs = [m(hidden_states) for m in self.experts.values()]
    else:                                            # ModuleList
        expert_outputs = [m(hidden_states) for m in self.experts]
    expert_outputs = torch.stack(expert_outputs, dim=2)   # [B,T,E,H]

    # 2) routing
    router_probs, router_z_loss = self.router(hidden_states)   # [B,T,E]

    # 3) weighted mixture
    combined = (expert_outputs * router_probs.unsqueeze(-1)).sum(dim=2)

    # 4) create dummy expert-index tensor so shape rules match
    dummy_idx = torch.zeros_like(router_probs[..., 0], dtype=torch.long)

    # 5) return in the format HF expects
    return combined, (router_probs, dummy_idx)

# Replace the original forward
mst.SwitchTransformersSparseMLP.forward = sparsemlp_forward

# -------------------------------------------------------------------
# 4.  NLTK data & data set
# -------------------------------------------------------------------
nltk.download("punkt")
model_id   = "google/switch-base-16"
dataset_id = "samsum"

raw = load_dataset(dataset_id)
tok  = AutoTokenizer.from_pretrained(model_id)

# determine max lengths
tmp = concatenate_datasets([raw["train"], raw["test"]])
max_src = max(len(t) for t in tmp.map(lambda x: tok(x["dialogue"],
                                                   truncation=True),
                                      batched=True)["input_ids"])
max_tgt = max(len(t) for t in tmp.map(lambda x: tok(x["summary"],
                                                   truncation=True),
                                      batched=True)["input_ids"])

def preprocess(batch, padding="max_length"):
    inputs = ["summarize: " + d for d in batch["dialogue"]]
    model_in = tok(inputs, max_length=max_src, truncation=True,
                   padding=padding)
    with tok.as_target_tokenizer():
        labels = tok(batch["summary"], max_length=max_tgt,
                     truncation=True, padding=padding)["input_ids"]
    labels = [[t if t != tok.pad_token_id else -100 for t in lab]
              for lab in labels]
    model_in["labels"] = labels
    return model_in

data = raw.map(preprocess, batched=True,
               remove_columns=["dialogue", "summary", "id"])

# -------------------------------------------------------------------
# 5.  Load model & replace all routers with our Sparsemax router
# -------------------------------------------------------------------
model = SwitchTransformersForConditionalGeneration.from_pretrained(model_id)

# print("Pretrained model:")
# print(model)
def replace_routers(model, top_k: int = None):
    from transformers.models.switch_transformers import modeling_switch_transformers as mst
    for module in model.modules():
        if isinstance(module, mst.SwitchTransformersSparseMLP):
            module.router = make_router(model.config, top_k=top_k)
replace_routers(model)
print("✅ All routers replaced")
# print("New model:")
# print(model)

# -------------------------------------------------------------------
# 6.  Trainer setup
# -------------------------------------------------------------------
rouge = evaluate.load("rouge")
def postproc(preds, refs):
    preds = ["\n".join(sent_tokenize(p.strip())) for p in preds]
    refs  = ["\n".join(sent_tokenize(r.strip())) for r in refs]
    return preds, refs

def metrics(eval_pred):
    preds, lab = eval_pred
    if isinstance(preds, tuple): preds = preds[0]
    preds = np.where(preds != -100, preds, tok.pad_token_id)
    dec_preds = tok.batch_decode(preds, skip_special_tokens=True)
    lab = np.where(lab != -100, lab, tok.pad_token_id)
    dec_lab = tok.batch_decode(lab, skip_special_tokens=True)
    dec_preds, dec_lab = postproc(dec_preds, dec_lab)
    res = rouge.compute(predictions=dec_preds, references=dec_lab,
                        use_stemmer=True)
    return {k: round(v*100,4) for k,v in res.items()}

collator = DataCollatorForSeq2Seq(tok, model, label_pad_token_id=-100,
                                  pad_to_multiple_of=4)

args = Seq2SeqTrainingArguments(
    output_dir="switch8",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    learning_rate=5e-5,
    predict_with_generate=True,
    fp16=False,
    logging_steps=500,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    report_to="none",
    save_safetensors=False
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    tokenizer=tok,
    data_collator=collator,
    train_dataset=data["train"],
    eval_dataset=data["validation"],
    compute_metrics=metrics,
)

# -------------------------------------------------------------------
# 7. Training
# -------------------------------------------------------------------
print("\n🧪  Eval before training (first 50 examples):")
print(trainer.evaluate(eval_dataset=data["validation"].select(range(50))))

# Train the model
print("\n🚀  Start fine-tuning …")
print(trainer.train())

# -------------------------------------------------------------------
# 8. Saving and evaluate
# -------------------------------------------------------------------

# Save the model
save_dir = f"/dbfs/switch_base_8/{router_type}_{norm_type}_top{top_k}"
trainer.save_model(save_dir)
print(f"Model saved to {save_dir}")

# Evaluate the results
print("\n🧪 Evaluating on test set:")
test_results = trainer.evaluate()
print("Test set results:", test_results)

# Print GPU memory usage
print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

  warn(f"Failed to load image Python extension: {e}")
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Map:   0%|          | 0/819 [00:00<?, ? examples/s]

✅ All routers replaced
  trainer = Seq2SeqTrainer(

🧪  Eval before training (first 50 examples):
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


{'eval_loss': 8.274967193603516, 'eval_model_preparation_time': 0.0149, 'eval_rouge1': 0.3111, 'eval_rouge2': 0.0, 'eval_rougeL': 0.3111, 'eval_rougeLsum': 0.3111, 'eval_runtime': 13.3415, 'eval_samples_per_second': 3.748, 'eval_steps_per_second': 0.974}

🚀  Start fine-tuning …


Epoch,Training Loss,Validation Loss,Model Preparation Time,Rouge1,Rouge2,Rougel,Rougelsum
1,2.0388,1.650044,0.0149,44.0832,20.2509,36.6896,40.3614
2,1.78,1.556581,0.0149,45.9901,22.5954,38.7357,42.4854
3,1.5988,1.514346,0.0149,47.0176,23.3365,39.5096,43.3715
4,1.4712,1.502311,0.0149,47.2851,24.0115,39.9804,43.8309
5,1.3969,1.501981,0.0149,47.5116,24.2364,40.0738,44.0178


TrainOutput(global_step=18415, training_loss=1.7618367369395702, metrics={'train_runtime': 12664.1537, 'train_samples_per_second': 5.816, 'train_steps_per_second': 1.454, 'total_flos': 2.3708237782450176e+17, 'train_loss': 1.7618367369395702, 'epoch': 5.0})
Model saved to /dbfs/switch_base_8/linear_softmax_topNone

🧪 Evaluating on test set:


Test set results: {'eval_loss': 1.5019805431365967, 'eval_model_preparation_time': 0.0149, 'eval_rouge1': 47.5116, 'eval_rouge2': 24.2364, 'eval_rougeL': 40.0738, 'eval_rougeLsum': 44.0178, 'eval_runtime': 198.2125, 'eval_samples_per_second': 4.127, 'eval_steps_per_second': 1.034, 'epoch': 5.0}
GPU Memory Allocated: 12.91 GB
