## Import

In [1]:
import os
import random

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import wandb

In [3]:
from neumeta.models import create_model_cifar10 as create_model
from neumeta.utils import (
    parse_args, print_omegaconf,
    load_checkpoint, save_checkpoint,
    set_seed,
    get_cifar10, 
    sample_coordinates, sample_subset, shuffle_coordinates_all,
    get_hypernetwork, get_optimizer, 
    sample_weights,
    weighted_regression_loss, validate_single, AverageMeter, EMA,
    sample_single_model, sample_merge_model,
)

## Functions

### 1 Find maximum dimension of the model

In [4]:
def find_max_dim(model_cls):
    """Find maximum dimension of the model"""
    # Get the learnable parameters of the model
    checkpoint = model_cls.learnable_parameter 

    # Set the maximum value to the length of the checkpoint
    max_value = len(checkpoint)

    # Iterate over the new model's weight
    for i, (k, tensor) in enumerate(checkpoint.items()):
        # Handle 2D tensors (e.g., weight matrices) 
        if len(tensor.shape) == 4:
            coords = [tensor.shape[0], tensor.shape[1]]
            max_value = max(max_value, max(coords))
        # Handle 1D tensors (e.g., biases)
        elif len(tensor.shape) == 1:
            max_value = max(max_value, tensor.shape[0])
    
    return max_value

### 2 Initialize wandb

In [5]:
def initialize_wandb(config):
    import time
    """
    Initializes Weights and Biases (wandb) with the given configuration.
    
    Args:
        configuration (dict): Configuration parameters for the run.
    """
    # Name the run using current time and configuration name
    run_name = f"{time.strftime('%Y%m%d%H%M%S')}-{config.experiment.name}"
    
    wandb.init(project="ninr-trial", name=run_name, config=dict(config), group='cifar10')

### 3 Initialize model dictionary

In [6]:
def init_model_dict(args, device):
    """
    Initializes a dictionary of models for each dimension in the given range, along with ground truth models for the starting dimension.

    Args:
        args: An object containing the arguments for initializing the models.

    Returns:
        dim_dict: A dictionary containing the models for each dimension, along with their corresponding coordinates, keys, indices, size, and ground truth models.
        gt_model_dict: A dictionary containing the ground truth models for the starting dimension.
    """
    dim_dict = {}
    gt_model_dict = {}
    
    # Create a model for each dimension in dimensions range
    for dim in args.dimensions.range:
        model_cls = create_model(args.model.type,
                                 hidden_dim=dim,
                                 path=args.model.pretrained_path,
                                 smooth=args.model.smooth).to(device)
        # Sample the coordinates, keys, indices, and the size for the model
        coords_tensor, keys_list, indices_list, size_list = sample_coordinates(model_cls)
        # Add the model, coordinates, keys, indices, size, and key mask to the dictionary
        dim_dict[f"{dim}"] = (model_cls, coords_tensor, keys_list, indices_list, size_list, None)

        # Print to makes line looked better
        print('\n')
        
        # If the dimension is the starting dimension (the dimension of pretrained_model), add the ground truth model to the dictionary
        if dim == args.dimensions.start:
            print(f"Loading model for dim {dim}")
            model_trained = create_model(args.model.type, 
                                         hidden_dim=dim, 
                                         path=args.model.pretrained_path, 
                                         smooth=args.model.smooth).to(device)
            model_trained.eval()
            gt_model_dict[f'{dim}'] = model_trained

    
    return dim_dict, gt_model_dict

### 4 Training function for target model of a random dimension

In [7]:
# Function to train the model for one epoch
def train_one_epoch(model, train_loader, optimizer, criterion, dim_dict, gt_model_dict, epoch_idx, ema=None, args=None, device='cpu'):
    # Set the model to training mode
    model.train()
    total_loss = 0.0

    # Initialize AverageMeter objects to track the losses
    losses = AverageMeter()
    cls_losses = AverageMeter()
    reg_losses = AverageMeter()
    reconstruct_losses = AverageMeter()

    # Iterate over the training data
    for batch_idx, (x, target) in enumerate(train_loader):
        # Zero the gradients
        optimizer.zero_grad()

        # Preprocess input
        # ------------------------------------------------------------------------------------------------------
        # Move the data to the device
        x, target = x.to(device), target.to(device)
        # Choose a random hidden dimension
        hidden_dim = random.choice(args.dimensions.range)
        # Get the model class, coordinates, keys, indices, size, and key mask for the chosen dimension
        model_cls, coords_tensor, keys_list, indices_list, size_list, key_mask = dim_dict[f"{hidden_dim}"]
        # Sample a subset the input tensor of the coordinates, keys, indices, size, and selected keys
        coords_tensor, keys_list, indices_list, size_list, selected_keys = sample_subset(coords_tensor,
                                                                                         keys_list,
                                                                                         indices_list,
                                                                                         size_list,
                                                                                         key_mask,
                                                                                         ratio=args.ratio)
        # Add noise to the coordinates if specified
        if args.training.coordinate_noise > 0.0:
            coords_tensor = coords_tensor + (torch.rand_like(coords_tensor) - 0.5) * args.training.coordinate_noise


        # Main task of hypernetwork and target network
        # ------------------------------------------------------------------------------------------------------
        # Sample the weights for the target model using hypernetwork
        model_cls, reconstructed_weights = sample_weights(model, model_cls,
                                                          coords_tensor, keys_list, indices_list, size_list, key_mask, selected_keys,
                                                          device=device, NORM=args.dimensions.norm)
        # Forward pass
        predict = model_cls(x)


        # Compute losses
        # ------------------------------------------------------------------------------------------------------
        # Compute classification loss
        cls_loss = criterion(predict, target) 
        # Compute regularization loss
        reg_loss = sum([torch.norm(w, p=2) for w in reconstructed_weights])
        # Compute reconstruction loss if ground truth model is available
        if f"{hidden_dim}" in gt_model_dict:
            gt_model = gt_model_dict[f"{hidden_dim}"]
            gt_selected_weights = [
                w for k, w in gt_model.learnable_parameter.items() if k in selected_keys]

            reconstruct_loss = weighted_regression_loss(
                reconstructed_weights, gt_selected_weights)
        else:
            reconstruct_loss = torch.tensor(0.0)
        # Compute the total loss
        loss = args.hyper_model.loss_weight.ce_weight * cls_loss + args.hyper_model.loss_weight.reg_weight * \
            reg_loss + args.hyper_model.loss_weight.recon_weight * reconstruct_loss


        # Compute gradients and update weights
        # ------------------------------------------------------------------------------------------------------
        # Zero the gradients of the updated weights
        for updated_weight in model_cls.parameters():
            updated_weight.grad = None

        # Compute the gradients of the reconstructed weights
        loss.backward(retain_graph=True)
        torch.autograd.backward(reconstructed_weights, [
                                w.grad for k, w in model_cls.named_parameters() if k in selected_keys])
        
        # Clip the gradients if specified
        if args.training.get('clip_grad', 0.0) > 0:
            torch.nn.utils.clip_grad_value_(
                model.parameters(), args.training.clip_grad)
            
        # Update the weights
        optimizer.step()

        # Update the EMA if specified
        if ema:
            ema.update()  # Update the EMA after each training step
        total_loss += loss.item()

        # Update the AverageMeter objects
        losses.update(loss.item())
        cls_losses.update(cls_loss.item())
        reg_losses.update(reg_loss.item())
        reconstruct_losses.update(reconstruct_loss.item())


        # Log (or plot) losses
        # ------------------------------------------------------------------------------------------------------
        # Log the losses and learning rate to wandb
        if batch_idx % args.experiment.log_interval == 0:
            wandb.log({
                "Loss": losses.avg,
                "Cls Loss": cls_losses.avg,
                "Reg Loss": reg_losses.avg,
                "Reconstruct Loss": reconstruct_losses.avg,
                "Learning rate": optimizer.param_groups[0]['lr']
            }, step=batch_idx + epoch_idx * len(train_loader))
            # Print the losses and learning rate
            print(
                f"Iteration {batch_idx}: Loss = {losses.avg:.4f}, Reg Loss = {reg_losses.avg:.4f}, Reconstruct Loss = {reconstruct_losses.avg:.4f}, Cls Loss = {cls_losses.avg:.4f}, Learning rate = {optimizer.param_groups[0]['lr']:.4e}")
    
    # Returns the training loss, structure of network in each dimension, and the original structure of pretrained network
    return losses.avg, dim_dict, gt_model_dict

