# Load model, parameters, performance metrics

In [1]:
import os
import torch
import argparse
from models_fine.vqvae import VQVAE
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from torchvision.utils import make_grid
import numpy as np

%matplotlib inline
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

"""
Utility functions
"""

def load_model(model_filename):
    path = os.getcwd() + '/results/'
    
    if torch.cuda.is_available():
        data = torch.load(path + model_filename,  weights_only=False)
    else:
        data = torch.load(path+model_filename,map_location=lambda storage, loc: storage)
    
    params = data["hyperparameters"]
    
    model = VQVAE(params['n_hiddens'], params['n_residual_hiddens'],
                  params['n_residual_layers'], params['n_embeddings'], 
                  params['embedding_dim'], params['beta']).to(device)

    model.load_state_dict(data['model'])
    
    return model, data


"""
End of utilities
"""

model_filename = 'vqvae_data_thu_jun_19_16_39_23_2025.pth'

model,vqvae_data = load_model(model_filename)


# Load dataset and loaders

In [2]:
import utils
training_data, validation_data, training_loader, validation_loader, x_train_var = utils.load_data_and_data_loaders('CIFAR10', 32)

In [3]:
params = vqvae_data['hyperparameters']
params

{'batch_size': 32,
 'n_updates': 30000,
 'n_hiddens': 128,
 'n_residual_hiddens': 32,
 'n_residual_layers': 2,
 'embedding_dim': 64,
 'n_embeddings': 512,
 'beta': 0.25,
 'learning_rate': 0.0003,
 'log_interval': 50,
 'dataset': 'CIFAR10',
 'save': True,
 'filename': 'thu_jun_19_16_39_23_2025'}

## CGA

In [4]:
import torch
import torch.nn.functional as F

def update_codebook_with_cga(model, z_e_all, ei_all, ej_all, usage_counts, bottom_percent=0.6):
    """
    Perform genetic algorithm-based update for underutilized codebook tokens.

    Parameters:
    - model: VQVAE model with .vector_quantization.embedding.weight
    - z_e_all: (N, D) encoder outputs flattened
    - ei_all: (N,) top-1 encoding indices
    - ej_all: (N,) top-2 encoding indices
    - usage_counts: (K,) usage count of each codebook vector
    - bottom_percent: float, percentage of least-used tokens to consider for update
    """
    device = z_e_all.device
    codebook = model.vector_quantization.embedding.weight  # (K, D)
    K = codebook.shape[0]

    # Identify bottom X% underutilized codeword indices
    num_bottom = int(K * bottom_percent)
    _, sorted_indices = torch.sort(usage_counts)
    underutilized_tokens = sorted_indices[:num_bottom].tolist()

    # For collecting e_lo vectors for each underutilized codeword
    update_vectors = {k: [] for k in underutilized_tokens}

    for z_e, ei, ej in zip(z_e_all, ei_all, ej_all):
        if ei.item() not in underutilized_tokens:
            continue  # skip if token is not underutilized

        e_i = codebook[ei]  # (D,)
        e_j = codebook[ej]  # (D,)

        # Crossover
        alpha = torch.rand(1).item()
        e_cross = alpha * e_i + (1 - alpha) * e_j

        # Mutation
        beta = torch.empty(1).uniform_(-2, 2).item()
        e_mut = beta * e_cross

        # Local Search
        epsilon = torch.randn_like(e_mut)
        s = 1.0
        e_lo = e_mut + s * epsilon
        dist_ei = F.mse_loss(z_e, e_i)
        dist_lo = F.mse_loss(z_e, e_lo)

        # Find smallest s such that distance condition met
        tries = 10
        while dist_lo >= dist_ei and tries > 0:
            s *= 0.5
            e_lo = e_mut + s * epsilon
            dist_lo = F.mse_loss(z_e, e_lo)
            tries -= 1

        update_vectors[ei.item()].append(e_lo)

    # Token Update: average e_lo vectors for each token
    with torch.no_grad():
        for k in underutilized_tokens:
            if update_vectors[k]:
                update_vectors_tensor = torch.stack(update_vectors[k], dim=0)
                model.vector_quantization.embedding.weight[k] = update_vectors_tensor.mean(dim=0)


In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import utils

parser = argparse.ArgumentParser()

"""
Hyperparameters
"""
timestamp = utils.readable_timestamp()

parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--n_updates", type=int, default=10)
parser.add_argument("--n_hiddens", type=int, default=128)
parser.add_argument("--n_residual_hiddens", type=int, default=32)
parser.add_argument("--n_residual_layers", type=int, default=2)
parser.add_argument("--embedding_dim", type=int, default=64)
parser.add_argument("--n_embeddings", type=int, default=512)
parser.add_argument("--beta", type=float, default=.25)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--log_interval", type=int, default=50)
parser.add_argument("--dataset",  type=str, default='CIFAR10')

# whether or not to save model
parser.add_argument("-save", action="store_true")
parser.add_argument("--filename",  type=str, default=timestamp)

args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.save = True
if args.save:
    print('Results will be saved in ./results/vqvae_' + args.filename + '.pth')


training_data, validation_data, training_loader, validation_loader, x_train_var = utils.load_data_and_data_loaders(
    args.dataset, args.batch_size)
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, amsgrad=True)


results = {
    'n_updates': 0,
    'recon_errors': [],
    'loss_vals': [],
    'perplexities': [],
}
def train():
    model.train()
    step = 0

    for epoch in range(args.n_updates):
        all_z_e = []
        all_ei = []
        all_ej = []

        for (x, _) in training_loader:
            x = x.to(device)
            optimizer.zero_grad()

            # Forward pass with z_e, top-1, top-2 indices
            embedding_loss, x_hat, perplexity, z_e, ei, ej = model(x)
            recon_loss = torch.mean((x_hat - x) ** 2) / x_train_var
            loss = recon_loss + embedding_loss

            loss.backward()
            optimizer.step()

            # 기록
            results["recon_errors"].append(recon_loss.item())
            results["perplexities"].append(perplexity.item())
            results["loss_vals"].append(loss.item())
            results["n_updates"] = step

            # CGA 준비를 위한 데이터 누적
            z_e_flat = z_e.permute(0, 2, 3, 1).reshape(-1, z_e.shape[1])  # (BHW, D)
            all_z_e.append(z_e_flat.detach().cpu())
            all_ei.append(ei.view(-1).detach().cpu())
            all_ej.append(ej.view(-1).detach().cpu())

            # 로그
            if step % args.log_interval == 0:
                print(f"[Epoch {epoch}] Step {step} | Recon: {recon_loss.item():.4f} | "
                      f"Loss: {loss.item():.4f} | Perplexity: {perplexity.item():.2f}")

            step += 1

        # === ✅ 1 에폭 끝난 후 CGA 업데이트 ===
        print(f"[Epoch {epoch}] ⏳ Running CGA update for underutilized tokens...")
        z_e_all = torch.cat(all_z_e, dim=0).to(device)
        ei_all = torch.cat(all_ei, dim=0)
        ej_all = torch.cat(all_ej, dim=0)

        usage_counts = torch.bincount(ei_all, minlength=model.vector_quantization.n_e)
        update_codebook_with_cga(model, z_e_all, ei_all, ej_all, usage_counts)

        print(f"[Epoch {epoch}] ✅ CGA update complete.\n")

train()

Results will be saved in ./results/vqvae_c:\Users\minju\AppData\Roaming\jupyter\runtime\kernel-v3d251d7a30161ce0661e476edaa1ba3b61613a5c0.json.pth
[Epoch 0] Step 0 | Recon: 0.3749 | Loss: 0.4349 | Perplexity: 126.96
[Epoch 0] Step 50 | Recon: 0.4133 | Loss: 0.4919 | Perplexity: 130.77
[Epoch 0] Step 100 | Recon: 0.3449 | Loss: 0.3903 | Perplexity: 132.32
[Epoch 0] Step 150 | Recon: 0.3486 | Loss: 0.3957 | Perplexity: 135.01
[Epoch 0] Step 200 | Recon: 0.3378 | Loss: 0.3843 | Perplexity: 128.64
[Epoch 0] Step 250 | Recon: 0.3572 | Loss: 0.4067 | Perplexity: 129.96
[Epoch 0] Step 300 | Recon: 0.3560 | Loss: 0.4068 | Perplexity: 141.52
[Epoch 0] Step 350 | Recon: 0.3494 | Loss: 0.3960 | Perplexity: 129.94
[Epoch 0] Step 400 | Recon: 0.3611 | Loss: 0.4026 | Perplexity: 130.17
[Epoch 0] Step 450 | Recon: 0.3323 | Loss: 0.3774 | Perplexity: 137.04
[Epoch 0] Step 500 | Recon: 0.3432 | Loss: 0.3961 | Perplexity: 139.54
[Epoch 0] Step 550 | Recon: 0.3567 | Loss: 0.4060 | Perplexity: 131.95
[Epo