## Import

In [1]:
import os
import random
import copy

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import wandb

In [3]:
from neumeta.models import create_mnist_model as create_model
from neumeta.utils import (
    parse_args, print_omegaconf,
    load_checkpoint, save_checkpoint,
    set_seed,
    get_cifar10, get_dataset,
    sample_coordinates, sample_subset, shuffle_coordinates_all,
    get_hypernetwork, get_optimizer, 
    sample_weights,
    weighted_regression_loss, validate_single, AverageMeter, EMA,
    sample_single_model, sample_merge_model,
)

In [4]:
from smooth.permute import PermutationManager, compute_tv_loss_for_network

## Functions

### 1 Find maximum dimension of the model

In [5]:
def find_max_dim(model_cls):
    """Find maximum dimension of the model.

    The result is the largest of:
      * the number of entries in ``model_cls.learnable_parameter``,
      * the first two dimensions (``shape[0]`` and ``shape[1]``) of every
        parameter tensor that has two or more dimensions, and
      * the length of every 1-D parameter tensor.

    Args:
        model_cls: Object exposing ``learnable_parameter``, a mapping of
            parameter name -> tensor.

    Returns:
        int: The maximum value described above.
    """
    # Get the learnable parameters of the model
    checkpoint = model_cls.learnable_parameter

    # The number of parameter tensors is the lower bound of the result
    max_value = len(checkpoint)

    for tensor in checkpoint.values():
        ndim = len(tensor.shape)
        if ndim >= 2:
            # Weight tensors: 4-D conv kernels and 2-D linear matrices alike.
            # (Only the 4-D case used to be handled, so 2-D weights were skipped.)
            max_value = max(max_value, tensor.shape[0], tensor.shape[1])
        elif ndim == 1:
            # 1-D tensors (e.g., biases)
            max_value = max(max_value, tensor.shape[0])

    return max_value

### 2 Initialize wandb

In [6]:
def initialize_wandb(config):
    """
    Initializes Weights and Biases (wandb) with the given configuration.

    The run is named ``<YYYYmmddHHMMSS>-<config.experiment.name>`` and is
    created in the "ninr-trial" project under the 'LeNet' group.

    Args:
        config: Configuration parameters for the run. Must expose
            ``experiment.name`` and be convertible with ``dict(config)``.
    """
    # Local import; it has to come after the docstring, otherwise the string
    # above is a plain expression and the function has no __doc__.
    import time

    # Name the run using current time and configuration name
    run_name = f"{time.strftime('%Y%m%d%H%M%S')}-{config.experiment.name}"

    wandb.init(project="ninr-trial", name=run_name, config=dict(config), group='LeNet')

### 3 Initialize model dictionary

In [7]:
def init_model_dict(args, device):
    """
    Initializes a dictionary of models for each dimension in the given range, along with ground truth models for the starting dimension.

    Args:
        args: Configuration object. Reads ``args.dimensions.range``, ``args.dimensions.start``,
            ``args.model.type`` and ``args.model.pretrained_path``.
        device: Device the models and the dummy permutation input are moved to. It is also used
            as ``map_location`` when the pretrained weights are loaded.

    Returns:
        dim_dict: Maps ``str(dim)`` to the tuple
            ``(model, coords_tensor, keys_list, indices_list, size_list, None)``.
            The trailing ``None`` is the key-mask slot.
        gt_model_dict: Maps ``str(args.dimensions.start)`` to a deep copy of the permuted pretrained
            model. It stays empty if the starting dimension is not in ``args.dimensions.range``.
    """
    dim_dict = {}
    gt_model_dict = {}

    # Create a model for each dimension in dimensions range
    for dim in args.dimensions.range:
        model_cls = create_model(args.model.type,
                                 hidden_dim=dim,
                                 path=args.model.pretrained_path,
                                 # smooth=args.model.smooth (Not used in LeNet)
                                 ).to(device)
        model_cls.eval()
        # Sample the coordinates, keys, indices, and the size for the model
        coords_tensor, keys_list, indices_list, size_list = sample_coordinates(model_cls)
        # Add the model, coordinates, keys, indices, size, and key mask to the dictionary
        dim_dict[f"{dim}"] = (model_cls, coords_tensor, keys_list, indices_list, size_list, None)

        # Blank lines to visually separate the per-dimension output
        print('\n')

        # If the dimension is the starting dimension (the dimension of pretrained_model), add the ground truth model to the dictionary
        if dim == args.dimensions.start:
            print(f"Loading model for dim {dim}")
            model_trained = create_model(args.model.type,
                                         hidden_dim=dim,
                                         path=args.model.pretrained_path,
                                         # smooth=args.model.smooth
                                         ).to(device)
            # map_location lets a checkpoint saved from a GPU load on a CPU-only machine
            model_trained.load_state_dict(torch.load(args.model.pretrained_path, map_location=device))
            model_trained.eval()

            # Smooth the ground truth model by permuting it, printing the TV loss before and after
            print("TV original model: ", compute_tv_loss_for_network(model_trained, lambda_tv=1.0).item())
            # Dummy MNIST-shaped input (1 x 1 x 28 x 28) handed to the permutation manager
            input_tensor = torch.randn(1, 1, 28, 28).to(device)
            permute_func = PermutationManager(model_trained, input_tensor)
            permute_dict = permute_func.compute_permute_dict()
            # The network's input channels and output classes are excluded from the permutation
            model_trained = permute_func.apply_permutations(permute_dict, ignored_keys=[
                ('conv_1.weight', 'in_channels'),
                ('linear.weight', 'out_channels'),
                ('linear.bias', 'out_channels')
            ])
            print("TV permutated model: ", compute_tv_loss_for_network(model_trained, lambda_tv=1.0).item())

            gt_model_dict[f'{dim}'] = copy.deepcopy(model_trained)

    return dim_dict, gt_model_dict

### 4 Training function for target model of a random dimension

In [8]:
# Function to train the model for one epoch
def train_one_epoch(model, train_loader, optimizer, criterion, dim_dict, gt_model_dict, epoch_idx, ema=None, args=None, device='cpu'):
    """
    Train the hypernetwork for one pass over ``train_loader``.

    For every batch a hidden dimension is drawn at random from ``args.dimensions.range``,
    weights for the matching target model are generated with ``sample_weights``, and the
    loss is a weighted sum of classification, L2 regularization and (when a ground truth
    model exists for that dimension) reconstruction terms.

    Args:
        model: The hypernetwork being trained (passed to ``sample_weights``).
        train_loader: Iterable of ``(x, target)`` batches; must support ``len()``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        criterion: Classification loss applied to ``(predict, target)``.
        dim_dict: Maps ``str(dim)`` to
            ``(model_cls, coords_tensor, keys_list, indices_list, size_list, key_mask)``.
        gt_model_dict: Maps ``str(dim)`` to a ground truth model exposing ``learnable_parameter``.
        epoch_idx: Zero-based epoch index, used to compute the wandb logging step.
        ema: Optional object with an ``update()`` method, called after each optimizer step.
        args: Configuration object. Required in practice, despite the ``None`` default.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(losses.avg, dim_dict, gt_model_dict)``. The two dictionaries are the
        objects that were passed in.
    """
    # Set the model to training mode
    model.train()

    # Initialize AverageMeter objects to track the losses
    losses = AverageMeter()
    cls_losses = AverageMeter()
    reg_losses = AverageMeter()
    reconstruct_losses = AverageMeter()

    # Iterate over the training data
    for batch_idx, (x, target) in enumerate(train_loader):
        # Zero the gradients
        optimizer.zero_grad()

        # Preprocess input
        # ------------------------------------------------------------------------------------------------------
        # Move the data to the device
        x, target = x.to(device), target.to(device)
        # Choose a random hidden dimension
        hidden_dim = random.choice(args.dimensions.range)
        # Get the model class, coordinates, keys, indices, size, and key mask for the chosen dimension
        model_cls, coords_tensor, keys_list, indices_list, size_list, key_mask = dim_dict[f"{hidden_dim}"]
        # Sample a subset of the coordinates, keys, indices and sizes, together with the selected keys
        coords_tensor, keys_list, indices_list, size_list, selected_keys = sample_subset(coords_tensor,
                                                                                         keys_list,
                                                                                         indices_list,
                                                                                         size_list,
                                                                                         key_mask,
                                                                                         ratio=args.ratio)
        # Add uniform noise in [-0.5, 0.5) * coordinate_noise to the coordinates if specified
        if args.training.coordinate_noise > 0.0:
            coords_tensor = coords_tensor + (torch.rand_like(coords_tensor) - 0.5) * args.training.coordinate_noise


        # Main task of hypernetwork and target network
        # ------------------------------------------------------------------------------------------------------
        # Sample the weights for the target model using hypernetwork
        model_cls, reconstructed_weights = sample_weights(model, model_cls,
                                                          coords_tensor, keys_list, indices_list, size_list, key_mask, selected_keys,
                                                          device=device, NORM=args.dimensions.norm)
        # Forward pass
        predict = model_cls(x)


        # Compute losses
        # ------------------------------------------------------------------------------------------------------
        # Compute classification loss
        cls_loss = criterion(predict, target)
        # Compute regularization loss (sum of the L2 norms of the generated weights)
        reg_loss = sum([torch.norm(w, p=2) for w in reconstructed_weights])
        # Compute reconstruction loss if ground truth model is available
        if f"{hidden_dim}" in gt_model_dict:
            gt_model = gt_model_dict[f"{hidden_dim}"]
            gt_selected_weights = [
                w for k, w in gt_model.learnable_parameter.items() if k in selected_keys]

            reconstruct_loss = weighted_regression_loss(
                reconstructed_weights, gt_selected_weights)
        else:
            reconstruct_loss = torch.tensor(0.0)
        # Compute the total loss
        loss = args.hyper_model.loss_weight.ce_weight * cls_loss + args.hyper_model.loss_weight.reg_weight * \
            reg_loss + args.hyper_model.loss_weight.recon_weight * reconstruct_loss


        # Compute gradients and update weights
        # ------------------------------------------------------------------------------------------------------
        # Clear stale gradients on the target model's parameters
        for updated_weight in model_cls.parameters():
            updated_weight.grad = None

        # First pass: backpropagate the total loss (graph is retained for the second pass)
        loss.backward(retain_graph=True)
        # Second pass: feed the gradients collected on the target model's parameters back
        # through the generated weights.
        # NOTE(review): the gradients are paired with reconstructed_weights by position, which
        # assumes named_parameters() filtered by selected_keys has the same order — confirm
        # against sample_weights.
        torch.autograd.backward(reconstructed_weights, [
                                w.grad for k, w in model_cls.named_parameters() if k in selected_keys])

        # Clip the gradients if specified
        if args.training.get('clip_grad', 0.0) > 0:
            torch.nn.utils.clip_grad_value_(
                model.parameters(), args.training.clip_grad)

        # Update the weights
        optimizer.step()

        # Update the EMA if specified
        if ema:
            ema.update()  # Update the EMA after each training step

        # Update the AverageMeter objects
        losses.update(loss.item())
        cls_losses.update(cls_loss.item())
        reg_losses.update(reg_loss.item())
        reconstruct_losses.update(reconstruct_loss.item())


        # Log (or plot) losses
        # ------------------------------------------------------------------------------------------------------
        # Log the running averages and learning rate to wandb
        if batch_idx % args.experiment.log_interval == 0:
            wandb.log({
                "Loss": losses.avg,
                "Cls Loss": cls_losses.avg,
                "Reg Loss": reg_losses.avg,
                "Reconstruct Loss": reconstruct_losses.avg,
                # "Recon Weight": args.hyper_model.loss_weight.recon_weight,  # Newly added for LeNet
                "Learning rate": optimizer.param_groups[0]['lr']
            }, step=batch_idx + epoch_idx * len(train_loader))
            # Print the losses and learning rate
            print(
                f"Iteration {batch_idx}: Loss = {losses.avg:.4f}, Reg Loss = {reg_losses.avg:.4f}, Reconstruct Loss = {reconstruct_losses.avg:.4f}, Cls Loss = {cls_losses.avg:.4f}, Learning rate = {optimizer.param_groups[0]['lr']:.4e}")

    # Return the average training loss plus the two model dictionaries that were passed in
    return losses.avg, dim_dict, gt_model_dict

