# Preprocessing

## Environment & imports

In [1]:
# ---------------------------------------------
# 0)  Install (first‐time only) & import libs
# ---------------------------------------------
# !pip install -q datasets transformers emoji==2.10.0 tqdm

from pathlib import Path
import re
import random
import json
from collections import defaultdict
from typing import List, Dict, Tuple

import emoji
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import torch

  from .autonotebook import tqdm as notebook_tqdm


## Load SAMSum

In [2]:
# ---------------------------------------------------------
# 1) Load SAMSum — 14 732 / 819 / 818 dialogues
# ---------------------------------------------------------
raw_ds: DatasetDict = load_dataset("samsum")
print({k: len(v) for k, v in raw_ds.items()})

Using the latest cached version of the module from /home/drl-68/.cache/huggingface/modules/datasets_modules/datasets/samsum/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e (last modified on Mon May  5 10:16:55 2025) since it couldn't be found locally at samsum, or remotely on the Hugging Face Hub.


{'train': 14732, 'test': 819, 'validation': 818}


## Build an emoji vocabulary and speaker token & Build / extend the tokenizer

count [UNK] occurrences in one HF Dataset

In [3]:
from tqdm import tqdm
import numpy as np
import torch

def count_unk(ds, tokenizer, field="dialogue", batch_size=1024):
    unk_id = tokenizer.unk_token_id
    total_unk, total_tokens = 0, 0

    for i in tqdm(range(0, len(ds), batch_size), desc="Tokenising"):
        batch_texts = ds[i : i + batch_size][field]
        enc = tokenizer(batch_texts, add_special_tokens=True, padding=False, truncation=False)
        for ids in enc["input_ids"]:
            arr = np.array(ids)
            total_unk += np.sum(arr == unk_id)
            total_tokens += len(arr)
    return total_unk, total_tokens

BEFORE adding emojis

In [4]:
tok_base = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama_v1.1")
unk_stats_before = {}
for split in ["train", "validation", "test"]:
    unk_stats_before[split] = count_unk(raw_ds[split], tok_base)
print("\n[UNK] counts BEFORE adding emoji tokens")
for split, (u, t) in unk_stats_before.items():
    print(f"{split:<10}: {u:8d}  ({u/t:.3%} of tokens)")

Tokenising: 100%|██████████| 15/15 [00:00<00:00, 26.71it/s]
Tokenising: 100%|██████████| 1/1 [00:00<00:00, 42.07it/s]
Tokenising: 100%|██████████| 1/1 [00:00<00:00, 41.07it/s]


[UNK] counts BEFORE adding emoji tokens
train     :        0  (0.000% of tokens)
validation:        0  (0.000% of tokens)
test      :        0  (0.000% of tokens)





สร้าง EMOJI_TOKENS

In [5]:
# ถ้า kernel เพิ่งรีสตาร์ต ตัวแปรจะหายหมด
# สร้างชุด emoji ใหม่จาก raw_ds
from typing import List
import emoji

def extract_emojis(text: str) -> List[str]:
    return [ch for ch in text if ch in emoji.EMOJI_DATA]

emoji_set = set()
for split in ["train", "validation", "test"]:
    for dlg in raw_ds[split]["dialogue"]:
        emoji_set.update(extract_emojis(dlg))

EMOJI_TOKENS = sorted(emoji_set)          # ≈ 300-320 รายการ
print(f"Unique emojis found: {len(EMOJI_TOKENS)}")

Unique emojis found: 305


In [6]:
# # Use a pipeline as a high-level helper
# from transformers import pipeline

# pipe = pipeline("text-generation", model="meta-llama/Llama-3.2-1B")

Extend tokenizer with emojis + speaker tags

In [7]:
from transformers import AutoTokenizer

# ---------- 1) โหลด tokenizer ดั้งเดิม ----------
tok_base = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama_v1.1")
vocab_orig = len(tok_base)

# ---------- 2) เตรียมชุด token ใหม่ ----------
#   • EMOJI_TOKENS  : ทุกอิโมจิที่ “พบอย่างน้อย 1 ครั้ง” ใน SAMSum
#   • SPEAKER_TOKENS: [S1] – [S10]
SPEAKER_TOKENS = [f"[S{i}]" for i in range(1, 11)]
new_tokens = EMOJI_TOKENS + SPEAKER_TOKENS

# ---------- 3) สร้าง tokenizer สำเนาแล้วเพิ่ม token ----------
tok_ext = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama_v1.1")
added = tok_ext.add_tokens(new_tokens)
vocab_new = len(tok_ext)

# ---------- 4) แสดงผล ----------
print(f"Original vocab size : {vocab_orig}")
print(f"Added new tokens     : {added}  "
      f"(emoji = {len(EMOJI_TOKENS)}, speaker = {len(SPEAKER_TOKENS)})")
print(f"New vocab size       : {vocab_new}")

# (Optional) พิมพ์ตัวอย่างอิโมจิ 20 ตัวแรก
print("\nFirst 20 emoji tokens:", EMOJI_TOKENS[:20])

tok_ext.save_pretrained("tokenizer_samsum_su")   # โฟลเดอร์ใหม่

Original vocab size : 32000
Added new tokens     : 315  (emoji = 305, speaker = 10)
New vocab size       : 32310

First 20 emoji tokens: ['‼', '⏱', '☀', '☂', '☔', '☕', '☘', '☝', '☠', '☢', '☹', '☺', '♀', '♂', '♥', '♻', '⚪', '⚫', '⚰', '⚽']


('tokenizer_samsum_su/tokenizer_config.json',
 'tokenizer_samsum_su/special_tokens_map.json',
 'tokenizer_samsum_su/tokenizer.model',
 'tokenizer_samsum_su/added_tokens.json',
 'tokenizer_samsum_su/tokenizer.json')

AFTER adding emojis

In [8]:
unk_stats_after = {}
for split in ["train", "validation", "test"]:
    unk_stats_after[split] = count_unk(raw_ds[split], tok_ext)
print("\n[UNK] counts AFTER adding emoji tokens")
for split, (u, t) in unk_stats_after.items():
    print(f"{split:<10}: {u:8d}  ({u/t:.3%} of tokens)")

Tokenising: 100%|██████████| 15/15 [00:00<00:00, 28.52it/s]
Tokenising: 100%|██████████| 1/1 [00:00<00:00, 41.31it/s]
Tokenising: 100%|██████████| 1/1 [00:00<00:00, 38.35it/s]


[UNK] counts AFTER adding emoji tokens
train     :        0  (0.000% of tokens)
validation:        0  (0.000% of tokens)
test      :        0  (0.000% of tokens)





reduction check in UNKs

In [9]:
print("\nΔ [UNK] (before ➜ after):")
for split in ["train", "validation", "test"]:
    u0, _ = unk_stats_before[split]
    u1, _ = unk_stats_after[split]
    print(f"{split:<10}: {u0-u1:+d}  fewer UNKs  (↓{(u0-u1)/u0:.2%})")


Δ [UNK] (before ➜ after):
train     : +0  fewer UNKs  (↓nan%)
validation: +0  fewer UNKs  (↓nan%)
test      : +0  fewer UNKs  (↓nan%)


  print(f"{split:<10}: {u0-u1:+d}  fewer UNKs  (↓{(u0-u1)/u0:.2%})")


## Preprocess SAMSum Dateset

Speaker-name mapping → [S#]

In [10]:
# ---------------------------------------------------------
# 4) Helper to replace speaker names by [S#]
# ---------------------------------------------------------
SPEAKER_RE = re.compile(r"^([^:]+):\s*(.*)$")

def map_speakers(dialogue: str, max_speakers: int = 10
                 ) -> Tuple[str, Dict[str, str]]:
    """
    Returns dialogue with names replaced by [S#] and a mapping dict.
    """
    speaker_map, next_id = {}, 1
    new_lines = []
    for line in dialogue.split("\n"):
        m = SPEAKER_RE.match(line)
        if not m:                # safety – keep line as is
            new_lines.append(line)
            continue
        name, utt = m.groups()
        if name not in speaker_map:
            if next_id > max_speakers:      # truncate extra speakers
                name_token = "[SUNK]"
            else:
                name_token = f"[S{next_id}]"
                speaker_map[name] = name_token
                next_id += 1
        new_lines.append(f"{speaker_map.get(name, '[SUNK]')}: {utt}")
    return "\n".join(new_lines), speaker_map


Insert [SEP] after every utterance

In [11]:
def add_sep_every_utt(dialogue: str) -> str:
    lines = [l + " [SEP]" for l in dialogue.split("\n") if l.strip()]
    return " ".join(lines)

Switching-Utterance corruption
- Hyper-parameters: Pu = 1.0, Pn = 0/1

โดยที่

Pu (permute-utterance prob.) ความน่าจะเป็นที่ แต่ละ utterance จะถูกเลือก ใส่ลงในชุดที่นำไปสับตำแหน่ง

- pu = 1.0 แสดงว่าบังคับเลือกทุกบรรทัดแล้วค่อยสับคำแบบสุ่ม

Pn (name-mask prob.) ความน่าจะเป็นที่ token [S#] ด้านหน้าจะถูกเปลี่ยนเป็น [MASK]

- pn = 0.0 แสดงว่า ไม่ mask, โมเดลเห็น speaker tag

- pn = 1.0 แสดงว่า mask หมด, บังคับดู context

In [12]:
def make_switching_utterance(dialogue: str,
                             pu: float = 1.0,
                             pn: float = 0.0,
                             rng: random.Random = random
                            ) -> Tuple[str, List[int]]:
    """
    • dialogue  - speaker-tokenised, SEP-inserted string
    • pu        - prob. an utterance is selected for permutation
    • pn        - prob. we MASK the speaker token (⇒ [MASK])
    Returns:
        corrupted_dialogue, labels_per_utt  (1 = permuted (สลับบทพูด), 0 = original)
    """
    # 1) split back into utterances
    utts = [u.strip() for u in dialogue.split("[SEP]") if u.strip()]
    idxs = list(range(len(utts)))

    # 2) pick indices to permute
    perm_idx = [i for i in idxs if rng.random() < pu]
    shuffled = perm_idx.copy()
    rng.shuffle(shuffled)                 # in-place
    perm_map = dict(zip(perm_idx, shuffled))

    # 3) build new utterance list, labels
    new_utts, labels = [], []
    for i in idxs:
        src = perm_map.get(i, i)          # swapped or same
        u = utts[src]
        # optionally mask speaker token ([S#]: → [MASK]:)
        if rng.random() < pn:
            u = re.sub(r"^\[S\d+\]", "[MASK]", u)
        new_utts.append(u)
        labels.append(int(src != i))      # 1 if permuted
    corrupted = " [SEP] ".join(new_utts) + " [SEP]"
    return corrupted, labels


## Switching-Utterance (SU) pre-training dataset

In [13]:
# ---------------------------------------------------------
# 7) Create HF Datasets with tokenised inputs, attention,
#    SEP positions, and per-utterance labels
# ---------------------------------------------------------
MAX_LEN = 512                          # paper setting
Pu, Pn = 1.0, 0.0                      # best config in Table 2


def preprocess_example(example, split):

    if tok_ext.pad_token is None:
        tok_ext.add_special_tokens({'pad_token': '[PAD]'})
        
    # a) replace speakers & add SEP
    dlg, _ = map_speakers(example["dialogue"])
    dlg = add_sep_every_utt(dlg)

    # b) corruption
    corrupted, labels = make_switching_utterance(dlg, Pu, Pn)

    # c) tokenize (truncate if >512 tokens)
    enc = tok_ext(corrupted,
              truncation=True, max_length=MAX_LEN,
              padding="max_length")
    
    # d) find SEP token positions (needed for loss later)
    sep_id = tok_ext("[SEP]")["input_ids"][0]
    sep_positions = [i for i, id_ in enumerate(enc["input_ids"])
                     if id_ == sep_id][:len(labels)]  # clip if truncated

    enc["labels"] = labels[:len(sep_positions)]
    enc["sep_positions"] = sep_positions
    enc["dialogue_len"] = len(labels)
    return enc

su_ds = DatasetDict()
for split in ["train", "validation", "test"]:
    su_ds[split] = raw_ds[split].map(
        preprocess_example,
        fn_kwargs={"split": split},
        remove_columns=raw_ds[split].column_names,
        desc=f"Building SU {split}"
    )

su_ds.save_to_disk("data/samsum_switching_utterance")
print(su_ds)

Building SU train: 100%|██████████| 14732/14732 [00:06<00:00, 2314.91 examples/s]
Building SU validation: 100%|██████████| 818/818 [00:00<00:00, 2314.77 examples/s]
Building SU test: 100%|██████████| 819/819 [00:00<00:00, 1859.38 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 14732/14732 [00:00<00:00, 455587.98 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 818/818 [00:00<00:00, 212916.76 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 819/819 [00:00<00:00, 221435.89 examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'sep_positions', 'dialogue_len'],
        num_rows: 14732
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'sep_positions', 'dialogue_len'],
        num_rows: 818
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'sep_positions', 'dialogue_len'],
        num_rows: 819
    })
})





ไฟล์ Arrow ถูกบันทึกไว้ที่ data/samsum_switching_utterance/ พร้อมฟิลด์ input_ids / attention_mask / labels / sep_positions / dialogue_len.

## Self-supervised Pre-training

ใช้ Dataset เฉพาะส่วนของ train ของ SAMSum มาทำการ pre_train แล้วใช้ validation ไว้ดู early-stopping / tuning ส่วน test ต้องไม่ถูกแตะ เพื่อไม่ให้โมเดล “เห็น” บทสนทนาที่จะใช้วัด ROUGE ภายหลัง

In [1]:
import torch
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

Imports & helpers

In [10]:
import math
import torch
import random
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    AutoModelForCausalLM
)
from datasets import load_from_disk

# -------------------------------
# CONFIG
# -------------------------------
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4 # paper 128
MAX_LEN    = 512 # paper 512
LR         = 3e-5
WARMUP     = 500
MAX_STEPS  = 40000

Dataset & collate

In [11]:
# -------------------------------
# LOAD DATASET
# -------------------------------
dataset = load_from_disk("data/samsum_switching_utterance")

# -------------------------------
# COLLATE FUNCTION
# -------------------------------
# def collate_fn(batch):
#     keys = ["input_ids", "token_type_ids", "attention_mask"]
#     inputs = {k: torch.tensor([b[k] for b in batch]) for k in keys}
#     labels = [torch.tensor(b["labels"], dtype=torch.float) for b in batch]
#     sep_pos = [torch.tensor(b["sep_positions"]) for b in batch]
#     return inputs, labels, sep_pos

def collate_fn(batch):
    keys = batch[0].keys()
    inputs = {
        k: torch.stack([torch.tensor(b[k]) if not isinstance(b[k], torch.Tensor) else b[k] for b in batch])
        for k in keys if k not in ["labels", "sep_positions"]
    }
    labels = [torch.tensor(b["labels"], dtype=torch.float) for b in batch]
    sep_pos = [torch.tensor(b["sep_positions"]) for b in batch]
    return inputs, labels, sep_pos



Model

