# PoC: Logarithmic Query Number is Sufficient for ZOO

## Prepare

In [None]:
import torch
import torch.nn as nn
import torch.optim
import torch.utils.data

import torch.nn.utils.prune as prune

import numpy as np
import math
import os
import time
import matplotlib.pyplot as plt

from import_shelf import shelf
from shelf.models.transformer import VisionTransformer
from shelf.dataloaders.cifar import get_CIFAR10_dataset
from shelf.trainers.zeroth_order import learning_rate_estimate_second_order
from shelf.trainers.classic import validate

from tqdm import tqdm


In [None]:
### HYPERPARAMS ###

# Training length, batch size, and the image/patch/class sizes handed to the model.
EPOCHS = 200
BATCH_SIZE = 512
IMAGE_SIZE = 32
PATCH_SIZE = 4
NUM_CLASSES = 10

# Zeroth-order optimisation settings; the training loop forwards these to train_zo.
LR_MAX = 1e-2  # upper clamp applied to the per-step learning rate
LR_MIN = 1e-5  # learning rate used when the automatic estimate is disabled or exactly zero
SMOOTHING = 5e-4  # scale of the random parameter perturbation used for finite differences
QUERY_BASE = 1.05  # log base: num_groups = int(log(num_params) / log(QUERY_BASE))
NUM_QUERY = 1  # random perturbations drawn per parameter group per step
MOMENTUM = 0.8  # decay of the running gradient estimate used to regroup parameters

# Keyword arguments unpacked into the VisionTransformer constructor.
MODEL_CONFIG = {
    "dim": 128,
    "depth": 4,
    "heads": 2,
    "mlp_dim": 128,
    "dropout": 0.1,
    "emb_dropout": 0.1,
}
# Alternative (larger) configuration, currently disabled:
# MODEL_CONFIG = {
#     "dim": 512,
#     "depth": 4,
#     "heads": 6,
#     "mlp_dim": 256,
#     "dropout": 0.1,
#     "emb_dropout": 0.1,
# }


# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Destination of the state dict written after the last epoch.
PATH_MODEL = './saves/zoo_poc.pth'


In [None]:
### DATA LOADING ###

# Build the CIFAR-10 train/validation loaders with the configured batch size.
train_loader, val_loader = get_CIFAR10_dataset(batch_size=BATCH_SIZE)

# Human-readable label names, indexed by class id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Sanity check: show the first ten training images with their labels.
# NOTE(review): assumes the dataset object exposes `.data` and `.targets`
# (as torchvision's CIFAR10 does) -- confirm against get_CIFAR10_dataset.
plt.figure(figsize=(10, 1))
for i in range(10):
    plt.subplot(1, 10, i+1)
    plt.imshow(train_loader.dataset.data[i])
    plt.title(classes[train_loader.dataset.targets[i]])
    plt.axis('off')

In [None]:
### MODEL ###

# Instantiate the transformer from the hyperparameters above and move it to DEVICE.
model = VisionTransformer(
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    **MODEL_CONFIG
).to(DEVICE)

# Count trainable scalars; the number of perturbation groups is
# log_{QUERY_BASE}(num_params), truncated to an integer.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_groups = int(np.log(num_params)/np.log(QUERY_BASE))
print(f"Number of parameters: {num_params}")
print(f"Number of groups: {num_groups}")

In [None]:
### OTHERS ###

# Loss used both for the finite-difference queries and for reporting.
criterion = nn.CrossEntropyLoss()
# Adam consumes the estimated gradients written to `param.grad`; the `lr` given
# here is only an initial value because train_zo overwrites the learning rate of
# every param group on each step.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# NOTE(review): no `scheduler.step()` call appears in the visible cells, so this
# cosine schedule looks unused -- confirm before relying on it.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS, eta_min=1e-5)

