MobileNetV2 on CIFAR-10 dataset with pruning

In [1]:
import argparse
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
from collections import OrderedDict
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

In [2]:
# ---------------------------------
# Reproducibility Configuration
# ---------------------------------
import torch
import numpy as np
import random
import os

def set_seed(seed=42):
    """
    Sets seed for reproducibility across:
    - Python
    - NumPy
    - PyTorch (CPU + GPU)
    - cuDNN (deterministic)
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Make cuDNN deterministic (slower but reproducible)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Ensure hash-based ops are deterministic
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set to: {seed}")

# Call the seed function
set_seed(42)

Random seed set to: 42


In [3]:
def get_prunable_convs(model):
    """
    Return an ordered list of (name, module) for Conv2d layers we want to prune.
    Here: only pointwise 1x1 convs with groups = 1.
    """
    prunable = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d) and m.kernel_size == (1, 1) and m.groups == 1:
            prunable.append((name, m))
    return prunable

In [4]:
def channel_importance_l2(conv: nn.Conv2d):
    """
    Compute L2 norm of each output channel in a Conv2d layer.
    Returns a tensor of shape (out_channels,).
    """
    W = conv.weight.detach()   # [out_c, in_c, k_h, k_w]
    # Flatten spatial + input dims, then compute L2 per output channel
    W_flat = W.view(W.size(0), -1)
    importance = torch.norm(W_flat, p=2, dim=1)
    return importance

In [5]:
from torch.utils.data import Subset

def make_calib_loader(test_dataset, batch_size=128, num_samples=1024):
    indices = list(range(min(num_samples, len(test_dataset))))
    calib_ds = Subset(test_dataset, indices)
    calib_loader = torch.utils.data.DataLoader(
        calib_ds, batch_size=batch_size, shuffle=False, num_workers=2
    )
    return calib_loader

In [6]:
def count_params_per_layer(convs):
    """
    convs: list of (name, conv_module)
    Returns dict {name: num_params_in_layer}
    """
    param_counts = {}
    for name, conv in convs:
        param_counts[name] = conv.weight.numel() + (conv.bias.numel() if conv.bias is not None else 0)
    return param_counts

In [7]:
def layer_cap(name: str) -> float:
    """
    Per-layer max prune fraction based on MobileNetV2 block index.
    """
    if not name.startswith("features."):
        return 0.0

    parts = name.split(".")
    try:
        block_idx = int(parts[1])
    except (ValueError, IndexError):
        return 0.0

    # Final big conv: still conservative
    if block_idx == 18:
        return 0.20  # was 0.15

    # Very early feature extractor: don't touch much
    if block_idx <= 3:
        return 0.10

    # Early-mid (4–7): slightly higher
    if block_idx <= 7:
        return 0.20   # was 0.15

    # Mid (8–13): can be more aggressive
    if block_idx <= 13:
        return 0.30   # was 0.20

    # Late (14–17): even more room
    if block_idx <= 17:
        return 0.35   # was 0.25

    return 0.0

In [8]:
def compute_prunability_scores(sensitivities, eps=1e-3, min_sens=1e-3):
    scores = {}
    for name, s in sensitivities.items():
        s_clamped = max(s, min_sens)   # avoid s=0
        scores[name] = 1.0 / (eps + s_clamped)
    return scores

In [9]:
def count_total_and_nonzero_params(model):
    total = 0
    nonzero = 0
    for p in model.parameters():
        if not p.requires_grad:
            continue
        numel = p.numel()
        nz = (p != 0).sum().item()
        total += numel
        nonzero += nz
    sparsity = 1.0 - (nonzero / total)
    return total, nonzero, sparsity

In [10]:
def allocate_prune_fractions(
    scores,
    param_counts,
    target_global_prune=0.25,  # start more conservatively!
):
    names = list(scores.keys())
    scores_vec = torch.tensor([scores[n] for n in names], dtype=torch.float32)
    params_vec = torch.tensor([param_counts[n] for n in names], dtype=torch.float32)
    caps_vec = torch.tensor([layer_cap(n) for n in names], dtype=torch.float32)

    total_params = params_vec.sum().item()

    if target_global_prune <= 0 or scores_vec.sum() == 0:
        return {n: 0.0 for n in names}

    def global_prune_for_alpha(alpha):
        frac = torch.clamp(alpha * scores_vec, max=caps_vec)
        pruned_params = (frac * params_vec).sum()
        return (pruned_params / total_params).item()

    low, high = 0.0, 1e6
    for _ in range(50):
        mid = 0.5 * (low + high)
        g = global_prune_for_alpha(mid)
        if g > target_global_prune:
            high = mid
        else:
            low = mid

    alpha_opt = low
    frac_vec = torch.clamp(alpha_opt * scores_vec, max=caps_vec)

    prune_fracs = {n: float(frac_vec[i].item()) for i, n in enumerate(names)}
    return prune_fracs

In [11]:
def build_optimizer_and_scheduler(model, base_lr, weight_decay, stage_epochs):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndimension() == 1 or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)

    optimizer = optim.SGD(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=base_lr,
        momentum=0.9,
        nesterov=True,
    )

    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=stage_epochs, eta_min=1e-4
    )
    return optimizer, scheduler

In [12]:
def create_mobilenetv2_cifar10(num_classes=10, pretrained=True):
    """
    Create a MobileNetV2 model adapted for CIFAR-10.
    """
    if pretrained:
        print("Load ImageNet weights")
        weights = MobileNet_V2_Weights.IMAGENET1K_V1
        model = mobilenet_v2(weights=weights)
    else:
        model = mobilenet_v2(weights=None)

    # Replace classifier (last linear layer) for CIFAR-10
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)

    return model


In [13]:
def get_cifar10_loaders(data_dir, batch_size=128, num_workers=4):
    """
    Returns (trainloader, testloader) for CIFAR-10 with good augmentations.
    """
    # ImageNet-like normalization
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # CIFAR-10 images are 32x32; we upscale to 224x224 for MobileNetV2
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        normalize,
        transforms.RandomErasing(p=0.2)
    ])

    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=train_transform
    )

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=test_transform
    )

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    return trainloader, testloader


In [14]:
class LabelSmoothingCrossEntropy(nn.Module):
    """
    Cross-entropy with label smoothing.
    """
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, preds, target):
        num_classes = preds.size(1)
        log_preds = torch.log_softmax(preds, dim=1)
        with torch.no_grad():
            true_dist = torch.zeros_like(log_preds)
            true_dist.fill_(self.smoothing / (num_classes - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), 1.0 - self.smoothing)
        return torch.mean(torch.sum(-true_dist * log_preds, dim=1))


def accuracy(output, target, topk=(1,)):
    """
    Computes the top-k accuracy for the specified values of k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Get top-k indices
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        # Compare with targets expanded
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