In [12]:
# -------------------------------
# MODEL
# -------------------------------
class SepClassifier(nn.Module):
    def __init__(self, model_name="TinyLlama/TinyLlama_v1.1", dropout=0.1):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name)
        self.llama = AutoModelForCausalLM.from_pretrained(
            model_name,
            config=config,
            # torch_dtype=torch.float16  # ใช้ float16 เฉพาะ LLaMA model
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)  

    def forward(self, input_ids, attention_mask, sep_positions, **kwargs):
        # เรียก model โดยให้คืนค่า hidden states ทั้งหมด
        outputs = self.llama(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )

        # เอา hidden state ชั้นสุดท้ายมาใช้
        hidden_states = outputs.hidden_states[-1]

        # ดึง hidden vectors ที่ตำแหน่ง sep ของแต่ละ sample
        sep_vecs = []
        for i, pos_tensor in enumerate(sep_positions):
            if pos_tensor.numel() == 0:
                continue  # ข้าม sample ที่ไม่มี SEP
            pos_tensor = pos_tensor.to(hidden_states.device).long()
            if torch.any(pos_tensor >= hidden_states.shape[1]):
                raise ValueError(f"Invalid sep_position index: {pos_tensor.tolist()} exceeds hidden size {hidden_states.shape[1]}")
            sep_vecs.append(hidden_states[i].index_select(0, pos_tensor))

        if not sep_vecs:
            raise ValueError("All sep_positions are empty. Cannot compute logits.")

        sep_vecs = torch.cat(sep_vecs, dim=0)
        sep_vecs = torch.nan_to_num(sep_vecs, nan=0.0, posinf=1e4, neginf=-1e4)
        sep_vecs = torch.clamp(sep_vecs, -1e4, 1e4)  # limit range

        # Apply dropout และ classifier
        logits = self.classifier(self.dropout(sep_vecs)).squeeze(-1)  # (total_seps,)
        logits = torch.clamp(logits, -30, 30)  # clamp เพื่อป้องกัน NaN ใน BCEWithLogitsLoss
        return logits


In [13]:
import torch
import gc

# เคลียร์ cache ของ CUDA
torch.cuda.empty_cache()

# เคลียร์ object ที่ไม่ใช้งานแล้ว
gc.collect()

# (Optional) หากใช้ AMP หรือ context manager อื่นๆ
if torch.cuda.is_available():
    torch.cuda.ipc_collect()  # สำหรับเคลียร์ inter-process memory (เฉพาะบางเคส)


Training loop (train model until the train loss converged (upper bounded by 5k steps)

In [15]:
from tqdm import tqdm
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# -------------------------------
# INITIALIZATION
# -------------------------------
torch.cuda.empty_cache()
# Load tokenizer and ensure it has all required special tokens
tokenizer = AutoTokenizer.from_pretrained("tokenizer_samsum_su")

# Add special tokens if missing
special_tokens = {}
if tokenizer.pad_token is None:
    special_tokens["pad_token"] = "[PAD]"
if tokenizer.sep_token is None:
    special_tokens["sep_token"] = "[SEP]"

if special_tokens:
    tokenizer.add_special_tokens(special_tokens)

# Initialize model
model = SepClassifier()
for param in model.llama.parameters():  
    param.requires_grad = False

# Resize token embeddings to match new tokenizer length
model.llama.resize_token_embeddings(len(tokenizer))

# Setup DataLoader
# train_loader = DataLoader(
#     dataset["train"],
#     batch_size=BATCH_SIZE,
#     shuffle=True,
#     drop_last=True,
#     collate_fn=collate_fn
# )

train_loader = DataLoader(
    dataset["train"],
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
    num_workers=4,  # increase if you're using CPU-heavy preprocessing
    pin_memory=True  # if using GPU
)


# Optimizer, scheduler, loss
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
total_steps = min(10000, len(train_loader))
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=500,
    num_training_steps=total_steps
)
loss_fn = nn.BCEWithLogitsLoss()

# -------------------------------
# TRAINING LOOP
# -------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(next(model.parameters()).device)
step = 0
running_loss = 0.0
model.train()

from torch.amp import autocast, GradScaler

scaler = GradScaler(device='cuda')  # <-- add before your loop

for epoch in range(100):
    print(f"Epoch {epoch}")
    torch.cuda.empty_cache()

    with tqdm(train_loader, desc=f"Training Epoch {epoch}") as t_loader:
        for inputs, label_lists, sep_lists in t_loader:
            if step >= MAX_STEPS:
                break

            inputs = {k: v.to(device) for k, v in inputs.items()}
            sep_lists = [s.to(device) for s in sep_lists]
            flat_labels = torch.cat(label_lists).float().to(device)

            optimizer.zero_grad(set_to_none=True)

            with autocast(device_type='cuda'):
                logits = model(**inputs, sep_positions=sep_lists)
                loss = loss_fn(logits, flat_labels)

            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            running_loss += loss.item()
            step += 1

            t_loader.set_postfix(step=step, loss=loss.item(), avg_loss=running_loss / max(1, step % 100))
            if step % 100 == 0:
                print(f"[Step {step:4d}/{MAX_STEPS}] Avg Loss: {running_loss / 100:.4f}")
                running_loss = 0.0

    if step >= MAX_STEPS:
        break

# ------------------------
# SAVE MODEL
# ------------------------
torch.save(model.state_dict(), "llama_su_pretrained.pt")
print("Model saved to 'llama_su_pretrained.pt'")

cuda:0
Epoch 0


Training Epoch 0:   3%|▎         | 101/3683 [00:10<06:02,  9.88it/s, avg_loss=0.69, loss=0.69, step=101]

[Step  100/40000] Avg Loss: 0.6889


Training Epoch 0:   5%|▌         | 201/3683 [00:20<05:52,  9.88it/s, avg_loss=0.686, loss=0.686, step=201]

[Step  200/40000] Avg Loss: 0.6884


Training Epoch 0:   8%|▊         | 301/3683 [00:30<05:41,  9.90it/s, avg_loss=0.693, loss=0.693, step=301]

[Step  300/40000] Avg Loss: 0.6870


Training Epoch 0:  11%|█         | 401/3683 [00:40<05:33,  9.85it/s, avg_loss=0.687, loss=0.687, step=401]

[Step  400/40000] Avg Loss: 0.6856


Training Epoch 0:  14%|█▎        | 501/3683 [00:50<05:19,  9.95it/s, avg_loss=0.678, loss=0.678, step=501]

[Step  500/40000] Avg Loss: 0.6826


Training Epoch 0:  16%|█▋        | 601/3683 [01:00<05:10,  9.92it/s, avg_loss=0.673, loss=0.673, step=601]

[Step  600/40000] Avg Loss: 0.6794


Training Epoch 0:  19%|█▉        | 702/3683 [01:10<04:50, 10.27it/s, avg_loss=0.669, loss=0.669, step=702]

[Step  700/40000] Avg Loss: 0.6762


Training Epoch 0:  22%|██▏       | 802/3683 [01:20<04:38, 10.34it/s, avg_loss=0.665, loss=0.665, step=802]

[Step  800/40000] Avg Loss: 0.6742


Training Epoch 0:  24%|██▍       | 902/3683 [01:29<04:26, 10.45it/s, avg_loss=0.669, loss=0.678, step=902]

[Step  900/40000] Avg Loss: 0.6713


Training Epoch 0:  27%|██▋       | 1002/3683 [01:39<04:21, 10.24it/s, avg_loss=0.667, loss=0.658, step=1002]

[Step 1000/40000] Avg Loss: 0.6672


Training Epoch 0:  30%|██▉       | 1102/3683 [01:48<04:05, 10.52it/s, avg_loss=0.664, loss=0.674, step=1102]

[Step 1100/40000] Avg Loss: 0.6636


Training Epoch 0:  33%|███▎      | 1202/3683 [01:58<04:00, 10.33it/s, avg_loss=0.651, loss=0.651, step=1202]

[Step 1200/40000] Avg Loss: 0.6629


Training Epoch 0:  35%|███▌      | 1302/3683 [02:08<03:48, 10.40it/s, avg_loss=0.66, loss=0.649, step=1302] 

[Step 1300/40000] Avg Loss: 0.6636


Training Epoch 0:  38%|███▊      | 1402/3683 [02:17<03:38, 10.44it/s, avg_loss=0.682, loss=0.67, step=1402] 

[Step 1400/40000] Avg Loss: 0.6612


Training Epoch 0:  41%|████      | 1502/3683 [02:27<03:26, 10.54it/s, avg_loss=0.656, loss=0.669, step=1502]

[Step 1500/40000] Avg Loss: 0.6588


Training Epoch 0:  43%|████▎     | 1602/3683 [02:36<03:20, 10.37it/s, avg_loss=0.655, loss=0.668, step=1602]

[Step 1600/40000] Avg Loss: 0.6578


Training Epoch 0:  46%|████▌     | 1702/3683 [02:46<03:11, 10.35it/s, avg_loss=0.653, loss=0.639, step=1702]

[Step 1700/40000] Avg Loss: 0.6554


Training Epoch 0:  49%|████▉     | 1802/3683 [02:56<03:02, 10.32it/s, avg_loss=0.651, loss=0.666, step=1802]

[Step 1800/40000] Avg Loss: 0.6517


Training Epoch 0:  52%|█████▏    | 1902/3683 [03:05<02:51, 10.40it/s, avg_loss=0.649, loss=0.665, step=1902]

[Step 1900/40000] Avg Loss: 0.6479


Training Epoch 0:  54%|█████▍    | 2002/3683 [03:15<02:41, 10.43it/s, avg_loss=0.648, loss=0.632, step=2002]

[Step 2000/40000] Avg Loss: 0.6481


Training Epoch 0:  57%|█████▋    | 2102/3683 [03:25<02:33, 10.31it/s, avg_loss=0.631, loss=0.631, step=2102]

[Step 2100/40000] Avg Loss: 0.6484


Training Epoch 0:  60%|█████▉    | 2202/3683 [03:34<02:22, 10.42it/s, avg_loss=0.63, loss=0.63, step=2202]  

[Step 2200/40000] Avg Loss: 0.6483


Training Epoch 0:  63%|██████▎   | 2302/3683 [03:44<02:12, 10.39it/s, avg_loss=0.662, loss=0.662, step=2302]

[Step 2300/40000] Avg Loss: 0.6467


Training Epoch 0:  65%|██████▌   | 2402/3683 [03:53<02:03, 10.37it/s, avg_loss=0.628, loss=0.628, step=2402]

[Step 2400/40000] Avg Loss: 0.6463


Training Epoch 0:  68%|██████▊   | 2502/3683 [04:03<01:52, 10.46it/s, avg_loss=0.644, loss=0.627, step=2502]

[Step 2500/40000] Avg Loss: 0.6487


Training Epoch 0:  71%|███████   | 2602/3683 [04:13<01:43, 10.43it/s, avg_loss=0.644, loss=0.626, step=2602]

[Step 2600/40000] Avg Loss: 0.6455


Training Epoch 0:  73%|███████▎  | 2702/3683 [04:22<01:34, 10.37it/s, avg_loss=0.643, loss=0.661, step=2702]

[Step 2700/40000] Avg Loss: 0.6466


Training Epoch 0:  76%|███████▌  | 2802/3683 [04:32<01:25, 10.29it/s, avg_loss=0.643, loss=0.66, step=2802] 

[Step 2800/40000] Avg Loss: 0.6397


Training Epoch 0:  79%|███████▉  | 2902/3683 [04:41<01:14, 10.50it/s, avg_loss=0.66, loss=0.66, step=2902]  

[Step 2900/40000] Avg Loss: 0.6454


Training Epoch 0:  82%|████████▏ | 3002/3683 [04:51<01:05, 10.42it/s, avg_loss=0.659, loss=0.659, step=3002]

[Step 3000/40000] Avg Loss: 0.6420


Training Epoch 0:  84%|████████▍ | 3102/3683 [05:01<00:56, 10.36it/s, avg_loss=0.622, loss=0.622, step=3102]

[Step 3100/40000] Avg Loss: 0.6405


Training Epoch 0:  87%|████████▋ | 3202/3683 [05:10<00:46, 10.35it/s, avg_loss=0.658, loss=0.696, step=3202]

[Step 3200/40000] Avg Loss: 0.6376


Training Epoch 0:  90%|████████▉ | 3302/3683 [05:20<00:36, 10.46it/s, avg_loss=0.677, loss=0.696, step=3302]

[Step 3300/40000] Avg Loss: 0.6393


Training Epoch 0:  92%|█████████▏| 3402/3683 [05:30<00:27, 10.40it/s, avg_loss=0.658, loss=0.696, step=3402]

[Step 3400/40000] Avg Loss: 0.6431


Training Epoch 0:  95%|█████████▌| 3502/3683 [05:39<00:17, 10.37it/s, avg_loss=0.619, loss=0.619, step=3502]

[Step 3500/40000] Avg Loss: 0.6390


Training Epoch 0:  98%|█████████▊| 3602/3683 [05:49<00:07, 10.44it/s, avg_loss=0.638, loss=0.619, step=3602]

[Step 3600/40000] Avg Loss: 0.6404


Training Epoch 0: 100%|██████████| 3683/3683 [05:57<00:00, 10.31it/s, avg_loss=0.642, loss=0.658, step=3683]


Epoch 1


Training Epoch 1:   1%|          | 19/3683 [00:01<05:56, 10.29it/s, avg_loss=0.619, loss=0.619, step=3702]

[Step 3700/40000] Avg Loss: 0.6418


Training Epoch 1:   3%|▎         | 119/3683 [00:10<05:08, 11.55it/s, avg_loss=0.638, loss=0.619, step=3802]

[Step 3800/40000] Avg Loss: 0.6392


Training Epoch 1:   6%|▌         | 219/3683 [00:19<04:58, 11.59it/s, avg_loss=0.638, loss=0.619, step=3902]

[Step 3900/40000] Avg Loss: 0.6388


Training Epoch 1:   9%|▊         | 319/3683 [00:28<04:50, 11.57it/s, avg_loss=0.619, loss=0.619, step=4002]

[Step 4000/40000] Avg Loss: 0.6438


Training Epoch 1:  11%|█▏        | 419/3683 [00:36<04:41, 11.58it/s, avg_loss=0.619, loss=0.619, step=4102]

[Step 4100/40000] Avg Loss: 0.6353


Training Epoch 1:  14%|█▍        | 519/3683 [00:45<04:33, 11.58it/s, avg_loss=0.619, loss=0.619, step=4202]

[Step 4200/40000] Avg Loss: 0.6372


Training Epoch 1:  17%|█▋        | 619/3683 [00:54<04:24, 11.58it/s, avg_loss=0.658, loss=0.696, step=4302]

[Step 4300/40000] Avg Loss: 0.6384


Training Epoch 1:  20%|█▉        | 719/3683 [01:02<04:15, 11.58it/s, avg_loss=0.638, loss=0.658, step=4402]

[Step 4400/40000] Avg Loss: 0.6368


Training Epoch 1:  22%|██▏       | 819/3683 [01:11<04:07, 11.59it/s, avg_loss=0.638, loss=0.619, step=4502]

[Step 4500/40000] Avg Loss: 0.6403


Training Epoch 1:  25%|██▍       | 919/3683 [01:20<03:59, 11.56it/s, avg_loss=0.677, loss=0.658, step=4602]

[Step 4600/40000] Avg Loss: 0.6411


Training Epoch 1:  28%|██▊       | 1019/3683 [01:28<03:49, 11.59it/s, avg_loss=0.638, loss=0.619, step=4702]

[Step 4700/40000] Avg Loss: 0.6372


Training Epoch 1:  30%|███       | 1119/3683 [01:37<03:41, 11.60it/s, avg_loss=0.658, loss=0.658, step=4802]

[Step 4800/40000] Avg Loss: 0.6395


Training Epoch 1:  33%|███▎      | 1219/3683 [01:46<03:32, 11.59it/s, avg_loss=0.638, loss=0.619, step=4902]

[Step 4900/40000] Avg Loss: 0.6403


Training Epoch 1:  36%|███▌      | 1319/3683 [01:54<03:24, 11.58it/s, avg_loss=0.638, loss=0.658, step=5002]

[Step 5000/40000] Avg Loss: 0.6407


Training Epoch 1:  39%|███▊      | 1419/3683 [02:03<03:15, 11.59it/s, avg_loss=0.638, loss=0.619, step=5102]

[Step 5100/40000] Avg Loss: 0.6384


Training Epoch 1:  41%|████      | 1519/3683 [02:11<03:07, 11.57it/s, avg_loss=0.638, loss=0.658, step=5202]

[Step 5200/40000] Avg Loss: 0.6415


Training Epoch 1:  44%|████▍     | 1619/3683 [02:20<02:58, 11.58it/s, avg_loss=0.638, loss=0.658, step=5302]

[Step 5300/40000] Avg Loss: 0.6426


Training Epoch 1:  47%|████▋     | 1719/3683 [02:29<02:49, 11.57it/s, avg_loss=0.619, loss=0.619, step=5402]

