<a href="https://colab.research.google.com/github/jang-hyunjun/icml_hyunjun/blob/main/vqvae%ED%95%99%EC%8A%B5_%EB%B0%8F_%EC%A0%80%EC%9E%A5_%EC%BD%94%EB%93%9C.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VQVAE 학습 (VQ-VAE Training)

In [None]:
import os
import re
import gc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from sklearn.cluster import MiniBatchKMeans
from sklearn.model_selection import train_test_split
from google.colab import drive
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset, concatenate_datasets

# 1. Setup and Google Drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SAVE_DIR = '/content/drive/MyDrive/VQVAE_Comparison_Final'
os.makedirs(SAVE_DIR, exist_ok=True)

# 2. Data preparation
print("Loading Data...")
# The parquet revision needs no loading script. `trust_remote_code` is no longer
# accepted by `datasets` and only produced an ERROR log, so it is not passed.
dataset = load_dataset("li2017dailydialog/daily_dialog", revision="refs/convert/parquet")
combined = concatenate_datasets([dataset["train"], dataset["validation"], dataset["test"]])

# Split each utterance after '.', '?' or '!' followed by whitespace, and keep
# only pieces with more than two whitespace-separated words.
SENTENCE_SPLIT_RE = re.compile(r'(?<=[\.\?\!])\s+')
sentences = []
for dialog in tqdm(combined["dialog"], desc="Parsing"):
    for utt in dialog:
        for s in SENTENCE_SPLIT_RE.split(utt.strip()):
            if s and len(s.split()) > 2:
                sentences.append(s)

# Data split (Train 90% : Val 10%), fixed seed for reproducibility.
train_sentences, val_sentences = train_test_split(sentences, test_size=0.1, random_state=42)
print(f"Total Sentences: {len(sentences)} | Train: {len(train_sentences)} | Val: {len(val_sentences)}")