## Main

### 0 Select device (GPU if available, otherwise CPU)

In [9]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### 1 Parsing arguments for inputs

In [10]:
# Experiment configuration file, passed to parse_args via --config
CONFIG_PATH = 'neumeta/config/mnist/LeNet_mnist_8-32_coodnoise.yaml'
# Kept as a string because it is passed as a command-line token (--ratio)
RATIO = '1.0'

In [11]:
# CLI-style argument list consumed by parse_args
argv_train = ['--config', CONFIG_PATH, '--ratio', RATIO]

In [12]:
# Build the run configuration from the argument list, then print it as a table
args = parse_args(argv_train)  # Parse arguments
print_omegaconf(args)  # Print arguments

Loading base config from toy/experiments/base_config.yaml
+--------------------------------------+----------------------------------------------------------------------------------------------------+
|                 Key                  |                                               Value                                                |
+--------------------------------------+----------------------------------------------------------------------------------------------------+
|           experiment.name            |                                       mnist_lenet_8-32-noise                                       |
|         experiment.recononly         |                                                 0                                                  |
|        experiment.num_epochs         |                                                 50                                                 |
|       experiment.log_interval        |                                                 5

In [13]:
# Seed the run for reproducibility; the seed value comes from the config
set_seed(args.experiment.seed)

Setting seed... 42 for reproducibility


### 2 Get training and validation data (in dataloader format)

In [14]:
# Build the training and validation loaders; 'strong_aug' is optional in the config
train_loader, val_loader = get_dataset(
    args.training.dataset,
    args.training.batch_size,
    strong_transform=args.training.get('strong_aug', None),
)

Using dataset: mnist with batch size: 128 and strong transform: None


### 3 Create target model

#### 3.0 Create the model

In [15]:
# Target model at the starting (pretrained) hidden dimension.
# The `smooth=args.model.smooth` option is intentionally not passed here.
model = create_model(
    args.model.type,
    hidden_dim=args.dimensions.start,
    path=args.model.pretrained_path,
).to(device)

In [16]:
# Load LeNet pretrained weight.
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(args.model.pretrained_path, map_location=device))
model.eval()