[Step 5400/40000] Avg Loss: 0.6426


Training Epoch 1:  49%|████▉     | 1819/3683 [02:37<02:40, 11.60it/s, avg_loss=0.638, loss=0.658, step=5502]

[Step 5500/40000] Avg Loss: 0.6372


Training Epoch 1:  52%|█████▏    | 1919/3683 [02:46<02:32, 11.60it/s, avg_loss=0.619, loss=0.619, step=5602]

[Step 5600/40000] Avg Loss: 0.6457


Training Epoch 1:  55%|█████▍    | 2019/3683 [02:55<02:23, 11.57it/s, avg_loss=0.638, loss=0.619, step=5702]

[Step 5700/40000] Avg Loss: 0.6361


Training Epoch 1:  58%|█████▊    | 2119/3683 [03:03<02:15, 11.56it/s, avg_loss=0.696, loss=0.696, step=5802]

[Step 5800/40000] Avg Loss: 0.6384


Training Epoch 1:  60%|██████    | 2219/3683 [03:12<02:06, 11.60it/s, avg_loss=0.638, loss=0.619, step=5902]

[Step 5900/40000] Avg Loss: 0.6407


Training Epoch 1:  63%|██████▎   | 2319/3683 [03:21<01:57, 11.59it/s, avg_loss=0.619, loss=0.619, step=6002]

[Step 6000/40000] Avg Loss: 0.6445


Training Epoch 1:  66%|██████▌   | 2419/3683 [03:29<01:49, 11.55it/s, avg_loss=0.619, loss=0.619, step=6102]

[Step 6100/40000] Avg Loss: 0.6380


Training Epoch 1:  68%|██████▊   | 2519/3683 [03:38<01:40, 11.59it/s, avg_loss=0.638, loss=0.619, step=6202]

[Step 6200/40000] Avg Loss: 0.6399


Training Epoch 1:  71%|███████   | 2619/3683 [03:47<01:32, 11.56it/s, avg_loss=0.619, loss=0.619, step=6302]

[Step 6300/40000] Avg Loss: 0.6361


Training Epoch 1:  74%|███████▍  | 2719/3683 [03:55<01:23, 11.58it/s, avg_loss=0.638, loss=0.619, step=6402]

[Step 6400/40000] Avg Loss: 0.6368


Training Epoch 1:  77%|███████▋  | 2819/3683 [04:04<01:14, 11.58it/s, avg_loss=0.619, loss=0.619, step=6502]

[Step 6500/40000] Avg Loss: 0.6388


Training Epoch 1:  79%|███████▉  | 2919/3683 [04:12<01:05, 11.58it/s, avg_loss=0.638, loss=0.619, step=6602]

[Step 6600/40000] Avg Loss: 0.6365


Training Epoch 1:  82%|████████▏ | 3019/3683 [04:21<00:57, 11.60it/s, avg_loss=0.638, loss=0.658, step=6702]

[Step 6700/40000] Avg Loss: 0.6372


Training Epoch 1:  85%|████████▍ | 3119/3683 [04:30<00:48, 11.59it/s, avg_loss=0.638, loss=0.619, step=6802]

[Step 6800/40000] Avg Loss: 0.6384


Training Epoch 1:  87%|████████▋ | 3219/3683 [04:38<00:40, 11.55it/s, avg_loss=0.638, loss=0.658, step=6902]

[Step 6900/40000] Avg Loss: 0.6388


Training Epoch 1:  90%|█████████ | 3319/3683 [04:47<00:31, 11.57it/s, avg_loss=0.638, loss=0.619, step=7002]

[Step 7000/40000] Avg Loss: 0.6399


Training Epoch 1:  93%|█████████▎| 3419/3683 [04:56<00:22, 11.57it/s, avg_loss=0.658, loss=0.696, step=7102]

[Step 7100/40000] Avg Loss: 0.6388


Training Epoch 1:  96%|█████████▌| 3519/3683 [05:04<00:14, 11.57it/s, avg_loss=0.658, loss=0.696, step=7202]

[Step 7200/40000] Avg Loss: 0.6372


Training Epoch 1:  98%|█████████▊| 3619/3683 [05:13<00:05, 11.59it/s, avg_loss=0.619, loss=0.619, step=7302]

[Step 7300/40000] Avg Loss: 0.6372


Training Epoch 1: 100%|██████████| 3683/3683 [05:19<00:00, 11.54it/s, avg_loss=0.637, loss=0.696, step=7366]


Epoch 2


Training Epoch 2:   1%|          | 35/3683 [00:03<05:15, 11.57it/s, avg_loss=0.638, loss=0.619, step=7402]

[Step 7400/40000] Avg Loss: 0.6361


Training Epoch 2:   4%|▎         | 135/3683 [00:11<05:06, 11.56it/s, avg_loss=0.715, loss=0.735, step=7502]

[Step 7500/40000] Avg Loss: 0.6415


Training Epoch 2:   6%|▋         | 235/3683 [00:20<04:57, 11.58it/s, avg_loss=0.619, loss=0.619, step=7602]

[Step 7600/40000] Avg Loss: 0.6403


Training Epoch 2:   9%|▉         | 335/3683 [00:29<04:49, 11.58it/s, avg_loss=0.619, loss=0.619, step=7702]

[Step 7700/40000] Avg Loss: 0.6403


Training Epoch 2:  12%|█▏        | 435/3683 [00:37<04:40, 11.56it/s, avg_loss=0.638, loss=0.658, step=7802]

[Step 7800/40000] Avg Loss: 0.6384


Training Epoch 2:  15%|█▍        | 535/3683 [00:46<04:32, 11.57it/s, avg_loss=0.638, loss=0.619, step=7902]

[Step 7900/40000] Avg Loss: 0.6361


Training Epoch 2:  17%|█▋        | 635/3683 [00:55<04:23, 11.56it/s, avg_loss=0.638, loss=0.658, step=8002]

[Step 8000/40000] Avg Loss: 0.6399


Training Epoch 2:  20%|█▉        | 735/3683 [01:03<04:15, 11.52it/s, avg_loss=0.677, loss=0.658, step=8102]

[Step 8100/40000] Avg Loss: 0.6384


Training Epoch 2:  23%|██▎       | 835/3683 [01:12<04:05, 11.58it/s, avg_loss=0.638, loss=0.619, step=8202]

[Step 8200/40000] Avg Loss: 0.6407


Training Epoch 2:  25%|██▌       | 935/3683 [01:21<03:57, 11.57it/s, avg_loss=0.638, loss=0.619, step=8302]

[Step 8300/40000] Avg Loss: 0.6334


Training Epoch 2:  28%|██▊       | 1035/3683 [01:29<03:48, 11.59it/s, avg_loss=0.658, loss=0.619, step=8402]

[Step 8400/40000] Avg Loss: 0.6376


Training Epoch 2:  31%|███       | 1135/3683 [01:38<03:40, 11.58it/s, avg_loss=0.638, loss=0.619, step=8502]

[Step 8500/40000] Avg Loss: 0.6442


Training Epoch 2:  34%|███▎      | 1235/3683 [01:47<03:31, 11.56it/s, avg_loss=0.638, loss=0.658, step=8602]

[Step 8600/40000] Avg Loss: 0.6418


Training Epoch 2:  36%|███▌      | 1335/3683 [01:55<03:23, 11.56it/s, avg_loss=0.638, loss=0.658, step=8702]

[Step 8700/40000] Avg Loss: 0.6445


Training Epoch 2:  39%|███▉      | 1435/3683 [02:04<03:14, 11.57it/s, avg_loss=0.638, loss=0.619, step=8802]

[Step 8800/40000] Avg Loss: 0.6399


Training Epoch 2:  42%|████▏     | 1535/3683 [02:12<03:05, 11.57it/s, avg_loss=0.638, loss=0.619, step=8902]

[Step 8900/40000] Avg Loss: 0.6395


Training Epoch 2:  44%|████▍     | 1635/3683 [02:21<02:56, 11.61it/s, avg_loss=0.638, loss=0.619, step=9002]

[Step 9000/40000] Avg Loss: 0.6372


Training Epoch 2:  47%|████▋     | 1735/3683 [02:30<02:48, 11.56it/s, avg_loss=0.619, loss=0.619, step=9102]

[Step 9100/40000] Avg Loss: 0.6445


Training Epoch 2:  50%|████▉     | 1835/3683 [02:38<02:39, 11.56it/s, avg_loss=0.619, loss=0.619, step=9202]

[Step 9200/40000] Avg Loss: 0.6399


Training Epoch 2:  53%|█████▎    | 1935/3683 [02:47<02:31, 11.58it/s, avg_loss=0.638, loss=0.658, step=9302]

[Step 9300/40000] Avg Loss: 0.6357


Training Epoch 2:  55%|█████▌    | 2035/3683 [02:56<02:22, 11.58it/s, avg_loss=0.638, loss=0.619, step=9402]

[Step 9400/40000] Avg Loss: 0.6368


Training Epoch 2:  58%|█████▊    | 2135/3683 [03:04<02:13, 11.57it/s, avg_loss=0.619, loss=0.619, step=9502]

[Step 9500/40000] Avg Loss: 0.6372


Training Epoch 2:  61%|██████    | 2235/3683 [03:13<02:05, 11.57it/s, avg_loss=0.619, loss=0.619, step=9602]

[Step 9600/40000] Avg Loss: 0.6415


Training Epoch 2:  63%|██████▎   | 2335/3683 [03:22<01:56, 11.56it/s, avg_loss=0.619, loss=0.619, step=9702]

[Step 9700/40000] Avg Loss: 0.6392


Training Epoch 2:  66%|██████▌   | 2435/3683 [03:30<01:47, 11.58it/s, avg_loss=0.619, loss=0.619, step=9802]

[Step 9800/40000] Avg Loss: 0.6422


Training Epoch 2:  69%|██████▉   | 2535/3683 [03:39<01:39, 11.59it/s, avg_loss=0.638, loss=0.658, step=9902]

[Step 9900/40000] Avg Loss: 0.6388


Training Epoch 2:  72%|███████▏  | 2635/3683 [03:48<01:30, 11.58it/s, avg_loss=0.638, loss=0.619, step=1e+4]

[Step 10000/40000] Avg Loss: 0.6368


Training Epoch 2:  74%|███████▍  | 2735/3683 [03:56<01:21, 11.58it/s, avg_loss=0.619, loss=0.619, step=10102]

[Step 10100/40000] Avg Loss: 0.6411


Training Epoch 2:  77%|███████▋  | 2835/3683 [04:05<01:13, 11.57it/s, avg_loss=0.638, loss=0.619, step=10202]

[Step 10200/40000] Avg Loss: 0.6357


Training Epoch 2:  80%|███████▉  | 2935/3683 [04:13<01:04, 11.56it/s, avg_loss=0.619, loss=0.619, step=10302]

[Step 10300/40000] Avg Loss: 0.6372


Training Epoch 2:  82%|████████▏ | 3035/3683 [04:22<00:55, 11.59it/s, avg_loss=0.677, loss=0.696, step=10402]

[Step 10400/40000] Avg Loss: 0.6403


Training Epoch 2:  85%|████████▌ | 3135/3683 [04:31<00:47, 11.59it/s, avg_loss=0.619, loss=0.619, step=10502]

[Step 10500/40000] Avg Loss: 0.6418


Training Epoch 2:  88%|████████▊ | 3235/3683 [04:39<00:38, 11.56it/s, avg_loss=0.658, loss=0.658, step=10602]

[Step 10600/40000] Avg Loss: 0.6349


Training Epoch 2:  91%|█████████ | 3335/3683 [04:48<00:30, 11.59it/s, avg_loss=0.658, loss=0.696, step=10702]

[Step 10700/40000] Avg Loss: 0.6372


Training Epoch 2:  93%|█████████▎| 3435/3683 [04:57<00:21, 11.57it/s, avg_loss=0.696, loss=0.696, step=10802]

[Step 10800/40000] Avg Loss: 0.6384


Training Epoch 2:  96%|█████████▌| 3535/3683 [05:05<00:12, 11.57it/s, avg_loss=0.638, loss=0.619, step=10902]

[Step 10900/40000] Avg Loss: 0.6380


Training Epoch 2:  99%|█████████▊| 3635/3683 [05:14<00:04, 11.60it/s, avg_loss=0.619, loss=0.619, step=11002]

[Step 11000/40000] Avg Loss: 0.6416


Training Epoch 2: 100%|██████████| 3683/3683 [05:18<00:00, 11.56it/s, avg_loss=0.635, loss=0.658, step=11049]


Epoch 3


Training Epoch 3:   1%|▏         | 53/3683 [00:04<05:13, 11.59it/s, avg_loss=0.658, loss=0.658, step=11102]

[Step 11100/40000] Avg Loss: 0.6368


Training Epoch 3:   4%|▍         | 153/3683 [00:13<05:04, 11.58it/s, avg_loss=0.638, loss=0.619, step=11202]

[Step 11200/40000] Avg Loss: 0.6388


Training Epoch 3:   7%|▋         | 253/3683 [00:21<04:55, 11.60it/s, avg_loss=0.619, loss=0.619, step=11302]

[Step 11300/40000] Avg Loss: 0.6345


Training Epoch 3:  10%|▉         | 353/3683 [00:30<04:47, 11.59it/s, avg_loss=0.619, loss=0.619, step=11402]

[Step 11400/40000] Avg Loss: 0.6376


Training Epoch 3:  12%|█▏        | 453/3683 [00:39<04:39, 11.56it/s, avg_loss=0.619, loss=0.619, step=11502]

[Step 11500/40000] Avg Loss: 0.6403


Training Epoch 3:  15%|█▌        | 553/3683 [00:47<04:29, 11.60it/s, avg_loss=0.619, loss=0.619, step=11602]

[Step 11600/40000] Avg Loss: 0.6388


Training Epoch 3:  18%|█▊        | 653/3683 [00:56<04:21, 11.57it/s, avg_loss=0.677, loss=0.696, step=11702]

[Step 11700/40000] Avg Loss: 0.6384


Training Epoch 3:  20%|██        | 753/3683 [01:05<04:12, 11.59it/s, avg_loss=0.658, loss=0.619, step=11802]

[Step 11800/40000] Avg Loss: 0.6399


Training Epoch 3:  23%|██▎       | 853/3683 [01:13<04:05, 11.55it/s, avg_loss=0.658, loss=0.696, step=11902]

[Step 11900/40000] Avg Loss: 0.6438


Training Epoch 3:  26%|██▌       | 953/3683 [01:22<03:55, 11.58it/s, avg_loss=0.658, loss=0.658, step=12002]

[Step 12000/40000] Avg Loss: 0.6434


Training Epoch 3:  29%|██▊       | 1053/3683 [01:31<03:47, 11.58it/s, avg_loss=0.638, loss=0.658, step=12102]

[Step 12100/40000] Avg Loss: 0.6418


Training Epoch 3:  31%|███▏      | 1153/3683 [01:39<03:38, 11.55it/s, avg_loss=0.638, loss=0.658, step=12202]

[Step 12200/40000] Avg Loss: 0.6361


Training Epoch 3:  34%|███▍      | 1253/3683 [01:48<03:30, 11.56it/s, avg_loss=0.619, loss=0.619, step=12302]

[Step 12300/40000] Avg Loss: 0.6411


Training Epoch 3:  37%|███▋      | 1353/3683 [01:57<03:21, 11.58it/s, avg_loss=0.619, loss=0.619, step=12402]

[Step 12400/40000] Avg Loss: 0.6357


Training Epoch 3:  39%|███▉      | 1453/3683 [02:05<03:12, 11.56it/s, avg_loss=0.619, loss=0.619, step=12502]

[Step 12500/40000] Avg Loss: 0.6415


Training Epoch 3:  42%|████▏     | 1553/3683 [02:14<03:04, 11.56it/s, avg_loss=0.638, loss=0.658, step=12602]

[Step 12600/40000] Avg Loss: 0.6407


Training Epoch 3:  45%|████▍     | 1653/3683 [02:23<02:55, 11.58it/s, avg_loss=0.638, loss=0.658, step=12702]