In [None]:
@torch.no_grad()
def gradient_estimate_groupwise(input, label, model, criterion, num_groups, group_dict, query=1, smoothing=5e-4):
    """Estimate gradients with forward differences, one parameter group at a time.

    For every group index ``g`` in ``range(num_groups)`` and every one of
    ``query`` repetitions, Gaussian noise is drawn for all trainable
    parameters, zeroed everywhere except on the entries assigned to group
    ``g``, added to the weights (scaled by ``smoothing``), and the loss is
    re-evaluated. The noise weighted by the scaled loss difference is
    accumulated into the result, and the weights are restored afterwards.

    Args:
        input: Batch passed to ``model``.
        label: Targets passed to ``criterion`` together with the model output.
        model: Module whose trainable parameters are perturbed in place and
            restored. It is switched to eval mode.
        criterion: Callable ``criterion(output, label)`` returning a scalar loss.
        num_groups: Number of parameter groups to query.
        group_dict: Mapping from parameter name to a tensor of the parameter's
            shape holding the group index of each entry. ``None`` is accepted
            when ``num_groups == 1`` and means "everything is in group 0".
        query: Number of random perturbations per group.
        smoothing: Perturbation scale (finite-difference step).

    Returns:
        Dict mapping each trainable parameter name to its estimated gradient,
        averaged over ``query``.

    Raises:
        ValueError: If ``group_dict`` is ``None`` while ``num_groups > 1``.
    """
    if group_dict is None and num_groups > 1:
        raise ValueError("group_dict is required when num_groups > 1")

    model.eval()

    state_dict = model.state_dict()

    # The set of trainable parameters does not change during estimation.
    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]

    # Pruned parameters ('<x>_orig' with a matching '<x>_mask' entry) must not
    # be perturbed where the mask is zero.
    prune_masks = {}
    for name, _ in trainable:
        mask_name = name.replace('_orig', '_mask')
        if '_orig' in name and mask_name in state_dict:
            prune_masks[name] = state_dict[mask_name]

    # Prepare the result dictionary
    result_gradient = {name: torch.zeros_like(param.data) for name, param in trainable}

    # Original loss
    loss_original = criterion(model(input), label)

    # Perturb and measure the loss
    for g in range(num_groups):
        for _ in range(query):
            estimated_gradient = {}

            # perturb the model
            for name, param in trainable:
                noise = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)

                # handle pruned parameters
                if name in prune_masks:
                    noise *= prune_masks[name]

                # handle grouping: only entries that belong to group g are
                # perturbed (index the noise tensor, not the dict)
                if group_dict is not None:
                    noise[group_dict[name] != g] = 0

                estimated_gradient[name] = noise

                # add the perturbation
                param.data += noise * smoothing

            # measure the loss
            loss_perturbed = criterion(model(input), label)

            # restore the model
            for name, param in trainable:
                param.data -= estimated_gradient[name] * smoothing

            # accumulate the gradient
            loss_difference = (loss_perturbed - loss_original) / smoothing
            for name, _ in trainable:
                result_gradient[name] += loss_difference * estimated_gradient[name]

    # Average the gradient (groups touch disjoint entries, so only the
    # repetitions are averaged)
    for name in result_gradient:
        result_gradient[name] /= query

    return result_gradient


In [None]:
def calc_r(N, d):
    """Return the largest real root of ``r**(N+1) - d*r + (d - 1) = 0``.

    For ``r != 1`` that equation is the geometric-series identity
    ``1 + r + ... + r**N == d``.

    Args:
        N: Number of zero-padded terms plus one; the polynomial has degree ``N + 1``.
        d: Target value of the geometric sum.

    Returns:
        The largest real root, as a NumPy float.

    Raises:
        ValueError: If the largest real root is not greater than 1.
    """
    # Coefficients from the highest power down: 1, (N - 1 zeros), -d, d - 1.
    coefficients = [1]
    coefficients.extend(0 for _ in range(N - 1))
    coefficients.extend((-d, d - 1))

    candidates = np.roots(np.poly1d(coefficients, False))
    real_candidates = candidates[np.isreal(candidates)]
    largest = np.real(np.max(real_candidates))

    # r == 1 always solves the polynomial but is useless as a growth ratio.
    if largest <= 1:
        raise ValueError("r must be greater than 1")

    return largest