In [15]:
def train_one_epoch(model, criterion, optimizer, dataloader, device, epoch, scaler=None):
    model.train()
    running_loss = 0.0
    running_top1 = 0.0
    total = 0

    start_time = time.time()

    for i, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()

        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        top1, = accuracy(outputs, targets, topk=(1,))
        bs = targets.size(0)
        running_loss += loss.item() * bs
        running_top1 += top1.item() * bs
        total += bs

        if (i + 1) % 100 == 0:
            print(
                f"Epoch [{epoch}] Step [{i+1}/{len(dataloader)}] "
                f"Loss: {running_loss / total:.4f} | "
                f"Top-1: {running_top1 / total:.2f}%"
            )

    epoch_loss = running_loss / total
    epoch_acc1 = running_top1 / total
    elapsed = time.time() - start_time
    print(
        f"Epoch [{epoch}] TRAIN - "
        f"Loss: {epoch_loss:.4f} | Top-1: {epoch_acc1:.2f}% | "
        f"Time: {elapsed:.1f}s"
    )
    return epoch_loss, epoch_acc1


@torch.no_grad()
def evaluate(model, criterion, dataloader, device, epoch="TEST"):
    model.eval()
    running_loss = 0.0
    running_top1 = 0.0
    total = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        top1, = accuracy(outputs, targets, topk=(1,))
        bs = targets.size(0)
        running_loss += loss.item() * bs
        running_top1 += top1.item() * bs
        total += bs

    loss = running_loss / total
    acc1 = running_top1 / total
    print(f"Epoch [{epoch}] VALID - Loss: {loss:.4f} | Top-1: {acc1:.2f}%")
    return loss, acc1