[Step 12700/40000] Avg Loss: 0.6411


Training Epoch 3:  48%|████▊     | 1753/3683 [02:31<02:46, 11.57it/s, avg_loss=0.638, loss=0.619, step=12802]

[Step 12800/40000] Avg Loss: 0.6399


Training Epoch 3:  50%|█████     | 1853/3683 [02:40<02:38, 11.57it/s, avg_loss=0.677, loss=0.658, step=12902]

[Step 12900/40000] Avg Loss: 0.6411


Training Epoch 3:  53%|█████▎    | 1953/3683 [02:48<02:29, 11.58it/s, avg_loss=0.696, loss=0.696, step=13002]

[Step 13000/40000] Avg Loss: 0.6418


Training Epoch 3:  56%|█████▌    | 2053/3683 [02:57<02:20, 11.58it/s, avg_loss=0.619, loss=0.619, step=13102]

[Step 13100/40000] Avg Loss: 0.6376


Training Epoch 3:  58%|█████▊    | 2153/3683 [03:06<02:12, 11.56it/s, avg_loss=0.619, loss=0.619, step=13202]

[Step 13200/40000] Avg Loss: 0.6411


Training Epoch 3:  61%|██████    | 2253/3683 [03:14<02:03, 11.56it/s, avg_loss=0.638, loss=0.658, step=13302]

[Step 13300/40000] Avg Loss: 0.6384


Training Epoch 3:  64%|██████▍   | 2353/3683 [03:23<01:55, 11.55it/s, avg_loss=0.619, loss=0.619, step=13402]

[Step 13400/40000] Avg Loss: 0.6388


Training Epoch 3:  67%|██████▋   | 2453/3683 [03:32<01:46, 11.56it/s, avg_loss=0.619, loss=0.619, step=13502]

[Step 13500/40000] Avg Loss: 0.6349


Training Epoch 3:  69%|██████▉   | 2553/3683 [03:40<01:37, 11.58it/s, avg_loss=0.677, loss=0.696, step=13602]

[Step 13600/40000] Avg Loss: 0.6380


Training Epoch 3:  72%|███████▏  | 2653/3683 [03:49<01:28, 11.58it/s, avg_loss=0.677, loss=0.696, step=13702]

[Step 13700/40000] Avg Loss: 0.6392


Training Epoch 3:  75%|███████▍  | 2753/3683 [03:58<01:20, 11.61it/s, avg_loss=0.619, loss=0.619, step=13802]

[Step 13800/40000] Avg Loss: 0.6430


Training Epoch 3:  77%|███████▋  | 2853/3683 [04:06<01:11, 11.58it/s, avg_loss=0.619, loss=0.619, step=13902]

[Step 13900/40000] Avg Loss: 0.6361


Training Epoch 3:  80%|████████  | 2953/3683 [04:15<01:03, 11.56it/s, avg_loss=0.619, loss=0.619, step=14002]

[Step 14000/40000] Avg Loss: 0.6345


Training Epoch 3:  83%|████████▎ | 3053/3683 [04:24<00:54, 11.58it/s, avg_loss=0.619, loss=0.619, step=14102]

[Step 14100/40000] Avg Loss: 0.6361


Training Epoch 3:  86%|████████▌ | 3153/3683 [04:32<00:45, 11.59it/s, avg_loss=0.619, loss=0.619, step=14202]

[Step 14200/40000] Avg Loss: 0.6361


Training Epoch 3:  88%|████████▊ | 3253/3683 [04:41<00:37, 11.56it/s, avg_loss=0.619, loss=0.619, step=14302]

[Step 14300/40000] Avg Loss: 0.6349


Training Epoch 3:  91%|█████████ | 3353/3683 [04:49<00:28, 11.57it/s, avg_loss=0.638, loss=0.658, step=14402]

[Step 14400/40000] Avg Loss: 0.6426


Training Epoch 3:  94%|█████████▍| 3453/3683 [04:58<00:19, 11.60it/s, avg_loss=0.638, loss=0.619, step=14502]

[Step 14500/40000] Avg Loss: 0.6418


Training Epoch 3:  96%|█████████▋| 3553/3683 [05:07<00:11, 11.58it/s, avg_loss=0.658, loss=0.658, step=14602]

[Step 14600/40000] Avg Loss: 0.6399


Training Epoch 3:  99%|█████████▉| 3653/3683 [05:15<00:02, 11.58it/s, avg_loss=0.619, loss=0.619, step=14702]

[Step 14700/40000] Avg Loss: 0.6411


Training Epoch 3: 100%|██████████| 3683/3683 [05:18<00:00, 11.56it/s, avg_loss=0.635, loss=0.619, step=14732]


Epoch 4


Training Epoch 4:   2%|▏         | 69/3683 [00:06<05:11, 11.61it/s, avg_loss=0.658, loss=0.658, step=14802]

[Step 14800/40000] Avg Loss: 0.6368


Training Epoch 4:   5%|▍         | 169/3683 [00:14<05:02, 11.61it/s, avg_loss=0.638, loss=0.658, step=14902]

[Step 14900/40000] Avg Loss: 0.6399


Training Epoch 4:   7%|▋         | 269/3683 [00:23<04:54, 11.59it/s, avg_loss=0.619, loss=0.619, step=15002]

[Step 15000/40000] Avg Loss: 0.6403


Training Epoch 4:  10%|█         | 369/3683 [00:32<04:46, 11.59it/s, avg_loss=0.638, loss=0.658, step=15102]

[Step 15100/40000] Avg Loss: 0.6361


Training Epoch 4:  13%|█▎        | 469/3683 [00:40<04:37, 11.60it/s, avg_loss=0.658, loss=0.696, step=15202]

[Step 15200/40000] Avg Loss: 0.6411


Training Epoch 4:  15%|█▌        | 569/3683 [00:49<04:27, 11.62it/s, avg_loss=0.638, loss=0.619, step=15302]

[Step 15300/40000] Avg Loss: 0.6384


Training Epoch 4:  18%|█▊        | 669/3683 [00:57<04:20, 11.56it/s, avg_loss=0.658, loss=0.658, step=15402]

[Step 15400/40000] Avg Loss: 0.6368


Training Epoch 4:  21%|██        | 769/3683 [01:06<04:10, 11.61it/s, avg_loss=0.638, loss=0.658, step=15502]

[Step 15500/40000] Avg Loss: 0.6430


Training Epoch 4:  24%|██▎       | 869/3683 [01:15<04:02, 11.58it/s, avg_loss=0.658, loss=0.658, step=15602]

[Step 15600/40000] Avg Loss: 0.6407


Training Epoch 4:  26%|██▋       | 969/3683 [01:23<03:54, 11.59it/s, avg_loss=0.619, loss=0.619, step=15702]

[Step 15700/40000] Avg Loss: 0.6438


Training Epoch 4:  29%|██▉       | 1069/3683 [01:32<03:45, 11.61it/s, avg_loss=0.638, loss=0.619, step=15802]

[Step 15800/40000] Avg Loss: 0.6380


Training Epoch 4:  32%|███▏      | 1169/3683 [01:41<03:37, 11.57it/s, avg_loss=0.619, loss=0.619, step=15902]

[Step 15900/40000] Avg Loss: 0.6353


Training Epoch 4:  34%|███▍      | 1269/3683 [01:49<03:28, 11.59it/s, avg_loss=0.638, loss=0.658, step=16002]

[Step 16000/40000] Avg Loss: 0.6361


Training Epoch 4:  37%|███▋      | 1369/3683 [01:58<03:19, 11.59it/s, avg_loss=0.638, loss=0.619, step=16102]

[Step 16100/40000] Avg Loss: 0.6403


Training Epoch 4:  40%|███▉      | 1469/3683 [02:07<03:11, 11.58it/s, avg_loss=0.619, loss=0.619, step=16202]

[Step 16200/40000] Avg Loss: 0.6365


Training Epoch 4:  43%|████▎     | 1569/3683 [02:15<03:02, 11.58it/s, avg_loss=0.696, loss=0.735, step=16302]

[Step 16300/40000] Avg Loss: 0.6349


Training Epoch 4:  45%|████▌     | 1669/3683 [02:24<02:54, 11.57it/s, avg_loss=0.638, loss=0.658, step=16402]

[Step 16400/40000] Avg Loss: 0.6411


Training Epoch 4:  48%|████▊     | 1769/3683 [02:32<02:45, 11.59it/s, avg_loss=0.638, loss=0.619, step=16502]

[Step 16500/40000] Avg Loss: 0.6395


Training Epoch 4:  51%|█████     | 1869/3683 [02:41<02:36, 11.61it/s, avg_loss=0.638, loss=0.619, step=16602]

[Step 16600/40000] Avg Loss: 0.6354


Training Epoch 4:  53%|█████▎    | 1969/3683 [02:50<02:27, 11.60it/s, avg_loss=0.619, loss=0.619, step=16702]

[Step 16700/40000] Avg Loss: 0.6380


Training Epoch 4:  56%|█████▌    | 2069/3683 [02:58<02:19, 11.60it/s, avg_loss=0.638, loss=0.619, step=16802]

[Step 16800/40000] Avg Loss: 0.6334


Training Epoch 4:  59%|█████▉    | 2169/3683 [03:07<02:10, 11.60it/s, avg_loss=0.638, loss=0.619, step=16902]

[Step 16900/40000] Avg Loss: 0.6442


Training Epoch 4:  62%|██████▏   | 2269/3683 [03:16<02:02, 11.58it/s, avg_loss=0.658, loss=0.658, step=17002]

[Step 17000/40000] Avg Loss: 0.6399


Training Epoch 4:  64%|██████▍   | 2369/3683 [03:24<01:53, 11.60it/s, avg_loss=0.638, loss=0.619, step=17102]

[Step 17100/40000] Avg Loss: 0.6422


Training Epoch 4:  67%|██████▋   | 2469/3683 [03:33<01:44, 11.59it/s, avg_loss=0.638, loss=0.658, step=17202]

[Step 17200/40000] Avg Loss: 0.6365


Training Epoch 4:  70%|██████▉   | 2569/3683 [03:41<01:36, 11.58it/s, avg_loss=0.619, loss=0.619, step=17302]

[Step 17300/40000] Avg Loss: 0.6395


Training Epoch 4:  72%|███████▏  | 2669/3683 [03:50<01:27, 11.62it/s, avg_loss=0.638, loss=0.658, step=17402]

[Step 17400/40000] Avg Loss: 0.6392


Training Epoch 4:  75%|███████▌  | 2769/3683 [03:59<01:18, 11.61it/s, avg_loss=0.619, loss=0.619, step=17502]

[Step 17500/40000] Avg Loss: 0.6457


Training Epoch 4:  78%|███████▊  | 2869/3683 [04:07<01:10, 11.58it/s, avg_loss=0.658, loss=0.658, step=17602]

[Step 17600/40000] Avg Loss: 0.6368


Training Epoch 4:  81%|████████  | 2969/3683 [04:16<01:01, 11.61it/s, avg_loss=0.638, loss=0.658, step=17702]

[Step 17700/40000] Avg Loss: 0.6399


Training Epoch 4:  83%|████████▎ | 3069/3683 [04:25<00:52, 11.59it/s, avg_loss=0.658, loss=0.696, step=17802]

[Step 17800/40000] Avg Loss: 0.6407


Training Epoch 4:  86%|████████▌ | 3169/3683 [04:33<00:44, 11.59it/s, avg_loss=0.677, loss=0.658, step=17902]

[Step 17900/40000] Avg Loss: 0.6430


Training Epoch 4:  89%|████████▉ | 3269/3683 [04:42<00:35, 11.56it/s, avg_loss=0.658, loss=0.658, step=18002]

[Step 18000/40000] Avg Loss: 0.6384


Training Epoch 4:  91%|█████████▏| 3369/3683 [04:51<00:27, 11.59it/s, avg_loss=0.619, loss=0.619, step=18102]

[Step 18100/40000] Avg Loss: 0.6392


Training Epoch 4:  94%|█████████▍| 3469/3683 [04:59<00:18, 11.59it/s, avg_loss=0.638, loss=0.619, step=18202]

[Step 18200/40000] Avg Loss: 0.6392


Training Epoch 4:  97%|█████████▋| 3569/3683 [05:08<00:09, 11.59it/s, avg_loss=0.677, loss=0.735, step=18302]

[Step 18300/40000] Avg Loss: 0.6392


Training Epoch 4: 100%|█████████▉| 3669/3683 [05:16<00:01, 11.55it/s, avg_loss=0.619, loss=0.619, step=18402]

[Step 18400/40000] Avg Loss: 0.6376


Training Epoch 4: 100%|██████████| 3683/3683 [05:18<00:00, 11.58it/s, avg_loss=0.64, loss=0.619, step=18415] 


Epoch 5


Training Epoch 5:   2%|▏         | 87/3683 [00:07<05:11, 11.55it/s, avg_loss=0.619, loss=0.619, step=18502]

[Step 18500/40000] Avg Loss: 0.6418


Training Epoch 5:   5%|▌         | 187/3683 [00:16<05:00, 11.62it/s, avg_loss=0.638, loss=0.619, step=18602]

[Step 18600/40000] Avg Loss: 0.6392


Training Epoch 5:   8%|▊         | 287/3683 [00:24<04:52, 11.62it/s, avg_loss=0.658, loss=0.619, step=18702]

[Step 18700/40000] Avg Loss: 0.6349


Training Epoch 5:  11%|█         | 387/3683 [00:33<04:44, 11.57it/s, avg_loss=0.619, loss=0.619, step=18802]

[Step 18800/40000] Avg Loss: 0.6411


Training Epoch 5:  13%|█▎        | 487/3683 [00:42<04:36, 11.58it/s, avg_loss=0.638, loss=0.658, step=18902]

[Step 18900/40000] Avg Loss: 0.6426


Training Epoch 5:  16%|█▌        | 587/3683 [00:50<04:27, 11.59it/s, avg_loss=0.638, loss=0.658, step=19002]

[Step 19000/40000] Avg Loss: 0.6407


Training Epoch 5:  19%|█▊        | 687/3683 [00:59<04:18, 11.58it/s, avg_loss=0.658, loss=0.619, step=19102]

[Step 19100/40000] Avg Loss: 0.6357


Training Epoch 5:  21%|██▏       | 787/3683 [01:08<04:10, 11.55it/s, avg_loss=0.658, loss=0.658, step=19202]

[Step 19200/40000] Avg Loss: 0.6403


Training Epoch 5:  24%|██▍       | 887/3683 [01:16<04:01, 11.59it/s, avg_loss=0.638, loss=0.658, step=19302]

[Step 19300/40000] Avg Loss: 0.6366


Training Epoch 5:  27%|██▋       | 987/3683 [01:25<03:52, 11.59it/s, avg_loss=0.619, loss=0.619, step=19402]

[Step 19400/40000] Avg Loss: 0.6368


Training Epoch 5:  30%|██▉       | 1087/3683 [01:33<03:44, 11.57it/s, avg_loss=0.619, loss=0.619, step=19502]

[Step 19500/40000] Avg Loss: 0.6380


Training Epoch 5:  32%|███▏      | 1187/3683 [01:42<03:36, 11.55it/s, avg_loss=0.658, loss=0.619, step=19602]

[Step 19600/40000] Avg Loss: 0.6403


Training Epoch 5:  35%|███▍      | 1287/3683 [01:51<03:26, 11.58it/s, avg_loss=0.619, loss=0.619, step=19702]

[Step 19700/40000] Avg Loss: 0.6380


Training Epoch 5:  38%|███▊      | 1387/3683 [01:59<03:18, 11.57it/s, avg_loss=0.638, loss=0.619, step=19802]

[Step 19800/40000] Avg Loss: 0.6341


Training Epoch 5:  40%|████      | 1487/3683 [02:08<03:10, 11.54it/s, avg_loss=0.638, loss=0.658, step=19902]

[Step 19900/40000] Avg Loss: 0.6392


Training Epoch 5:  43%|████▎     | 1587/3683 [02:17<03:00, 11.58it/s, avg_loss=0.619, loss=0.619, step=2e+4] 