def group_by_gradient_exp(estimated_gradient, num_groups):
    """Assign every parameter entry to a group based on its gradient magnitude.

    Magnitude thresholds ("milestones") are taken at quantiles whose
    cumulative fractions grow geometrically with ratio ``r`` (from
    ``calc_r``); the last threshold is replaced by the maximum magnitude.

    Args:
        estimated_gradient: Mapping from parameter name to a gradient tensor.
        num_groups: Number of milestones to compute.

    Returns:
        Dict mapping each name to a tensor shaped (and typed) like its
        gradient that holds the group index of every entry.
    """
    magnitudes = torch.cat([g.flatten() for g in estimated_gradient.values()]).abs()
    total = magnitudes.size(0)

    # Growth ratio of the geometric group sizes
    ratio = calc_r(num_groups, total)

    # Quantile thresholds at the cumulative fractions (r^(k+1) - 1) / (r - 1) / total
    thresholds = [
        torch.quantile(
            magnitudes,
            (ratio ** (k + 1) - 1) / (ratio - 1) / total,
            interpolation='lower',
        )
        for k in range(num_groups)
    ]
    thresholds[-1] = torch.quantile(magnitudes, 1.0)

    # Later thresholds overwrite earlier ones, so an entry ends up with the
    # highest index whose threshold it strictly exceeds (0 if it exceeds none).
    # NOTE(review): the last threshold is the maximum magnitude, so no entry can
    # exceed it and the last group index is never assigned -- confirm intent.
    assignment = {}
    for name, grad in estimated_gradient.items():
        magnitude = grad.abs()
        groups = torch.zeros_like(grad)
        for k, threshold in enumerate(thresholds):
            groups[magnitude > threshold] = k
        assignment[name] = groups

    return assignment