MnistNet(
  (conv_1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
  (conv_2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
  (conv_3): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
  (f_1): ReLU()
  (f_2): ReLU()
  (f_3): ReLU()
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (linear): Linear(in_features=128, out_features=10, bias=True)
)

#### 3.1 Print the structure and shape of the model

In [17]:
# Display the module structure of the target model
model

MnistNet(
  (conv_1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
  (conv_2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
  (conv_3): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
  (f_1): ReLU()
  (f_2): ReLU()
  (f_3): ReLU()
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (linear): Linear(in_features=128, out_features=10, bias=True)
)

In [18]:
# List every learnable parameter tensor with its shape
for k, tensor in model.learnable_parameter.items():
    print(k, tensor.shape)

conv_1.weight torch.Size([32, 1, 3, 3])
conv_1.bias torch.Size([32])
conv_2.weight torch.Size([64, 32, 5, 5])
conv_2.bias torch.Size([64])
conv_3.weight torch.Size([128, 64, 5, 5])
conv_3.bias torch.Size([128])
linear.weight torch.Size([10, 128])
linear.bias torch.Size([10])


#### 3.2 The maximum dimension of the target model

In [19]:
# Print the maximum dimension of the model
# print(f'Maximum DIM: {find_max_dim(model)}')

#### 3.3 Validate the accuracy of pretrained network

In [20]:
# Validate the model for the starting dimension (its pretrained form)
# NOTE(review): the printed label says "Permutated", but no permutation has been applied to
# `model` at this point in the notebook — confirm the label is intended.
val_loss, acc = validate_single(model, val_loader, nn.CrossEntropyLoss(), args=args)
print(f'Initial Permutated model Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc * 100:.2f}%')

100%|██████████| 79/79 [00:01<00:00, 57.37it/s]

Initial Permutated model Validation Loss: 0.0621, Validation Accuracy: 98.14%





In [21]:
# Get the learnable parameters of the model (a mapping of parameter name -> tensor)
checkpoint = model.learnable_parameter
# Number of learnable parameter tensors (not the number of scalar weights)
number_param = len(checkpoint)

In [22]:
# Print the parameter names exposed by the model and how many parameter tensors there are
print(f"Parameters keys: {model.keys}")
print(f"Number of parameters to be learned: {number_param}")

Parameters keys: ['conv_1.weight', 'conv_1.bias', 'conv_2.weight', 'conv_2.bias', 'conv_3.weight', 'conv_3.bias', 'linear.weight', 'linear.bias']
Number of parameters to be learned: 8


### 4 Create the hypernetwork

#### 4.0 Create the model

In [23]:
# Get the hypermodel; number_param is the count of learnable parameter tensors in the target model
hyper_model = get_hypernetwork(args, number_param)

Hyper model type: resmlp
Using scalar 0.1
num_freqs:  16 <class 'int'>


#### 4.1 Print model structure

In [24]:
# Display the module structure of the hypernetwork
hyper_model

NeRF_ResMLP_Compose(
  (positional_encoding): PositionalEncoding()
  (model): ModuleList(
    (0-7): 8 x NeRF_MLP_Residual_Scaled(
      (initial_layer): Linear(in_features=198, out_features=256, bias=True)
      (residual_blocks): ModuleList(
        (0-2): 3 x Linear(in_features=256, out_features=256, bias=True)
      )
      (scalars): ParameterList(
          (0): Parameter containing: [torch.float32 of size  (cuda:0)]
          (1): Parameter containing: [torch.float32 of size  (cuda:0)]
          (2): Parameter containing: [torch.float32 of size  (cuda:0)]
      )
      (act): ReLU(inplace=True)
      (output_layer): Linear(in_features=256, out_features=25, bias=True)
    )
  )
)

#### 4.2 Initialize EMA to track a smoothed copy of the hypernetwork weights

In [25]:
# Initialize the EMA over the hypernetwork, using the decay rate from the config
ema = EMA(hyper_model, decay=args.hyper_model.ema_decay)

### 5 Get Loss function, Optimizer, and Scheduler

In [26]:
# Build the training/validation criteria, the optimizer and the LR scheduler for the hypernetwork
criterion, val_criterion, optimizer, scheduler = get_optimizer(args, hyper_model)

In [27]:
# Show the training objects, one per line
print(
    f'Criterion: {criterion}',
    f'Val_criterion: {val_criterion}',
    f'Optimizer: {optimizer}',
    f'Scheduler: {scheduler}',
    sep='\n',
)

Criterion: CrossEntropyLoss()
Val_criterion: CrossEntropyLoss()
Optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.001
    lr: 0.001
    maximize: False
    weight_decay: 0.0
)
Scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x0000021C47E4FFD0>


### 6 Training loop

#### 6.1 Initialize training parameters

In [28]:
# Initialize the starting epoch and best accuracy
# (both are overwritten when training is resumed from a checkpoint)
start_epoch = 0
best_acc = 0.0

#### 6.2 Directory to save the model

In [29]:
# Create the directory to save the model (no error if it already exists)
os.makedirs(args.training.save_model_path, exist_ok=True)

#### 6.3 Resume from a checkpoint (only if `resume_from` is set)

In [30]:
# Optionally restore the hypernetwork, optimizer and EMA state from a checkpoint
if args.resume_from:
    print(f"Resuming from checkpoint: {args.resume_from}")
    checkpoint_info = load_checkpoint(args.resume_from, hyper_model, optimizer, ema)
    # Continue from the stored epoch counter and best accuracy
    start_epoch, best_acc = checkpoint_info['epoch'], checkpoint_info['best_acc']
    print(f"Resuming from epoch: {start_epoch}, best accuracy: {best_acc*100:.2f}%")
    # Note: If there are more elements to retrieve, do so here.

#### 6.4 Initialize wandb for plotting

In [31]:
# Initialize wandb; this starts a new run named after the current time and experiment name
initialize_wandb(args)

[34m[1mwandb[0m: Currently logged in as: [33mefradosuryadi[0m ([33mefradosuryadi-universitas-indonesia[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


#### 6.5 Initialize model dictionary for each dimension and shuffle it

In [32]:
# Initialize model dictionary: one target model per hidden dimension,
# plus the ground truth model for the starting dimension
dim_dict, gt_model_dict = init_model_dict(args, device)



















































Loading model for dim 32
TV original model:  3088.543212890625
TV permutated model:  2175.4375


In [33]:
# Display the ground truth models, keyed by hidden dimension
gt_model_dict

{'32': MnistNet(
   (conv_1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
   (conv_2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
   (conv_3): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
   (f_1): ReLU()
   (f_2): ReLU()
   (f_3): ReLU()
   (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
   (linear): Linear(in_features=128, out_features=10, bias=True)
 )}

In [34]:
# Display the per-dimension entries (model, coordinates, keys, indices, sizes, key mask)
dim_dict

{'8': (MnistNet(
    (conv_1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
    (conv_2): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
    (conv_3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
    (f_1): ReLU()
    (f_2): ReLU()
    (f_3): ReLU()
    (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (linear): Linear(in_features=32, out_features=10, bias=True)
  ),
  tensor([[ 0.,  0.,  0.,  8.,  8.,  1.],
          [ 0.,  1.,  0.,  8.,  8.,  1.],
          [ 0.,  2.,  0.,  8.,  8.,  1.],
          ...,
          [ 7.,  7.,  0.,  8., 10.,  1.],
          [ 7.,  8.,  0.,  8., 10.,  1.],
          [ 7.,  9.,  0.,  8., 10.,  1.]]),
  array(['conv_1.weight', 'conv_1.weight', 'conv_1.weight', ...,
         'linear.bias', 'linear.bias', 'linear.bias'], dtype='<U13'),
  tensor([[0, 0, 3, 3],
          [1, 0, 3, 3],
          [2, 0, 3, 3],
          ...,
 

In [35]:
# Shuffle the coordinates stored for every dimension.
# NOTE(review): this overwrites `dim_dict` with the result, so re-running the cell
# presumably shuffles again — confirm that is harmless.
dim_dict = shuffle_coordinates_all(dim_dict)
dim_dict

{'8': (MnistNet(
    (conv_1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
    (conv_2): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
    (conv_3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
    (f_1): ReLU()
    (f_2): ReLU()
    (f_3): ReLU()
    (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (linear): Linear(in_features=32, out_features=10, bias=True)
  ),
  tensor([[ 0.,  0.,  0.,  8.,  8.,  1.],
          [ 0.,  1.,  0.,  8.,  8.,  1.],
          [ 0.,  2.,  0.,  8.,  8.,  1.],
          ...,
          [ 7.,  7.,  0.,  8., 10.,  1.],
          [ 7.,  8.,  0.,  8., 10.,  1.],
          [ 7.,  9.,  0.,  8., 10.,  1.]]),
  array(['conv_1.weight', 'conv_1.weight', 'conv_1.weight', ...,
         'linear.bias', 'linear.bias', 'linear.bias'], dtype='<U13'),
  tensor([[0, 0, 3, 3],
          [1, 0, 3, 3],
          [2, 0, 3, 3],
          ...,
 

#### 6.6 Hypernetwork training loop

In [36]:
# Total number of training epochs configured for this run
args.experiment.num_epochs

50

In [37]:
# Iterate over the epochs: train the hypernetwork, step the scheduler, and
# evaluate (and possibly checkpoint) every `eval_interval` epochs
for epoch in range(start_epoch, args.experiment.num_epochs):
    train_loss, dim_dict, gt_model_dict = train_one_epoch(hyper_model, train_loader, optimizer, criterion,
                                                          dim_dict, gt_model_dict, epoch_idx=epoch, ema=ema,
                                                          args=args, device=device)
    # Step the scheduler
    scheduler.step()

    # Print the training loss and learning rate
    print(f"Epoch [{epoch+1}/{args.experiment.num_epochs}], Training Loss: {train_loss:.4f}, Learning Rate: {scheduler.get_last_lr()[0]:.6f}")

    # If it's time to evaluate the model
    if (epoch + 1) % args.experiment.eval_interval == 0:
        # Evaluate with the EMA weights if EMA is enabled
        # (presumably apply() swaps the averaged weights into hyper_model — confirm against EMA)
        if ema:
            ema.apply()

        try:
            # Generate the target model's weights with the hypernetwork, then measure how well
            # the generated model performs on the validation set
            model = sample_merge_model(hyper_model, model, args)
            # Validate the merged model
            val_loss, acc = validate_single(model, val_loader, val_criterion, args=args)
        finally:
            # Always undo ema.apply(), even if sampling/validation fails or is interrupted,
            # so training never continues on the EMA weights by accident
            if ema:
                ema.restore()

        # Log the validation loss and accuracy to wandb
        wandb.log({
            "Validation Loss": val_loss,
            "Validation Accuracy": acc
        })
        # Print the validation loss and accuracy
        print(f"Epoch [{epoch+1}/{args.experiment.num_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%")
        print('\n\n')

        # Save the checkpoint if the accuracy is better than the previous best
        # (the epoch stored and printed here is zero-based, unlike the "Epoch [n/N]" messages)
        if acc > best_acc:
            best_acc = acc
            save_checkpoint(f"{args.training.save_model_path}/mnist_nerf_best.pth",hyper_model,optimizer,ema,epoch,best_acc)
            print(f"Checkpoint saved at epoch {epoch} with accuracy: {best_acc*100:.2f}%")


Iteration 0: Loss = 2.3828, Reg Loss = 11.6718, Reconstruct Loss = 0.0000, Cls Loss = 2.3817, Learning rate = 1.0000e-03
Iteration 50: Loss = 2.3194, Reg Loss = 10.1162, Reconstruct Loss = 0.0010, Cls Loss = 2.3174, Learning rate = 1.0000e-03
Iteration 100: Loss = 2.3180, Reg Loss = 9.9594, Reconstruct Loss = 0.0007, Cls Loss = 2.3163, Learning rate = 1.0000e-03
Iteration 150: Loss = 2.3146, Reg Loss = 9.3854, Reconstruct Loss = 0.0007, Cls Loss = 2.3129, Learning rate = 1.0000e-03
Iteration 200: Loss = 2.3007, Reg Loss = 14.8758, Reconstruct Loss = 0.0029, Cls Loss = 2.2963, Learning rate = 1.0000e-03
Iteration 250: Loss = 2.2684, Reg Loss = 25.5748, Reconstruct Loss = 0.0035, Cls Loss = 2.2623, Learning rate = 1.0000e-03
Iteration 300: Loss = 2.2431, Reg Loss = 34.4771, Reconstruct Loss = 0.0029, Cls Loss = 2.2368, Learning rate = 1.0000e-03
Iteration 350: Loss = 2.2250, Reg Loss = 40.6304, Reconstruct Loss = 0.0032, Cls Loss = 2.2177, Learning rate = 1.0000e-03
Iteration 400: Loss =

100%|██████████| 79/79 [00:01<00:00, 54.10it/s]


Epoch [1/50], Validation Loss: 2.1636, Validation Accuracy: 20.85%



Checkpoint saved at epoch 0 with accuracy: 20.85%
Iteration 0: Loss = 1.9897, Reg Loss = 84.2579, Reconstruct Loss = 0.0000, Cls Loss = 1.9812, Learning rate = 1.0000e-03
Iteration 50: Loss = 2.1158, Reg Loss = 85.4009, Reconstruct Loss = 0.0069, Cls Loss = 2.1004, Learning rate = 1.0000e-03
Iteration 100: Loss = 2.1027, Reg Loss = 82.5665, Reconstruct Loss = 0.0057, Cls Loss = 2.0888, Learning rate = 1.0000e-03
Iteration 150: Loss = 2.0919, Reg Loss = 82.5984, Reconstruct Loss = 0.0038, Cls Loss = 2.0799, Learning rate = 1.0000e-03
Iteration 200: Loss = 2.0891, Reg Loss = 81.0112, Reconstruct Loss = 0.0063, Cls Loss = 2.0747, Learning rate = 1.0000e-03
Iteration 250: Loss = 2.0877, Reg Loss = 78.5290, Reconstruct Loss = 0.0074, Cls Loss = 2.0725, Learning rate = 1.0000e-03
Iteration 300: Loss = 2.0844, Reg Loss = 76.2417, Reconstruct Loss = 0.0071, Cls Loss = 2.0697, Learning rate = 1.0000e-03
Iteration 350: Loss = 

100%|██████████| 79/79 [00:01<00:00, 55.98it/s]


Epoch [2/50], Validation Loss: 2.0498, Validation Accuracy: 19.61%



Iteration 0: Loss = 2.1258, Reg Loss = 71.7127, Reconstruct Loss = 0.0000, Cls Loss = 2.1186, Learning rate = 1.0000e-03
Iteration 50: Loss = 2.1568, Reg Loss = 73.8484, Reconstruct Loss = 0.0077, Cls Loss = 2.1417, Learning rate = 1.0000e-03
Iteration 100: Loss = 2.1069, Reg Loss = 75.2805, Reconstruct Loss = 0.0077, Cls Loss = 2.0916, Learning rate = 1.0000e-03
Iteration 150: Loss = 2.0954, Reg Loss = 74.3859, Reconstruct Loss = 0.0085, Cls Loss = 2.0795, Learning rate = 1.0000e-03
Iteration 200: Loss = 2.0860, Reg Loss = 73.2251, Reconstruct Loss = 0.0064, Cls Loss = 2.0723, Learning rate = 1.0000e-03
Iteration 250: Loss = 2.0816, Reg Loss = 72.1324, Reconstruct Loss = 0.0081, Cls Loss = 2.0662, Learning rate = 1.0000e-03
Iteration 300: Loss = 2.0751, Reg Loss = 70.8331, Reconstruct Loss = 0.0077, Cls Loss = 2.0603, Learning rate = 1.0000e-03
Iteration 350: Loss = 2.0681, Reg Loss = 70.2187, Reconstruct Loss = 0.0

100%|██████████| 79/79 [00:01<00:00, 57.61it/s]


Epoch [3/50], Validation Loss: 2.0209, Validation Accuracy: 21.73%



Checkpoint saved at epoch 2 with accuracy: 21.73%
Iteration 0: Loss = 2.0241, Reg Loss = 63.8759, Reconstruct Loss = 0.0000, Cls Loss = 2.0178, Learning rate = 1.0000e-03
Iteration 50: Loss = 2.0533, Reg Loss = 61.5759, Reconstruct Loss = 0.0110, Cls Loss = 2.0362, Learning rate = 1.0000e-03
Iteration 100: Loss = 2.0374, Reg Loss = 61.7143, Reconstruct Loss = 0.0067, Cls Loss = 2.0245, Learning rate = 1.0000e-03
Iteration 150: Loss = 2.0316, Reg Loss = 62.9099, Reconstruct Loss = 0.0045, Cls Loss = 2.0208, Learning rate = 1.0000e-03
Iteration 200: Loss = 2.0269, Reg Loss = 63.2430, Reconstruct Loss = 0.0050, Cls Loss = 2.0155, Learning rate = 1.0000e-03
Iteration 250: Loss = 2.0228, Reg Loss = 63.4183, Reconstruct Loss = 0.0048, Cls Loss = 2.0117, Learning rate = 1.0000e-03
Iteration 300: Loss = 2.0164, Reg Loss = 63.9445, Reconstruct Loss = 0.0044, Cls Loss = 2.0056, Learning rate = 1.0000e-03
Iteration 350: Loss = 

100%|██████████| 79/79 [00:01<00:00, 57.06it/s]


Epoch [4/50], Validation Loss: 1.9219, Validation Accuracy: 27.99%



Checkpoint saved at epoch 3 with accuracy: 27.99%
Iteration 0: Loss = 1.8504, Reg Loss = 87.1445, Reconstruct Loss = 0.0000, Cls Loss = 1.8416, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.8924, Reg Loss = 86.2481, Reconstruct Loss = 0.0135, Cls Loss = 1.8703, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.8803, Reg Loss = 85.9406, Reconstruct Loss = 0.0110, Cls Loss = 1.8607, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.8564, Reg Loss = 83.6909, Reconstruct Loss = 0.0086, Cls Loss = 1.8394, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.8499, Reg Loss = 83.1222, Reconstruct Loss = 0.0081, Cls Loss = 1.8334, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.8407, Reg Loss = 82.6490, Reconstruct Loss = 0.0094, Cls Loss = 1.8230, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.8434, Reg Loss = 81.9813, Reconstruct Loss = 0.0083, Cls Loss = 1.8269, Learning rate = 1.0000e-03
Iteration 350: Loss = 

100%|██████████| 79/79 [00:01<00:00, 57.23it/s]


Epoch [5/50], Validation Loss: 1.7398, Validation Accuracy: 28.00%



Checkpoint saved at epoch 4 with accuracy: 28.00%
Iteration 0: Loss = 2.0296, Reg Loss = 81.6159, Reconstruct Loss = 0.0000, Cls Loss = 2.0214, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.7834, Reg Loss = 94.0362, Reconstruct Loss = 0.0111, Cls Loss = 1.7630, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.7468, Reg Loss = 94.6226, Reconstruct Loss = 0.0080, Cls Loss = 1.7293, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.7287, Reg Loss = 94.7949, Reconstruct Loss = 0.0101, Cls Loss = 1.7091, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.7164, Reg Loss = 94.4905, Reconstruct Loss = 0.0104, Cls Loss = 1.6966, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.7090, Reg Loss = 95.0189, Reconstruct Loss = 0.0094, Cls Loss = 1.6901, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.7022, Reg Loss = 94.9538, Reconstruct Loss = 0.0094, Cls Loss = 1.6833, Learning rate = 1.0000e-03
Iteration 350: Loss = 

100%|██████████| 79/79 [00:01<00:00, 59.06it/s]


Epoch [6/50], Validation Loss: 1.5930, Validation Accuracy: 34.56%



Checkpoint saved at epoch 5 with accuracy: 34.56%
Iteration 0: Loss = 1.4436, Reg Loss = 92.0683, Reconstruct Loss = 0.0000, Cls Loss = 1.4344, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.6196, Reg Loss = 88.2438, Reconstruct Loss = 0.0073, Cls Loss = 1.6035, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.6300, Reg Loss = 86.7180, Reconstruct Loss = 0.0055, Cls Loss = 1.6158, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.6339, Reg Loss = 85.3734, Reconstruct Loss = 0.0064, Cls Loss = 1.6190, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.6410, Reg Loss = 84.6970, Reconstruct Loss = 0.0087, Cls Loss = 1.6239, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.6410, Reg Loss = 84.0228, Reconstruct Loss = 0.0088, Cls Loss = 1.6238, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.6353, Reg Loss = 83.7846, Reconstruct Loss = 0.0091, Cls Loss = 1.6178, Learning rate = 1.0000e-03
Iteration 350: Loss = 

100%|██████████| 79/79 [00:01<00:00, 56.66it/s]


Epoch [7/50], Validation Loss: 1.5703, Validation Accuracy: 34.94%



Checkpoint saved at epoch 6 with accuracy: 34.94%
Iteration 0: Loss = 1.5532, Reg Loss = 76.3154, Reconstruct Loss = 0.0000, Cls Loss = 1.5456, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.5981, Reg Loss = 77.7150, Reconstruct Loss = 0.0093, Cls Loss = 1.5810, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.5958, Reg Loss = 78.5416, Reconstruct Loss = 0.0067, Cls Loss = 1.5813, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.5961, Reg Loss = 78.8540, Reconstruct Loss = 0.0058, Cls Loss = 1.5824, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.5960, Reg Loss = 78.3896, Reconstruct Loss = 0.0089, Cls Loss = 1.5792, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.5963, Reg Loss = 77.9999, Reconstruct Loss = 0.0091, Cls Loss = 1.5794, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.5920, Reg Loss = 77.9520, Reconstruct Loss = 0.0076, Cls Loss = 1.5766, Learning rate = 1.0000e-03
Iteration 350: Loss = 

100%|██████████| 79/79 [00:01<00:00, 57.86it/s]


Epoch [8/50], Validation Loss: 1.5425, Validation Accuracy: 45.34%



Checkpoint saved at epoch 7 with accuracy: 45.34%
Iteration 0: Loss = 1.5288, Reg Loss = 78.4068, Reconstruct Loss = 0.0000, Cls Loss = 1.5209, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.4687, Reg Loss = 84.3875, Reconstruct Loss = 0.0052, Cls Loss = 1.4551, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.4464, Reg Loss = 85.5270, Reconstruct Loss = 0.0052, Cls Loss = 1.4327, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.4531, Reg Loss = 84.4830, Reconstruct Loss = 0.0129, Cls Loss = 1.4317, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.4562, Reg Loss = 83.8119, Reconstruct Loss = 0.0115, Cls Loss = 1.4364, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.4459, Reg Loss = 83.7982, Reconstruct Loss = 0.0112, Cls Loss = 1.4264, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.4423, Reg Loss = 83.2111, Reconstruct Loss = 0.0099, Cls Loss = 1.4240, Learning rate = 1.0000e-03
Iteration 350: Loss = 

100%|██████████| 79/79 [00:01<00:00, 58.29it/s]


Epoch [9/50], Validation Loss: 1.3079, Validation Accuracy: 48.18%



Checkpoint saved at epoch 8 with accuracy: 48.18%
Iteration 0: Loss = 1.5264, Reg Loss = 85.7052, Reconstruct Loss = 0.0000, Cls Loss = 1.5179, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.2944, Reg Loss = 79.3669, Reconstruct Loss = 0.0122, Cls Loss = 1.2743, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.3120, Reg Loss = 79.0329, Reconstruct Loss = 0.0082, Cls Loss = 1.2959, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.3143, Reg Loss = 78.8184, Reconstruct Loss = 0.0100, Cls Loss = 1.2964, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.3138, Reg Loss = 78.3893, Reconstruct Loss = 0.0105, Cls Loss = 1.2955, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.3124, Reg Loss = 78.3636, Reconstruct Loss = 0.0091, Cls Loss = 1.2955, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.3109, Reg Loss = 78.5814, Reconstruct Loss = 0.0076, Cls Loss = 1.2955, Learning rate = 1.0000e-03
Iteration 350: Loss = 

100%|██████████| 79/79 [00:01<00:00, 55.47it/s]


Epoch [10/50], Validation Loss: 1.1786, Validation Accuracy: 56.95%



Checkpoint saved at epoch 9 with accuracy: 56.95%
Iteration 0: Loss = 1.4018, Reg Loss = 71.8321, Reconstruct Loss = 0.0000, Cls Loss = 1.3946, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.3135, Reg Loss = 79.0335, Reconstruct Loss = 0.0034, Cls Loss = 1.3023, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.2908, Reg Loss = 79.7084, Reconstruct Loss = 0.0017, Cls Loss = 1.2811, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.2834, Reg Loss = 79.5901, Reconstruct Loss = 0.0037, Cls Loss = 1.2717, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.2733, Reg Loss = 80.2010, Reconstruct Loss = 0.0038, Cls Loss = 1.2615, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.2654, Reg Loss = 80.7228, Reconstruct Loss = 0.0048, Cls Loss = 1.2525, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.2655, Reg Loss = 80.7283, Reconstruct Loss = 0.0059, Cls Loss = 1.2516, Learning rate = 1.0000e-03
Iteration 350: Loss =

100%|██████████| 79/79 [00:01<00:00, 57.84it/s]


Epoch [11/50], Validation Loss: 1.1611, Validation Accuracy: 54.12%



Iteration 0: Loss = 1.2745, Reg Loss = 74.7574, Reconstruct Loss = 0.0000, Cls Loss = 1.2670, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.2641, Reg Loss = 82.6557, Reconstruct Loss = 0.0206, Cls Loss = 1.2353, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.2253, Reg Loss = 83.4716, Reconstruct Loss = 0.0143, Cls Loss = 1.2027, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.2168, Reg Loss = 83.8024, Reconstruct Loss = 0.0111, Cls Loss = 1.1973, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.2105, Reg Loss = 83.6409, Reconstruct Loss = 0.0109, Cls Loss = 1.1913, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.2038, Reg Loss = 83.5626, Reconstruct Loss = 0.0096, Cls Loss = 1.1858, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.2004, Reg Loss = 83.2781, Reconstruct Loss = 0.0097, Cls Loss = 1.1823, Learning rate = 1.0000e-03
Iteration 350: Loss = 1.2003, Reg Loss = 82.9834, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 57.36it/s]


Epoch [12/50], Validation Loss: 1.0854, Validation Accuracy: 62.91%



Checkpoint saved at epoch 11 with accuracy: 62.91%
Iteration 0: Loss = 1.1142, Reg Loss = 77.0061, Reconstruct Loss = 0.0000, Cls Loss = 1.1065, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.1644, Reg Loss = 82.6709, Reconstruct Loss = 0.0048, Cls Loss = 1.1513, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.1716, Reg Loss = 81.9297, Reconstruct Loss = 0.0069, Cls Loss = 1.1565, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.1677, Reg Loss = 82.5968, Reconstruct Loss = 0.0088, Cls Loss = 1.1506, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.1625, Reg Loss = 82.8605, Reconstruct Loss = 0.0104, Cls Loss = 1.1438, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.1549, Reg Loss = 82.7746, Reconstruct Loss = 0.0122, Cls Loss = 1.1344, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.1537, Reg Loss = 82.5160, Reconstruct Loss = 0.0110, Cls Loss = 1.1345, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 57.06it/s]


Epoch [13/50], Validation Loss: 1.0465, Validation Accuracy: 64.31%



Checkpoint saved at epoch 12 with accuracy: 64.31%
Iteration 0: Loss = 1.3181, Reg Loss = 77.8500, Reconstruct Loss = 0.0000, Cls Loss = 1.3103, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.1329, Reg Loss = 81.7974, Reconstruct Loss = 0.0085, Cls Loss = 1.1163, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.1266, Reg Loss = 83.0349, Reconstruct Loss = 0.0043, Cls Loss = 1.1140, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.1205, Reg Loss = 83.2570, Reconstruct Loss = 0.0045, Cls Loss = 1.1077, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.1113, Reg Loss = 83.0288, Reconstruct Loss = 0.0057, Cls Loss = 1.0973, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.1132, Reg Loss = 83.2034, Reconstruct Loss = 0.0074, Cls Loss = 1.0975, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.1112, Reg Loss = 83.3857, Reconstruct Loss = 0.0076, Cls Loss = 1.0952, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 60.01it/s]


Epoch [14/50], Validation Loss: 0.9655, Validation Accuracy: 67.68%



Checkpoint saved at epoch 13 with accuracy: 67.68%
Iteration 0: Loss = 0.9838, Reg Loss = 69.4891, Reconstruct Loss = 0.0000, Cls Loss = 0.9769, Learning rate = 1.0000e-03
Iteration 50: Loss = 1.0633, Reg Loss = 80.1050, Reconstruct Loss = 0.0172, Cls Loss = 1.0381, Learning rate = 1.0000e-03
Iteration 100: Loss = 1.0441, Reg Loss = 81.5505, Reconstruct Loss = 0.0109, Cls Loss = 1.0251, Learning rate = 1.0000e-03
Iteration 150: Loss = 1.0374, Reg Loss = 81.7882, Reconstruct Loss = 0.0105, Cls Loss = 1.0187, Learning rate = 1.0000e-03
Iteration 200: Loss = 1.0355, Reg Loss = 81.4472, Reconstruct Loss = 0.0090, Cls Loss = 1.0183, Learning rate = 1.0000e-03
Iteration 250: Loss = 1.0290, Reg Loss = 81.1436, Reconstruct Loss = 0.0095, Cls Loss = 1.0114, Learning rate = 1.0000e-03
Iteration 300: Loss = 1.0238, Reg Loss = 80.7352, Reconstruct Loss = 0.0108, Cls Loss = 1.0050, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 54.72it/s]


Epoch [15/50], Validation Loss: 0.8734, Validation Accuracy: 73.56%



Checkpoint saved at epoch 14 with accuracy: 73.56%
Iteration 0: Loss = 0.9533, Reg Loss = 77.7465, Reconstruct Loss = 0.0000, Cls Loss = 0.9455, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.9899, Reg Loss = 80.8359, Reconstruct Loss = 0.0105, Cls Loss = 0.9713, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.9911, Reg Loss = 80.0213, Reconstruct Loss = 0.0106, Cls Loss = 0.9725, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.9786, Reg Loss = 78.4918, Reconstruct Loss = 0.0071, Cls Loss = 0.9637, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.9721, Reg Loss = 78.3093, Reconstruct Loss = 0.0089, Cls Loss = 0.9554, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.9732, Reg Loss = 77.9566, Reconstruct Loss = 0.0086, Cls Loss = 0.9568, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.9745, Reg Loss = 77.7486, Reconstruct Loss = 0.0115, Cls Loss = 0.9552, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 54.71it/s]


Epoch [16/50], Validation Loss: 0.8346, Validation Accuracy: 72.79%



Iteration 0: Loss = 0.8268, Reg Loss = 77.5219, Reconstruct Loss = 0.0000, Cls Loss = 0.8191, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.9185, Reg Loss = 78.1783, Reconstruct Loss = 0.0000, Cls Loss = 0.9107, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.9353, Reg Loss = 79.5333, Reconstruct Loss = 0.0107, Cls Loss = 0.9167, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.9376, Reg Loss = 79.5960, Reconstruct Loss = 0.0137, Cls Loss = 0.9160, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.9229, Reg Loss = 79.3679, Reconstruct Loss = 0.0132, Cls Loss = 0.9017, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.9202, Reg Loss = 79.1170, Reconstruct Loss = 0.0128, Cls Loss = 0.8995, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.9163, Reg Loss = 78.5704, Reconstruct Loss = 0.0131, Cls Loss = 0.8953, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.9096, Reg Loss = 78.1651, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:03<00:00, 25.85it/s]


Epoch [17/50], Validation Loss: 0.7790, Validation Accuracy: 77.14%



Checkpoint saved at epoch 16 with accuracy: 77.14%
Iteration 0: Loss = 0.9311, Reg Loss = 80.2649, Reconstruct Loss = 0.0000, Cls Loss = 0.9231, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.8518, Reg Loss = 77.4005, Reconstruct Loss = 0.0108, Cls Loss = 0.8332, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.8492, Reg Loss = 76.9578, Reconstruct Loss = 0.0198, Cls Loss = 0.8218, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.8461, Reg Loss = 76.7032, Reconstruct Loss = 0.0132, Cls Loss = 0.8252, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.8387, Reg Loss = 77.3415, Reconstruct Loss = 0.0162, Cls Loss = 0.8148, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.8278, Reg Loss = 77.3810, Reconstruct Loss = 0.0130, Cls Loss = 0.8071, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.8279, Reg Loss = 77.3043, Reconstruct Loss = 0.0142, Cls Loss = 0.8060, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:03<00:00, 24.55it/s]


Epoch [18/50], Validation Loss: 0.6731, Validation Accuracy: 79.66%



Checkpoint saved at epoch 17 with accuracy: 79.66%
Iteration 0: Loss = 1.1698, Reg Loss = 67.9531, Reconstruct Loss = 0.0000, Cls Loss = 1.1630, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.7692, Reg Loss = 75.8098, Reconstruct Loss = 0.0153, Cls Loss = 0.7463, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.7617, Reg Loss = 75.9635, Reconstruct Loss = 0.0138, Cls Loss = 0.7402, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.7617, Reg Loss = 75.3843, Reconstruct Loss = 0.0144, Cls Loss = 0.7398, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.7666, Reg Loss = 75.3170, Reconstruct Loss = 0.0186, Cls Loss = 0.7405, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.7642, Reg Loss = 75.3836, Reconstruct Loss = 0.0189, Cls Loss = 0.7377, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.7497, Reg Loss = 75.1986, Reconstruct Loss = 0.0174, Cls Loss = 0.7248, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:03<00:00, 25.61it/s]


Epoch [19/50], Validation Loss: 0.5659, Validation Accuracy: 82.57%



Checkpoint saved at epoch 18 with accuracy: 82.57%
Iteration 0: Loss = 0.8489, Reg Loss = 78.4439, Reconstruct Loss = 0.0000, Cls Loss = 0.8410, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.6639, Reg Loss = 75.7837, Reconstruct Loss = 0.0135, Cls Loss = 0.6429, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.6635, Reg Loss = 75.6298, Reconstruct Loss = 0.0097, Cls Loss = 0.6463, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.6617, Reg Loss = 75.7780, Reconstruct Loss = 0.0105, Cls Loss = 0.6436, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.6627, Reg Loss = 75.6393, Reconstruct Loss = 0.0106, Cls Loss = 0.6445, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.6629, Reg Loss = 75.5566, Reconstruct Loss = 0.0124, Cls Loss = 0.6429, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.6593, Reg Loss = 75.1832, Reconstruct Loss = 0.0103, Cls Loss = 0.6414, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:02<00:00, 26.82it/s]


Epoch [20/50], Validation Loss: 0.5069, Validation Accuracy: 85.22%



Checkpoint saved at epoch 19 with accuracy: 85.22%
Iteration 0: Loss = 0.5683, Reg Loss = 69.0108, Reconstruct Loss = 0.0000, Cls Loss = 0.5614, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.6348, Reg Loss = 75.4033, Reconstruct Loss = 0.0118, Cls Loss = 0.6154, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.6308, Reg Loss = 75.8486, Reconstruct Loss = 0.0115, Cls Loss = 0.6117, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.6192, Reg Loss = 75.6403, Reconstruct Loss = 0.0154, Cls Loss = 0.5962, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.6140, Reg Loss = 75.1585, Reconstruct Loss = 0.0169, Cls Loss = 0.5896, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.6126, Reg Loss = 74.8568, Reconstruct Loss = 0.0178, Cls Loss = 0.5873, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.6093, Reg Loss = 74.6950, Reconstruct Loss = 0.0166, Cls Loss = 0.5852, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:03<00:00, 24.88it/s]


Epoch [21/50], Validation Loss: 0.4686, Validation Accuracy: 86.40%



Checkpoint saved at epoch 20 with accuracy: 86.40%
Iteration 0: Loss = 0.7001, Reg Loss = 60.6686, Reconstruct Loss = 0.0000, Cls Loss = 0.6941, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.5934, Reg Loss = 70.7896, Reconstruct Loss = 0.0048, Cls Loss = 0.5815, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.5688, Reg Loss = 71.6947, Reconstruct Loss = 0.0048, Cls Loss = 0.5568, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.5515, Reg Loss = 72.2543, Reconstruct Loss = 0.0082, Cls Loss = 0.5360, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.5513, Reg Loss = 71.8376, Reconstruct Loss = 0.0062, Cls Loss = 0.5379, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.5507, Reg Loss = 71.8857, Reconstruct Loss = 0.0072, Cls Loss = 0.5363, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.5501, Reg Loss = 71.7766, Reconstruct Loss = 0.0077, Cls Loss = 0.5352, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 57.43it/s]


Epoch [22/50], Validation Loss: 0.4389, Validation Accuracy: 86.62%



Checkpoint saved at epoch 21 with accuracy: 86.62%
Iteration 0: Loss = 0.5600, Reg Loss = 77.7009, Reconstruct Loss = 0.0000, Cls Loss = 0.5522, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.5110, Reg Loss = 71.3886, Reconstruct Loss = 0.0053, Cls Loss = 0.4986, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.5105, Reg Loss = 71.2684, Reconstruct Loss = 0.0053, Cls Loss = 0.4981, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.5088, Reg Loss = 71.3861, Reconstruct Loss = 0.0069, Cls Loss = 0.4948, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.5106, Reg Loss = 71.4679, Reconstruct Loss = 0.0091, Cls Loss = 0.4943, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.5104, Reg Loss = 71.4403, Reconstruct Loss = 0.0083, Cls Loss = 0.4949, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.5114, Reg Loss = 71.4117, Reconstruct Loss = 0.0089, Cls Loss = 0.4954, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 57.26it/s]


Epoch [23/50], Validation Loss: 0.3692, Validation Accuracy: 89.96%



Checkpoint saved at epoch 22 with accuracy: 89.96%
Iteration 0: Loss = 0.4876, Reg Loss = 72.4189, Reconstruct Loss = 0.0000, Cls Loss = 0.4804, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.4996, Reg Loss = 70.1117, Reconstruct Loss = 0.0153, Cls Loss = 0.4773, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.4946, Reg Loss = 70.5766, Reconstruct Loss = 0.0205, Cls Loss = 0.4670, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.4814, Reg Loss = 70.3072, Reconstruct Loss = 0.0173, Cls Loss = 0.4571, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.4807, Reg Loss = 70.1020, Reconstruct Loss = 0.0163, Cls Loss = 0.4574, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.4805, Reg Loss = 70.1019, Reconstruct Loss = 0.0146, Cls Loss = 0.4589, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.4802, Reg Loss = 70.0051, Reconstruct Loss = 0.0128, Cls Loss = 0.4604, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 55.95it/s]


Epoch [24/50], Validation Loss: 0.3366, Validation Accuracy: 90.13%



Checkpoint saved at epoch 23 with accuracy: 90.13%
Iteration 0: Loss = 0.5035, Reg Loss = 72.1487, Reconstruct Loss = 0.2216, Cls Loss = 0.2747, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.4657, Reg Loss = 69.0437, Reconstruct Loss = 0.0131, Cls Loss = 0.4457, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.4495, Reg Loss = 69.4000, Reconstruct Loss = 0.0094, Cls Loss = 0.4331, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.4505, Reg Loss = 69.4698, Reconstruct Loss = 0.0094, Cls Loss = 0.4341, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.4443, Reg Loss = 69.5824, Reconstruct Loss = 0.0080, Cls Loss = 0.4293, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.4416, Reg Loss = 69.5599, Reconstruct Loss = 0.0075, Cls Loss = 0.4272, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.4425, Reg Loss = 69.3962, Reconstruct Loss = 0.0076, Cls Loss = 0.4280, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 58.18it/s]


Epoch [25/50], Validation Loss: 0.3164, Validation Accuracy: 90.88%



Checkpoint saved at epoch 24 with accuracy: 90.88%
Iteration 0: Loss = 0.4496, Reg Loss = 66.2635, Reconstruct Loss = 0.0000, Cls Loss = 0.4429, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.4057, Reg Loss = 69.2718, Reconstruct Loss = 0.0058, Cls Loss = 0.3931, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.4099, Reg Loss = 69.8949, Reconstruct Loss = 0.0079, Cls Loss = 0.3949, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.4080, Reg Loss = 69.5371, Reconstruct Loss = 0.0070, Cls Loss = 0.3941, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.4300, Reg Loss = 69.0923, Reconstruct Loss = 0.0103, Cls Loss = 0.4128, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.4287, Reg Loss = 68.9679, Reconstruct Loss = 0.0118, Cls Loss = 0.4101, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.4281, Reg Loss = 68.6160, Reconstruct Loss = 0.0098, Cls Loss = 0.4114, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 56.92it/s]


Epoch [26/50], Validation Loss: 0.2986, Validation Accuracy: 91.64%



Checkpoint saved at epoch 25 with accuracy: 91.64%
Iteration 0: Loss = 0.4414, Reg Loss = 65.1355, Reconstruct Loss = 0.0000, Cls Loss = 0.4349, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.3953, Reg Loss = 67.4670, Reconstruct Loss = 0.0093, Cls Loss = 0.3792, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.3932, Reg Loss = 67.1987, Reconstruct Loss = 0.0067, Cls Loss = 0.3798, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.3953, Reg Loss = 67.3983, Reconstruct Loss = 0.0100, Cls Loss = 0.3786, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.3963, Reg Loss = 67.6244, Reconstruct Loss = 0.0094, Cls Loss = 0.3802, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.4019, Reg Loss = 67.6319, Reconstruct Loss = 0.0092, Cls Loss = 0.3859, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.3990, Reg Loss = 67.6446, Reconstruct Loss = 0.0084, Cls Loss = 0.3839, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 55.97it/s]


Epoch [27/50], Validation Loss: 0.2761, Validation Accuracy: 92.34%



Checkpoint saved at epoch 26 with accuracy: 92.34%
Iteration 0: Loss = 0.4027, Reg Loss = 67.8740, Reconstruct Loss = 0.0000, Cls Loss = 0.3959, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.3595, Reg Loss = 67.2611, Reconstruct Loss = 0.0032, Cls Loss = 0.3496, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.3800, Reg Loss = 67.9854, Reconstruct Loss = 0.0035, Cls Loss = 0.3697, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.3757, Reg Loss = 68.1421, Reconstruct Loss = 0.0048, Cls Loss = 0.3641, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.3712, Reg Loss = 68.5387, Reconstruct Loss = 0.0049, Cls Loss = 0.3595, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.3713, Reg Loss = 68.6608, Reconstruct Loss = 0.0057, Cls Loss = 0.3586, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.3732, Reg Loss = 68.4245, Reconstruct Loss = 0.0052, Cls Loss = 0.3612, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 57.50it/s]


Epoch [28/50], Validation Loss: 0.2812, Validation Accuracy: 91.78%



Iteration 0: Loss = 0.4038, Reg Loss = 72.6365, Reconstruct Loss = 0.0000, Cls Loss = 0.3966, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.3831, Reg Loss = 69.0644, Reconstruct Loss = 0.0050, Cls Loss = 0.3713, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.3813, Reg Loss = 68.5217, Reconstruct Loss = 0.0046, Cls Loss = 0.3698, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.3747, Reg Loss = 68.5725, Reconstruct Loss = 0.0059, Cls Loss = 0.3619, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.3707, Reg Loss = 68.8059, Reconstruct Loss = 0.0056, Cls Loss = 0.3583, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.3706, Reg Loss = 68.7120, Reconstruct Loss = 0.0055, Cls Loss = 0.3583, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.3675, Reg Loss = 68.5756, Reconstruct Loss = 0.0052, Cls Loss = 0.3554, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.3686, Reg Loss = 68.5384, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 58.62it/s]


Epoch [29/50], Validation Loss: 0.2771, Validation Accuracy: 92.14%



Iteration 0: Loss = 0.3359, Reg Loss = 70.7402, Reconstruct Loss = 0.0000, Cls Loss = 0.3288, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.3515, Reg Loss = 68.1726, Reconstruct Loss = 0.0075, Cls Loss = 0.3372, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.3643, Reg Loss = 68.2131, Reconstruct Loss = 0.0075, Cls Loss = 0.3500, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.3713, Reg Loss = 68.4420, Reconstruct Loss = 0.0081, Cls Loss = 0.3564, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.3663, Reg Loss = 68.4118, Reconstruct Loss = 0.0078, Cls Loss = 0.3517, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.3686, Reg Loss = 68.1530, Reconstruct Loss = 0.0082, Cls Loss = 0.3535, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.3666, Reg Loss = 68.0313, Reconstruct Loss = 0.0068, Cls Loss = 0.3530, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.3672, Reg Loss = 67.9926, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 57.80it/s]


Epoch [30/50], Validation Loss: 0.2644, Validation Accuracy: 92.51%



Checkpoint saved at epoch 29 with accuracy: 92.51%
Iteration 0: Loss = 0.3256, Reg Loss = 71.0495, Reconstruct Loss = 0.0000, Cls Loss = 0.3185, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.3664, Reg Loss = 68.0919, Reconstruct Loss = 0.0070, Cls Loss = 0.3526, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.3642, Reg Loss = 67.5287, Reconstruct Loss = 0.0082, Cls Loss = 0.3492, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.3623, Reg Loss = 67.5349, Reconstruct Loss = 0.0096, Cls Loss = 0.3460, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.3680, Reg Loss = 67.3567, Reconstruct Loss = 0.0085, Cls Loss = 0.3528, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.3609, Reg Loss = 67.3315, Reconstruct Loss = 0.0074, Cls Loss = 0.3468, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.3596, Reg Loss = 67.3443, Reconstruct Loss = 0.0077, Cls Loss = 0.3451, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 56.36it/s]


Epoch [31/50], Validation Loss: 0.2652, Validation Accuracy: 92.26%



Iteration 0: Loss = 0.4010, Reg Loss = 64.3212, Reconstruct Loss = 0.0000, Cls Loss = 0.3946, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.3536, Reg Loss = 68.4525, Reconstruct Loss = 0.0035, Cls Loss = 0.3433, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.3573, Reg Loss = 68.3075, Reconstruct Loss = 0.0076, Cls Loss = 0.3428, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.3503, Reg Loss = 68.1370, Reconstruct Loss = 0.0051, Cls Loss = 0.3385, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.3423, Reg Loss = 68.0896, Reconstruct Loss = 0.0055, Cls Loss = 0.3300, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.3432, Reg Loss = 67.9193, Reconstruct Loss = 0.0050, Cls Loss = 0.3314, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.3405, Reg Loss = 68.0311, Reconstruct Loss = 0.0048, Cls Loss = 0.3289, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.3403, Reg Loss = 68.0031, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 56.38it/s]


Epoch [32/50], Validation Loss: 0.2532, Validation Accuracy: 93.03%



Checkpoint saved at epoch 31 with accuracy: 93.03%
Iteration 0: Loss = 0.3181, Reg Loss = 62.0092, Reconstruct Loss = 0.0000, Cls Loss = 0.3119, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.3404, Reg Loss = 67.5647, Reconstruct Loss = 0.0000, Cls Loss = 0.3337, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.3192, Reg Loss = 67.7281, Reconstruct Loss = 0.0000, Cls Loss = 0.3124, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.3274, Reg Loss = 67.8505, Reconstruct Loss = 0.0041, Cls Loss = 0.3166, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.3264, Reg Loss = 67.9706, Reconstruct Loss = 0.0040, Cls Loss = 0.3156, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.3240, Reg Loss = 68.2208, Reconstruct Loss = 0.0032, Cls Loss = 0.3140, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.3302, Reg Loss = 68.2065, Reconstruct Loss = 0.0032, Cls Loss = 0.3202, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 55.99it/s]


Epoch [33/50], Validation Loss: 0.2437, Validation Accuracy: 92.99%



Iteration 0: Loss = 0.1909, Reg Loss = 73.3430, Reconstruct Loss = 0.0000, Cls Loss = 0.1836, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.3119, Reg Loss = 69.3858, Reconstruct Loss = 0.0000, Cls Loss = 0.3050, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.3255, Reg Loss = 68.9429, Reconstruct Loss = 0.0071, Cls Loss = 0.3116, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.3280, Reg Loss = 68.5775, Reconstruct Loss = 0.0067, Cls Loss = 0.3145, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.3258, Reg Loss = 68.3320, Reconstruct Loss = 0.0056, Cls Loss = 0.3133, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.3226, Reg Loss = 68.4897, Reconstruct Loss = 0.0070, Cls Loss = 0.3087, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.3221, Reg Loss = 68.6097, Reconstruct Loss = 0.0070, Cls Loss = 0.3082, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.3263, Reg Loss = 68.4724, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 57.15it/s]


Epoch [34/50], Validation Loss: 0.2421, Validation Accuracy: 93.17%



Checkpoint saved at epoch 33 with accuracy: 93.17%
Iteration 0: Loss = 0.2302, Reg Loss = 67.8178, Reconstruct Loss = 0.0000, Cls Loss = 0.2234, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.3193, Reg Loss = 68.2334, Reconstruct Loss = 0.0051, Cls Loss = 0.3074, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.3177, Reg Loss = 68.2094, Reconstruct Loss = 0.0039, Cls Loss = 0.3070, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.3139, Reg Loss = 68.3591, Reconstruct Loss = 0.0038, Cls Loss = 0.3033, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.3120, Reg Loss = 68.6988, Reconstruct Loss = 0.0042, Cls Loss = 0.3010, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.3125, Reg Loss = 68.7178, Reconstruct Loss = 0.0054, Cls Loss = 0.3002, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.3124, Reg Loss = 68.5690, Reconstruct Loss = 0.0055, Cls Loss = 0.3000, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 56.61it/s]


Epoch [35/50], Validation Loss: 0.2412, Validation Accuracy: 93.03%



Iteration 0: Loss = 0.2731, Reg Loss = 71.4054, Reconstruct Loss = 0.0000, Cls Loss = 0.2659, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2915, Reg Loss = 70.1661, Reconstruct Loss = 0.0064, Cls Loss = 0.2781, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.3067, Reg Loss = 69.7964, Reconstruct Loss = 0.0072, Cls Loss = 0.2925, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.3061, Reg Loss = 69.1975, Reconstruct Loss = 0.0082, Cls Loss = 0.2909, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.3067, Reg Loss = 68.8646, Reconstruct Loss = 0.0074, Cls Loss = 0.2924, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.3059, Reg Loss = 68.9854, Reconstruct Loss = 0.0068, Cls Loss = 0.2923, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.3065, Reg Loss = 69.0020, Reconstruct Loss = 0.0071, Cls Loss = 0.2925, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.3025, Reg Loss = 69.0593, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 58.40it/s]


Epoch [36/50], Validation Loss: 0.2358, Validation Accuracy: 93.25%



Checkpoint saved at epoch 35 with accuracy: 93.25%
Iteration 0: Loss = 0.2855, Reg Loss = 70.7850, Reconstruct Loss = 0.0000, Cls Loss = 0.2784, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2815, Reg Loss = 70.5998, Reconstruct Loss = 0.0059, Cls Loss = 0.2686, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2875, Reg Loss = 70.1609, Reconstruct Loss = 0.0047, Cls Loss = 0.2758, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2922, Reg Loss = 70.1631, Reconstruct Loss = 0.0042, Cls Loss = 0.2810, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2965, Reg Loss = 70.0183, Reconstruct Loss = 0.0045, Cls Loss = 0.2850, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2905, Reg Loss = 70.0608, Reconstruct Loss = 0.0042, Cls Loss = 0.2793, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2928, Reg Loss = 69.8442, Reconstruct Loss = 0.0043, Cls Loss = 0.2815, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 58.16it/s]


Epoch [37/50], Validation Loss: 0.2337, Validation Accuracy: 93.36%



Checkpoint saved at epoch 36 with accuracy: 93.36%
Iteration 0: Loss = 0.2266, Reg Loss = 73.5127, Reconstruct Loss = 0.0000, Cls Loss = 0.2192, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2977, Reg Loss = 69.8954, Reconstruct Loss = 0.0076, Cls Loss = 0.2831, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2859, Reg Loss = 69.4830, Reconstruct Loss = 0.0049, Cls Loss = 0.2741, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2879, Reg Loss = 69.8601, Reconstruct Loss = 0.0033, Cls Loss = 0.2776, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2871, Reg Loss = 69.8897, Reconstruct Loss = 0.0040, Cls Loss = 0.2761, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2851, Reg Loss = 70.2056, Reconstruct Loss = 0.0038, Cls Loss = 0.2743, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2855, Reg Loss = 70.2409, Reconstruct Loss = 0.0061, Cls Loss = 0.2724, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 58.53it/s]


Epoch [38/50], Validation Loss: 0.2095, Validation Accuracy: 93.84%



Checkpoint saved at epoch 37 with accuracy: 93.84%
Iteration 0: Loss = 0.2434, Reg Loss = 71.0692, Reconstruct Loss = 0.0000, Cls Loss = 0.2363, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2900, Reg Loss = 70.6695, Reconstruct Loss = 0.0077, Cls Loss = 0.2752, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2943, Reg Loss = 71.1797, Reconstruct Loss = 0.0053, Cls Loss = 0.2819, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2870, Reg Loss = 71.3208, Reconstruct Loss = 0.0035, Cls Loss = 0.2763, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2838, Reg Loss = 71.1806, Reconstruct Loss = 0.0049, Cls Loss = 0.2719, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2826, Reg Loss = 71.0198, Reconstruct Loss = 0.0054, Cls Loss = 0.2701, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2835, Reg Loss = 70.7406, Reconstruct Loss = 0.0064, Cls Loss = 0.2700, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 56.69it/s]