[Step 20000/40000] Avg Loss: 0.6422


Training Epoch 5:  46%|████▌     | 1687/3683 [02:25<02:52, 11.59it/s, avg_loss=0.638, loss=0.619, step=20102]

[Step 20100/40000] Avg Loss: 0.6376


Training Epoch 5:  49%|████▊     | 1787/3683 [02:34<02:43, 11.58it/s, avg_loss=0.638, loss=0.619, step=20202]

[Step 20200/40000] Avg Loss: 0.6365


Training Epoch 5:  51%|█████     | 1887/3683 [02:43<02:35, 11.58it/s, avg_loss=0.619, loss=0.619, step=20302]

[Step 20300/40000] Avg Loss: 0.6415


Training Epoch 5:  54%|█████▍    | 1987/3683 [02:51<02:26, 11.59it/s, avg_loss=0.638, loss=0.619, step=20402]

[Step 20400/40000] Avg Loss: 0.6411


Training Epoch 5:  57%|█████▋    | 2087/3683 [03:00<02:17, 11.58it/s, avg_loss=0.638, loss=0.658, step=20502]

[Step 20500/40000] Avg Loss: 0.6349


Training Epoch 5:  59%|█████▉    | 2187/3683 [03:08<02:09, 11.58it/s, avg_loss=0.619, loss=0.619, step=20602]

[Step 20600/40000] Avg Loss: 0.6365


Training Epoch 5:  62%|██████▏   | 2287/3683 [03:17<02:00, 11.61it/s, avg_loss=0.619, loss=0.619, step=20702]

[Step 20700/40000] Avg Loss: 0.6380


Training Epoch 5:  65%|██████▍   | 2387/3683 [03:26<01:51, 11.59it/s, avg_loss=0.619, loss=0.619, step=20802]

[Step 20800/40000] Avg Loss: 0.6376


Training Epoch 5:  68%|██████▊   | 2487/3683 [03:34<01:43, 11.58it/s, avg_loss=0.677, loss=0.658, step=20902]

[Step 20900/40000] Avg Loss: 0.6438


Training Epoch 5:  70%|███████   | 2587/3683 [03:43<01:34, 11.59it/s, avg_loss=0.638, loss=0.658, step=21002]

[Step 21000/40000] Avg Loss: 0.6395


Training Epoch 5:  73%|███████▎  | 2687/3683 [03:52<01:25, 11.60it/s, avg_loss=0.638, loss=0.619, step=21102]

[Step 21100/40000] Avg Loss: 0.6368


Training Epoch 5:  76%|███████▌  | 2787/3683 [04:00<01:17, 11.59it/s, avg_loss=0.638, loss=0.619, step=21202]

[Step 21200/40000] Avg Loss: 0.6392


Training Epoch 5:  78%|███████▊  | 2887/3683 [04:09<01:08, 11.58it/s, avg_loss=0.619, loss=0.619, step=21302]

[Step 21300/40000] Avg Loss: 0.6388


Training Epoch 5:  81%|████████  | 2987/3683 [04:18<01:00, 11.59it/s, avg_loss=0.619, loss=0.619, step=21402]

[Step 21400/40000] Avg Loss: 0.6426


Training Epoch 5:  84%|████████▍ | 3087/3683 [04:26<00:51, 11.58it/s, avg_loss=0.619, loss=0.619, step=21502]

[Step 21500/40000] Avg Loss: 0.6399


Training Epoch 5:  87%|████████▋ | 3187/3683 [04:35<00:42, 11.61it/s, avg_loss=0.619, loss=0.619, step=21602]

[Step 21600/40000] Avg Loss: 0.6426


Training Epoch 5:  89%|████████▉ | 3287/3683 [04:43<00:34, 11.60it/s, avg_loss=0.696, loss=0.696, step=21702]

[Step 21700/40000] Avg Loss: 0.6376


Training Epoch 5:  92%|█████████▏| 3387/3683 [04:52<00:25, 11.57it/s, avg_loss=0.619, loss=0.619, step=21802]

[Step 21800/40000] Avg Loss: 0.6384


Training Epoch 5:  95%|█████████▍| 3487/3683 [05:01<00:16, 11.59it/s, avg_loss=0.638, loss=0.619, step=21902]

[Step 21900/40000] Avg Loss: 0.6388


Training Epoch 5:  97%|█████████▋| 3587/3683 [05:09<00:08, 11.58it/s, avg_loss=0.638, loss=0.619, step=22002]

[Step 22000/40000] Avg Loss: 0.6403


Training Epoch 5: 100%|██████████| 3683/3683 [05:18<00:00, 11.57it/s, avg_loss=0.645, loss=0.658, step=22098]


Epoch 6


Training Epoch 6:   0%|          | 3/3683 [00:00<07:32,  8.14it/s, avg_loss=0.638, loss=0.658, step=22102]

[Step 22100/40000] Avg Loss: 0.6442


Training Epoch 6:   3%|▎         | 103/3683 [00:09<05:09, 11.56it/s, avg_loss=0.638, loss=0.619, step=22202]

[Step 22200/40000] Avg Loss: 0.6403


Training Epoch 6:   6%|▌         | 203/3683 [00:17<05:01, 11.53it/s, avg_loss=0.619, loss=0.619, step=22302]

[Step 22300/40000] Avg Loss: 0.6361


Training Epoch 6:   8%|▊         | 303/3683 [00:26<04:53, 11.53it/s, avg_loss=0.619, loss=0.619, step=22402]

[Step 22400/40000] Avg Loss: 0.6365


Training Epoch 6:  11%|█         | 403/3683 [00:35<04:44, 11.54it/s, avg_loss=0.619, loss=0.619, step=22502]

[Step 22500/40000] Avg Loss: 0.6392


Training Epoch 6:  14%|█▎        | 503/3683 [00:43<04:35, 11.54it/s, avg_loss=0.619, loss=0.619, step=22602]

[Step 22600/40000] Avg Loss: 0.6357


Training Epoch 6:  16%|█▋        | 603/3683 [00:52<04:26, 11.56it/s, avg_loss=0.638, loss=0.619, step=22702]

[Step 22700/40000] Avg Loss: 0.6392


Training Epoch 6:  19%|█▉        | 703/3683 [01:01<04:17, 11.56it/s, avg_loss=0.677, loss=0.658, step=22802]

[Step 22800/40000] Avg Loss: 0.6384


Training Epoch 6:  22%|██▏       | 803/3683 [01:09<04:08, 11.58it/s, avg_loss=0.619, loss=0.619, step=22902]

[Step 22900/40000] Avg Loss: 0.6430


Training Epoch 6:  25%|██▍       | 903/3683 [01:18<04:00, 11.54it/s, avg_loss=0.619, loss=0.619, step=23002]

[Step 23000/40000] Avg Loss: 0.6399


Training Epoch 6:  27%|██▋       | 1003/3683 [01:27<03:51, 11.57it/s, avg_loss=0.658, loss=0.619, step=23102]

[Step 23100/40000] Avg Loss: 0.6380


Training Epoch 6:  30%|██▉       | 1103/3683 [01:35<03:43, 11.56it/s, avg_loss=0.658, loss=0.696, step=23202]

[Step 23200/40000] Avg Loss: 0.6365


Training Epoch 6:  33%|███▎      | 1203/3683 [01:44<03:34, 11.54it/s, avg_loss=0.638, loss=0.619, step=23302]

[Step 23300/40000] Avg Loss: 0.6376


Training Epoch 6:  35%|███▌      | 1303/3683 [01:53<03:25, 11.55it/s, avg_loss=0.619, loss=0.619, step=23402]

[Step 23400/40000] Avg Loss: 0.6426


Training Epoch 6:  38%|███▊      | 1403/3683 [02:01<03:17, 11.56it/s, avg_loss=0.658, loss=0.619, step=23502]

[Step 23500/40000] Avg Loss: 0.6415


Training Epoch 6:  41%|████      | 1503/3683 [02:10<03:08, 11.59it/s, avg_loss=0.638, loss=0.619, step=23602]

[Step 23600/40000] Avg Loss: 0.6438


Training Epoch 6:  44%|████▎     | 1603/3683 [02:18<03:00, 11.55it/s, avg_loss=0.619, loss=0.619, step=23702]

[Step 23700/40000] Avg Loss: 0.6368


Training Epoch 6:  46%|████▌     | 1703/3683 [02:27<02:51, 11.55it/s, avg_loss=0.619, loss=0.619, step=23802]

[Step 23800/40000] Avg Loss: 0.6399


Training Epoch 6:  49%|████▉     | 1803/3683 [02:36<02:42, 11.58it/s, avg_loss=0.638, loss=0.658, step=23902]

[Step 23900/40000] Avg Loss: 0.6365


Training Epoch 6:  52%|█████▏    | 1903/3683 [02:44<02:33, 11.59it/s, avg_loss=0.619, loss=0.619, step=24002]

[Step 24000/40000] Avg Loss: 0.6372


Training Epoch 6:  54%|█████▍    | 2003/3683 [02:53<02:25, 11.53it/s, avg_loss=0.638, loss=0.619, step=24102]

[Step 24100/40000] Avg Loss: 0.6426


Training Epoch 6:  57%|█████▋    | 2103/3683 [03:02<02:16, 11.57it/s, avg_loss=0.638, loss=0.619, step=24202]

[Step 24200/40000] Avg Loss: 0.6368


Training Epoch 6:  60%|█████▉    | 2203/3683 [03:10<02:07, 11.57it/s, avg_loss=0.658, loss=0.658, step=24302]

[Step 24300/40000] Avg Loss: 0.6345


Training Epoch 6:  63%|██████▎   | 2303/3683 [03:19<01:59, 11.54it/s, avg_loss=0.677, loss=0.658, step=24402]

[Step 24400/40000] Avg Loss: 0.6430


Training Epoch 6:  65%|██████▌   | 2403/3683 [03:28<01:50, 11.60it/s, avg_loss=0.658, loss=0.658, step=24502]

[Step 24500/40000] Avg Loss: 0.6418


Training Epoch 6:  68%|██████▊   | 2503/3683 [03:36<01:41, 11.62it/s, avg_loss=0.638, loss=0.619, step=24602]

[Step 24600/40000] Avg Loss: 0.6395


Training Epoch 6:  71%|███████   | 2603/3683 [03:45<01:32, 11.62it/s, avg_loss=0.619, loss=0.619, step=24702]

[Step 24700/40000] Avg Loss: 0.6372


Training Epoch 6:  73%|███████▎  | 2703/3683 [03:54<01:24, 11.60it/s, avg_loss=0.638, loss=0.619, step=24802]

[Step 24800/40000] Avg Loss: 0.6388


Training Epoch 6:  76%|███████▌  | 2803/3683 [04:02<01:15, 11.60it/s, avg_loss=0.658, loss=0.658, step=24902]

[Step 24900/40000] Avg Loss: 0.6397


Training Epoch 6:  79%|███████▉  | 2903/3683 [04:11<01:07, 11.61it/s, avg_loss=0.638, loss=0.658, step=25002]

[Step 25000/40000] Avg Loss: 0.6407


Training Epoch 6:  82%|████████▏ | 3003/3683 [04:19<00:58, 11.59it/s, avg_loss=0.619, loss=0.619, step=25102]

[Step 25100/40000] Avg Loss: 0.6384


Training Epoch 6:  84%|████████▍ | 3103/3683 [04:28<00:49, 11.60it/s, avg_loss=0.658, loss=0.658, step=25202]

[Step 25200/40000] Avg Loss: 0.6392


Training Epoch 6:  87%|████████▋ | 3203/3683 [04:37<00:41, 11.61it/s, avg_loss=0.638, loss=0.619, step=25302]

[Step 25300/40000] Avg Loss: 0.6465


Training Epoch 6:  90%|████████▉ | 3303/3683 [04:45<00:32, 11.63it/s, avg_loss=0.658, loss=0.696, step=25402]

[Step 25400/40000] Avg Loss: 0.6392


Training Epoch 6:  92%|█████████▏| 3403/3683 [04:54<00:24, 11.61it/s, avg_loss=0.619, loss=0.619, step=25502]

[Step 25500/40000] Avg Loss: 0.6407


Training Epoch 6:  95%|█████████▌| 3503/3683 [05:02<00:15, 11.62it/s, avg_loss=0.619, loss=0.619, step=25602]

[Step 25600/40000] Avg Loss: 0.6365


Training Epoch 6:  98%|█████████▊| 3603/3683 [05:11<00:06, 11.60it/s, avg_loss=0.658, loss=0.619, step=25702]

[Step 25700/40000] Avg Loss: 0.6330


Training Epoch 6: 100%|██████████| 3683/3683 [05:18<00:00, 11.57it/s, avg_loss=0.642, loss=0.619, step=25781]


Epoch 7


Training Epoch 7:   1%|          | 21/3683 [00:01<05:18, 11.51it/s, avg_loss=0.619, loss=0.619, step=25802]

[Step 25800/40000] Avg Loss: 0.6411


Training Epoch 7:   3%|▎         | 121/3683 [00:10<05:06, 11.64it/s, avg_loss=0.638, loss=0.658, step=25902]

[Step 25900/40000] Avg Loss: 0.6399


Training Epoch 7:   6%|▌         | 221/3683 [00:19<04:57, 11.64it/s, avg_loss=0.658, loss=0.658, step=26002]

[Step 26000/40000] Avg Loss: 0.6399


Training Epoch 7:   9%|▊         | 321/3683 [00:27<04:48, 11.64it/s, avg_loss=0.658, loss=0.658, step=26102]

[Step 26100/40000] Avg Loss: 0.6376


Training Epoch 7:  11%|█▏        | 421/3683 [00:36<04:40, 11.63it/s, avg_loss=0.638, loss=0.658, step=26202]

[Step 26200/40000] Avg Loss: 0.6388


Training Epoch 7:  14%|█▍        | 521/3683 [00:44<04:32, 11.61it/s, avg_loss=0.619, loss=0.619, step=26302]

[Step 26300/40000] Avg Loss: 0.6361


Training Epoch 7:  17%|█▋        | 621/3683 [00:53<04:23, 11.62it/s, avg_loss=0.658, loss=0.658, step=26402]

[Step 26400/40000] Avg Loss: 0.6426


Training Epoch 7:  20%|█▉        | 721/3683 [01:02<04:14, 11.65it/s, avg_loss=0.638, loss=0.658, step=26502]

[Step 26500/40000] Avg Loss: 0.6368


Training Epoch 7:  22%|██▏       | 821/3683 [01:10<04:05, 11.64it/s, avg_loss=0.677, loss=0.696, step=26602]

[Step 26600/40000] Avg Loss: 0.6430


Training Epoch 7:  25%|██▌       | 921/3683 [01:19<03:58, 11.60it/s, avg_loss=0.619, loss=0.619, step=26702]

[Step 26700/40000] Avg Loss: 0.6399


Training Epoch 7:  28%|██▊       | 1021/3683 [01:27<03:49, 11.61it/s, avg_loss=0.638, loss=0.658, step=26802]

[Step 26800/40000] Avg Loss: 0.6388


Training Epoch 7:  30%|███       | 1121/3683 [01:36<03:40, 11.63it/s, avg_loss=0.638, loss=0.658, step=26902]

[Step 26900/40000] Avg Loss: 0.6403


Training Epoch 7:  33%|███▎      | 1221/3683 [01:45<03:31, 11.62it/s, avg_loss=0.658, loss=0.619, step=27002]

[Step 27000/40000] Avg Loss: 0.6422


Training Epoch 7:  36%|███▌      | 1321/3683 [01:53<03:23, 11.62it/s, avg_loss=0.638, loss=0.619, step=27102]

[Step 27100/40000] Avg Loss: 0.6422


Training Epoch 7:  39%|███▊      | 1421/3683 [02:02<03:14, 11.64it/s, avg_loss=0.619, loss=0.619, step=27202]

[Step 27200/40000] Avg Loss: 0.6361


Training Epoch 7:  41%|████▏     | 1521/3683 [02:10<03:05, 11.64it/s, avg_loss=0.638, loss=0.619, step=27302]