In [16]:
def estimate_layer_sensitivity(model, criterion, device, calib_loader, base_prune=0.3):
    """
    Estimate sensitivity of each prunable conv layer using your existing evaluate().
    base_prune: fraction of channels to prune *only for sensitivity testing*.
    Returns:
        sensitivities: dict {layer_name: accuracy_drop (fraction, e.g. 0.01 = 1%)}
        convs: ordered list of (name, conv_module)
        baseline_acc: baseline accuracy on calib set (fraction)
    """
    model.eval()
    convs = get_prunable_convs(model)

    # Baseline accuracy on calibration set
    loss, baseline_acc = evaluate(model, criterion, calib_loader, device, epoch="CALIB_BASE")
    baseline_acc /= 100.0  # convert % -> fraction
    print(f"Baseline calib accuracy: {baseline_acc*100:.2f}%")

    sensitivities = OrderedDict()

    for name, conv in convs:
        print(f"Testing sensitivity for layer: {name}")
        imp = channel_importance_l2(conv)
        num_channels = imp.numel()
        num_prune = int(base_prune * num_channels)
        if num_prune < 1:
            sensitivities[name] = 0.0
            continue

        prune_idxs = torch.argsort(imp)[:num_prune]

        # Backup weights (and bias)
        W_orig = conv.weight.data.clone()
        b_orig = conv.bias.data.clone() if conv.bias is not None else None

        # Temporary pruning
        with torch.no_grad():
            conv.weight.data[prune_idxs] = 0
            if conv.bias is not None:
                conv.bias.data[prune_idxs] = 0

        # Evaluate with this layer pruned
        _, acc1_pruned = evaluate(model, criterion, calib_loader, device, epoch=f"CALIB_{name}")
        acc_pruned = acc1_pruned / 100.0
        acc_drop = baseline_acc - acc_pruned
        sensitivities[name] = float(acc_drop)
        print(f"  Acc with layer pruned: {acc_pruned*100:.2f}% (drop {acc_drop*100:.2f}%)")

        # Restore weights
        with torch.no_grad():
            conv.weight.data.copy_(W_orig)
            if conv.bias is not None:
                conv.bias.data.copy_(b_orig)

    return sensitivities, convs, baseline_acc

In [17]:
def apply_structured_channel_pruning(model, convs, prune_fracs):
    """
    Zero entire output channels in selected Conv2d layers according to prune_fracs.
    convs: list of (name, conv_module)
    prune_fracs: dict {layer_name: fraction_to_prune}
    """
    model.eval()
    with torch.no_grad():
        for name, conv in convs:
            p = prune_fracs.get(name, 0.0)
            if p <= 0.0:
                continue

            imp = channel_importance_l2(conv)
            num_channels = imp.numel()
            num_prune = int(p * num_channels)
            if num_prune < 1:
                continue

            prune_idxs = torch.argsort(imp)[:num_prune]

            # Zero selected output channels
            conv.weight.data[prune_idxs] = 0
            if conv.bias is not None:
                conv.bias.data[prune_idxs] = 0

    return model

In [18]:
class Config:
    data_dir = "./data"           # For Colab, use "/content/data"
    epochs = 3
    batch_size = 128
    lr = 0.05
    weight_decay = 4e-5
    num_workers = 4
    label_smoothing = 0.1
    no_pretrained = False         # Set True to disable ImageNet pretraining
    save_path = "mobilenetv2_cifar10_best_pruned.pth"
    resume = ""                   # Path to checkpoint, or "" to start fresh
    mixed_precision = False        # Use AMP if GPU is available
    is_pruning = True             # Use this to detect pruning experiment is in progress

cfg = Config()
print("Config:", vars(cfg))


Config: {}


In [19]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Data
trainloader, testloader = get_cifar10_loaders(
    data_dir=cfg.data_dir,
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
)