Epoch [39/50], Validation Loss: 0.2206, Validation Accuracy: 93.53%



Iteration 0: Loss = 0.3662, Reg Loss = 65.9061, Reconstruct Loss = 0.0000, Cls Loss = 0.3596, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2644, Reg Loss = 69.6911, Reconstruct Loss = 0.0086, Cls Loss = 0.2488, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2838, Reg Loss = 68.9570, Reconstruct Loss = 0.0081, Cls Loss = 0.2689, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2808, Reg Loss = 68.4184, Reconstruct Loss = 0.0066, Cls Loss = 0.2673, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2776, Reg Loss = 68.2481, Reconstruct Loss = 0.0055, Cls Loss = 0.2653, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2771, Reg Loss = 68.2811, Reconstruct Loss = 0.0049, Cls Loss = 0.2654, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2765, Reg Loss = 68.5032, Reconstruct Loss = 0.0044, Cls Loss = 0.2652, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.2754, Reg Loss = 68.8386, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 57.03it/s]


Epoch [40/50], Validation Loss: 0.2022, Validation Accuracy: 94.04%



Checkpoint saved at epoch 39 with accuracy: 94.04%
Iteration 0: Loss = 0.2845, Reg Loss = 63.5304, Reconstruct Loss = 0.0000, Cls Loss = 0.2781, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2646, Reg Loss = 70.2348, Reconstruct Loss = 0.0056, Cls Loss = 0.2519, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2617, Reg Loss = 70.4408, Reconstruct Loss = 0.0059, Cls Loss = 0.2487, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2562, Reg Loss = 70.2054, Reconstruct Loss = 0.0052, Cls Loss = 0.2440, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2610, Reg Loss = 70.1962, Reconstruct Loss = 0.0058, Cls Loss = 0.2482, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2647, Reg Loss = 70.2326, Reconstruct Loss = 0.0062, Cls Loss = 0.2515, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2673, Reg Loss = 70.2506, Reconstruct Loss = 0.0052, Cls Loss = 0.2551, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 57.73it/s]