[Step 27300/40000] Avg Loss: 0.6407


Training Epoch 7:  44%|████▍     | 1621/3683 [02:19<02:57, 11.62it/s, avg_loss=0.619, loss=0.619, step=27402]

[Step 27400/40000] Avg Loss: 0.6376


Training Epoch 7:  47%|████▋     | 1721/3683 [02:28<02:48, 11.62it/s, avg_loss=0.638, loss=0.619, step=27502]

[Step 27500/40000] Avg Loss: 0.6376


Training Epoch 7:  49%|████▉     | 1821/3683 [02:36<02:40, 11.61it/s, avg_loss=0.658, loss=0.658, step=27602]

[Step 27600/40000] Avg Loss: 0.6403


Training Epoch 7:  52%|█████▏    | 1921/3683 [02:45<02:31, 11.65it/s, avg_loss=0.619, loss=0.619, step=27702]

[Step 27700/40000] Avg Loss: 0.6438


Training Epoch 7:  55%|█████▍    | 2021/3683 [02:53<02:23, 11.60it/s, avg_loss=0.677, loss=0.696, step=27802]

[Step 27800/40000] Avg Loss: 0.6403


Training Epoch 7:  58%|█████▊    | 2121/3683 [03:02<02:14, 11.63it/s, avg_loss=0.658, loss=0.658, step=27902]

[Step 27900/40000] Avg Loss: 0.6388


Training Epoch 7:  60%|██████    | 2221/3683 [03:11<02:05, 11.63it/s, avg_loss=0.658, loss=0.658, step=28002]

[Step 28000/40000] Avg Loss: 0.6403


Training Epoch 7:  63%|██████▎   | 2321/3683 [03:19<01:57, 11.63it/s, avg_loss=0.658, loss=0.658, step=28102]

[Step 28100/40000] Avg Loss: 0.6368


Training Epoch 7:  66%|██████▌   | 2421/3683 [03:28<01:48, 11.61it/s, avg_loss=0.677, loss=0.619, step=28202]

[Step 28200/40000] Avg Loss: 0.6442


Training Epoch 7:  68%|██████▊   | 2521/3683 [03:37<01:39, 11.62it/s, avg_loss=0.619, loss=0.619, step=28302]

[Step 28300/40000] Avg Loss: 0.6349


Training Epoch 7:  71%|███████   | 2621/3683 [03:45<01:31, 11.62it/s, avg_loss=0.638, loss=0.619, step=28402]

[Step 28400/40000] Avg Loss: 0.6380


Training Epoch 7:  74%|███████▍  | 2721/3683 [03:54<01:22, 11.62it/s, avg_loss=0.619, loss=0.619, step=28502]

[Step 28500/40000] Avg Loss: 0.6341


Training Epoch 7:  77%|███████▋  | 2821/3683 [04:02<01:14, 11.63it/s, avg_loss=0.619, loss=0.619, step=28602]

[Step 28600/40000] Avg Loss: 0.6311


Training Epoch 7:  79%|███████▉  | 2921/3683 [04:11<01:05, 11.64it/s, avg_loss=0.619, loss=0.619, step=28702]

[Step 28700/40000] Avg Loss: 0.6418


Training Epoch 7:  82%|████████▏ | 3021/3683 [04:20<00:57, 11.50it/s, avg_loss=0.619, loss=0.619, step=28802]

[Step 28800/40000] Avg Loss: 0.6399


Training Epoch 7:  85%|████████▍ | 3121/3683 [04:28<00:48, 11.64it/s, avg_loss=0.638, loss=0.619, step=28902]

[Step 28900/40000] Avg Loss: 0.6349


Training Epoch 7:  87%|████████▋ | 3221/3683 [04:37<00:39, 11.61it/s, avg_loss=0.638, loss=0.658, step=29002]

[Step 29000/40000] Avg Loss: 0.6407


Training Epoch 7:  90%|█████████ | 3321/3683 [04:45<00:31, 11.64it/s, avg_loss=0.658, loss=0.619, step=29102]

[Step 29100/40000] Avg Loss: 0.6399


Training Epoch 7:  93%|█████████▎| 3421/3683 [04:54<00:22, 11.63it/s, avg_loss=0.619, loss=0.619, step=29202]

[Step 29200/40000] Avg Loss: 0.6395


Training Epoch 7:  96%|█████████▌| 3521/3683 [05:03<00:13, 11.61it/s, avg_loss=0.619, loss=0.619, step=29302]

[Step 29300/40000] Avg Loss: 0.6388


Training Epoch 7:  98%|█████████▊| 3621/3683 [05:11<00:05, 11.63it/s, avg_loss=0.638, loss=0.619, step=29402]

[Step 29400/40000] Avg Loss: 0.6384


Training Epoch 7: 100%|██████████| 3683/3683 [05:17<00:00, 11.62it/s, avg_loss=0.643, loss=0.619, step=29464]


Epoch 8


Training Epoch 8:   1%|          | 37/3683 [00:03<05:14, 11.58it/s, avg_loss=0.638, loss=0.619, step=29502]

[Step 29500/40000] Avg Loss: 0.6438


Training Epoch 8:   4%|▎         | 137/3683 [00:12<05:05, 11.59it/s, avg_loss=0.638, loss=0.619, step=29602]

[Step 29600/40000] Avg Loss: 0.6411


Training Epoch 8:   6%|▋         | 237/3683 [00:20<04:56, 11.62it/s, avg_loss=0.619, loss=0.619, step=29702]

[Step 29700/40000] Avg Loss: 0.6365


Training Epoch 8:   9%|▉         | 337/3683 [00:29<04:48, 11.60it/s, avg_loss=0.677, loss=0.696, step=29802]

[Step 29800/40000] Avg Loss: 0.6384


Training Epoch 8:  12%|█▏        | 437/3683 [00:37<04:39, 11.63it/s, avg_loss=0.638, loss=0.658, step=29902]

[Step 29900/40000] Avg Loss: 0.6411


Training Epoch 8:  15%|█▍        | 537/3683 [00:46<04:30, 11.62it/s, avg_loss=0.677, loss=0.658, step=3e+4] 

[Step 30000/40000] Avg Loss: 0.6415


Training Epoch 8:  17%|█▋        | 637/3683 [00:55<04:22, 11.59it/s, avg_loss=0.638, loss=0.658, step=30102]

[Step 30100/40000] Avg Loss: 0.6407


Training Epoch 8:  20%|██        | 737/3683 [01:03<04:13, 11.61it/s, avg_loss=0.658, loss=0.658, step=30202]

[Step 30200/40000] Avg Loss: 0.6365


Training Epoch 8:  23%|██▎       | 837/3683 [01:12<04:05, 11.61it/s, avg_loss=0.658, loss=0.696, step=30302]

[Step 30300/40000] Avg Loss: 0.6388


Training Epoch 8:  25%|██▌       | 937/3683 [01:20<03:56, 11.62it/s, avg_loss=0.638, loss=0.658, step=30402]

[Step 30400/40000] Avg Loss: 0.6365


Training Epoch 8:  28%|██▊       | 1037/3683 [01:29<03:47, 11.61it/s, avg_loss=0.638, loss=0.658, step=30502]

[Step 30500/40000] Avg Loss: 0.6392


Training Epoch 8:  31%|███       | 1137/3683 [01:38<03:39, 11.59it/s, avg_loss=0.638, loss=0.658, step=30602]

[Step 30600/40000] Avg Loss: 0.6422


Training Epoch 8:  34%|███▎      | 1237/3683 [01:46<03:30, 11.63it/s, avg_loss=0.638, loss=0.658, step=30702]

[Step 30700/40000] Avg Loss: 0.6399


Training Epoch 8:  36%|███▋      | 1337/3683 [01:55<03:22, 11.61it/s, avg_loss=0.619, loss=0.619, step=30802]

[Step 30800/40000] Avg Loss: 0.6388


Training Epoch 8:  39%|███▉      | 1437/3683 [02:04<03:13, 11.61it/s, avg_loss=0.638, loss=0.658, step=30902]

[Step 30900/40000] Avg Loss: 0.6345


Training Epoch 8:  42%|████▏     | 1537/3683 [02:12<03:04, 11.62it/s, avg_loss=0.638, loss=0.619, step=31002]

[Step 31000/40000] Avg Loss: 0.6422


Training Epoch 8:  44%|████▍     | 1637/3683 [02:21<02:56, 11.59it/s, avg_loss=0.658, loss=0.658, step=31102]

[Step 31100/40000] Avg Loss: 0.6415


Training Epoch 8:  47%|████▋     | 1737/3683 [02:29<02:47, 11.62it/s, avg_loss=0.658, loss=0.696, step=31202]

[Step 31200/40000] Avg Loss: 0.6372


Training Epoch 8:  50%|████▉     | 1837/3683 [02:38<02:39, 11.60it/s, avg_loss=0.619, loss=0.619, step=31302]

[Step 31300/40000] Avg Loss: 0.6411


Training Epoch 8:  53%|█████▎    | 1937/3683 [02:47<02:30, 11.61it/s, avg_loss=0.619, loss=0.619, step=31402]

[Step 31400/40000] Avg Loss: 0.6349


Training Epoch 8:  55%|█████▌    | 2037/3683 [02:55<02:21, 11.61it/s, avg_loss=0.638, loss=0.619, step=31502]

[Step 31500/40000] Avg Loss: 0.6415


Training Epoch 8:  58%|█████▊    | 2137/3683 [03:04<02:13, 11.60it/s, avg_loss=0.658, loss=0.658, step=31602]

[Step 31600/40000] Avg Loss: 0.6384


Training Epoch 8:  61%|██████    | 2237/3683 [03:12<02:04, 11.58it/s, avg_loss=0.638, loss=0.658, step=31702]

[Step 31700/40000] Avg Loss: 0.6392


Training Epoch 8:  63%|██████▎   | 2337/3683 [03:21<01:55, 11.62it/s, avg_loss=0.638, loss=0.619, step=31802]

[Step 31800/40000] Avg Loss: 0.6415


Training Epoch 8:  66%|██████▌   | 2437/3683 [03:30<01:47, 11.60it/s, avg_loss=0.638, loss=0.658, step=31902]

[Step 31900/40000] Avg Loss: 0.6407


Training Epoch 8:  69%|██████▉   | 2537/3683 [03:38<01:38, 11.60it/s, avg_loss=0.619, loss=0.619, step=32002]

[Step 32000/40000] Avg Loss: 0.6418


Training Epoch 8:  72%|███████▏  | 2637/3683 [03:47<01:30, 11.61it/s, avg_loss=0.658, loss=0.658, step=32102]

[Step 32100/40000] Avg Loss: 0.6365


Training Epoch 8:  74%|███████▍  | 2737/3683 [03:56<01:21, 11.58it/s, avg_loss=0.619, loss=0.619, step=32202]

[Step 32200/40000] Avg Loss: 0.6376


Training Epoch 8:  77%|███████▋  | 2837/3683 [04:04<01:12, 11.62it/s, avg_loss=0.677, loss=0.696, step=32302]

[Step 32300/40000] Avg Loss: 0.6392


Training Epoch 8:  80%|███████▉  | 2937/3683 [04:13<01:04, 11.61it/s, avg_loss=0.619, loss=0.619, step=32402]

[Step 32400/40000] Avg Loss: 0.6407


Training Epoch 8:  82%|████████▏ | 3037/3683 [04:21<00:55, 11.61it/s, avg_loss=0.619, loss=0.619, step=32502]

[Step 32500/40000] Avg Loss: 0.6392


Training Epoch 8:  85%|████████▌ | 3137/3683 [04:30<00:47, 11.59it/s, avg_loss=0.638, loss=0.658, step=32602]

[Step 32600/40000] Avg Loss: 0.6384


Training Epoch 8:  88%|████████▊ | 3237/3683 [04:39<00:38, 11.60it/s, avg_loss=0.619, loss=0.619, step=32702]

[Step 32700/40000] Avg Loss: 0.6357


Training Epoch 8:  91%|█████████ | 3337/3683 [04:47<00:29, 11.62it/s, avg_loss=0.638, loss=0.658, step=32802]

[Step 32800/40000] Avg Loss: 0.6415


Training Epoch 8:  93%|█████████▎| 3437/3683 [04:56<00:21, 11.61it/s, avg_loss=0.619, loss=0.619, step=32902]

[Step 32900/40000] Avg Loss: 0.6399


Training Epoch 8:  96%|█████████▌| 3537/3683 [05:05<00:12, 11.58it/s, avg_loss=0.619, loss=0.619, step=33002]

[Step 33000/40000] Avg Loss: 0.6392


Training Epoch 8:  99%|█████████▉| 3637/3683 [05:13<00:03, 11.60it/s, avg_loss=0.638, loss=0.619, step=33102]

[Step 33100/40000] Avg Loss: 0.6368


Training Epoch 8: 100%|██████████| 3683/3683 [05:17<00:00, 11.60it/s, avg_loss=0.632, loss=0.619, step=33147]


Epoch 9


Training Epoch 9:   1%|▏         | 55/3683 [00:04<05:12, 11.61it/s, avg_loss=0.658, loss=0.658, step=33202]

[Step 33200/40000] Avg Loss: 0.6361


Training Epoch 9:   4%|▍         | 155/3683 [00:13<05:04, 11.59it/s, avg_loss=0.638, loss=0.658, step=33302]

[Step 33300/40000] Avg Loss: 0.6376


Training Epoch 9:   7%|▋         | 255/3683 [00:22<04:55, 11.61it/s, avg_loss=0.638, loss=0.658, step=33402]

[Step 33400/40000] Avg Loss: 0.6372


Training Epoch 9:  10%|▉         | 355/3683 [00:30<04:46, 11.61it/s, avg_loss=0.638, loss=0.619, step=33502]

[Step 33500/40000] Avg Loss: 0.6415


Training Epoch 9:  12%|█▏        | 455/3683 [00:39<04:38, 11.60it/s, avg_loss=0.619, loss=0.619, step=33602]

[Step 33600/40000] Avg Loss: 0.6411


Training Epoch 9:  15%|█▌        | 555/3683 [00:47<04:29, 11.61it/s, avg_loss=0.619, loss=0.619, step=33702]

[Step 33700/40000] Avg Loss: 0.6384


Training Epoch 9:  18%|█▊        | 655/3683 [00:56<04:21, 11.59it/s, avg_loss=0.638, loss=0.658, step=33802]

[Step 33800/40000] Avg Loss: 0.6388


Training Epoch 9:  20%|██        | 755/3683 [01:05<04:13, 11.56it/s, avg_loss=0.619, loss=0.619, step=33902]

[Step 33900/40000] Avg Loss: 0.6434


Training Epoch 9:  23%|██▎       | 855/3683 [01:13<04:03, 11.62it/s, avg_loss=0.619, loss=0.619, step=34002]

[Step 34000/40000] Avg Loss: 0.6442


Training Epoch 9:  26%|██▌       | 955/3683 [01:22<03:55, 11.61it/s, avg_loss=0.638, loss=0.619, step=34102]

[Step 34100/40000] Avg Loss: 0.6361


Training Epoch 9:  29%|██▊       | 1055/3683 [01:31<03:45, 11.63it/s, avg_loss=0.619, loss=0.619, step=34202]

[Step 34200/40000] Avg Loss: 0.6384


Training Epoch 9:  31%|███▏      | 1155/3683 [01:39<03:37, 11.61it/s, avg_loss=0.638, loss=0.658, step=34302]

[Step 34300/40000] Avg Loss: 0.6395


Training Epoch 9:  34%|███▍      | 1255/3683 [01:48<03:29, 11.60it/s, avg_loss=0.677, loss=0.735, step=34402]

[Step 34400/40000] Avg Loss: 0.6407


Training Epoch 9:  37%|███▋      | 1355/3683 [01:56<03:20, 11.61it/s, avg_loss=0.677, loss=0.658, step=34502]

[Step 34500/40000] Avg Loss: 0.6399


Training Epoch 9:  40%|███▉      | 1455/3683 [02:05<03:12, 11.57it/s, avg_loss=0.638, loss=0.619, step=34602]