## Main

### 0 Set device to GPU

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### 1 Parsing arguments for inputs

In [9]:
CONFIG_PATH = 'neumeta/config/base_config_smooth_ratio_0.75.yaml'
RATIO = '0.75'

In [10]:
argv_train = ['--config', CONFIG_PATH, '--ratio', RATIO]
argv_test = ['--config', CONFIG_PATH, '--test']

In [11]:
args = parse_args(argv_train)  # Parse arguments
print_omegaconf(args)  # Print arguments

+--------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------+
|                 Key                  |                                                                Value                                                                 |
+--------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------+
|           experiment.name            |                                  ninr_resnet20_cifar10_32-64-4layer-5_base_config_smooth_ratio_0.75                                  |
|         experiment.recononly         |                                                                  0                                                                   |
|        experiment.num_epochs         |                                                                  30            

### 2 Get training and validation data (in dataloader format)

In [12]:
train_loader, val_loader = get_cifar10(args.training.batch_size, strong_transform=args.training.get('strong_aug', None))

### 3 Create target model

#### 3.0 Create the model

In [13]:
model = create_model(args.model.type,
                     hidden_dim=args.dimensions.start,
                     path=args.model.pretrained_path,
                     smooth=args.model.smooth).to(device)

Replace the last 2 block of layer3 with new block with hidden dim 64
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1321.4044189453125
Permuted TV original model: 1039.4769287109375


#### 3.1 Print the structure and shape of the model

In [14]:
model

CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): Identity()
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
    (2): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (layer2): Sequential(


In [15]:
for i, (k, tensor) in enumerate(model.learnable_parameter.items()):
    print(k, tensor.shape)

layer3.2.conv1.weight torch.Size([64, 64, 3, 3])
layer3.2.conv1.bias torch.Size([64])
layer3.2.conv2.weight torch.Size([64, 64, 3, 3])
layer3.2.conv2.bias torch.Size([64])


#### 3.2 The maximum dimension of the target model

In [16]:
# Print the maximum dimension of the model
print(f'Maximum DIM: {find_max_dim(model)}')

Maximum DIM: 64


#### 3.3 Validate the accuracy of pretrained network

In [17]:
# Validate the model for the starting dimension (its pretrained form)
val_loss, acc = validate_single(model, val_loader, nn.CrossEntropyLoss(), args=args)
print(f'Initial Permutated model Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc * 100:.2f}%')

100%|██████████| 157/157 [00:02<00:00, 69.76it/s]

Initial Permutated model Validation Loss: 0.2825, Validation Accuracy: 92.60%





In [18]:
# Get the learnable parameters of the model
checkpoint = model.learnable_parameter
# Get the number of parameters
number_param = len(checkpoint)

In [19]:
# Print the keys of the parameters and the number of parameters
print(f"Parameters keys: {model.keys}")
print(f"Number of parameters to be learned: {number_param}")

Parameters keys: ['layer3.2.conv1.weight', 'layer3.2.conv1.bias', 'layer3.2.conv2.weight', 'layer3.2.conv2.bias']
Number of parameters to be learned: 4


### 4 Create the hypernetwork

#### 4.0 Create the model

In [20]:
# Get the hypermodel
hyper_model = get_hypernetwork(args, number_param)

Hyper model type: mlp
num_freqs:  16 <class 'int'>


#### 4.1 Print model structure

In [21]:
hyper_model

NeRF_MLP_Compose(
  (positional_encoding): PositionalEncoding()
  (model): ModuleList(
    (0-3): 4 x NeRF_MLP_Residual_Scaled(
      (initial_layer): Linear(in_features=198, out_features=256, bias=True)
      (residual_blocks): ModuleList(
        (0-2): 3 x Linear(in_features=256, out_features=256, bias=True)
      )
      (scalars): ParameterList(
          (0): Parameter containing: [torch.float32 of size  (cuda:0)]
          (1): Parameter containing: [torch.float32 of size  (cuda:0)]
          (2): Parameter containing: [torch.float32 of size  (cuda:0)]
      )
      (act): ReLU(inplace=True)
      (output_layer): Linear(in_features=256, out_features=9, bias=True)
    )
  )
)

#### 4.2 Initialize EMA to track only a smooth version of the model weight

In [22]:
# Initialize the EMA
ema = EMA(hyper_model, decay=args.hyper_model.ema_decay)

### 5 Get Loss function, Optimizer, and Scheduler

In [23]:
criterion, val_criterion, optimizer, scheduler = get_optimizer(args, hyper_model)

In [24]:
print(f'Criterion: {criterion}\nVal_criterion: {val_criterion}\nOptimizer: {optimizer}\nScheduler: {scheduler}')

Criterion: CrossEntropyLoss()
Val_criterion: CrossEntropyLoss()
Optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.001
    lr: 0.001
    maximize: False
    weight_decay: 0.01
)
Scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x0000025658BDA550>


### 6 Training loop

#### 6.1 Initialize training parameters

In [25]:
# Initialize the starting epoch and best accuracy
start_epoch = 0
best_acc = 0.0

#### 6.2 Directory to save the model

In [26]:
# Create the directory to save the model
os.makedirs(args.training.save_model_path, exist_ok=True)

#### 6.3 Resume training loop

In [27]:
if args.resume_from:
        print(f"Resuming from checkpoint: {args.resume_from}")
        checkpoint_info = load_checkpoint(args.resume_from, hyper_model, optimizer, ema)
        start_epoch = checkpoint_info['epoch']
        best_acc = checkpoint_info['best_acc']
        print(f"Resuming from epoch: {start_epoch}, best accuracy: {best_acc*100:.2f}%")
        # Note: If there are more elements to retrieve, do so here.

#### 6.4 Initialize wandb for plotting

In [28]:
# Initialize wandb
initialize_wandb(args)