Epoch [41/50], Validation Loss: 0.2037, Validation Accuracy: 94.36%



Checkpoint saved at epoch 40 with accuracy: 94.36%
Iteration 0: Loss = 0.4508, Reg Loss = 70.1484, Reconstruct Loss = 0.0000, Cls Loss = 0.4438, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2708, Reg Loss = 68.6780, Reconstruct Loss = 0.0018, Cls Loss = 0.2621, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2664, Reg Loss = 69.3378, Reconstruct Loss = 0.0030, Cls Loss = 0.2565, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2676, Reg Loss = 69.5209, Reconstruct Loss = 0.0038, Cls Loss = 0.2568, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2678, Reg Loss = 69.6677, Reconstruct Loss = 0.0029, Cls Loss = 0.2580, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2630, Reg Loss = 69.6267, Reconstruct Loss = 0.0031, Cls Loss = 0.2530, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2591, Reg Loss = 69.5455, Reconstruct Loss = 0.0042, Cls Loss = 0.2479, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 54.87it/s]


Epoch [42/50], Validation Loss: 0.1851, Validation Accuracy: 94.58%



Checkpoint saved at epoch 41 with accuracy: 94.58%
Iteration 0: Loss = 0.2851, Reg Loss = 72.6000, Reconstruct Loss = 0.0000, Cls Loss = 0.2778, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2509, Reg Loss = 72.4837, Reconstruct Loss = 0.0073, Cls Loss = 0.2364, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2568, Reg Loss = 71.1737, Reconstruct Loss = 0.0037, Cls Loss = 0.2460, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2599, Reg Loss = 70.6415, Reconstruct Loss = 0.0037, Cls Loss = 0.2491, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2584, Reg Loss = 70.4315, Reconstruct Loss = 0.0033, Cls Loss = 0.2480, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2577, Reg Loss = 70.3123, Reconstruct Loss = 0.0034, Cls Loss = 0.2472, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2566, Reg Loss = 70.3189, Reconstruct Loss = 0.0029, Cls Loss = 0.2467, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 57.46it/s]