[Step 34600/40000] Avg Loss: 0.6415


Training Epoch 9:  42%|████▏     | 1555/3683 [02:14<03:03, 11.62it/s, avg_loss=0.658, loss=0.619, step=34702]

[Step 34700/40000] Avg Loss: 0.6395


Training Epoch 9:  45%|████▍     | 1655/3683 [02:22<02:54, 11.61it/s, avg_loss=0.619, loss=0.619, step=34802]

[Step 34800/40000] Avg Loss: 0.6392


Training Epoch 9:  48%|████▊     | 1755/3683 [02:31<02:46, 11.58it/s, avg_loss=0.638, loss=0.619, step=34902]

[Step 34900/40000] Avg Loss: 0.6395


Training Epoch 9:  50%|█████     | 1855/3683 [02:40<02:37, 11.59it/s, avg_loss=0.619, loss=0.619, step=35002]

[Step 35000/40000] Avg Loss: 0.6349


Training Epoch 9:  53%|█████▎    | 1955/3683 [02:48<02:28, 11.60it/s, avg_loss=0.638, loss=0.619, step=35102]

[Step 35100/40000] Avg Loss: 0.6388


Training Epoch 9:  56%|█████▌    | 2055/3683 [02:57<02:20, 11.57it/s, avg_loss=0.658, loss=0.658, step=35202]

[Step 35200/40000] Avg Loss: 0.6368


Training Epoch 9:  59%|█████▊    | 2155/3683 [03:05<02:11, 11.60it/s, avg_loss=0.658, loss=0.619, step=35302]

[Step 35300/40000] Avg Loss: 0.6372


Training Epoch 9:  61%|██████    | 2255/3683 [03:14<02:03, 11.61it/s, avg_loss=0.658, loss=0.658, step=35402]

[Step 35400/40000] Avg Loss: 0.6399


Training Epoch 9:  64%|██████▍   | 2355/3683 [03:23<01:54, 11.61it/s, avg_loss=0.619, loss=0.619, step=35502]

[Step 35500/40000] Avg Loss: 0.6353


Training Epoch 9:  67%|██████▋   | 2455/3683 [03:31<01:45, 11.60it/s, avg_loss=0.619, loss=0.619, step=35602]

[Step 35600/40000] Avg Loss: 0.6368


Training Epoch 9:  69%|██████▉   | 2555/3683 [03:40<01:37, 11.59it/s, avg_loss=0.638, loss=0.658, step=35702]

[Step 35700/40000] Avg Loss: 0.6380


Training Epoch 9:  72%|███████▏  | 2655/3683 [03:49<01:28, 11.60it/s, avg_loss=0.638, loss=0.658, step=35802]

[Step 35800/40000] Avg Loss: 0.6411


Training Epoch 9:  75%|███████▍  | 2755/3683 [03:57<01:19, 11.60it/s, avg_loss=0.638, loss=0.658, step=35902]

[Step 35900/40000] Avg Loss: 0.6392


Training Epoch 9:  78%|███████▊  | 2855/3683 [04:06<01:11, 11.62it/s, avg_loss=0.619, loss=0.619, step=36002]

[Step 36000/40000] Avg Loss: 0.6418


Training Epoch 9:  80%|████████  | 2955/3683 [04:14<01:02, 11.61it/s, avg_loss=0.658, loss=0.658, step=36102]

[Step 36100/40000] Avg Loss: 0.6403


Training Epoch 9:  83%|████████▎ | 3055/3683 [04:23<00:54, 11.58it/s, avg_loss=0.658, loss=0.696, step=36202]

[Step 36200/40000] Avg Loss: 0.6411


Training Epoch 9:  86%|████████▌ | 3155/3683 [04:32<00:45, 11.60it/s, avg_loss=0.638, loss=0.619, step=36302]

[Step 36300/40000] Avg Loss: 0.6392


Training Epoch 9:  88%|████████▊ | 3255/3683 [04:40<00:36, 11.60it/s, avg_loss=0.619, loss=0.619, step=36402]

[Step 36400/40000] Avg Loss: 0.6372


Training Epoch 9:  91%|█████████ | 3355/3683 [04:49<00:28, 11.58it/s, avg_loss=0.658, loss=0.658, step=36502]

[Step 36500/40000] Avg Loss: 0.6341


Training Epoch 9:  94%|█████████▍| 3455/3683 [04:58<00:19, 11.62it/s, avg_loss=0.658, loss=0.658, step=36602]

[Step 36600/40000] Avg Loss: 0.6434


Training Epoch 9:  97%|█████████▋| 3555/3683 [05:06<00:11, 11.58it/s, avg_loss=0.638, loss=0.658, step=36702]

[Step 36700/40000] Avg Loss: 0.6372


Training Epoch 9:  99%|█████████▉| 3655/3683 [05:15<00:02, 11.60it/s, avg_loss=0.638, loss=0.658, step=36802]

[Step 36800/40000] Avg Loss: 0.6392


Training Epoch 9: 100%|██████████| 3683/3683 [05:17<00:00, 11.59it/s, avg_loss=0.641, loss=0.658, step=36830]


Epoch 10


Training Epoch 10:   2%|▏         | 71/3683 [00:06<05:10, 11.64it/s, avg_loss=0.658, loss=0.658, step=36902]

[Step 36900/40000] Avg Loss: 0.6426


Training Epoch 10:   5%|▍         | 171/3683 [00:14<05:02, 11.63it/s, avg_loss=0.658, loss=0.658, step=37002]

[Step 37000/40000] Avg Loss: 0.6388


Training Epoch 10:   7%|▋         | 271/3683 [00:23<04:53, 11.62it/s, avg_loss=0.638, loss=0.619, step=37102]

[Step 37100/40000] Avg Loss: 0.6388


Training Epoch 10:  10%|█         | 371/3683 [00:32<04:44, 11.64it/s, avg_loss=0.677, loss=0.658, step=37202]

[Step 37200/40000] Avg Loss: 0.6407


Training Epoch 10:  13%|█▎        | 471/3683 [00:40<04:36, 11.62it/s, avg_loss=0.619, loss=0.619, step=37302]

[Step 37300/40000] Avg Loss: 0.6418


Training Epoch 10:  16%|█▌        | 571/3683 [00:49<04:28, 11.57it/s, avg_loss=0.638, loss=0.658, step=37402]

[Step 37400/40000] Avg Loss: 0.6372


Training Epoch 10:  18%|█▊        | 671/3683 [00:57<04:18, 11.63it/s, avg_loss=0.638, loss=0.619, step=37502]

[Step 37500/40000] Avg Loss: 0.6384


Training Epoch 10:  21%|██        | 771/3683 [01:06<04:10, 11.63it/s, avg_loss=0.638, loss=0.658, step=37602]

[Step 37600/40000] Avg Loss: 0.6422


Training Epoch 10:  24%|██▎       | 871/3683 [01:15<04:01, 11.63it/s, avg_loss=0.638, loss=0.658, step=37702]

[Step 37700/40000] Avg Loss: 0.6403


Training Epoch 10:  26%|██▋       | 971/3683 [01:23<03:53, 11.62it/s, avg_loss=0.658, loss=0.619, step=37802]

[Step 37800/40000] Avg Loss: 0.6368


Training Epoch 10:  29%|██▉       | 1071/3683 [01:32<03:44, 11.63it/s, avg_loss=0.638, loss=0.619, step=37902]

[Step 37900/40000] Avg Loss: 0.6365


Training Epoch 10:  32%|███▏      | 1171/3683 [01:40<03:35, 11.63it/s, avg_loss=0.638, loss=0.619, step=38002]

[Step 38000/40000] Avg Loss: 0.6361


Training Epoch 10:  35%|███▍      | 1271/3683 [01:49<03:27, 11.63it/s, avg_loss=0.638, loss=0.619, step=38102]

[Step 38100/40000] Avg Loss: 0.6395


Training Epoch 10:  37%|███▋      | 1371/3683 [01:58<03:18, 11.64it/s, avg_loss=0.658, loss=0.696, step=38202]

[Step 38200/40000] Avg Loss: 0.6357


Training Epoch 10:  40%|███▉      | 1471/3683 [02:06<03:10, 11.62it/s, avg_loss=0.638, loss=0.619, step=38302]

[Step 38300/40000] Avg Loss: 0.6380


Training Epoch 10:  43%|████▎     | 1571/3683 [02:15<03:01, 11.64it/s, avg_loss=0.619, loss=0.619, step=38402]

[Step 38400/40000] Avg Loss: 0.6422


Training Epoch 10:  45%|████▌     | 1671/3683 [02:24<02:53, 11.63it/s, avg_loss=0.619, loss=0.619, step=38502]

[Step 38500/40000] Avg Loss: 0.6418


Training Epoch 10:  48%|████▊     | 1771/3683 [02:32<02:44, 11.64it/s, avg_loss=0.638, loss=0.619, step=38602]

[Step 38600/40000] Avg Loss: 0.6380


Training Epoch 10:  51%|█████     | 1871/3683 [02:41<02:35, 11.62it/s, avg_loss=0.619, loss=0.619, step=38702]

[Step 38700/40000] Avg Loss: 0.6376


Training Epoch 10:  54%|█████▎    | 1971/3683 [02:49<02:27, 11.60it/s, avg_loss=0.638, loss=0.658, step=38802]

[Step 38800/40000] Avg Loss: 0.6372


Training Epoch 10:  56%|█████▌    | 2071/3683 [02:58<02:18, 11.63it/s, avg_loss=0.658, loss=0.658, step=38902]

[Step 38900/40000] Avg Loss: 0.6430


Training Epoch 10:  59%|█████▉    | 2171/3683 [03:07<02:10, 11.62it/s, avg_loss=0.677, loss=0.619, step=39002]

[Step 39000/40000] Avg Loss: 0.6407


Training Epoch 10:  62%|██████▏   | 2271/3683 [03:15<02:01, 11.59it/s, avg_loss=0.619, loss=0.619, step=39102]

[Step 39100/40000] Avg Loss: 0.6399


Training Epoch 10:  64%|██████▍   | 2371/3683 [03:24<01:52, 11.62it/s, avg_loss=0.638, loss=0.658, step=39202]

[Step 39200/40000] Avg Loss: 0.6407


Training Epoch 10:  67%|██████▋   | 2471/3683 [03:32<01:44, 11.64it/s, avg_loss=0.658, loss=0.658, step=39302]

[Step 39300/40000] Avg Loss: 0.6376


Training Epoch 10:  70%|██████▉   | 2571/3683 [03:41<01:35, 11.60it/s, avg_loss=0.638, loss=0.658, step=39402]

[Step 39400/40000] Avg Loss: 0.6407


Training Epoch 10:  73%|███████▎  | 2671/3683 [03:50<01:27, 11.60it/s, avg_loss=0.658, loss=0.658, step=39502]

[Step 39500/40000] Avg Loss: 0.6372


Training Epoch 10:  75%|███████▌  | 2771/3683 [03:58<01:18, 11.61it/s, avg_loss=0.658, loss=0.696, step=39602]

[Step 39600/40000] Avg Loss: 0.6430


Training Epoch 10:  78%|███████▊  | 2871/3683 [04:07<01:10, 11.60it/s, avg_loss=0.619, loss=0.619, step=39702]

[Step 39700/40000] Avg Loss: 0.6376


Training Epoch 10:  81%|████████  | 2971/3683 [04:15<01:01, 11.64it/s, avg_loss=0.619, loss=0.619, step=39802]

[Step 39800/40000] Avg Loss: 0.6407


Training Epoch 10:  83%|████████▎ | 3071/3683 [04:24<00:52, 11.63it/s, avg_loss=0.658, loss=0.619, step=39902]

[Step 39900/40000] Avg Loss: 0.6395


Training Epoch 10:  86%|████████▌ | 3170/3683 [04:32<00:44, 11.61it/s, avg_loss=63.8, loss=0.658, step=4e+4]  


[Step 40000/40000] Avg Loss: 0.6384
Model saved to 'llama_su_pretrained.pt'


In [16]:
torch.save(model.bert.state_dict(), "llama_su_pretrained.pt")

AttributeError: 'SepClassifier' object has no attribute 'bert'

Validation & early-stop (optional)

- Use the same DataLoader/loop on su_ds["validation"], compute average BCE loss; if it plateaus you can stop earlier than 5 k steps (what the authors mean by “until train loss converged”).

# Create Summarization Dataset

ขั้นตอนการทำ preprocess
1. โหลดชุดข้อมูล SAMSum
2. ทำ preprocessing:
    - แทนชื่อ speaker ด้วย [S1]–[S10]
    - เติม [SEP] ท้ายทุกประโยค
    - ใช้ tokenizer เดิมจาก pretraining (tokenizer_samsum_su)
    - truncate/pad ความยาวที่ max_length = 512
3. แปลงให้อยู่ในรูปแบบที่พร้อมใช้สำหรับ Seq2SeqTrainer
4. Save เป็นไฟล์ .pt หรือ DatasetDict ที่พร้อมใช้งาน

Load SAMSum Dataset

In [10]:
raw_ds: DatasetDict = load_dataset("samsum")
print({k: len(v) for k, v in raw_ds.items()})

{'train': 14732, 'test': 819, 'validation': 818}


Load Pretrained Tokenizer (same as used during pretraining)

In [11]:
tokenizer = AutoTokenizer.from_pretrained("tokenizer_samsum_su")
MAX_LEN = 512

Speaker Normalization Helpers

In [12]:
SPEAKER_RE = re.compile(r"^([^:]+):\s*(.*)$")

def map_speakers(dialogue: str, max_speakers: int = 10) -> Tuple[str, Dict[str, str]]:
    """
    Replace speaker names with generic [S1], [S2], ... tokens.
    """
    speaker_map, next_id = {}, 1
    new_lines = []
    for line in dialogue.split("\n"):
        m = SPEAKER_RE.match(line)
        if not m:
            new_lines.append(line)
            continue
        name, utt = m.groups()
        if name not in speaker_map:
            if next_id > max_speakers:
                name_token = "[SUNK]"
            else:
                name_token = f"[S{next_id}]"
                speaker_map[name] = name_token
                next_id += 1
        name_token = speaker_map.get(name, "[SUNK]")
        new_lines.append(f"{name_token}: {utt}")
    return "\n".join(new_lines), speaker_map

def add_sep_every_utt(dialogue: str) -> str:
    lines = [l + " [SEP]" for l in dialogue.split("\n") if l.strip()]
    return " ".join(lines)


Preprocessing Function

In [13]:
def preprocess_fn(example):
    normed_dialogue, _ = map_speakers(example["dialogue"])
    sep_dialogue = add_sep_every_utt(normed_dialogue)

    inputs = tokenizer(
        sep_dialogue,
        truncation=True,
        padding='max_length',
        max_length=MAX_LEN,
    )
    targets = tokenizer(
        example["summary"],
        truncation=True,
        padding='max_length',
        max_length=MAX_LEN,
    )

    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": targets["input_ids"]
    }

Apply Preprocessing

In [14]:
tokenized_ds = raw_ds.map(preprocess_fn, batched=False)
tokenized_ds.save_to_disk("samsum_finetune_ready")
print("Preprocessed dataset saved to 'samsum_finetune_ready'")