In [None]:
def train_zo(
        train_loader, model, criterion, optimizer, epoch,
        smoothing=1e-3, query=1, lr_auto=True, lr_max=1e-2, lr_min=1e-5, momentum=0.9,
        num_groups=1, group_dict=None,
        config=None
    ):
    """Train ``model`` for one epoch with group-wise zeroth-order gradient estimates.

    For each batch the gradient is estimated by
    ``gradient_estimate_groupwise``, written to ``param.grad`` and applied by
    ``optimizer``. A running (momentum) sum of the estimates is kept and used
    after every step to regroup the parameters with ``group_by_gradient_exp``.

    Args:
        train_loader: Iterable yielding ``(input, label)`` batches.
        model: Module to train; it is switched to eval mode.
        criterion: Loss callable ``criterion(output, label)``.
        optimizer: Optimizer whose param-group learning rates are overwritten
            on every step.
        epoch: Zero-based epoch index, used for the progress-bar label.
        smoothing: Perturbation scale for the finite-difference queries.
        query: Perturbations per group per step.
        lr_auto: If true, take the step size from
            ``learning_rate_estimate_second_order`` (absolute value, or
            ``lr_min`` when it is zero) and count 3 extra queries per step.
        lr_max: Upper clamp for the learning rate.
        lr_min: Learning rate used when ``lr_auto`` is false or the estimate is zero.
        momentum: Decay factor of the running gradient sum used for grouping.
        num_groups: Number of parameter groups.
        group_dict: Optional initial grouping (parameter name -> tensor of
            group indices). When ``None`` a uniformly random grouping is drawn.
        config: Optional dict that receives ``lr_avg``, ``lr_std``,
            ``num_query`` and the final ``group_dict``.

    Returns:
        Tuple ``(accuracy, avg_loss)`` accumulated over the epoch, each batch
        being measured right after its update.
    """
    model.eval()

    # Prepare statistics
    num_data = 0
    num_correct = 0
    sum_loss = 0
    num_query = 0

    lr_history = []

    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]

    # Batches follow the model's device instead of assuming CUDA, so the CPU
    # fallback works and a model on a non-default GPU is handled too.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # Prepare grouping. This also covers num_groups == 1, where the estimator
    # would otherwise receive None on the first batch.
    if group_dict is None and num_groups >= 1:
        group_dict = {
            name: torch.randint(0, num_groups, size=param.data.size(), device=param.data.device, dtype=torch.long)
            for name, param in trainable
        }

    gradient_momentum = {name: torch.zeros_like(param.data) for name, param in trainable}

    # Train the model
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}', leave=False)
    for input, label in pbar:
        if device is not None:
            input = input.to(device)
            label = label.to(device)

        optimizer.zero_grad()

        # Gradient estimation
        estimated_gradient = gradient_estimate_groupwise(input, label, model, criterion, num_groups, group_dict, query, smoothing)
        num_query += query * num_groups

        # Update momentum
        for name, _ in trainable:
            gradient_momentum[name] = momentum * gradient_momentum[name] + estimated_gradient[name]

        # Apply gradient
        for name, param in trainable:
            param.grad = estimated_gradient[name]

        # Estimate learning rate
        lr = lr_min
        if lr_auto:
            lr = learning_rate_estimate_second_order(input, label, model, criterion, estimated_gradient, smoothing=smoothing)
            lr = abs(lr.item()) if lr != 0 else lr_min

            num_query += 3

        lr = min(lr, lr_max)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        lr_history.append(lr)

        # Update group
        group_dict = group_by_gradient_exp(gradient_momentum, num_groups)

        # Update the model
        optimizer.step()

        # Statistics (no autograd graph is needed for reporting)
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, label)

        _, predicted = torch.max(output.data, 1)
        num_data += label.size(0)
        num_correct += (predicted == label).sum().item()
        sum_loss += loss.item() * label.size(0)

        accuracy = num_correct / num_data
        avg_loss = sum_loss / num_data

        pbar.set_postfix(train_accuracy=accuracy, train_loss=avg_loss)

    accuracy = num_correct / num_data
    avg_loss = sum_loss / num_data

    if config is not None:
        config['lr_avg'] = np.mean(lr_history)
        config['lr_std'] = np.std(lr_history)
        config['num_query'] = num_query
        config['group_dict'] = group_dict

    return accuracy, avg_loss

In [None]:
# Grouping carried over from one epoch to the next (None -> random start).
group_dict = None

# The group count depends only on the model size and QUERY_BASE, so it is
# computed once rather than on every epoch.
num_groups = int(np.log(num_params)/np.log(QUERY_BASE))

for epoch in range(EPOCHS):
    start_time = time.time()

    # Filled by train_zo with lr statistics, the query count and the new grouping.
    config = {}

    train_acc, train_loss = train_zo(
        train_loader, model, criterion, optimizer, epoch,
        smoothing=SMOOTHING, query=NUM_QUERY, lr_auto=True, lr_max=LR_MAX, lr_min=LR_MIN, momentum=MOMENTUM,
        num_groups=num_groups, group_dict=group_dict,
        config=config
    )
    val_acc, val_loss = validate(val_loader, model, criterion, epoch)

    lr_avg = config['lr_avg']
    lr_std = config['lr_std']
    group_dict = config['group_dict']

    print(
        f"Epoch {epoch+1:3d}/{EPOCHS}, "
        f"LR: {lr_avg:.6f}±{lr_std:.6f} | "
        f"Train Acc: {train_acc * 100:.2f}%, "
        f"Train Loss: {train_loss:.4f}, "
        f"Val Acc: {val_acc*100:.2f}%, "
        f"Val Loss: {val_loss:.4f} | "
        f"Num Query: {config['num_query']}, "
        f"Time: {time.time() - start_time:.3f}s"
    )

# torch.save does not create missing directories; make sure the target folder
# exists so the whole run is not lost at the very last step.
save_dir = os.path.dirname(PATH_MODEL)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), PATH_MODEL)