Epoch [43/50], Validation Loss: 0.2000, Validation Accuracy: 94.10%



Iteration 0: Loss = 0.2668, Reg Loss = 75.1230, Reconstruct Loss = 0.0000, Cls Loss = 0.2593, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2324, Reg Loss = 72.4375, Reconstruct Loss = 0.0000, Cls Loss = 0.2252, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2461, Reg Loss = 71.5807, Reconstruct Loss = 0.0038, Cls Loss = 0.2352, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2425, Reg Loss = 71.3724, Reconstruct Loss = 0.0045, Cls Loss = 0.2309, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2422, Reg Loss = 71.1035, Reconstruct Loss = 0.0047, Cls Loss = 0.2304, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2439, Reg Loss = 70.9730, Reconstruct Loss = 0.0051, Cls Loss = 0.2317, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2442, Reg Loss = 70.7128, Reconstruct Loss = 0.0048, Cls Loss = 0.2323, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.2464, Reg Loss = 70.7334, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 56.41it/s]


Epoch [44/50], Validation Loss: 0.2001, Validation Accuracy: 93.77%



Iteration 0: Loss = 0.1513, Reg Loss = 73.4253, Reconstruct Loss = 0.0000, Cls Loss = 0.1439, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2300, Reg Loss = 71.7829, Reconstruct Loss = 0.0032, Cls Loss = 0.2196, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2344, Reg Loss = 71.5295, Reconstruct Loss = 0.0023, Cls Loss = 0.2249, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2348, Reg Loss = 71.0981, Reconstruct Loss = 0.0039, Cls Loss = 0.2238, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2359, Reg Loss = 70.9907, Reconstruct Loss = 0.0042, Cls Loss = 0.2246, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2317, Reg Loss = 70.8469, Reconstruct Loss = 0.0040, Cls Loss = 0.2205, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2326, Reg Loss = 70.6096, Reconstruct Loss = 0.0039, Cls Loss = 0.2217, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.2330, Reg Loss = 70.4998, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 57.20it/s]