class VQVAE_Original_Structure(nn.Module):
    """VQ-VAE over individual hidden-state vectors with an EMA-updated codebook.

    Encoder: an MLP (``input_dim -> hidden_dim -> 216 -> 108 -> 64 ->
    embedding_dim`` with LeakyReLU) plus a bias-free linear residual path. The
    sum is multiplied by the learnable scalar ``z_scale``. Each encoding is
    replaced by its nearest codebook vector (Euclidean distance), and a single
    linear layer decodes it back to ``input_dim``.

    The codebook receives no gradient. In training mode it is recomputed from
    exponential moving averages of assignment counts and assigned encodings.
    Entries whose EMA count drops below 1e-3 are re-seeded from random
    encodings of the current batch.

    Args:
        input_dim: size of each input vector.
        hidden_dim: width of the first encoder layer. NOTE(review): 518 is an
            unusual width (512 intended?). It is kept because saved
            checkpoints depend on it.
        num_embeddings: number of codebook entries.
        embedding_dim: size of each codebook vector.
        beta: weight of the commitment loss.
        lambda_ent: weight of the codebook-usage entropy term.
        decay: EMA decay for the codebook statistics.
        epsilon: Laplace-smoothing constant applied to EMA cluster sizes.
    """

    def __init__(
        self,
        input_dim=768,
        hidden_dim=518,
        num_embeddings=8000,
        embedding_dim=64,
        beta=0.25,
        lambda_ent=0.1,
        decay=0.99,
        epsilon=1e-5,
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim  = embedding_dim
        self.beta           = beta
        self.lambda_ent     = lambda_ent
        self.decay          = decay
        self.epsilon        = epsilon

        self.encoder_body = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, 216), nn.LeakyReLU(0.1),
            nn.Linear(216, 108), nn.LeakyReLU(0.1),
            nn.Linear(108, 64), nn.LeakyReLU(0.1),
            nn.Linear(64, embedding_dim)
        )
        self.encoder_residual = nn.Linear(input_dim, embedding_dim, bias=False)
        self.z_scale = nn.Parameter(torch.tensor(5.0))

        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.codebook.weight.data.normal_()

        # EMA state: per-code assignment count, per-code sum of assigned
        # encodings, and a cumulative (non-decayed) usage counter.
        self.register_buffer('cluster_size', torch.zeros(num_embeddings, dtype=torch.float))
        self.register_buffer('embedding_avg', torch.zeros(num_embeddings, embedding_dim, dtype=torch.float))
        self.register_buffer('codebook_usage', torch.zeros(num_embeddings, dtype=torch.long))

        self.decoder = nn.Linear(embedding_dim, input_dim)

    def _update_codebook(self, z_e, indices):
        """Apply the EMA codebook update and re-seed dead codes (call under no_grad).

        Args:
            z_e: detached encodings, one row per vector.
            indices: code index assigned to each row of ``z_e``.
        """
        uidx, counts = torch.unique(indices, return_counts=True)
        self.codebook_usage[uidx] += counts

        one_hot = F.one_hot(indices, self.num_embeddings).type_as(z_e)
        batch_cluster_size = one_hot.sum(dim=0)
        batch_embed_sum    = one_hot.t() @ z_e

        self.cluster_size.mul_(self.decay).add_(batch_cluster_size, alpha=1 - self.decay)
        self.embedding_avg.mul_(self.decay).add_(batch_embed_sum, alpha=1 - self.decay)

        # Laplace-smoothed counts keep the division finite for unused codes.
        n = self.cluster_size.sum()
        cluster_size_norm = (self.cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n
        weight_norm = self.embedding_avg / cluster_size_norm.unsqueeze(1)
        self.codebook.weight.data.copy_(weight_norm)

        # Re-seed codes whose EMA count has decayed to ~0 with random encodings
        # from this batch, and reset their statistics to match.
        dead_mask = self.cluster_size < 1e-3
        if dead_mask.any():
            n_dead = dead_mask.sum().item()
            rand_idx = torch.randint(0, z_e.size(0), (n_dead,), device=z_e.device)
            self.codebook.weight.data[dead_mask] = z_e[rand_idx]
            self.embedding_avg[dead_mask] = z_e[rand_idx]
            self.cluster_size[dead_mask] = 1.0

    def forward(self, x):
        """Encode, quantise and reconstruct ``x``.

        Args:
            x: 2-D tensor (N, input_dim). ``argmin(dim=1)`` below treats each
                row as one vector to quantise.

        Returns:
            dict with ``'loss'`` (reconstruction + beta * commitment + entropy
            term), ``'recon_loss'`` (MSE of ``x_recon`` vs ``x``), ``'indices'``
            (chosen code per row) and ``'x_recon'``.
        """
        z_body = self.encoder_body(x)
        z_res  = self.encoder_residual(x)
        z_e    = (z_body + z_res) * self.z_scale

        dists   = torch.cdist(z_e, self.codebook.weight)
        indices = torch.argmin(dists, dim=1)
        z_q     = self.codebook(indices)

        if self.training:
            # Codebook maintenance is bookkeeping, not part of the loss. Running
            # it without autograd avoids building a graph for the EMA stats.
            with torch.no_grad():
                self._update_codebook(z_e.detach(), indices)

        # Straight-through estimator: the forward pass uses z_q, and gradients
        # flow to z_e.
        z_q_st = z_e + (z_q - z_e).detach()
        x_recon = self.decoder(z_q_st)

        recon_loss  = F.mse_loss(x_recon, x)
        commit_loss = F.mse_loss(z_e, z_q.detach())

        entropy_penalty = torch.tensor(0., device=x.device)
        if self.lambda_ent > 0:
            # Lower loss for more evenly spread code usage. NOTE(review): this is
            # computed from the `cluster_size` buffer, so it adds no gradient and
            # only shifts the value of 'loss'.
            p = (self.cluster_size + 1e-9) / (self.cluster_size.sum() + 1e-9)
            H = -(p * torch.log2(p)).sum()
            entropy_penalty = -self.lambda_ent * H

        total_loss = recon_loss + self.beta * commit_loss + entropy_penalty

        return {
            'loss': total_loss,
            'recon_loss': recon_loss,
            'indices': indices,
            'x_recon': x_recon
        }


class EmbeddingDataset(Dataset):
    """Map-style dataset that serves raw sentence strings by index.

    No tokenisation happens here. ``tokenizer`` is only stored on the instance.
    """

    def __init__(self, tokenizer, sentences):
        self.tokenizer = tokenizer
        self.sentences = sentences

    def __len__(self):
        total = len(self.sentences)
        return total

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        return sentence

def get_collate_fn(tokenizer, base_model, device):
    """Build a DataLoader ``collate_fn`` that turns sentences into token vectors.

    The returned callable does the following:
    - tokenises a list of strings (padding to the longest, truncating at 64 tokens);
    - runs ``base_model`` without gradients;
    - returns the last-layer hidden state of every token that is not the
      tokenizer's CLS, SEP or PAD token.

    Those hidden states are flattened across the batch into one 2-D tensor
    (num_kept_tokens, hidden_size).

    Args:
        tokenizer: Hugging Face tokenizer paired with ``base_model``.
        base_model: encoder whose output exposes ``last_hidden_state``.
        device: device the tokenised inputs are moved to.

    Returns:
        ``collate(batch_text) -> Tensor``.
    """
    # Resolve the excluded ids once. Ids the tokenizer does not define are
    # skipped, so the mask never compares token ids against ``None``.
    excluded_ids = [
        tok_id
        for tok_id in (tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id)
        if tok_id is not None
    ]

    def collate(batch_text):
        inputs = tokenizer(batch_text, return_tensors='pt', padding=True, truncation=True, max_length=64).to(device)
        with torch.no_grad():
            # Only the last layer is used, so the model is not asked to return
            # (and keep in memory) every intermediate hidden state.
            last_hidden = base_model(**inputs).last_hidden_state

        input_ids = inputs.input_ids
        excluded = torch.tensor(excluded_ids, dtype=input_ids.dtype, device=input_ids.device)
        keep_mask = ~torch.isin(input_ids, excluded)
        return last_hidden[keep_mask]
    return collate


# Encoders whose token representations are quantised: display name -> HF model id.
target_models = {
    "ModernBERT": "answerdotai/ModernBERT-base",
    "BERT": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "ELECTRA": "google/electra-base-discriminator"
}


EPOCHS = 20
BATCH_SIZE = 64
NUM_EMBEDDINGS = 8000
# Number of encoder outputs (token vectors) gathered for K-Means codebook init.
KMEANS_INIT_SAMPLES = 50000

for model_name, model_id in target_models.items():
    print(f"\n{'='*40}")
    print(f"Processing: {model_name}")
    print(f"{'='*40}")

    try:
        # Only ModernBERT is given trust_remote_code. The other checkpoints are
        # natively supported, so running Hub code for them is unnecessary.
        trust = "ModernBERT" in model_name
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust)
        base_model = AutoModel.from_pretrained(model_id, trust_remote_code=trust).to(device)
        base_model.eval()
        hidden_size = base_model.config.hidden_size
    except Exception as e:
        # Report why the model is skipped. A bare `except:` hid the error and
        # also swallowed KeyboardInterrupt.
        print(f"Skipping {model_name} (Error: {e})")
        continue

    vqvae = VQVAE_Original_Structure(
        input_dim=hidden_size,
        hidden_dim=518,
        num_embeddings=NUM_EMBEDDINGS,
        embedding_dim=64
    ).to(device)

    optimizer = optim.Adam(vqvae.parameters(), lr=1e-3)

    # Data loaders: each batch is the flattened set of token hidden states
    # produced by the collate function.
    collate_fn = get_collate_fn(tokenizer, base_model, device)
    train_ds = EmbeddingDataset(tokenizer, train_sentences)
    val_ds = EmbeddingDataset(tokenizer, val_sentences)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=0)

    # K-Means init: cluster encoder outputs of the untrained VQ-VAE and use the
    # centroids as the initial codebook and EMA averages.
    print("  [Init] K-Means Clustering...")
    init_chunks = []
    n_collected = 0
    with torch.no_grad():
        for batch_emb in train_loader:
            z_e = (vqvae.encoder_body(batch_emb) + vqvae.encoder_residual(batch_emb)) * vqvae.z_scale
            init_chunks.append(z_e.cpu().numpy())
            # Count the real number of token vectors instead of estimating it
            # from an assumed 20 tokens per sentence.
            n_collected += z_e.size(0)
            if n_collected >= KMEANS_INIT_SAMPLES:
                break

    init_data = np.concatenate(init_chunks, axis=0)
    kmeans = MiniBatchKMeans(n_clusters=NUM_EMBEDDINGS, n_init=1, batch_size=4096).fit(init_data)
    centers = torch.from_numpy(kmeans.cluster_centers_).to(device)

    vqvae.codebook.weight.data.copy_(centers)
    vqvae.embedding_avg.data.copy_(centers)
    # Every code starts with the same non-zero EMA count, so none is treated
    # as dead on the first training step.
    vqvae.cluster_size.data.fill_(1.0)
    del init_chunks, init_data, kmeans, centers
    print("    -> Done.")


    best_val_loss = float('inf')

    # Training loop
    print("  [Train] Start Training...")

    for epoch in range(EPOCHS):
        # 1. Train
        vqvae.train()
        total_recon = 0
        count = 0
        pbar = tqdm(train_loader, desc=f"Ep {epoch+1} [Train]", leave=False)
        for batch_emb in pbar:
            optimizer.zero_grad()
            out = vqvae(batch_emb)
            out['loss'].backward()
            optimizer.step()
            total_recon += out['recon_loss'].item()
            count += 1
            pbar.set_postfix({'Recon': f"{out['recon_loss'].item():.4f}"})
        avg_train_loss = total_recon / max(count, 1)

        # 2. Validation
        vqvae.eval()
        val_recon = 0
        val_count = 0
        with torch.no_grad():
            for batch_emb in val_loader:
                out = vqvae(batch_emb)
                val_recon += out['recon_loss'].item()
                val_count += 1
        avg_val_loss = val_recon / max(val_count, 1)

        alive_count = (vqvae.cluster_size > 1e-3).sum().item()
        print(f"    Ep {epoch+1} | Train MSE: {avg_train_loss:.4f} | Val MSE: {avg_val_loss:.4f} | Alive: {alive_count}")

        # Save the best model so far (epoch number in the file name).
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # File name format: VQVAE_OrigStruct_{model_name}_Ep{epoch}.pth
            save_name = f"VQVAE_OrigStruct_{model_name}_Ep{epoch+1}.pth"
            save_path = os.path.join(SAVE_DIR, save_name)

            torch.save(vqvae.state_dict(), save_path)
            print(f"      üíæ New Best Saved: {save_name} (Loss: {best_val_loss:.4f})")

    # Release everything that references the base model before freeing GPU
    # memory. collate_fn's closure holds base_model/tokenizer, and the last
    # loop tensors (batch_emb, out, z_e) live on the device.
    out = batch_emb = z_e = None
    del vqvae, base_model, tokenizer, optimizer, collate_fn, train_ds, val_ds, train_loader, val_loader
    gc.collect()
    torch.cuda.empty_cache()