Map: 100%|██████████| 14732/14732 [00:08<00:00, 1769.02 examples/s]
Map: 100%|██████████| 819/819 [00:00<00:00, 1693.37 examples/s]
Map: 100%|██████████| 818/818 [00:00<00:00, 1804.79 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 14732/14732 [00:00<00:00, 327887.96 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 819/819 [00:00<00:00, 172195.85 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 818/818 [00:00<00:00, 167281.36 examples/s]

Preprocessed dataset saved to 'samsum_finetune_ready'





In [None]:
# MAX_LEN = 512  # paper setting

# def preprocess_example(example, split):
#     # a) replace speakers & add SEP (same as pretraining)
#     dlg, _ = map_speakers(example["dialogue"])  # แปลงชื่อให้เป็น token สั้น ๆ เช่น <USR1>
#     dlg = add_sep_every_utt(dlg)                # เพิ่ม [SEP] ทุกท้ายประโยค

#     # b) tokenize dialogue input
#     enc = tok_base(dlg,
#               truncation=True,
#               max_length=MAX_LEN,
#               padding="max_length")

#     # c) tok_baseenize target summary
#     with tok_base.as_target_tokenizer():
#         summary = example["summary"]
#         summary_enc = tok_base(summary,
#                           truncation=True,
#                           max_length=MAX_LEN,
#                           padding="max_length")
    
#     # d) pack input and label
#     enc["labels"] = summary_enc["input_ids"]
#     return enc

# # สร้าง dataset ใหม่สำหรับ fine-tune
# finetune_ds = DatasetDict()
# for split in ["train", "validation", "test"]:
#     finetune_ds[split] = raw_ds[split].map(
#         preprocess_example,
#         fn_kwargs={"split": split},
#         remove_columns=raw_ds[split].column_names,
#         desc=f"Building Fine-tuning {split}"
#     )

# finetune_ds.save_to_disk("data/samsum_finetune")
# print(finetune_ds)

Building Fine-tuning train: 100%|██████████| 14732/14732 [00:09<00:00, 1561.26 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 14732/14732 [00:00<00:00, 183818.74 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 818/818 [00:00<00:00, 129347.43 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 819/819 [00:00<00:00, 126710.99 examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 14732
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 818
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 819
    })
})





# Fine-tuning 


**เทียบกับ Paper**

| **Parameter**       | **Code**                          | **Paper (Section 3.2)**          | **เหมือน / ไม่เหมือน**        |
| ------------------- | --------------------------------- | -------------------------------- | --------------       |
| Model               | BERT2BERT (EncoderDecoderModel)   | BERT2BERT                        | เหมือน                 |
| Tokenizer           | bert-base-uncased + custom tokens | ใช้ tokenizer ดัดแปลง              | เหมือน              |
| Batch Size          | 8                                 | **16 (per step)**                | ไม่เหมือน → เล็กกว่า  |
| Epochs              | 3                                 | 3                                | เหมือน              |
| Learning Rate       | 5e-5                              | **3e-5**                         | ไม่เหมือน → สูงกว่า   |
| Warmup Steps        | 500                               | ใช้ scheduler (แต่ไม่ระบุ exact)     | เหมือน (สมเหตุสมผล) |
| Max Length (input)  | 512                               | 512                              | เหมือน              |
| Max Length (output) | 128                               | 128                              | เหมือน              |
| Beam Search         | 4                                 | 4                                | เหมือน              |

---


In [None]:
import torch
from transformers import (
    BertTokenizerFast,
    EncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
from datasets import load_from_disk

# ------------------------------
# Load processed dataset & tokenizer
# ------------------------------
dataset = load_from_disk("data/samsum_finetune_ready")
tokenizer = BertTokenizerFast.from_pretrained("tokenizer_samsum_su")

# ------------------------------
# Load pretrained EncoderDecoderModel
# ------------------------------
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)
model.encoder.resize_token_embeddings(len(tokenizer))
model.decoder.resize_token_embeddings(len(tokenizer))

# Load your pretrained encoder weights
model.encoder.load_state_dict(torch.load("bert_su_pretrained.pt", map_location="cpu"))
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.encoder.vocab_size
model.config.max_length = 128
model.config.num_beams = 4

# ------------------------------
# Define training arguments
# ------------------------------
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    evaluation_strategy="steps",
    logging_steps=500,
    save_steps=1000,
    num_train_epochs=3,
    learning_rate=5e-5,
    warmup_steps=500,
    fp16=torch.cuda.is_available(),
    save_total_limit=2,
)

# ------------------------------
# Data Collator & Trainer
# ------------------------------
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# ------------------------------
# Start Training
# ------------------------------
trainer.train()
model.save_pretrained("bert_samsum_finetuned")
tokenizer.save_pretrained("tokenizer_samsum_su_finetune")

  from .autonotebook import tqdm as notebook_tqdm
Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.enco

{'loss': 1.4418, 'grad_norm': 0.3646565079689026, 'learning_rate': 4.96e-05, 'epoch': 0.27}


                                                  
  9%|▉         | 501/5526 [01:41<2:22:17,  1.70s/it]

{'eval_loss': 0.22444657981395721, 'eval_runtime': 5.0061, 'eval_samples_per_second': 163.402, 'eval_steps_per_second': 20.575, 'epoch': 0.27}


 18%|█▊        | 1000/5526 [03:17<14:36,  5.17it/s] 

{'loss': 0.2172, 'grad_norm': 0.4135509133338928, 'learning_rate': 4.506565857540788e-05, 'epoch': 0.54}


                                                   


{'eval_loss': 0.19904808700084686, 'eval_runtime': 5.0413, 'eval_samples_per_second': 162.259, 'eval_steps_per_second': 20.431, 'epoch': 0.54}


 27%|██▋       | 1500/5526 [05:00<12:53,  5.21it/s]  

{'loss': 0.2008, 'grad_norm': 0.274005264043808, 'learning_rate': 4.009152407481098e-05, 'epoch': 0.81}


                                                   
 27%|██▋       | 1501/5526 [05:06<1:53:34,  1.69s/it]

{'eval_loss': 0.1895398199558258, 'eval_runtime': 4.9978, 'eval_samples_per_second': 163.672, 'eval_steps_per_second': 20.609, 'epoch': 0.81}


 36%|███▌      | 2000/5526 [06:42<11:18,  5.20it/s]  

{'loss': 0.1864, 'grad_norm': 0.31300443410873413, 'learning_rate': 3.511738957421409e-05, 'epoch': 1.09}


                                                   
 36%|███▌      | 2000/5526 [06:47<11:18,  5.20it/s]

{'eval_loss': 0.18359985947608948, 'eval_runtime': 4.9794, 'eval_samples_per_second': 164.278, 'eval_steps_per_second': 20.685, 'epoch': 1.09}


 45%|████▌     | 2500/5526 [08:25<09:41,  5.21it/s]  

{'loss': 0.17, 'grad_norm': 0.36899664998054504, 'learning_rate': 3.0143255073617192e-05, 'epoch': 1.36}


                                                   
 45%|████▌     | 2501/5526 [08:30<1:25:32,  1.70s/it]

{'eval_loss': 0.17844413220882416, 'eval_runtime': 5.0089, 'eval_samples_per_second': 163.311, 'eval_steps_per_second': 20.564, 'epoch': 1.36}


 54%|█████▍    | 3000/5526 [10:06<08:05,  5.20it/s]  

{'loss': 0.1689, 'grad_norm': 0.43778106570243835, 'learning_rate': 2.5169120573020293e-05, 'epoch': 1.63}


                                                   
 54%|█████▍    | 3000/5526 [10:11<08:05,  5.20it/s]

{'eval_loss': 0.17464406788349152, 'eval_runtime': 4.984, 'eval_samples_per_second': 164.126, 'eval_steps_per_second': 20.666, 'epoch': 1.63}


 63%|██████▎   | 3500/5526 [11:50<06:25,  5.26it/s]  

{'loss': 0.1632, 'grad_norm': 0.3436298966407776, 'learning_rate': 2.01949860724234e-05, 'epoch': 1.9}


                                                   
 63%|██████▎   | 3501/5526 [11:55<56:57,  1.69s/it]

{'eval_loss': 0.17154935002326965, 'eval_runtime': 4.9824, 'eval_samples_per_second': 164.179, 'eval_steps_per_second': 20.673, 'epoch': 1.9}


 72%|███████▏  | 4000/5526 [13:31<04:52,  5.21it/s]

{'loss': 0.1487, 'grad_norm': 0.4152975380420685, 'learning_rate': 1.5220851571826503e-05, 'epoch': 2.17}


                                                   
 72%|███████▏  | 4000/5526 [13:36<04:52,  5.21it/s]

{'eval_loss': 0.17133557796478271, 'eval_runtime': 4.98, 'eval_samples_per_second': 164.259, 'eval_steps_per_second': 20.683, 'epoch': 2.17}


 81%|████████▏ | 4500/5526 [15:14<03:17,  5.21it/s]  

{'loss': 0.1356, 'grad_norm': 0.35845550894737244, 'learning_rate': 1.0246717071229607e-05, 'epoch': 2.44}


                                                   
 81%|████████▏ | 4501/5526 [15:19<28:52,  1.69s/it]

{'eval_loss': 0.1692614108324051, 'eval_runtime': 4.9857, 'eval_samples_per_second': 164.068, 'eval_steps_per_second': 20.659, 'epoch': 2.44}


 90%|█████████ | 5000/5526 [16:55<01:41,  5.20it/s]

{'loss': 0.1402, 'grad_norm': 0.3714558780193329, 'learning_rate': 5.27258257063271e-06, 'epoch': 2.71}


                                                   
 90%|█████████ | 5000/5526 [17:00<01:41,  5.20it/s]

{'eval_loss': 0.1681322604417801, 'eval_runtime': 4.9793, 'eval_samples_per_second': 164.281, 'eval_steps_per_second': 20.686, 'epoch': 2.71}


100%|█████████▉| 5500/5526 [18:38<00:04,  5.21it/s]

{'loss': 0.1387, 'grad_norm': 0.4560355544090271, 'learning_rate': 2.984480700358138e-07, 'epoch': 2.99}


                                                   
100%|█████████▉| 5501/5526 [18:43<00:42,  1.69s/it]

{'eval_loss': 0.1668214201927185, 'eval_runtime': 4.9802, 'eval_samples_per_second': 164.251, 'eval_steps_per_second': 20.682, 'epoch': 2.99}


100%|██████████| 5526/5526 [18:50<00:00,  4.89it/s]


{'train_runtime': 1130.5819, 'train_samples_per_second': 39.091, 'train_steps_per_second': 4.888, 'train_loss': 0.28212739849194124, 'epoch': 3.0}


('tokenizer_samsum_su_finetune/tokenizer_config.json',
 'tokenizer_samsum_su_finetune/special_tokens_map.json',
 'tokenizer_samsum_su_finetune/vocab.txt',
 'tokenizer_samsum_su_finetune/added_tokens.json',
 'tokenizer_samsum_su_finetune/tokenizer.json')

## Evaluation

1. ROUGE (Recall-Oriented Understudy for Gisting Evaluation) ใช้วัดความคล้ายกันระหว่างสรุปที่โมเดลสร้างขึ้นกับสรุปอ้างอิง โดยเน้นไปที่ recall เป็นหลัก
	- ROUGE-1 (R-1) = Unigram overlap (คำเดี่ยว)
	- ROUGE-2 (R-2) = Bigram overlap (คำติดกัน 2 คำ)
	- ROUGE-L (R-L) = ใช้ Longest common subsequence (LCS) ในการวัดความคล้ายเชิงลำดับคำที่ยาวที่สุดที่ปรากฏในทั้งสองสรุป โดยคำนึงถึงลำดับคำด้วย

2. BLEU (Bilingual Evaluation Understudy) เดิมทีใช้ในงานแปลภาษา แต่ถูกประยุกต์ใช้ในงานสรุปข้อความได้เช่นกัน โดย BLEU จะเน้นการวัด precision คือดูว่า คำที่โมเดลสร้าง มีเท่าไรที่ตรงกับสรุปจริง ต่างจาก ROUGE ที่เน้น recall
	- BLEU วัดการทับซ้อนของ n-gram เช่น unigram, bigram, trigram
	- มีการใช้ brevity penalty หากสรุปสั้นกว่าที่ควรจะเป็น

3. BERTScore (BS) ใช้ embedding จากโมเดล BERT หรือ Transformer ตัวอื่น ๆ ในการวัด semantic similarity (ความใกล้เคียงด้านความหมาย) ระหว่างสรุปของโมเดลกับสรุปจริง โดยไม่จำเป็นต้องใช้คำเหมือนกันเป๊ะเหมือนกับ ROUGE หรือ BLEU แต่ BERTScore จะวัดว่าคำหรือวลีมีความหมายใกล้เคียงกันหรือไม่
	- วัดความคล้ายกันของคำใน embedding space เช่น "car" vs "vehicle" ก็ยังถือว่าใกล้เคียง
	- ใช้ precision / recall / F1 score ตามระยะห่างของ vector


In [1]:
from datasets import load_from_disk
from transformers import BertTokenizer, EncoderDecoderModel
from sklearn.metrics import precision_score, recall_score, f1_score
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
import torch
from bert_score import score
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# โหลด dataset
dataset = load_from_disk("data/samsum_finetune_ready")

# โหลดโมเดลและ tokenizer
tokenizer = BertTokenizer.from_pretrained('tokenizer_samsum_su_finetune')
model = EncoderDecoderModel.from_pretrained('bert_samsum_finetuned')

EncoderDecoderModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [None]:
# ย้ายโมเดลและข้อมูลไปยังอุปกรณ์ที่เหมาะสม (GPU หรือ CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# วิเคราะห์ข้อมูลและทำการ summary ด้วย bert_samsum_finetuned
def generate_summary(input_text):
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}  # ย้ายข้อมูลไปยัง device
    
    with torch.no_grad():
        output = model.generate(
            inputs['input_ids'], 
            max_length=512, 
            num_beams=4, 
            early_stopping=True,
            decoder_start_token_id=model.config.decoder_start_token_id,  # กำหนดที่นี่
            pad_token_id=model.config.pad_token_id  # กำหนด pad_token_id ถ้าจำเป็น
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [4]:
# 1. ROUGE Score Calculation
def calculate_rouge(predictions, references):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    
    for pred, ref in zip(predictions, references):
        score = scorer.score(ref, pred)
        scores['rouge1'].append(score['rouge1'].fmeasure)
        scores['rouge2'].append(score['rouge2'].fmeasure)
        scores['rougeL'].append(score['rougeL'].fmeasure)
    
    return {key: sum(value)/len(value) for key, value in scores.items()}

# 2. BLEU Score Calculation
def calculate_bleu(predictions, references):
    bleu_scores = []
    for pred, ref in zip(predictions, references):
        pred_tokens = pred.split()
        ref_tokens = [ref.split()]
        bleu_scores.append(sentence_bleu(ref_tokens, pred_tokens))
    return sum(bleu_scores) / len(bleu_scores)

# 3. BERTScore Calculation
def calculate_bertscore(predictions, references):
    P, R, F1 = score(predictions, references, lang='en')
    return P.mean().item(), R.mean().item(), F1.mean().item()

# การทดสอบกับ dataset
def evaluate_model(dataset):
    predictions = []
    references = []
    
    # ใช้ข้อมูลจาก train สำหรับทำนาย และข้อมูลจาก test สำหรับการเปรียบเทียบ
    for i in tqdm(range(len(dataset['test'])), desc="Evaluating", unit="sample"):
        # สร้างสรุปจากโมเดล
        input_text = dataset['train'][i]['dialogue']  # ใช้ 'dialogue' จาก train เพื่อสร้างสรุป
        reference_summary = dataset['test'][i]['summary']  # ใช้ 'summary' จาก test เป็นสรุปจริง
        pred_summary = generate_summary(input_text)  # สร้างสรุปจากโมเดล
        
        predictions.append(pred_summary)
        references.append(reference_summary)
    
    # ROUGE Score
    rouge_scores = calculate_rouge(predictions, references)
    print("ROUGE Scores:", rouge_scores)

    # BLEU Score
    bleu_score = calculate_bleu(predictions, references)
    print("BLEU Score:", bleu_score)

    # BERTScore
    P, R, F1 = calculate_bertscore(predictions, references)
    print("BERTScore - Precision:", P, "Recall:", R, "F1:", F1)

# เรียกใช้งาน
evaluate_model(dataset)

Evaluating: 100%|██████████| 819/819 [39:14<00:00,  2.88s/sample]
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


ROUGE Scores: {'rouge1': 0.08103071050965537, 'rouge2': 0.005501493938462314, 'rougeL': 0.07293283175759344}
BLEU Score: 8.859648156109322e-05


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BERTScore - Precision: 0.8399370312690735 Recall: 0.8468782305717468 F1: 0.843207597732544