Epoch [45/50], Validation Loss: 0.1682, Validation Accuracy: 95.01%



Checkpoint saved at epoch 44 with accuracy: 95.01%
Iteration 0: Loss = 0.2857, Reg Loss = 71.3004, Reconstruct Loss = 0.0000, Cls Loss = 0.2785, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2071, Reg Loss = 70.8065, Reconstruct Loss = 0.0014, Cls Loss = 0.1986, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2098, Reg Loss = 71.4228, Reconstruct Loss = 0.0023, Cls Loss = 0.2003, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2213, Reg Loss = 71.1432, Reconstruct Loss = 0.0021, Cls Loss = 0.2121, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2208, Reg Loss = 71.0972, Reconstruct Loss = 0.0022, Cls Loss = 0.2114, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2224, Reg Loss = 71.1308, Reconstruct Loss = 0.0024, Cls Loss = 0.2129, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2219, Reg Loss = 71.0818, Reconstruct Loss = 0.0020, Cls Loss = 0.2128, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 57.24it/s]


Epoch [46/50], Validation Loss: 0.1724, Validation Accuracy: 94.88%



Iteration 0: Loss = 0.2228, Reg Loss = 71.2874, Reconstruct Loss = 0.0731, Cls Loss = 0.1426, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2155, Reg Loss = 68.8987, Reconstruct Loss = 0.0055, Cls Loss = 0.2031, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2308, Reg Loss = 69.7682, Reconstruct Loss = 0.0051, Cls Loss = 0.2187, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2294, Reg Loss = 69.8294, Reconstruct Loss = 0.0039, Cls Loss = 0.2185, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2264, Reg Loss = 70.0183, Reconstruct Loss = 0.0032, Cls Loss = 0.2162, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2227, Reg Loss = 70.0777, Reconstruct Loss = 0.0029, Cls Loss = 0.2129, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2249, Reg Loss = 69.9987, Reconstruct Loss = 0.0024, Cls Loss = 0.2156, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.2233, Reg Loss = 70.0915, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 54.91it/s]


Epoch [47/50], Validation Loss: 0.1744, Validation Accuracy: 94.59%



Iteration 0: Loss = 0.2412, Reg Loss = 69.3469, Reconstruct Loss = 0.0000, Cls Loss = 0.2343, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2118, Reg Loss = 70.2150, Reconstruct Loss = 0.0028, Cls Loss = 0.2020, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2171, Reg Loss = 69.7298, Reconstruct Loss = 0.0033, Cls Loss = 0.2068, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2178, Reg Loss = 69.4249, Reconstruct Loss = 0.0051, Cls Loss = 0.2058, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2192, Reg Loss = 69.4637, Reconstruct Loss = 0.0060, Cls Loss = 0.2063, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2202, Reg Loss = 69.3275, Reconstruct Loss = 0.0052, Cls Loss = 0.2080, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2214, Reg Loss = 69.2300, Reconstruct Loss = 0.0046, Cls Loss = 0.2099, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.2240, Reg Loss = 69.2904, Reconstruct Loss = 0.

100%|██████████| 79/79 [00:01<00:00, 56.71it/s]


Epoch [48/50], Validation Loss: 0.1579, Validation Accuracy: 95.23%



Checkpoint saved at epoch 47 with accuracy: 95.23%
Iteration 0: Loss = 0.2193, Reg Loss = 69.5850, Reconstruct Loss = 0.0000, Cls Loss = 0.2124, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2356, Reg Loss = 70.1508, Reconstruct Loss = 0.0026, Cls Loss = 0.2260, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2369, Reg Loss = 70.4525, Reconstruct Loss = 0.0033, Cls Loss = 0.2266, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2312, Reg Loss = 70.6285, Reconstruct Loss = 0.0035, Cls Loss = 0.2206, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2254, Reg Loss = 70.5150, Reconstruct Loss = 0.0030, Cls Loss = 0.2154, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2252, Reg Loss = 70.3911, Reconstruct Loss = 0.0029, Cls Loss = 0.2153, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2233, Reg Loss = 70.4327, Reconstruct Loss = 0.0031, Cls Loss = 0.2132, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 53.10it/s]


Epoch [49/50], Validation Loss: 0.1551, Validation Accuracy: 95.29%