Mounted at /content/drive


`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'li2017dailydialog/daily_dialog' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'li2017dailydialog/daily_dialog' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loading Data...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


default/train/0000.parquet:   0%|          | 0.00/3.61M [00:00<?, ?B/s]

0000.parquet:   0%|          | 0.00/334k [00:00<?, ?B/s]

0000.parquet:   0%|          | 0.00/331k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Parsing:   0%|          | 0/13118 [00:00<?, ?it/s]

Total Sentences: 153945 | Train: 138550 | Val: 15395

Processing: ModernBERT


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/599M [00:00<?, ?B/s]

  [Init] K-Means Clustering...


  return torch._C._get_cublas_allow_tf32()
W0115 08:56:26.069000 1952 torch/_inductor/utils.py:1558] [1/0_1] Not enough SMs to use max_autotune_gemm mode


    -> Done.
  [Train] Start Training...


Ep 1 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 1 | Train MSE: 0.4837 | Val MSE: 0.3310 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep1.pth (Loss: 0.3310)


Ep 2 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 2 | Train MSE: 0.3128 | Val MSE: 0.3060 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep2.pth (Loss: 0.3060)


Ep 3 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 3 | Train MSE: 0.3025 | Val MSE: 0.3031 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep3.pth (Loss: 0.3031)


Ep 4 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 4 | Train MSE: 0.3013 | Val MSE: 0.3025 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep4.pth (Loss: 0.3025)


Ep 5 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 5 | Train MSE: 0.3010 | Val MSE: 0.3025 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep5.pth (Loss: 0.3025)


Ep 6 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 6 | Train MSE: 0.3008 | Val MSE: 0.3021 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep6.pth (Loss: 0.3021)


Ep 7 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 7 | Train MSE: 0.3006 | Val MSE: 0.3020 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep7.pth (Loss: 0.3020)


Ep 8 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 8 | Train MSE: 0.3005 | Val MSE: 0.3019 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep8.pth (Loss: 0.3019)


Ep 9 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 9 | Train MSE: 0.3005 | Val MSE: 0.3019 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep9.pth (Loss: 0.3019)


Ep 10 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 10 | Train MSE: 0.3004 | Val MSE: 0.3019 | Alive: 8000


Ep 11 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 11 | Train MSE: 0.3003 | Val MSE: 0.3024 | Alive: 8000


Ep 12 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 12 | Train MSE: 0.3003 | Val MSE: 0.3017 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep12.pth (Loss: 0.3017)


Ep 13 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 13 | Train MSE: 0.3003 | Val MSE: 0.3016 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep13.pth (Loss: 0.3016)


Ep 14 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 14 | Train MSE: 0.3003 | Val MSE: 0.3019 | Alive: 8000


Ep 15 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 15 | Train MSE: 0.3002 | Val MSE: 0.3018 | Alive: 8000


Ep 16 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 16 | Train MSE: 0.3001 | Val MSE: 0.3018 | Alive: 8000


Ep 17 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 17 | Train MSE: 0.3001 | Val MSE: 0.3015 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep17.pth (Loss: 0.3015)


Ep 18 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 18 | Train MSE: 0.3000 | Val MSE: 0.3017 | Alive: 8000


Ep 19 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 19 | Train MSE: 0.3000 | Val MSE: 0.3014 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ModernBERT_Ep19.pth (Loss: 0.3014)


Ep 20 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 20 | Train MSE: 0.3000 | Val MSE: 0.3015 | Alive: 8000

Processing: BERT


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

  [Init] K-Means Clustering...
    -> Done.
  [Train] Start Training...


Ep 1 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 1 | Train MSE: 0.1610 | Val MSE: 0.1293 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep1.pth (Loss: 0.1293)


Ep 2 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 2 | Train MSE: 0.1254 | Val MSE: 0.1238 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep2.pth (Loss: 0.1238)


Ep 3 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 3 | Train MSE: 0.1233 | Val MSE: 0.1232 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep3.pth (Loss: 0.1232)


Ep 4 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 4 | Train MSE: 0.1229 | Val MSE: 0.1230 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep4.pth (Loss: 0.1230)


Ep 5 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 5 | Train MSE: 0.1228 | Val MSE: 0.1229 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep5.pth (Loss: 0.1229)


Ep 6 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 6 | Train MSE: 0.1226 | Val MSE: 0.1228 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep6.pth (Loss: 0.1228)


Ep 7 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 7 | Train MSE: 0.1226 | Val MSE: 0.1227 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep7.pth (Loss: 0.1227)


Ep 8 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 8 | Train MSE: 0.1225 | Val MSE: 0.1227 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep8.pth (Loss: 0.1227)


Ep 9 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 9 | Train MSE: 0.1225 | Val MSE: 0.1227 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep9.pth (Loss: 0.1227)


Ep 10 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 10 | Train MSE: 0.1225 | Val MSE: 0.1226 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep10.pth (Loss: 0.1226)


Ep 11 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 11 | Train MSE: 0.1224 | Val MSE: 0.1226 | Alive: 8000


Ep 12 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 12 | Train MSE: 0.1224 | Val MSE: 0.1226 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep12.pth (Loss: 0.1226)


Ep 13 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 13 | Train MSE: 0.1224 | Val MSE: 0.1226 | Alive: 8000


Ep 14 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 14 | Train MSE: 0.1224 | Val MSE: 0.1226 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep14.pth (Loss: 0.1226)


Ep 15 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 15 | Train MSE: 0.1224 | Val MSE: 0.1225 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep15.pth (Loss: 0.1225)


Ep 16 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 16 | Train MSE: 0.1223 | Val MSE: 0.1225 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep16.pth (Loss: 0.1225)


Ep 17 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 17 | Train MSE: 0.1223 | Val MSE: 0.1225 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep17.pth (Loss: 0.1225)


Ep 18 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 18 | Train MSE: 0.1223 | Val MSE: 0.1225 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep18.pth (Loss: 0.1225)


Ep 19 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 19 | Train MSE: 0.1223 | Val MSE: 0.1225 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep19.pth (Loss: 0.1225)


Ep 20 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 20 | Train MSE: 0.1223 | Val MSE: 0.1224 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_BERT_Ep20.pth (Loss: 0.1224)

Processing: RoBERTa


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  [Init] K-Means Clustering...
    -> Done.
  [Train] Start Training...


Ep 1 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 1 | Train MSE: 0.0392 | Val MSE: 0.0243 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep1.pth (Loss: 0.0243)


Ep 2 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 2 | Train MSE: 0.0223 | Val MSE: 0.0212 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep2.pth (Loss: 0.0212)


Ep 3 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 3 | Train MSE: 0.0208 | Val MSE: 0.0206 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep3.pth (Loss: 0.0206)


Ep 4 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 4 | Train MSE: 0.0204 | Val MSE: 0.0205 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep4.pth (Loss: 0.0205)


Ep 5 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 5 | Train MSE: 0.0203 | Val MSE: 0.0204 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep5.pth (Loss: 0.0204)


Ep 6 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 6 | Train MSE: 0.0203 | Val MSE: 0.0204 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep6.pth (Loss: 0.0204)


Ep 7 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 7 | Train MSE: 0.0203 | Val MSE: 0.0203 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep7.pth (Loss: 0.0203)


Ep 8 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 8 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep8.pth (Loss: 0.0203)


Ep 9 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 9 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000


Ep 10 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 10 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000


Ep 11 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 11 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep11.pth (Loss: 0.0203)


Ep 12 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 12 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000


Ep 13 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 13 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep13.pth (Loss: 0.0203)


Ep 14 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 14 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep14.pth (Loss: 0.0203)


Ep 15 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 15 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000


Ep 16 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 16 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep16.pth (Loss: 0.0203)


Ep 17 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 17 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000


Ep 18 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 18 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000


Ep 19 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 19 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep19.pth (Loss: 0.0203)


Ep 20 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 20 | Train MSE: 0.0202 | Val MSE: 0.0203 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_RoBERTa_Ep20.pth (Loss: 0.0203)

Processing: ELECTRA


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

  [Init] K-Means Clustering...


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

    -> Done.
  [Train] Start Training...


Ep 1 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 1 | Train MSE: 0.0922 | Val MSE: 0.0563 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep1.pth (Loss: 0.0563)


Ep 2 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 2 | Train MSE: 0.0527 | Val MSE: 0.0511 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep2.pth (Loss: 0.0511)


Ep 3 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 3 | Train MSE: 0.0505 | Val MSE: 0.0505 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep3.pth (Loss: 0.0505)


Ep 4 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 4 | Train MSE: 0.0502 | Val MSE: 0.0504 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep4.pth (Loss: 0.0504)


Ep 5 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 5 | Train MSE: 0.0502 | Val MSE: 0.0504 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep5.pth (Loss: 0.0504)


Ep 6 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 6 | Train MSE: 0.0501 | Val MSE: 0.0503 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep6.pth (Loss: 0.0503)


Ep 7 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 7 | Train MSE: 0.0501 | Val MSE: 0.0503 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep7.pth (Loss: 0.0503)


Ep 8 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 8 | Train MSE: 0.0501 | Val MSE: 0.0503 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep8.pth (Loss: 0.0503)


Ep 9 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 9 | Train MSE: 0.0501 | Val MSE: 0.0503 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep9.pth (Loss: 0.0503)


Ep 10 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 10 | Train MSE: 0.0501 | Val MSE: 0.0503 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep10.pth (Loss: 0.0503)


Ep 11 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 11 | Train MSE: 0.0501 | Val MSE: 0.0503 | Alive: 8000


Ep 12 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 12 | Train MSE: 0.0501 | Val MSE: 0.0502 | Alive: 8000
      üíæ New Best Saved: VQVAE_OrigStruct_ELECTRA_Ep12.pth (Loss: 0.0502)


Ep 13 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 13 | Train MSE: 0.0500 | Val MSE: 0.0503 | Alive: 8000


Ep 14 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 14 | Train MSE: 0.0500 | Val MSE: 0.0503 | Alive: 8000


Ep 15 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 15 | Train MSE: 0.0500 | Val MSE: 0.0503 | Alive: 8000


Ep 16 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 16 | Train MSE: 0.0500 | Val MSE: 0.0503 | Alive: 8000


Ep 17 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 17 | Train MSE: 0.0500 | Val MSE: 0.0503 | Alive: 8000


Ep 18 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 18 | Train MSE: 0.0500 | Val MSE: 0.0503 | Alive: 8000


Ep 19 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 19 | Train MSE: 0.0500 | Val MSE: 0.0503 | Alive: 8000


Ep 20 [Train]:   0%|          | 0/2165 [00:00<?, ?it/s]

    Ep 20 | Train MSE: 0.0500 | Val MSE: 0.0503 | Alive: 8000


# 마지막 Layer 학습: codebook -> vocab_size (train the final layer that maps codebook vectors to vocabulary tokens)

In [None]:
import os
import re
import gc
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset, concatenate_datasets
from google.colab import drive


# Mount Google Drive unless it is already mounted.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

SAVE_DIR = '/content/drive/MyDrive/VQVAE_Comparison_Final'
os.makedirs(SAVE_DIR, exist_ok=True)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

# Display name -> Hugging Face checkpoint id for every encoder being compared.
target_models = {
    "ModernBERT": "answerdotai/ModernBERT-base",
    "BERT": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "ELECTRA": "google/electra-base-discriminator",
}

# Load the dataset (all three splits of the parquet conversion).
print("Loading Data...")
dataset = load_dataset("li2017dailydialog/daily_dialog", revision="refs/convert/parquet")
combined = concatenate_datasets([dataset[split] for split in ("train", "validation", "test")])

# Split utterances after '.', '?' or '!' + whitespace; keep pieces with > 2 words.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[\.\?\!])\s+')
sentences = [
    piece
    for conversation in tqdm(combined["dialog"], desc="Parsing")
    for utterance in conversation
    for piece in _SENTENCE_BOUNDARY.split(utterance.strip())
    if piece and len(piece.split()) > 2
]


# 90% / 10% split with a fixed seed (random_state=42).
train_sents, val_sents = train_test_split(sentences, test_size=0.1, random_state=42)
print(f"Total: {len(sentences)} | Train: {len(train_sents)} | Val: {len(val_sents)}")


# Model class definitions
class VQVAE_Original_Structure(nn.Module):
    """Encoder + codebook of the VQ-VAE, for loading a trained checkpoint.

    Module names and shapes (``encoder_body``, ``encoder_residual``,
    ``z_scale``, ``codebook``) define the state-dict layout. There is no
    decoder, no EMA buffer and no ``forward`` method in this class.
    """

    def __init__(self, input_dim=768, hidden_dim=518, num_embeddings=8000, embedding_dim=64):
        super().__init__()
        widths = (input_dim, hidden_dim, 216, 108, 64)
        layers = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(d_in, d_out))
            layers.append(nn.LeakyReLU(0.1))
        layers.append(nn.Linear(64, embedding_dim))
        self.encoder_body = nn.Sequential(*layers)

        self.encoder_residual = nn.Linear(input_dim, embedding_dim, bias=False)
        self.z_scale = nn.Parameter(torch.tensor(5.0))
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)

class VQVAE_Vocab_Classifier(nn.Module):
    """Linear head mapping frozen VQ-VAE codebook vectors to vocabulary logits.

    Args:
        vqvae_model: module exposing ``encoder_body``, ``encoder_residual``,
            ``z_scale`` and ``codebook``. All of its parameters are frozen.
        vocab_size: number of output logits.
        embedding_dim: size of a codebook vector (input size of the head).
    """

    def __init__(self, vqvae_model, vocab_size, embedding_dim=64):
        super().__init__()
        self.vqvae = vqvae_model
        # Freeze the VQ-VAE so only the classifier head is trained.
        self.vqvae.requires_grad_(False)

        # Codebook vector -> original vocabulary token classifier.
        self.classifier = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Quantise ``x`` with the frozen VQ-VAE and classify the chosen code vectors."""
        with torch.no_grad():
            vq = self.vqvae
            z_e = (vq.encoder_body(x) + vq.encoder_residual(x)) * vq.z_scale
            code_ids = torch.cdist(z_e, vq.codebook.weight).argmin(dim=1)
            z_q = vq.codebook(code_ids)
        return self.classifier(z_q)


# ÌïôÏäµ

CLASSIFIER_EPOCHS = 10
BATCH_SIZE = 32
LR_CLASSIFIER = 1e-3
NUM_EMBEDDINGS = 8000

def yield_batches(sents, batch_size):
    """Lazily yield consecutive slices of ``sents`` with at most ``batch_size`` items.

    The last slice is shorter when ``len(sents)`` is not a multiple of
    ``batch_size``. Slicing preserves the input's type (a list yields lists).
    """
    offsets = range(0, len(sents), batch_size)
    yield from (sents[offset:offset + batch_size] for offset in offsets)

print("\n" + "="*60)
print("STARTING VOCAB CLASSIFIER TRAINING")
print("="*60)


def _latest_epoch_checkpoint(paths, prefix):
    """Return the path whose base name carries the highest epoch number.

    Checkpoint names are expected to be ``<prefix><N>...`` (for example
    ``VQVAE_OrigStruct_BERT_Ep19.pth`` with prefix ``VQVAE_OrigStruct_BERT_Ep``).
    Only the base name is parsed, so digits elsewhere in the directory path can
    never be mistaken for the epoch. Files without a number right after the
    prefix are skipped instead of raising. On ties the later path wins, as with
    the previous ``sorted(...)[-1]``.

    Args:
        paths: candidate checkpoint paths (e.g. the result of ``glob.glob``).
        prefix: literal file-name prefix that precedes the epoch number.

    Returns:
        The selected path, or ``None`` when no path carries an epoch number.
    """
    epoch_re = re.compile(re.escape(prefix) + r'(\d+)')
    best_epoch, best_path = -1, None
    for path in paths:
        match = epoch_re.match(os.path.basename(path))
        if match is None:
            continue
        epoch = int(match.group(1))
        if epoch >= best_epoch:
            best_epoch, best_path = epoch, path
    return best_path


def _flatten_token_batch(tokenizer, base_model, batch_sents, hidden_size, vocab_size):
    """Encode sentences with the frozen base model and flatten them to per-token rows.

    Args:
        tokenizer: tokenizer matching ``base_model``.
        base_model: frozen encoder; its ``last_hidden_state`` is used as features.
        batch_sents: list of sentences (each truncated to 128 tokens).
        hidden_size: width of ``last_hidden_state``.
        vocab_size: ids ``>= vocab_size`` are dropped because the classifier has
            no output unit for them.

    Returns:
        ``(flat_emb, flat_ids)``: hidden states ``(N, hidden_size)`` and token ids
        ``(N,)`` of every non-padding, in-vocabulary token. ``N`` may be 0.
    """
    inputs = tokenizer(batch_sents, return_tensors='pt', padding=True, truncation=True, max_length=128).to(device)
    with torch.no_grad():
        embeddings = base_model(**inputs).last_hidden_state
    # Keep only real (attention_mask == 1) positions. reshape (rather than view)
    # avoids an error if the hidden-state tensor is not contiguous.
    keep = inputs['attention_mask'].reshape(-1) == 1
    flat_emb = embeddings.reshape(-1, hidden_size)[keep]
    flat_ids = inputs['input_ids'].reshape(-1)[keep]
    in_vocab = flat_ids < vocab_size
    return flat_emb[in_vocab], flat_ids[in_vocab]


for model_name, model_id in target_models.items():
    print(f"\nProcessing >>> [{model_name}]")
    # Bound up front so the `finally` below can always release them, including on
    # the early `continue`s that skip a model.
    tokenizer = base_model = vqvae = model = optimizer = None
    try:
        # Load the frozen base encoder; only ModernBERT needs trust_remote_code=True.
        try:
            trust = "ModernBERT" in model_name
            tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust)
            base_model = AutoModel.from_pretrained(model_id, trust_remote_code=trust).to(device)

            base_model.eval()
            hidden_size = base_model.config.hidden_size
            vocab_size = base_model.config.vocab_size
            print(f"  > Loaded Base Model. Hidden: {hidden_size}, Vocab: {vocab_size}")
        except Exception as e:
            print(f"  ! Error loading Base Model {model_name}: {e}")
            continue

        # Auto-select the VQ-VAE checkpoint with the largest epoch number in its name.
        ckpt_prefix = f"VQVAE_OrigStruct_{model_name}_Ep"
        found_files = glob.glob(os.path.join(SAVE_DIR, f"{ckpt_prefix}*.pth"))
        best_vqvae_path = _latest_epoch_checkpoint(found_files, ckpt_prefix)
        if best_vqvae_path is None:
            print(f"No VQ-VAE checkpoint found for {model_name}. Skipping...")
            continue
        best_filename = os.path.basename(best_vqvae_path)
        print(f" Auto-selected Best VQ-VAE: {best_filename}")

        vqvae = VQVAE_Original_Structure(input_dim=hidden_size, embedding_dim=64, num_embeddings=NUM_EMBEDDINGS).to(device)

        try:
            load_result = vqvae.load_state_dict(torch.load(best_vqvae_path, map_location=device), strict=False)
        except Exception as e:
            print(f" Error loading VQ-VAE weights: {e}")
            continue
        # strict=False tolerates checkpoint entries this encoder/codebook stub does not
        # define, but it also silently skips *missing* ones, which would leave encoder or
        # codebook weights randomly initialised. Treat missing keys as a load failure.
        if load_result.missing_keys:
            print(f" Error loading VQ-VAE weights: missing keys {load_result.missing_keys}")
            continue
        vqvae.eval()

        # Classifier head on top of the frozen VQ-VAE; only its parameters are optimised.
        model = VQVAE_Vocab_Classifier(vqvae, vocab_size=vocab_size, embedding_dim=64).to(device)
        optimizer = optim.Adam(model.classifier.parameters(), lr=LR_CLASSIFIER)
        criterion = nn.CrossEntropyLoss()

        save_path = os.path.join(SAVE_DIR, f"Classifier_Vocab_{model_name}_Best.pth")
        best_val_acc = 0.0

        # NOTE(review): the tokenizer is called with its default special-token handling and
        # the attention-mask filter keeps every non-padding position, so any special tokens
        # it inserts are trained on and counted in Val Acc as well. Confirm this is intended.
        print(f" Training Classifier ({CLASSIFIER_EPOCHS} epochs)...")
        total_len = (len(train_sents) + BATCH_SIZE - 1) // BATCH_SIZE

        for epoch in range(CLASSIFIER_EPOCHS):
            model.train()
            train_loss_sum = 0
            total_batches = 0

            pbar = tqdm(yield_batches(train_sents, BATCH_SIZE), desc=f"Ep {epoch+1}", leave=False, total=total_len)
            for batch_sents in pbar:
                flat_emb, flat_ids = _flatten_token_batch(tokenizer, base_model, batch_sents, hidden_size, vocab_size)
                if flat_emb.size(0) == 0:
                    continue

                optimizer.zero_grad()
                loss = criterion(model(flat_emb), flat_ids)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                train_loss_sum += batch_loss
                total_batches += 1
                pbar.set_postfix({'Loss': f"{batch_loss:.4f}"})

            avg_train_loss = train_loss_sum / max(total_batches, 1)

            # Validation: token-level accuracy over every kept validation token.
            model.eval()
            val_correct = 0
            val_total = 0
            with torch.no_grad():
                for batch_sents in yield_batches(val_sents, BATCH_SIZE):
                    flat_emb, flat_ids = _flatten_token_batch(tokenizer, base_model, batch_sents, hidden_size, vocab_size)
                    if flat_emb.size(0) == 0:
                        continue
                    preds = torch.argmax(model(flat_emb), dim=1)
                    val_correct += (preds == flat_ids).sum().item()
                    val_total += flat_ids.size(0)

            val_acc = (val_correct / val_total) * 100 if val_total > 0 else 0

            # Save the checkpoint only when validation accuracy strictly improves.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), save_path)
                print(f"Ep {epoch+1} | Loss: {avg_train_loss:.4f} | Val Acc: {val_acc:.2f}% (Best)")
            else:
                print(f"Ep {epoch+1} | Loss: {avg_train_loss:.4f} | Val Acc: {val_acc:.2f}%")
    finally:
        # Drop this model's references, collect any reference cycles, and only then
        # return cached CUDA blocks. empty_cache() cannot release memory still held by
        # tensors that gc.collect() has not freed yet.
        tokenizer = base_model = vqvae = model = optimizer = None
        gc.collect()
        torch.cuda.empty_cache()

Device: cuda
Loading Data...


Parsing:   0%|          | 0/13118 [00:00<?, ?it/s]

Total: 153945 | Train: 138550 | Val: 15395

STARTING VOCAB CLASSIFIER TRAINING

Processing >>> [ModernBERT]


model.safetensors:   0%|          | 0.00/599M [00:00<?, ?B/s]

  > Loaded Base Model. Hidden: 768, Vocab: 50368
  > 📌 Auto-selected Best VQ-VAE: VQVAE_OrigStruct_ModernBERT_Ep19.pth
  > Training Classifier (10 epochs)...


Ep 1:   0%|          | 0/4330 [00:00<?, ?it/s]

  return torch._C._get_cublas_allow_tf32()
W0116 01:30:48.440000 382 torch/_inductor/utils.py:1558] [1/0_1] Not enough SMs to use max_autotune_gemm mode


    Ep 1 | Loss: 1.9191 | Val Acc: 77.96% (Best) 💾


Ep 2:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 2 | Loss: 1.3315 | Val Acc: 78.54% (Best) 💾


Ep 3:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 3 | Loss: 1.2529 | Val Acc: 79.48% (Best) 💾


Ep 4:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 4 | Loss: 1.2139 | Val Acc: 79.55% (Best) 💾


Ep 5:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 5 | Loss: 1.1892 | Val Acc: 79.56% (Best) 💾


Ep 6:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 6 | Loss: 1.1717 | Val Acc: 79.58% (Best) 💾


Ep 7:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 7 | Loss: 1.1583 | Val Acc: 79.57%


Ep 8:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 8 | Loss: 1.1476 | Val Acc: 79.63% (Best) 💾


Ep 9:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 9 | Loss: 1.1388 | Val Acc: 79.64% (Best) 💾


Ep 10:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 10 | Loss: 1.1314 | Val Acc: 79.64% (Best) 💾

Processing >>> [BERT]


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

  > Loaded Base Model. Hidden: 768, Vocab: 30522
  > 📌 Auto-selected Best VQ-VAE: VQVAE_OrigStruct_BERT_Ep20.pth
  > Training Classifier (10 epochs)...


Ep 1:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 1 | Loss: 2.0491 | Val Acc: 81.23% (Best) 💾


Ep 2:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 2 | Loss: 1.1585 | Val Acc: 81.96% (Best) 💾


Ep 3:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 3 | Loss: 1.0623 | Val Acc: 82.09% (Best) 💾


Ep 4:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 4 | Loss: 1.0170 | Val Acc: 82.21% (Best) 💾


Ep 5:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 5 | Loss: 0.9893 | Val Acc: 82.25% (Best) 💾


Ep 6:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 6 | Loss: 0.9700 | Val Acc: 82.26% (Best) 💾


Ep 7:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 7 | Loss: 0.9557 | Val Acc: 82.27% (Best) 💾


Ep 8:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 8 | Loss: 0.9445 | Val Acc: 82.29% (Best) 💾


Ep 9:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 9 | Loss: 0.9355 | Val Acc: 82.30% (Best) 💾


Ep 10:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 10 | Loss: 0.9280 | Val Acc: 82.30%

Processing >>> [RoBERTa]


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  > Loaded Base Model. Hidden: 768, Vocab: 50265
  > 📌 Auto-selected Best VQ-VAE: VQVAE_OrigStruct_RoBERTa_Ep20.pth
  > Training Classifier (10 epochs)...


Ep 1:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 1 | Loss: 2.3045 | Val Acc: 83.94% (Best) 💾


Ep 2:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 2 | Loss: 1.1724 | Val Acc: 86.07% (Best) 💾


Ep 3:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 3 | Loss: 0.9960 | Val Acc: 86.51% (Best) 💾


Ep 4:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 4 | Loss: 0.9216 | Val Acc: 86.64% (Best) 💾


Ep 5:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 5 | Loss: 0.8785 | Val Acc: 86.70% (Best) 💾


Ep 6:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 6 | Loss: 0.8492 | Val Acc: 86.71% (Best) 💾


Ep 7:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 7 | Loss: 0.8274 | Val Acc: 86.72% (Best) 💾


Ep 8:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 8 | Loss: 0.8102 | Val Acc: 86.74% (Best) 💾


Ep 9:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 9 | Loss: 0.7963 | Val Acc: 86.76% (Best) 💾


Ep 10:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 10 | Loss: 0.7848 | Val Acc: 86.76% (Best) 💾

Processing >>> [ELECTRA]


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

  > Loaded Base Model. Hidden: 768, Vocab: 30522
  > 📌 Auto-selected Best VQ-VAE: VQVAE_OrigStruct_ELECTRA_Ep12.pth


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

  > Training Classifier (10 epochs)...


Ep 1:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 1 | Loss: 2.2972 | Val Acc: 68.30% (Best) 💾


Ep 2:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 2 | Loss: 1.5667 | Val Acc: 69.56% (Best) 💾


Ep 3:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 3 | Loss: 1.4578 | Val Acc: 69.94% (Best) 💾


Ep 4:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 4 | Loss: 1.4043 | Val Acc: 70.11% (Best) 💾


Ep 5:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 5 | Loss: 1.3708 | Val Acc: 70.20% (Best) 💾


Ep 6:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 6 | Loss: 1.3472 | Val Acc: 70.26% (Best) 💾


Ep 7:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 7 | Loss: 1.3293 | Val Acc: 70.32% (Best) 💾


Ep 8:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 8 | Loss: 1.3152 | Val Acc: 70.31%


Ep 9:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 9 | Loss: 1.3037 | Val Acc: 70.33% (Best) 💾


Ep 10:   0%|          | 0/4330 [00:00<?, ?it/s]

    Ep 10 | Loss: 1.2940 | Val Acc: 70.35% (Best) 💾

All Training Completed.