# Model
model = create_mobilenetv2_cifar10(
    num_classes=10,
    pretrained=not cfg.no_pretrained,
)
ckpt = torch.load("/kaggle/input/fp32-trained-model/mobilenetv2_cifar10_best_baseline.pth", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"], strict=True)
model = model.to(device)

# Dense model parameters for compression computation
dense_total, dense_nonzero, dense_sparsity = count_total_and_nonzero_params(model)
print(f"Dense model - total params: {dense_total}, nonzero: {dense_nonzero}, sparsity: {dense_sparsity*100:.2f}%")

criterion = LabelSmoothingCrossEntropy(smoothing=cfg.label_smoothing)

val_loss, val_acc1 = evaluate(model, criterion, testloader, device, epoch="BASELINE")
print(f"Baseline accuracy: {val_acc1:.2f}%")

# Create a small calibration loader from your test dataset
calib_loader = make_calib_loader(testloader.dataset, batch_size=128, num_samples=1024)

# Estimate layer sensitivities
sensitivities, convs, baseline_acc = estimate_layer_sensitivity(
    model,
    criterion,
    device,
    calib_loader,
    base_prune=0.2,
)

# Compute prunability scores
scores = compute_prunability_scores(sensitivities)
param_counts = count_params_per_layer(convs)

# Targets are "global prune fractions" on the prunable params
stage_targets = [0.10, 0.15, 0.2, 0.25, 0.3]
val_acc_threshold = 95.0  # stop if we go below this after fine-tune
stage_epochs = 5          # fine-tune epochs per stage
base_lr_ft = 0.01

best_overall_acc = val_acc1
best_state_dict  = model.state_dict()
best_pruned_acc = 0.0
best_pruned_state = None
best_pruned_sparsity = 0.0
best_pruned_stage = None
stage_results = []  # list of dicts: one entry per stage


for stage_idx, target in enumerate(stage_targets, 1):
    print("=" * 60)
    print(f"Stage {stage_idx}: target_global_prune = {target:.2f}")
    print("=" * 60)

    # 1) Allocate prune fractions for this target
    prune_fracs = allocate_prune_fractions(
        scores,
        param_counts,
        target_global_prune=target,
    )

    print("Per-layer prune fractions at this stage:")
    for name, frac in prune_fracs.items():
        print(f"{name:25s} -> {frac:.2f}")

    # 2) Apply pruning on top of current model
    model = apply_structured_channel_pruning(model, convs, prune_fracs)

    # 3) Evaluate immediately after pruning
    _, acc_raw = evaluate(
        model, criterion, testloader, device, epoch=f"PRUNED_RAW_S{stage_idx}"
    )
    print(f"Stage {stage_idx} - accuracy after pruning, before fine-tune: {acc_raw:.2f}%")

    # 4) Measure sparsity & compression
    total_params, nonzero_params, sparsity = count_total_and_nonzero_params(model)
    compression_ratio = total_params / nonzero_params
    print(
        f"Stage {stage_idx} - sparsity: {sparsity*100:.2f}% "
        f"(compression {compression_ratio:.2f}x, "
        f"nonzero {nonzero_params}/{total_params})"
    )

    # 5) Short fine-tune for this stage
    optimizer, scheduler = build_optimizer_and_scheduler(
        model,
        base_lr=base_lr_ft,
        weight_decay=cfg.weight_decay,
        stage_epochs=stage_epochs,
    )

    best_stage_acc = 0.0
    for e in range(stage_epochs):
        epoch_global = (stage_idx - 1) * stage_epochs + e
        train_loss, train_acc = train_one_epoch(
            model, criterion, optimizer, trainloader, device, epoch=epoch_global
        )
        val_loss, val_acc = evaluate(
            model, criterion, testloader, device, epoch=f"S{stage_idx}_E{e}"
        )
        scheduler.step()

        if val_acc > best_stage_acc:
            best_stage_acc = val_acc

    print(
        f"End of stage {stage_idx}: "
        f"best_val_acc={best_stage_acc:.2f}%, "
        f"sparsity={sparsity*100:.2f}%, "
        f"compression={compression_ratio:.2f}x"
    )

    # 6) Save this stage's model
    stage_ckpt_path = f"mobilenetv2_cifar10_pruned_stage{stage_idx}.pth"
    torch.save(
        {
            "stage": stage_idx,
            "target_global_prune": target,
            "state_dict": model.state_dict(),
            "best_val_acc": best_stage_acc,
            "sparsity": sparsity,
            "compression": compression_ratio,
            "total_params": total_params,
            "nonzero_params": nonzero_params,
        },
        stage_ckpt_path,
    )
    print(f"Saved stage {stage_idx} checkpoint to: {stage_ckpt_path}")

    # 7) Log metrics for later analysis
    stage_results.append(
        {
            "stage": stage_idx,
            "target_global_prune": target,
            "best_val_acc": best_stage_acc,
            "acc_after_prune": acc_raw,
            "sparsity": sparsity,
            "compression": compression_ratio,
            "ckpt_path": stage_ckpt_path,
        }
    )

print("\n=== Compression vs Accuracy per Stage ===")
for r in stage_results:
    print(
        f"Stage {r['stage']}: "
        f"target={r['target_global_prune']:.2f}, "
        f"sparsity={r['sparsity']*100:.2f}%, "
        f"compression={r['compression']:.2f}x, "
        f"acc_after_prune={r['acc_after_prune']:.2f}%, "
        f"best_val_acc={r['best_val_acc']:.2f}%, "
        f"ckpt={r['ckpt_path']}"
    )

Using device: cuda


100%|██████████| 170M/170M [00:03<00:00, 46.6MB/s] 
Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


Load ImageNet weights


100%|██████████| 13.6M/13.6M [00:00<00:00, 118MB/s] 


Dense model - total params: 2236682, nonzero: 2236682, sparsity: 0.00%
Epoch [BASELINE] VALID - Loss: 0.6330 | Top-1: 96.24%
Baseline accuracy: 96.24%
Epoch [CALIB_BASE] VALID - Loss: 0.6367 | Top-1: 96.39%
Baseline calib accuracy: 96.39%
Testing sensitivity for layer: features.1.conv.1
Epoch [CALIB_features.1.conv.1] VALID - Loss: 0.6934 | Top-1: 93.07%
  Acc with layer pruned: 93.07% (drop 3.32%)
Testing sensitivity for layer: features.2.conv.0.0
Epoch [CALIB_features.2.conv.0.0] VALID - Loss: 0.6392 | Top-1: 96.19%
  Acc with layer pruned: 96.19% (drop 0.20%)
Testing sensitivity for layer: features.2.conv.2
Epoch [CALIB_features.2.conv.2] VALID - Loss: 0.6462 | Top-1: 95.51%
  Acc with layer pruned: 95.51% (drop 0.88%)
Testing sensitivity for layer: features.3.conv.0.0
Epoch [CALIB_features.3.conv.0.0] VALID - Loss: 0.6375 | Top-1: 96.00%
  Acc with layer pruned: 96.00% (drop 0.39%)
Testing sensitivity for layer: features.3.conv.2
Epoch [CALIB_features.3.conv.2] VALID - Loss: 0.6358

In [23]:
# --- Final fine-tuning of best pruned model before quantization ---

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the same architecture
model = create_mobilenetv2_cifar10(num_classes=10, pretrained=False)

ckpt = torch.load("mobilenetv2_cifar10_pruned_stage5.pth", map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict=True)
model = model.to(device)

# Recreate data loaders if needed
trainloader, testloader = get_cifar10_loaders(
    cfg.data_dir,
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
)

criterion = LabelSmoothingCrossEntropy(smoothing=cfg.label_smoothing)

fine_tune_epochs = 15       # 5–20 is typical
fine_tune_lr     = 0.005    # if original was 0.05

optimizer, scheduler = build_optimizer_and_scheduler(
    model,
    base_lr=fine_tune_lr,
    weight_decay=cfg.weight_decay,
    stage_epochs=fine_tune_epochs,
)

best_ft_acc = 0.0
best_ft_state = None

for epoch in range(fine_tune_epochs):
    train_loss, train_acc = train_one_epoch(
        model, criterion, optimizer, trainloader, device, epoch
    )
    val_loss, val_acc = evaluate(
        model, criterion, testloader, device, epoch=f"FT_{epoch}"
    )

    scheduler.step()

    if val_acc > best_ft_acc:
        best_ft_acc = val_acc
        best_ft_state = {k: v.cpu() for k, v in model.state_dict().items()}
        print(f"[FT] New best accuracy: {best_ft_acc:.2f}% at epoch {epoch}")

if best_ft_state is not None:
    torch.save(best_ft_state, "mobilenetv2_cifar10_pruned_finetuned.pth")
    print(f"Saved final pruned+finetuned model with acc={best_ft_acc:.2f}%")
else:
    # fallback if somehow no improvement happened
    torch.save(model.state_dict(), "mobilenetv2_cifar10_pruned_finetuned.pth")

Epoch [0] Step [100/391] Loss: 0.7011 | Top-1: 93.66%
Epoch [0] Step [200/391] Loss: 0.7015 | Top-1: 93.43%
Epoch [0] Step [300/391] Loss: 0.7036 | Top-1: 93.26%
Epoch [0] TRAIN - Loss: 0.7044 | Top-1: 93.25% | Time: 174.5s
Epoch [FT_0] VALID - Loss: 0.6589 | Top-1: 95.09%
[FT] New best accuracy: 95.09% at epoch 0
Epoch [1] Step [100/391] Loss: 0.7079 | Top-1: 93.16%
Epoch [1] Step [200/391] Loss: 0.7061 | Top-1: 93.19%
Epoch [1] Step [300/391] Loss: 0.7053 | Top-1: 93.16%
Epoch [1] TRAIN - Loss: 0.7053 | Top-1: 93.19% | Time: 180.1s
Epoch [FT_1] VALID - Loss: 0.6573 | Top-1: 95.14%
[FT] New best accuracy: 95.14% at epoch 1
Epoch [2] Step [100/391] Loss: 0.7005 | Top-1: 93.66%
Epoch [2] Step [200/391] Loss: 0.6950 | Top-1: 93.78%
Epoch [2] Step [300/391] Loss: 0.6981 | Top-1: 93.62%
Epoch [2] TRAIN - Loss: 0.6995 | Top-1: 93.58% | Time: 181.9s
Epoch [FT_2] VALID - Loss: 0.6554 | Top-1: 95.14%
Epoch [3] Step [100/391] Loss: 0.7002 | Top-1: 93.48%
Epoch [3] Step [200/391] Loss: 0.6982 | 