Checkpoint saved at epoch 48 with accuracy: 95.29%
Iteration 0: Loss = 0.1248, Reg Loss = 67.5277, Reconstruct Loss = 0.0000, Cls Loss = 0.1180, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2144, Reg Loss = 69.6593, Reconstruct Loss = 0.0012, Cls Loss = 0.2062, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2141, Reg Loss = 69.2963, Reconstruct Loss = 0.0018, Cls Loss = 0.2054, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2151, Reg Loss = 69.5883, Reconstruct Loss = 0.0024, Cls Loss = 0.2057, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2142, Reg Loss = 69.5703, Reconstruct Loss = 0.0028, Cls Loss = 0.2044, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2126, Reg Loss = 69.6513, Reconstruct Loss = 0.0025, Cls Loss = 0.2031, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2110, Reg Loss = 69.5766, Reconstruct Loss = 0.0031, Cls Loss = 0.2010, Learning rate = 1.0000e-03
Iteration 350: Loss 

100%|██████████| 79/79 [00:01<00:00, 55.60it/s]

Epoch [50/50], Validation Loss: 0.1527, Validation Accuracy: 95.09%








In [42]:
# Sanity check: score the model stored as the first element of dim_dict['8']
# on the validation loader, using a fresh cross-entropy criterion.
val_loss, acc = validate_single(
    dim_dict['8'][0],
    val_loader,
    nn.CrossEntropyLoss(),
    args=args,
)
print(f'Initial Permutated model Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc * 100:.2f}%')

100%|██████████| 79/79 [00:01<00:00, 51.09it/s]

Initial Permutated model Validation Loss: 0.2343, Validation Accuracy: 92.52%





In [43]:
# Sanity check: score the model stored as the first element of dim_dict['32']
# on the validation loader, using a fresh cross-entropy criterion.
val_loss, acc = validate_single(
    dim_dict['32'][0],
    val_loader,
    nn.CrossEntropyLoss(),
    args=args,
)
print(f'Initial Permutated model Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc * 100:.2f}%')

100%|██████████| 79/79 [00:01<00:00, 46.00it/s]

Initial Permutated model Validation Loss: 0.1670, Validation Accuracy: 95.00%





In [44]:
# Reference point: score the model kept in gt_model_dict['32'] on the same
# validation loader, so its numbers can be compared with the cells above.
# NOTE(review): the printed label is shared with the dim_dict cells -- confirm it
# is the intended wording for this entry.
val_loss, acc = validate_single(
    gt_model_dict['32'],
    val_loader,
    nn.CrossEntropyLoss(),
    args=args,
)
print(f'Initial Permutated model Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc * 100:.2f}%')

100%|██████████| 79/79 [00:01<00:00, 47.00it/s]

Initial Permutated model Validation Loss: 0.0621, Validation Accuracy: 98.14%





In [45]:
# End the wandb tracking for the current run. The run-history and run-summary
# tables shown in this cell's output are emitted by this call.
wandb.finish()

0,1
Cls Loss,██▇▇▇▇▆▆▆▅▄▄▄▄▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁
Learning rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss,█▇▇▇█▇▇▇▆▆▅▅▄▄▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Reconstruct Loss,▁▂▃▃▁▃▄▄▄▃▄▁█▂▅▆▇▄▄▃▁▄▃▄▃▃▃▂▃▃▃▃▃▃▃▂▁▂▂▂
Reg Loss,▁█▇████████▇▇▇▇▇▇▇▇▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
Validation Accuracy,▁▁▁▂▂▂▃▄▄▄▅▅▆▆▆▇▇▇██████████████████████
Validation Loss,███▇▇▆▆▅▅▅▄▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Cls Loss,0.1986
Learning rate,0.001
Loss,0.20854
Reconstruct Loss,0.00297
Reg Loss,69.60761
Validation Accuracy,0.9509
Validation Loss,0.15271


### 7 Testing loop

In [48]:
# Location of the best hypernetwork checkpoint inside the experiment's save directory.
# os.path.join (already used for the save paths in the testing loop) keeps the path
# correct whether or not `args.training.save_model_path` ends with a separator;
# plain string concatenation would glue the file name onto the directory name.
saved_hypernet_path = os.path.join(args.training.save_model_path, 'mnist_nerf_best.pth')

In [49]:
# Echo the resolved checkpoint path as the cell output for a quick visual check
saved_hypernet_path

'toy/experiments/mnist_lenet_8-32-noise/mnist_nerf_best.pth'

In [50]:
# Build a fresh hypernetwork from the run configuration and `number_param`.
# It is the container that receives the saved checkpoint weights for testing.
hyper_model_test = get_hypernetwork(args, number_param)

Hyper model type: resmlp
Using scalar 0.1
num_freqs:  16 <class 'int'>


In [51]:
# Read the saved checkpoint, mapping its tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that come
# from a trusted source.
checkpoint = torch.load(saved_hypernet_path, map_location="cpu")  # or "cuda" if using GPU
# Copy the stored weights (the 'model_state_dict' entry) into hyper_model_test.
# As the last expression of the cell, the key-matching report returned by
# load_state_dict is what the cell displays.
hyper_model_test.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [52]:
# Results file for this experiment. The path does not depend on hidden_dim, so it is
# built once here instead of being recomputed on every iteration.
filename = f"mnist_results_{args.experiment.name}.txt"
filepath = os.path.join(args.training.save_model_path, filename)

# Sweep hidden dimensions 4..47: for each width, generate the network weights with the
# loaded hypernetwork, validate them, save the resulting model and log the metrics.
for hidden_dim in range(4, 48):
    # Create a model for this given dimension
    model = create_model(args.model.type,
                         hidden_dim=hidden_dim,
                         path=args.model.pretrained_path,
                         # smooth=args.model.smooth
                         ).to(device)

    # If EMA is specified, apply it
    # NOTE(review): `ema` is created earlier in the notebook; confirm it wraps the
    # hypernetwork evaluated here (hyper_model_test). If it wraps a different
    # instance, applying it does not change the results below.
    if ema:
        print('Applying EMA')
        ema.apply()

    try:
        # Sample the merged model (K=100 is forwarded to sample_merge_model)
        accumulated_model = sample_merge_model(hyper_model_test, model, args, K=100)

        # Validate the merged model
        val_loss, acc = validate_single(accumulated_model, val_loader, val_criterion, args=args)
    finally:
        # Always undo ema.apply(), even if sampling or validation raised; otherwise
        # the EMA weights would stay swapped in for the rest of the kernel session.
        if ema:
            ema.restore()  # Restore the original weights after applying EMA

    # Save the model
    save_name = os.path.join(args.training.save_model_path, f"mnist_{accumulated_model.__class__.__name__}_dim{hidden_dim}_single.pth")
    torch.save(accumulated_model.state_dict(), save_name)

    # Print the results
    print(f"Test using model {args.model}: hidden_dim {hidden_dim}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%")
    print('\n')

    # Write the results. 'a' appends, and creates the file if it doesn't exist yet.
    # The file is reopened per iteration so finished dimensions are persisted even
    # if a later one fails. Re-running this cell appends a second set of lines.
    with open(filepath, "a") as file:
        file.write(f"Hidden_dim: {hidden_dim}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%\n")


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 68.04it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 4, Validation Loss: 1.4290, Validation Accuracy: 60.85%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 68.50it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 5, Validation Loss: 1.0547, Validation Accuracy: 70.98%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 67.39it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 6, Validation Loss: 0.9845, Validation Accuracy: 72.42%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 66.22it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 7, Validation Loss: 0.2909, Validation Accuracy: 91.31%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 64.53it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 8, Validation Loss: 0.2174, Validation Accuracy: 93.59%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 61.84it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 9, Validation Loss: 0.3409, Validation Accuracy: 89.95%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 59.26it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 10, Validation Loss: 0.2725, Validation Accuracy: 91.37%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 55.47it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 11, Validation Loss: 0.2486, Validation Accuracy: 92.92%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 55.09it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 12, Validation Loss: 0.1967, Validation Accuracy: 94.17%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 52.05it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 13, Validation Loss: 0.1894, Validation Accuracy: 94.27%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 52.31it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 14, Validation Loss: 0.2270, Validation Accuracy: 93.36%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 53.42it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 15, Validation Loss: 0.1856, Validation Accuracy: 94.59%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 52.75it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 16, Validation Loss: 0.2015, Validation Accuracy: 94.30%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 50.03it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 17, Validation Loss: 0.1769, Validation Accuracy: 94.80%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 67.70it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 18, Validation Loss: 0.1825, Validation Accuracy: 94.71%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 46.61it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 19, Validation Loss: 0.1751, Validation Accuracy: 94.89%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 68.19it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 20, Validation Loss: 0.1867, Validation Accuracy: 94.72%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 59.19it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 21, Validation Loss: 0.1895, Validation Accuracy: 94.50%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 65.68it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 22, Validation Loss: 0.1827, Validation Accuracy: 94.81%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 62.45it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 23, Validation Loss: 0.1945, Validation Accuracy: 94.37%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 65.52it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 24, Validation Loss: 0.1689, Validation Accuracy: 94.88%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 50.28it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 25, Validation Loss: 0.1736, Validation Accuracy: 94.92%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 50.35it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 26, Validation Loss: 0.1953, Validation Accuracy: 94.18%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 47.39it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 27, Validation Loss: 0.1883, Validation Accuracy: 94.59%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 49.41it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 28, Validation Loss: 0.1845, Validation Accuracy: 94.67%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 50.03it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 29, Validation Loss: 0.2140, Validation Accuracy: 94.04%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 48.59it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 30, Validation Loss: 0.1817, Validation Accuracy: 94.73%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 51.19it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 31, Validation Loss: 0.1744, Validation Accuracy: 94.96%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 49.47it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 32, Validation Loss: 0.1858, Validation Accuracy: 94.43%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 43.84it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 33, Validation Loss: 0.1866, Validation Accuracy: 94.78%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 46.07it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 34, Validation Loss: 0.1861, Validation Accuracy: 94.46%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 42.16it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 35, Validation Loss: 0.1685, Validation Accuracy: 95.04%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 45.33it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 36, Validation Loss: 0.1711, Validation Accuracy: 94.99%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 43.86it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 37, Validation Loss: 0.1733, Validation Accuracy: 95.03%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 45.73it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 38, Validation Loss: 0.1895, Validation Accuracy: 94.37%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 44.16it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 39, Validation Loss: 0.1790, Validation Accuracy: 94.68%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 44.97it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 40, Validation Loss: 0.1702, Validation Accuracy: 94.93%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 42.96it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 41, Validation Loss: 0.1857, Validation Accuracy: 94.47%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 44.49it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 42, Validation Loss: 0.1702, Validation Accuracy: 94.88%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 45.27it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 43, Validation Loss: 0.1814, Validation Accuracy: 94.75%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 43.87it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 44, Validation Loss: 0.1982, Validation Accuracy: 94.50%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 44.32it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 45, Validation Loss: 0.1774, Validation Accuracy: 94.93%


Applying EMA


100%|██████████| 79/79 [00:02<00:00, 38.85it/s]


Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 46, Validation Loss: 0.1801, Validation Accuracy: 95.00%


Applying EMA


100%|██████████| 79/79 [00:01<00:00, 40.79it/s]

Test using model {'type': 'LeNet', 'pretrained_path': 'toy/mnist_MnistNet_dim32.pth', 'smooth': False}: hidden_dim 47, Validation Loss: 0.1736, Validation Accuracy: 94.89%