[34m[1mwandb[0m: Currently logged in as: [33mefradosuryadi[0m ([33mefradosuryadi-universitas-indonesia[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


#### 6.5 Initialize model dictionary for each dimension and shuffle it

In [29]:
# Initialize model dictionary
dim_dict, gt_model_dict = init_model_dict(args, device)

Replace the last 2 block of layer3 with new block with hidden dim 32
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2153.698974609375
Permuted TV original model: 2040.7257080078125


Replace the last 2 block of layer3 with new block with hidden dim 33
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2183.05126953125
Permuted TV original model: 2074.080322265625


Replace the last 2 block of layer3 with new block with hidden dim 34
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2224.338134765625
Permuted TV original model: 2108.9326171875


Replace the last 2 block of layer3 with new block with hidden dim 35
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2273.907470703125
Permuted TV original model: 2155.6025390625


Replace the last 2 block of layer3 with new block with hidden dim 36
Loa

In [30]:
gt_model_dict

{'64': CifarResNet(
   (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (bn1): Identity()
   (relu): ReLU(inplace=True)
   (layer1): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (bn1): Identity()
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (bn2): Identity()
     )
     (1): BasicBlock(
       (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (bn1): Identity()
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (bn2): Identity()
     )
     (2): BasicBlock(
       (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (bn1): Identity()
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (bn2): Identity()
   

In [31]:
dim_dict

{'32': (CifarResNet(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): Identity()
    (relu): ReLU(inplace=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): Identity()
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): Identity()
      )
      (1): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): Identity()
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): Identity()
      )
      (2): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): Identity()
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    

In [32]:
dim_dict = shuffle_coordinates_all(dim_dict)
dim_dict

{'32': (CifarResNet(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): Identity()
    (relu): ReLU(inplace=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): Identity()
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): Identity()
      )
      (1): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): Identity()
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): Identity()
      )
      (2): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): Identity()
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    

#### 6.6 Hypernetwork training loop

In [33]:
args.experiment.num_epochs

30

In [34]:
# Iterate over the epochs
for epoch in range(start_epoch, args.experiment.num_epochs):
    # Train the hypernetwork to generate a model with random dimension for one epoch
    train_loss, dim_dict, gt_model_dict = train_one_epoch(hyper_model, train_loader, optimizer, criterion, 
                                                          dim_dict, gt_model_dict, epoch_idx=epoch, ema=ema, 
                                                          args=args, device=device)
    # Step the scheduler
    scheduler.step()

    # Print the training loss and learning rate
    print(f"Epoch [{epoch+1}/{args.experiment.num_epochs}], Training Loss: {train_loss:.4f}, Learning Rate: {scheduler.get_last_lr()[0]:.6f}")

    # If it's time to evaluate the model
    if (epoch + 1) % args.experiment.eval_interval == 0:
        # Apply EMA if it is specified
        if ema:
            ema.apply()  # Save the weights of original model created before training_loop
        
        # Sample the merged model (create model of same structure before training loop by using the hypernetwork)
        # And then test the performance of the hypernetwork by seeing how good it is in generating the weights
        model = sample_merge_model(hyper_model, model, args) 
        # Validate the merged model
        val_loss, acc = validate_single(model, val_loader, val_criterion, args=args)

        # If EMA is specified, restore the original weights
        if ema:
            ema.restore()  # Restore the original weights to the weights of the pretrained networks

        # Log the validation loss and accuracy to wandb
        wandb.log({
            "Validation Loss": val_loss,
            "Validation Accuracy": acc
        })
        # Print the validation loss and accuracy
        print(f"Epoch [{epoch+1}/{args.experiment.num_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%")
        print('\n\n')

        # Save the checkpoint if the accuracy is better than the previous best
        if acc > best_acc:
            best_acc = acc
            save_checkpoint(f"{args.training.save_model_path}/cifar10_nerf_best.pth",hyper_model,optimizer,ema,epoch,best_acc)
            print(f"Checkpoint saved at epoch {epoch} with accuracy: {best_acc*100:.2f}%")


Iteration 0: Loss = 3.1424, Reg Loss = 1.9832, Reconstruct Loss = 0.0000, Cls Loss = 3.1422, Learning rate = 1.0000e-03
Iteration 25: Loss = 2.2339, Reg Loss = 7.5404, Reconstruct Loss = 0.0103, Cls Loss = 2.2228, Learning rate = 1.0000e-03
Iteration 50: Loss = 2.0907, Reg Loss = 8.3492, Reconstruct Loss = 0.0053, Cls Loss = 2.0846, Learning rate = 1.0000e-03
Iteration 75: Loss = 2.0460, Reg Loss = 9.8459, Reconstruct Loss = 0.0035, Cls Loss = 2.0415, Learning rate = 1.0000e-03
Iteration 100: Loss = 2.0027, Reg Loss = 9.9381, Reconstruct Loss = 0.0067, Cls Loss = 1.9950, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.9758, Reg Loss = 9.2837, Reconstruct Loss = 0.0080, Cls Loss = 1.9668, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.9592, Reg Loss = 9.8834, Reconstruct Loss = 0.0072, Cls Loss = 1.9510, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.9432, Reg Loss = 11.0429, Reconstruct Loss = 0.0062, Cls Loss = 1.9359, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.9322

100%|██████████| 157/157 [00:01<00:00, 78.59it/s]


Epoch [1/30], Validation Loss: 1.8088, Validation Accuracy: 74.17%



Checkpoint saved at epoch 0 with accuracy: 74.17%
Iteration 0: Loss = 1.8387, Reg Loss = 5.7232, Reconstruct Loss = 0.0000, Cls Loss = 1.8381, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8510, Reg Loss = 7.5815, Reconstruct Loss = 0.0040, Cls Loss = 1.8462, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8523, Reg Loss = 6.4217, Reconstruct Loss = 0.0020, Cls Loss = 1.8496, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8460, Reg Loss = 6.2954, Reconstruct Loss = 0.0020, Cls Loss = 1.8434, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8471, Reg Loss = 6.8010, Reconstruct Loss = 0.0018, Cls Loss = 1.8446, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8476, Reg Loss = 7.0248, Reconstruct Loss = 0.0015, Cls Loss = 1.8454, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8461, Reg Loss = 7.0237, Reconstruct Loss = 0.0015, Cls Loss = 1.8439, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8456, R

100%|██████████| 157/157 [00:01<00:00, 81.32it/s]


Epoch [2/30], Validation Loss: 1.8083, Validation Accuracy: 74.76%



Checkpoint saved at epoch 1 with accuracy: 74.76%
Iteration 0: Loss = 1.8415, Reg Loss = 7.2183, Reconstruct Loss = 0.0000, Cls Loss = 1.8408, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8545, Reg Loss = 5.1104, Reconstruct Loss = 0.0050, Cls Loss = 1.8490, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8498, Reg Loss = 4.6401, Reconstruct Loss = 0.0034, Cls Loss = 1.8459, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8453, Reg Loss = 4.3851, Reconstruct Loss = 0.0025, Cls Loss = 1.8425, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8453, Reg Loss = 4.1825, Reconstruct Loss = 0.0018, Cls Loss = 1.8430, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8453, Reg Loss = 3.9719, Reconstruct Loss = 0.0015, Cls Loss = 1.8434, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8442, Reg Loss = 3.7910, Reconstruct Loss = 0.0015, Cls Loss = 1.8422, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8436, R

100%|██████████| 157/157 [00:03<00:00, 44.72it/s]


Epoch [3/30], Validation Loss: 1.8075, Validation Accuracy: 75.78%



Checkpoint saved at epoch 2 with accuracy: 75.78%
Iteration 0: Loss = 1.8446, Reg Loss = 1.3697, Reconstruct Loss = 0.0000, Cls Loss = 1.8445, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8425, Reg Loss = 2.4610, Reconstruct Loss = 0.0015, Cls Loss = 1.8408, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8391, Reg Loss = 2.8179, Reconstruct Loss = 0.0007, Cls Loss = 1.8381, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8422, Reg Loss = 3.0278, Reconstruct Loss = 0.0005, Cls Loss = 1.8414, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8422, Reg Loss = 2.9033, Reconstruct Loss = 0.0004, Cls Loss = 1.8415, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8409, Reg Loss = 2.7032, Reconstruct Loss = 0.0003, Cls Loss = 1.8403, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8429, Reg Loss = 2.5957, Reconstruct Loss = 0.0003, Cls Loss = 1.8424, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8428, R

100%|██████████| 157/157 [00:02<00:00, 70.04it/s]


Epoch [4/30], Validation Loss: 1.8075, Validation Accuracy: 75.89%



Checkpoint saved at epoch 3 with accuracy: 75.89%
Iteration 0: Loss = 1.7981, Reg Loss = 96.6273, Reconstruct Loss = 0.0000, Cls Loss = 1.7885, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8533, Reg Loss = 69.0902, Reconstruct Loss = 0.0018, Cls Loss = 1.8446, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8500, Reg Loss = 65.0770, Reconstruct Loss = 0.0009, Cls Loss = 1.8425, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8464, Reg Loss = 61.0981, Reconstruct Loss = 0.0006, Cls Loss = 1.8397, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8814, Reg Loss = 52.1071, Reconstruct Loss = 0.0360, Cls Loss = 1.8402, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8739, Reg Loss = 45.9115, Reconstruct Loss = 0.0289, Cls Loss = 1.8405, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8688, Reg Loss = 42.1379, Reconstruct Loss = 0.0244, Cls Loss = 1.8401, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.

100%|██████████| 157/157 [00:02<00:00, 68.64it/s]


Epoch [5/30], Validation Loss: 1.8076, Validation Accuracy: 75.79%



Iteration 0: Loss = 1.8626, Reg Loss = 14.3502, Reconstruct Loss = 0.0000, Cls Loss = 1.8612, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8418, Reg Loss = 7.1805, Reconstruct Loss = 0.0000, Cls Loss = 1.8411, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8394, Reg Loss = 7.8637, Reconstruct Loss = 0.0000, Cls Loss = 1.8386, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8412, Reg Loss = 8.0574, Reconstruct Loss = 0.0005, Cls Loss = 1.8398, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8408, Reg Loss = 8.1832, Reconstruct Loss = 0.0004, Cls Loss = 1.8396, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8409, Reg Loss = 8.3258, Reconstruct Loss = 0.0011, Cls Loss = 1.8389, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8425, Reg Loss = 7.9382, Reconstruct Loss = 0.0017, Cls Loss = 1.8400, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8421, Reg Loss = 7.9368, Reconstruct Loss = 0.0014, Cls 

100%|██████████| 157/157 [00:02<00:00, 71.29it/s]


Epoch [6/30], Validation Loss: 1.8076, Validation Accuracy: 75.87%



Iteration 0: Loss = 1.8724, Reg Loss = 2.6406, Reconstruct Loss = 0.0000, Cls Loss = 1.8721, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8414, Reg Loss = 3.4870, Reconstruct Loss = 0.0033, Cls Loss = 1.8377, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8427, Reg Loss = 6.8423, Reconstruct Loss = 0.0027, Cls Loss = 1.8393, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8541, Reg Loss = 8.9667, Reconstruct Loss = 0.0116, Cls Loss = 1.8416, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8545, Reg Loss = 9.2595, Reconstruct Loss = 0.0106, Cls Loss = 1.8430, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8536, Reg Loss = 10.2210, Reconstruct Loss = 0.0092, Cls Loss = 1.8433, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8513, Reg Loss = 10.8365, Reconstruct Loss = 0.0083, Cls Loss = 1.8420, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8495, Reg Loss = 10.6031, Reconstruct Loss = 0.0071, Cl

100%|██████████| 157/157 [00:03<00:00, 41.33it/s]


Epoch [7/30], Validation Loss: 1.8076, Validation Accuracy: 75.83%



Iteration 0: Loss = 2.3428, Reg Loss = 27.6607, Reconstruct Loss = 0.4042, Cls Loss = 1.9358, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8556, Reg Loss = 20.0625, Reconstruct Loss = 0.0155, Cls Loss = 1.8380, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8568, Reg Loss = 23.3708, Reconstruct Loss = 0.0097, Cls Loss = 1.8447, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8532, Reg Loss = 23.4931, Reconstruct Loss = 0.0075, Cls Loss = 1.8433, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8512, Reg Loss = 23.6573, Reconstruct Loss = 0.0058, Cls Loss = 1.8431, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8503, Reg Loss = 23.0378, Reconstruct Loss = 0.0046, Cls Loss = 1.8433, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8494, Reg Loss = 22.3794, Reconstruct Loss = 0.0043, Cls Loss = 1.8428, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8485, Reg Loss = 22.2548, Reconstruct Loss = 0.003

100%|██████████| 157/157 [00:02<00:00, 72.44it/s]


Epoch [8/30], Validation Loss: 1.8076, Validation Accuracy: 75.85%



Iteration 0: Loss = 1.8638, Reg Loss = 14.3084, Reconstruct Loss = 0.0000, Cls Loss = 1.8624, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8424, Reg Loss = 11.8150, Reconstruct Loss = 0.0018, Cls Loss = 1.8394, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8411, Reg Loss = 11.0557, Reconstruct Loss = 0.0009, Cls Loss = 1.8391, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8406, Reg Loss = 10.6558, Reconstruct Loss = 0.0009, Cls Loss = 1.8386, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8396, Reg Loss = 10.8748, Reconstruct Loss = 0.0016, Cls Loss = 1.8368, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8407, Reg Loss = 11.7707, Reconstruct Loss = 0.0014, Cls Loss = 1.8381, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8409, Reg Loss = 11.7633, Reconstruct Loss = 0.0012, Cls Loss = 1.8385, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8415, Reg Loss = 11.6488, Reconstruct Loss = 0.001

100%|██████████| 157/157 [00:02<00:00, 72.49it/s]


Epoch [9/30], Validation Loss: 1.8076, Validation Accuracy: 75.84%



Iteration 0: Loss = 1.8849, Reg Loss = 9.4937, Reconstruct Loss = 0.0000, Cls Loss = 1.8840, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8401, Reg Loss = 7.0943, Reconstruct Loss = 0.0000, Cls Loss = 1.8394, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8362, Reg Loss = 6.4617, Reconstruct Loss = 0.0000, Cls Loss = 1.8355, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8370, Reg Loss = 5.8128, Reconstruct Loss = 0.0000, Cls Loss = 1.8364, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8404, Reg Loss = 6.0805, Reconstruct Loss = 0.0011, Cls Loss = 1.8387, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8392, Reg Loss = 6.7068, Reconstruct Loss = 0.0009, Cls Loss = 1.8376, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8406, Reg Loss = 6.8585, Reconstruct Loss = 0.0011, Cls Loss = 1.8389, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8395, Reg Loss = 6.5536, Reconstruct Loss = 0.0009, Cls L

100%|██████████| 157/157 [00:02<00:00, 71.57it/s]


Epoch [10/30], Validation Loss: 1.8076, Validation Accuracy: 75.84%



Iteration 0: Loss = 1.8426, Reg Loss = 7.8919, Reconstruct Loss = 0.0000, Cls Loss = 1.8418, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8439, Reg Loss = 8.7109, Reconstruct Loss = 0.0000, Cls Loss = 1.8430, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8435, Reg Loss = 9.3980, Reconstruct Loss = 0.0000, Cls Loss = 1.8425, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8432, Reg Loss = 8.5412, Reconstruct Loss = 0.0006, Cls Loss = 1.8417, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8434, Reg Loss = 8.4033, Reconstruct Loss = 0.0005, Cls Loss = 1.8421, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8450, Reg Loss = 12.0135, Reconstruct Loss = 0.0004, Cls Loss = 1.8434, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8452, Reg Loss = 25.2276, Reconstruct Loss = 0.0003, Cls Loss = 1.8424, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8741, Reg Loss = 33.2339, Reconstruct Loss = 0.0294, C

100%|██████████| 157/157 [00:01<00:00, 79.95it/s]


Epoch [11/30], Validation Loss: 1.8076, Validation Accuracy: 75.87%



Iteration 0: Loss = 1.8995, Reg Loss = 15.1237, Reconstruct Loss = 0.0000, Cls Loss = 1.8980, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8502, Reg Loss = 10.5084, Reconstruct Loss = 0.0015, Cls Loss = 1.8476, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8444, Reg Loss = 10.2410, Reconstruct Loss = 0.0025, Cls Loss = 1.8409, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8451, Reg Loss = 10.3662, Reconstruct Loss = 0.0018, Cls Loss = 1.8423, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8443, Reg Loss = 10.0876, Reconstruct Loss = 0.0021, Cls Loss = 1.8412, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8430, Reg Loss = 10.0248, Reconstruct Loss = 0.0021, Cls Loss = 1.8399, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8423, Reg Loss = 9.9171, Reconstruct Loss = 0.0021, Cls Loss = 1.8393, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8426, Reg Loss = 9.9144, Reconstruct Loss = 0.0018

100%|██████████| 157/157 [00:02<00:00, 71.81it/s]


Epoch [12/30], Validation Loss: 1.8076, Validation Accuracy: 75.86%



Iteration 0: Loss = 1.8394, Reg Loss = 8.3193, Reconstruct Loss = 0.0000, Cls Loss = 1.8386, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8403, Reg Loss = 9.6621, Reconstruct Loss = 0.0000, Cls Loss = 1.8394, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8411, Reg Loss = 8.4146, Reconstruct Loss = 0.0000, Cls Loss = 1.8403, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8396, Reg Loss = 7.5329, Reconstruct Loss = 0.0000, Cls Loss = 1.8388, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8396, Reg Loss = 7.3067, Reconstruct Loss = 0.0000, Cls Loss = 1.8388, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8403, Reg Loss = 7.1044, Reconstruct Loss = 0.0001, Cls Loss = 1.8395, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8398, Reg Loss = 6.8106, Reconstruct Loss = 0.0001, Cls Loss = 1.8390, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8394, Reg Loss = 6.6179, Reconstruct Loss = 0.0003, Cls 

100%|██████████| 157/157 [00:01<00:00, 79.34it/s]


Epoch [13/30], Validation Loss: 1.8076, Validation Accuracy: 75.88%



Iteration 0: Loss = 1.8162, Reg Loss = 11.0947, Reconstruct Loss = 0.0000, Cls Loss = 1.8151, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8464, Reg Loss = 7.6282, Reconstruct Loss = 0.0018, Cls Loss = 1.8438, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8438, Reg Loss = 7.7441, Reconstruct Loss = 0.0027, Cls Loss = 1.8403, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8421, Reg Loss = 9.0394, Reconstruct Loss = 0.0018, Cls Loss = 1.8394, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8411, Reg Loss = 8.6388, Reconstruct Loss = 0.0021, Cls Loss = 1.8382, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8427, Reg Loss = 10.1867, Reconstruct Loss = 0.0022, Cls Loss = 1.8395, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8429, Reg Loss = 10.9913, Reconstruct Loss = 0.0018, Cls Loss = 1.8400, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8440, Reg Loss = 10.8711, Reconstruct Loss = 0.0023, 

100%|██████████| 157/157 [00:04<00:00, 35.62it/s]


Epoch [14/30], Validation Loss: 1.8076, Validation Accuracy: 75.83%



Iteration 0: Loss = 1.8600, Reg Loss = 38.4461, Reconstruct Loss = 0.0000, Cls Loss = 1.8562, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8504, Reg Loss = 31.7706, Reconstruct Loss = 0.0000, Cls Loss = 1.8472, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8494, Reg Loss = 28.5522, Reconstruct Loss = 0.0020, Cls Loss = 1.8445, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8464, Reg Loss = 29.5726, Reconstruct Loss = 0.0014, Cls Loss = 1.8421, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8459, Reg Loss = 27.6473, Reconstruct Loss = 0.0014, Cls Loss = 1.8417, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8439, Reg Loss = 24.8129, Reconstruct Loss = 0.0011, Cls Loss = 1.8403, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8430, Reg Loss = 22.2967, Reconstruct Loss = 0.0016, Cls Loss = 1.8392, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8434, Reg Loss = 19.8764, Reconstruct Loss = 0.00

100%|██████████| 157/157 [00:02<00:00, 72.11it/s]


Epoch [15/30], Validation Loss: 1.8076, Validation Accuracy: 75.85%



Iteration 0: Loss = 1.8406, Reg Loss = 3.0446, Reconstruct Loss = 0.0000, Cls Loss = 1.8403, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8409, Reg Loss = 5.1843, Reconstruct Loss = 0.0018, Cls Loss = 1.8386, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8414, Reg Loss = 17.5799, Reconstruct Loss = 0.0009, Cls Loss = 1.8387, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8435, Reg Loss = 24.6148, Reconstruct Loss = 0.0013, Cls Loss = 1.8397, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8428, Reg Loss = 28.4198, Reconstruct Loss = 0.0010, Cls Loss = 1.8390, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8418, Reg Loss = 29.4191, Reconstruct Loss = 0.0013, Cls Loss = 1.8375, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8422, Reg Loss = 29.5157, Reconstruct Loss = 0.0011, Cls Loss = 1.8381, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8432, Reg Loss = 28.9307, Reconstruct Loss = 0.0014

100%|██████████| 157/157 [00:04<00:00, 36.69it/s]


Epoch [16/30], Validation Loss: 1.8076, Validation Accuracy: 75.87%



Iteration 0: Loss = 1.8404, Reg Loss = 15.6899, Reconstruct Loss = 0.0000, Cls Loss = 1.8388, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8510, Reg Loss = 9.8871, Reconstruct Loss = 0.0000, Cls Loss = 1.8500, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8488, Reg Loss = 8.2258, Reconstruct Loss = 0.0000, Cls Loss = 1.8480, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8510, Reg Loss = 7.1965, Reconstruct Loss = 0.0019, Cls Loss = 1.8484, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8487, Reg Loss = 6.4060, Reconstruct Loss = 0.0014, Cls Loss = 1.8467, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8471, Reg Loss = 6.3646, Reconstruct Loss = 0.0011, Cls Loss = 1.8453, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8455, Reg Loss = 6.1550, Reconstruct Loss = 0.0010, Cls Loss = 1.8439, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8442, Reg Loss = 5.7725, Reconstruct Loss = 0.0008, Cls

100%|██████████| 157/157 [00:02<00:00, 71.22it/s]


Epoch [17/30], Validation Loss: 1.8076, Validation Accuracy: 75.85%



Iteration 0: Loss = 1.8474, Reg Loss = 61.2100, Reconstruct Loss = 0.0000, Cls Loss = 1.8413, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8511, Reg Loss = 52.2476, Reconstruct Loss = 0.0037, Cls Loss = 1.8422, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8545, Reg Loss = 46.9738, Reconstruct Loss = 0.0029, Cls Loss = 1.8469, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8525, Reg Loss = 45.5446, Reconstruct Loss = 0.0020, Cls Loss = 1.8459, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8505, Reg Loss = 43.7018, Reconstruct Loss = 0.0015, Cls Loss = 1.8446, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8492, Reg Loss = 43.8445, Reconstruct Loss = 0.0012, Cls Loss = 1.8436, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8479, Reg Loss = 42.9672, Reconstruct Loss = 0.0014, Cls Loss = 1.8421, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8463, Reg Loss = 40.6908, Reconstruct Loss = 0.00

100%|██████████| 157/157 [00:02<00:00, 78.07it/s]


Epoch [18/30], Validation Loss: 1.8076, Validation Accuracy: 75.87%



Iteration 0: Loss = 1.8476, Reg Loss = 6.1766, Reconstruct Loss = 0.0000, Cls Loss = 1.8470, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8488, Reg Loss = 49.8320, Reconstruct Loss = 0.0101, Cls Loss = 1.8337, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8499, Reg Loss = 46.0146, Reconstruct Loss = 0.0052, Cls Loss = 1.8401, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8479, Reg Loss = 43.5434, Reconstruct Loss = 0.0035, Cls Loss = 1.8401, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8459, Reg Loss = 38.4674, Reconstruct Loss = 0.0026, Cls Loss = 1.8394, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8446, Reg Loss = 34.0030, Reconstruct Loss = 0.0028, Cls Loss = 1.8383, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8870, Reg Loss = 32.1755, Reconstruct Loss = 0.0024, Cls Loss = 1.8814, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.9089, Reg Loss = 42.9825, Reconstruct Loss = 0.028

100%|██████████| 157/157 [00:02<00:00, 77.81it/s]


Epoch [19/30], Validation Loss: 1.8076, Validation Accuracy: 75.88%



Iteration 0: Loss = 1.8284, Reg Loss = 5.7844, Reconstruct Loss = 0.0000, Cls Loss = 1.8278, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8472, Reg Loss = 25.9476, Reconstruct Loss = 0.0015, Cls Loss = 1.8432, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8445, Reg Loss = 26.8854, Reconstruct Loss = 0.0007, Cls Loss = 1.8410, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8469, Reg Loss = 25.8827, Reconstruct Loss = 0.0010, Cls Loss = 1.8433, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8449, Reg Loss = 24.5328, Reconstruct Loss = 0.0012, Cls Loss = 1.8412, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8450, Reg Loss = 22.9266, Reconstruct Loss = 0.0014, Cls Loss = 1.8413, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8459, Reg Loss = 22.0141, Reconstruct Loss = 0.0015, Cls Loss = 1.8422, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8471, Reg Loss = 23.2473, Reconstruct Loss = 0.001

100%|██████████| 157/157 [00:01<00:00, 79.08it/s]


Epoch [20/30], Validation Loss: 1.8076, Validation Accuracy: 75.85%



Iteration 0: Loss = 1.8327, Reg Loss = 38.7675, Reconstruct Loss = 0.0000, Cls Loss = 1.8288, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8424, Reg Loss = 31.8136, Reconstruct Loss = 0.0006, Cls Loss = 1.8386, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8440, Reg Loss = 31.2274, Reconstruct Loss = 0.0003, Cls Loss = 1.8406, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8444, Reg Loss = 28.8277, Reconstruct Loss = 0.0009, Cls Loss = 1.8406, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8445, Reg Loss = 27.7482, Reconstruct Loss = 0.0008, Cls Loss = 1.8408, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8460, Reg Loss = 25.5736, Reconstruct Loss = 0.0007, Cls Loss = 1.8428, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8439, Reg Loss = 23.5397, Reconstruct Loss = 0.0006, Cls Loss = 1.8410, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8423, Reg Loss = 21.5652, Reconstruct Loss = 0.00

100%|██████████| 157/157 [00:02<00:00, 76.72it/s]


Epoch [21/30], Validation Loss: 1.8076, Validation Accuracy: 75.88%



Iteration 0: Loss = 1.8825, Reg Loss = 148.0004, Reconstruct Loss = 0.0000, Cls Loss = 1.8677, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8587, Reg Loss = 116.3151, Reconstruct Loss = 0.0031, Cls Loss = 1.8440, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8591, Reg Loss = 116.3879, Reconstruct Loss = 0.0060, Cls Loss = 1.8415, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8573, Reg Loss = 112.5553, Reconstruct Loss = 0.0040, Cls Loss = 1.8420, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8563, Reg Loss = 103.5994, Reconstruct Loss = 0.0037, Cls Loss = 1.8422, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8521, Reg Loss = 89.8931, Reconstruct Loss = 0.0031, Cls Loss = 1.8399, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8515, Reg Loss = 84.6794, Reconstruct Loss = 0.0029, Cls Loss = 1.8401, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8514, Reg Loss = 81.1949, Reconstruct Loss =

100%|██████████| 157/157 [00:01<00:00, 79.42it/s]


Epoch [22/30], Validation Loss: 1.8076, Validation Accuracy: 75.85%



Iteration 0: Loss = 1.8038, Reg Loss = 25.7002, Reconstruct Loss = 0.0000, Cls Loss = 1.8013, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8581, Reg Loss = 62.9847, Reconstruct Loss = 0.0026, Cls Loss = 1.8492, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8616, Reg Loss = 113.1872, Reconstruct Loss = 0.0044, Cls Loss = 1.8459, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8624, Reg Loss = 126.3426, Reconstruct Loss = 0.0055, Cls Loss = 1.8443, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8621, Reg Loss = 132.5842, Reconstruct Loss = 0.0056, Cls Loss = 1.8433, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8619, Reg Loss = 140.5706, Reconstruct Loss = 0.0045, Cls Loss = 1.8434, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8629, Reg Loss = 148.1024, Reconstruct Loss = 0.0046, Cls Loss = 1.8435, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8620, Reg Loss = 153.8707, Reconstruct Loss 

100%|██████████| 157/157 [00:02<00:00, 77.72it/s]


Epoch [23/30], Validation Loss: 1.8076, Validation Accuracy: 75.87%



Iteration 0: Loss = 1.8639, Reg Loss = 18.9045, Reconstruct Loss = 0.0000, Cls Loss = 1.8620, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8439, Reg Loss = 43.9247, Reconstruct Loss = 0.0000, Cls Loss = 1.8395, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8454, Reg Loss = 39.3627, Reconstruct Loss = 0.0000, Cls Loss = 1.8415, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8435, Reg Loss = 34.1177, Reconstruct Loss = 0.0000, Cls Loss = 1.8401, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8456, Reg Loss = 28.8804, Reconstruct Loss = 0.0000, Cls Loss = 1.8427, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8441, Reg Loss = 25.5106, Reconstruct Loss = 0.0007, Cls Loss = 1.8408, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8460, Reg Loss = 38.2681, Reconstruct Loss = 0.0006, Cls Loss = 1.8415, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8479, Reg Loss = 49.8952, Reconstruct Loss = 0.00

100%|██████████| 157/157 [00:02<00:00, 77.73it/s]


Epoch [24/30], Validation Loss: 1.8076, Validation Accuracy: 75.87%



Iteration 0: Loss = 1.8262, Reg Loss = 356.7827, Reconstruct Loss = 0.0000, Cls Loss = 1.7905, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8653, Reg Loss = 290.0023, Reconstruct Loss = 0.0095, Cls Loss = 1.8267, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8727, Reg Loss = 247.6016, Reconstruct Loss = 0.0111, Cls Loss = 1.8368, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8684, Reg Loss = 207.7317, Reconstruct Loss = 0.0083, Cls Loss = 1.8393, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8644, Reg Loss = 168.7387, Reconstruct Loss = 0.0063, Cls Loss = 1.8412, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8608, Reg Loss = 140.9010, Reconstruct Loss = 0.0052, Cls Loss = 1.8416, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8597, Reg Loss = 134.0180, Reconstruct Loss = 0.0043, Cls Loss = 1.8420, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8571, Reg Loss = 126.3082, Reconstruct Los

100%|██████████| 157/157 [00:02<00:00, 75.66it/s]


Epoch [25/30], Validation Loss: 1.8076, Validation Accuracy: 75.86%



Iteration 0: Loss = 2.0552, Reg Loss = 1533.4734, Reconstruct Loss = 0.0000, Cls Loss = 1.9018, Learning rate = 1.0000e-03
Iteration 25: Loss = 2.0981, Reg Loss = 1155.5295, Reconstruct Loss = 0.1337, Cls Loss = 1.8488, Learning rate = 1.0000e-03
Iteration 50: Loss = 2.0160, Reg Loss = 1053.7751, Reconstruct Loss = 0.0682, Cls Loss = 1.8425, Learning rate = 1.0000e-03
Iteration 75: Loss = 2.0529, Reg Loss = 987.8526, Reconstruct Loss = 0.1122, Cls Loss = 1.8419, Learning rate = 1.0000e-03
Iteration 100: Loss = 2.0366, Reg Loss = 911.4081, Reconstruct Loss = 0.1022, Cls Loss = 1.8433, Learning rate = 1.0000e-03
Iteration 125: Loss = 2.0195, Reg Loss = 858.2988, Reconstruct Loss = 0.0923, Cls Loss = 1.8414, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.9971, Reg Loss = 799.1501, Reconstruct Loss = 0.0770, Cls Loss = 1.8402, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.9902, Reg Loss = 757.4726, Reconstruct 

100%|██████████| 157/157 [00:01<00:00, 79.61it/s]


Epoch [26/30], Validation Loss: 1.8076, Validation Accuracy: 75.86%



Iteration 0: Loss = 1.8445, Reg Loss = 29.8425, Reconstruct Loss = 0.0000, Cls Loss = 1.8416, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8760, Reg Loss = 73.0274, Reconstruct Loss = 0.0018, Cls Loss = 1.8669, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8652, Reg Loss = 115.5362, Reconstruct Loss = 0.0009, Cls Loss = 1.8527, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8643, Reg Loss = 136.4527, Reconstruct Loss = 0.0006, Cls Loss = 1.8501, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8655, Reg Loss = 136.9198, Reconstruct Loss = 0.0016, Cls Loss = 1.8502, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8633, Reg Loss = 129.0122, Reconstruct Loss = 0.0013, Cls Loss = 1.8491, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8822, Reg Loss = 118.9145, Reconstruct Loss = 0.0081, Cls Loss = 1.8623, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8794, Reg Loss = 118.5986, Reconstruct Loss 

100%|██████████| 157/157 [00:01<00:00, 80.44it/s]


Epoch [27/30], Validation Loss: 1.8076, Validation Accuracy: 75.83%



Iteration 0: Loss = 1.8089, Reg Loss = 28.5984, Reconstruct Loss = 0.0000, Cls Loss = 1.8060, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8426, Reg Loss = 23.5065, Reconstruct Loss = 0.0018, Cls Loss = 1.8384, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8408, Reg Loss = 23.5175, Reconstruct Loss = 0.0019, Cls Loss = 1.8365, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.9907, Reg Loss = 71.8490, Reconstruct Loss = 0.0022, Cls Loss = 1.9813, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.9748, Reg Loss = 165.0581, Reconstruct Loss = 0.0099, Cls Loss = 1.9485, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.9595, Reg Loss = 244.4697, Reconstruct Loss = 0.0079, Cls Loss = 1.9272, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.9508, Reg Loss = 264.2623, Reconstruct Loss = 0.0066, Cls Loss = 1.9178, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.9434, Reg Loss = 299.1847, Reconstruct Loss = 

100%|██████████| 157/157 [00:02<00:00, 69.16it/s]


Epoch [28/30], Validation Loss: 1.8076, Validation Accuracy: 75.86%



Iteration 0: Loss = 1.8618, Reg Loss = 205.0118, Reconstruct Loss = 0.0000, Cls Loss = 1.8413, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8673, Reg Loss = 153.3267, Reconstruct Loss = 0.0060, Cls Loss = 1.8459, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8633, Reg Loss = 145.0163, Reconstruct Loss = 0.0031, Cls Loss = 1.8457, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8617, Reg Loss = 138.1299, Reconstruct Loss = 0.0028, Cls Loss = 1.8451, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8591, Reg Loss = 117.6287, Reconstruct Loss = 0.0026, Cls Loss = 1.8448, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8609, Reg Loss = 117.5028, Reconstruct Loss = 0.0036, Cls Loss = 1.8455, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8602, Reg Loss = 127.6202, Reconstruct Loss = 0.0030, Cls Loss = 1.8444, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8615, Reg Loss = 138.9988, Reconstruct Los

100%|██████████| 157/157 [00:01<00:00, 79.58it/s]


Epoch [29/30], Validation Loss: 1.8076, Validation Accuracy: 75.86%



Iteration 0: Loss = 1.8384, Reg Loss = 13.7632, Reconstruct Loss = 0.0000, Cls Loss = 1.8370, Learning rate = 1.0000e-03
Iteration 25: Loss = 1.8443, Reg Loss = 30.3356, Reconstruct Loss = 0.0019, Cls Loss = 1.8393, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8584, Reg Loss = 114.1450, Reconstruct Loss = 0.0010, Cls Loss = 1.8460, Learning rate = 1.0000e-03
Iteration 75: Loss = 1.8732, Reg Loss = 211.9988, Reconstruct Loss = 0.0094, Cls Loss = 1.8425, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8835, Reg Loss = 250.3651, Reconstruct Loss = 0.0172, Cls Loss = 1.8412, Learning rate = 1.0000e-03
Iteration 125: Loss = 1.8810, Reg Loss = 268.0617, Reconstruct Loss = 0.0138, Cls Loss = 1.8404, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8824, Reg Loss = 267.1113, Reconstruct Loss = 0.0136, Cls Loss = 1.8421, Learning rate = 1.0000e-03
Iteration 175: Loss = 1.8833, Reg Loss = 266.2435, Reconstruct Loss 

100%|██████████| 157/157 [00:02<00:00, 76.43it/s]

Epoch [30/30], Validation Loss: 1.8076, Validation Accuracy: 75.88%








In [35]:
# End the wandb tracking
wandb.finish()

0,1
Cls Loss,▂▂▁▁▂▂▂█▂▂▂▂▁▁▂▁▁▂▂▂▂▂▂▁▂▁▇▂▂▂▂▂▂▂▃▁▂▂▃▃
Learning rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss,▁▂▁▁▁▁▁▃▂▁▃▂▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▂▄▂██▄▆▆▆▃█
Reconstruct Loss,▂▁▁▁▁▄▁▁▃▂▁▁▁▁▅▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▂▂█▆▂▂
Reg Loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▂▂█▂▁▃▄▃▂▃▃▄
Validation Accuracy,▁▃████████████████████████████
Validation Loss,█▅▁▁▁▁▂▁▂▂▁▁▁▂▁▁▁▁▁▂▁▂▁▁▁▁▂▂▂▁

0,1
Cls Loss,1.86363
Learning rate,0.001
Loss,1.91921
Reconstruct Loss,0.017
Reg Loss,385.83987
Validation Accuracy,0.7588
Validation Loss,1.8076


### 7 Testing loop

In [36]:
for hidden_dim in range(16, 65):
    # Create a model for this given dimension
    model = create_model(args.model.type,
                         hidden_dim=hidden_dim,
                         path=args.model.pretrained_path,
                         smooth=args.model.smooth).to(device)
    
    # If EMA is specified, apply it
    if ema:
        print('Applying EMA')
        ema.apply()

    # Sample the merged model
    accumulated_model = sample_merge_model(hyper_model, model, args, K=100)

    # Validate the merged model
    val_loss, acc = validate_single(accumulated_model, val_loader, val_criterion, args=args)

    # If EMA is specified, restore the original weights after applying EMA
    if ema:
        ema.restore()  # Restore the original weights after applying 
        
    # Save the model
    save_name = os.path.join(args.training.save_model_path, f"cifar10_{accumulated_model.__class__.__name__}_dim{hidden_dim}_single.pth")
    torch.save(accumulated_model.state_dict(),save_name)

    # Print the results
    print(f"Test using model {args.model}: hidden_dim {hidden_dim}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%")
    print('\n')

    # Define the directory and filename structure
    filename = f"cifar10_results_{args.experiment.name}.txt"
    filepath = os.path.join(args.training.save_model_path, filename)

    # Write the results. 'a' is used to append the results; a new file will be created if it doesn't exist.
    with open(filepath, "a") as file:
        file.write(f"Hidden_dim: {hidden_dim}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%\n")


Replace the last 2 block of layer3 with new block with hidden dim 16
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1401.1385498046875
Permuted TV original model: 1318.2835693359375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 69.24it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 16, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 17
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1442.9952392578125
Permuted TV original model: 1360.1436767578125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 71.51it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 17, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 18
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1483.0216064453125
Permuted TV original model: 1400.0594482421875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 71.44it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 18, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 19
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1538.33935546875
Permuted TV original model: 1451.1243896484375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 73.56it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 19, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 20
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1574.3048095703125
Permuted TV original model: 1487.0372314453125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 75.05it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 20, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 21
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1640.89697265625
Permuted TV original model: 1550.203857421875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 73.79it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 21, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 22
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1676.0743408203125
Permuted TV original model: 1585.9820556640625
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.06it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 22, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 23
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1719.2840576171875
Permuted TV original model: 1627.1593017578125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 75.43it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 23, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 24
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1768.282470703125
Permuted TV original model: 1676.5789794921875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 74.69it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 24, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 25
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1819.16845703125
Permuted TV original model: 1726.330810546875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.00it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 25, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 26
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1863.5792236328125
Permuted TV original model: 1767.5157470703125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 73.80it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 26, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 27
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1905.3087158203125
Permuted TV original model: 1807.0479736328125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.22it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 27, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 28
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1964.5340576171875
Permuted TV original model: 1863.0496826171875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.31it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 28, Validation Loss: 1.8076, Validation Accuracy: 75.87%


Replace the last 2 block of layer3 with new block with hidden dim 29
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2001.0850830078125
Permuted TV original model: 1896.6661376953125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.65it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 29, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 30
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2058.7451171875
Permuted TV original model: 1948.0211181640625
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 73.65it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 30, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 31
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2096.48388671875
Permuted TV original model: 1989.7352294921875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 69.93it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 31, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 32
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2147.742431640625
Permuted TV original model: 2033.6116943359375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 76.69it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 32, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 33
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2187.97314453125
Permuted TV original model: 2078.010009765625
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.48it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 33, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 34
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2222.1923828125
Permuted TV original model: 2113.208740234375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 70.14it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 34, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 35
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2278.254150390625
Permuted TV original model: 2169.591552734375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 71.50it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 35, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 36
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2309.660400390625
Permuted TV original model: 2191.5908203125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 73.26it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 36, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 37
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2370.7421875
Permuted TV original model: 2254.739013671875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.46it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 37, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 38
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2410.807861328125
Permuted TV original model: 2290.69677734375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 77.08it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 38, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 39
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2462.283447265625
Permuted TV original model: 2343.622314453125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.15it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 39, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 40
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2517.60595703125
Permuted TV original model: 2392.82958984375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 77.58it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 40, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 41
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2560.46435546875
Permuted TV original model: 2427.939697265625
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 71.38it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 41, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 42
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2607.870849609375
Permuted TV original model: 2477.041748046875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.78it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 42, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 43
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2645.8076171875
Permuted TV original model: 2513.85498046875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.03it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 43, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 44
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2691.095703125
Permuted TV original model: 2557.918701171875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 70.89it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 44, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 45
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2725.190185546875
Permuted TV original model: 2594.971923828125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 68.54it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 45, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 46
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2788.117431640625
Permuted TV original model: 2653.364501953125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 70.63it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 46, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 47
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2852.116455078125
Permuted TV original model: 2723.146728515625
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 70.91it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 47, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 48
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2893.088623046875
Permuted TV original model: 2753.6572265625
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 68.39it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 48, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 49
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2915.3388671875
Permuted TV original model: 2772.048828125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 70.92it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 49, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 50
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 2971.8388671875
Permuted TV original model: 2841.252197265625
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 76.37it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 50, Validation Loss: 1.8076, Validation Accuracy: 75.87%


Replace the last 2 block of layer3 with new block with hidden dim 51
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3001.677978515625
Permuted TV original model: 2870.9501953125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 67.10it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 51, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 52
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3051.837646484375
Permuted TV original model: 2913.075927734375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 73.08it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 52, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 53
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3106.040771484375
Permuted TV original model: 2969.665771484375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 75.89it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 53, Validation Loss: 1.8076, Validation Accuracy: 75.87%


Replace the last 2 block of layer3 with new block with hidden dim 54
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3163.617431640625
Permuted TV original model: 3005.335693359375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 74.18it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 54, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 55
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3216.7041015625
Permuted TV original model: 3062.23583984375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 63.51it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 55, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 56
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3250.066162109375
Permuted TV original model: 3103.95849609375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 74.42it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 56, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 57
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3290.260009765625
Permuted TV original model: 3131.65380859375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 67.23it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 57, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 58
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3348.930419921875
Permuted TV original model: 3179.28369140625
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 72.95it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 58, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 59
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3391.8193359375
Permuted TV original model: 3236.480224609375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 73.33it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 59, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 60
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3426.4931640625
Permuted TV original model: 3267.8173828125
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 74.00it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 60, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 61
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3482.376708984375
Permuted TV original model: 3323.4052734375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 75.02it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 61, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 62
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3537.788330078125
Permuted TV original model: 3363.88525390625
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 75.55it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 62, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 63
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 3598.633056640625
Permuted TV original model: 3416.765380859375
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 75.27it/s]


Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 63, Validation Loss: 1.8076, Validation Accuracy: 75.86%


Replace the last 2 block of layer3 with new block with hidden dim 64
Loading pretrained weights for resnet20
Smooth the parameters of the model
Old TV original model: 1321.4044189453125
Permuted TV original model: 1039.069091796875
Applying EMA


100%|██████████| 157/157 [00:02<00:00, 75.43it/s]

Test using model {'type': 'ResNet20', 'pretrained_path': 'resnet20-12fca82f.th', 'smooth': True}: hidden_dim 64, Validation Loss: 1.8076, Validation Accuracy: 75.86%





