## Import

In [1]:
import os
import random

In [2]:
import torch
import torch.nn as nn

import wandb

from sklearn.metrics import accuracy_score

In [3]:
from neumeta.models import create_densenet_model as create_model
from neumeta.utils import (
    parse_args, print_omegaconf,
    load_checkpoint, save_checkpoint,
    set_seed,
    get_dataset,
    sample_coordinates, sample_subset, shuffle_coordinates_all,
    get_hypernetwork, get_optimizer,
    sample_weights,
    weighted_regression_loss, validate_single, AverageMeter, EMA,
    sample_merge_model
)

## Functions

### Find max dimension of the model

In [4]:
def find_max_dim(model_cls):
    """Return the largest dimension found among the model's learnable parameters.

    The running maximum starts at the number of learnable parameter tensors.
    Only two tensor ranks contribute to it:
      * rank-4 tensors (e.g. convolution kernels): their first two axes;
      * rank-1 tensors (e.g. biases): their length.
    Tensors of any other rank (rank-2 matrices included) are skipped.

    Args:
        model_cls: Object exposing a ``learnable_parameter`` mapping of
            parameter name -> tensor-like value with a ``.shape`` attribute.

    Returns:
        The largest value among the parameter count and the inspected sizes.
    """
    params = model_cls.learnable_parameter

    # The number of parameter tensors is the lower bound of the result
    largest = len(params)

    for _, weight in params.items():
        rank = len(weight.shape)
        if rank == 4:
            # Only the first two axes are considered; the trailing two are ignored
            candidates = (weight.shape[0], weight.shape[1])
        elif rank == 1:
            candidates = (weight.shape[0],)
        else:
            # Any other rank does not take part in the maximum
            continue
        largest = max(largest, *candidates)

    return largest

### Initialize wandb

In [None]:
def initialize_wandb(config):
    """
    Initializes Weights and Biases (wandb) with the given configuration.

    The run is named ``<YYYYmmddHHMMSS>-<config.experiment.name>`` and is
    started in the "dense-inr-trial" project under the "cifar10" group.

    Args:
        config: Configuration for the run. It must expose ``experiment.name``
            and be convertible with ``dict(config)``; the resulting dict is
            passed to ``wandb.init`` as the run config.
    """
    # Local import kept from the original cell. It has to follow the docstring:
    # a string literal is only a docstring when it is the first statement.
    import time

    # Name the run using current time and configuration name
    run_name = f"{time.strftime('%Y%m%d%H%M%S')}-{config.experiment.name}"

    wandb.init(project="dense-inr-trial", name=run_name, config=dict(config), group='cifar10')

### Init model dictionary

In [6]:
def init_model_dict(args, device):
    """
    Build one target model per hidden dimension, plus a ground-truth model for the starting dimension.

    Args:
        args: Configuration object. Reads ``args.dimensions.range``, ``args.dimensions.start``
            and the ``args.model.*`` fields forwarded to ``create_model``.
        device: Device every created model is moved to.

    Returns:
        dim_dict: Maps ``str(dim)`` to the tuple
            ``(model, coords_tensor, keys_list, indices_list, size_list, None)``;
            the last slot is the key-mask placeholder.
        gt_model_dict: Maps ``str(dim)`` to the model created with ``smooth=True`` and put in
            eval mode. Only the dimension equal to ``args.dimensions.start`` gets an entry.
    """
    def _build(hidden_dim, **extra):
        # Single place holding the create_model arguments shared by both kinds of model
        return create_model(args.model.type,
                            layers=args.model.layers,
                            growth=args.model.growth,
                            compression=args.model.compression,
                            bottleneck=args.model.bottleneck,
                            drop_rate=args.model.drop_rate,
                            hidden_dim=hidden_dim,
                            path=args.model.pretrained_path,
                            **extra).to(device)

    dim_dict, gt_model_dict = {}, {}

    for dim in args.dimensions.range:
        key = f"{dim}"

        # Target network for this dimension and the coordinates describing its weights
        candidate = _build(dim)
        coords, keys, indices, sizes = sample_coordinates(candidate)
        dim_dict[key] = (candidate, coords, keys, indices, sizes, None)

        # Blank lines to keep the printed log readable
        print('\n')

        # Only the starting dimension (the one of the pretrained model) has a ground truth
        if dim != args.dimensions.start:
            continue

        print(f"Loading model for dim {dim}")
        reference = _build(dim, smooth=True)
        reference.eval()
        gt_model_dict[key] = reference

    return dim_dict, gt_model_dict

### Training function

In [7]:
# Function to train the model for one epoch
def train_one_epoch(model, train_loader, optimizer, criterion, dim_dict, gt_model_dict, epoch_idx, ema=None, args=None, device='cpu'):
    """Train the hypernetwork ``model`` for one pass over ``train_loader``.

    For every batch a hidden dimension is drawn at random from
    ``args.dimensions.range``. The matching target network from ``dim_dict``
    receives weights produced through ``sample_weights`` and is run on the
    batch. The optimised loss is a weighted sum (weights taken from
    ``args.hyper_model.loss_weight``) of:
      * the classification loss ``criterion(predict, target)``;
      * the sum of the L2 norms of the generated weights;
      * a reconstruction loss against the ground-truth model's weights, only
        when ``gt_model_dict`` has an entry for the drawn dimension.

    Args:
        model: The hypernetwork being trained.
        train_loader: Iterable yielding ``(x, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Classification loss callable.
        dim_dict: Maps ``str(dim)`` to
            ``(model_cls, coords_tensor, keys_list, indices_list, size_list, key_mask)``.
        gt_model_dict: Maps ``str(dim)`` to a ground-truth model exposing
            ``learnable_parameter``.
        epoch_idx: Index of the current epoch; used to compute the wandb step.
        ema: Optional object with an ``update()`` method, called after each
            optimizer step when truthy.
        args: Configuration object. Required in practice: the default ``None``
            is dereferenced on the first batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(losses.avg, dim_dict, gt_model_dict, train_acc)`` where
        ``losses.avg`` is the value tracked by the total-loss ``AverageMeter``
        and ``train_acc`` is ``accuracy_score`` over all batches of the epoch.
    """
    # Set the model to training mode
    model.train()

    # Initialize AverageMeter objects to track the losses
    losses = AverageMeter()
    cls_losses = AverageMeter()
    reg_losses = AverageMeter()
    reconstruct_losses = AverageMeter()

    # Predictions and labels of every batch, used for the epoch's training accuracy
    preds = []
    gt = []

    # Iterate over the training data
    for batch_idx, (x, target) in enumerate(train_loader):
        # Zero the gradients
        optimizer.zero_grad()

        # Preprocess input
        # ------------------------------------------------------------------------------------------------------
        # Move the data to the device
        x, target = x.to(device), target.to(device)
        # Choose a random hidden dimension
        hidden_dim = random.choice(args.dimensions.range)
        # Get the model class, coordinates, keys, indices, size, and key mask for the chosen dimension
        model_cls, coords_tensor, keys_list, indices_list, size_list, key_mask = dim_dict[f"{hidden_dim}"]
        # Sample a subset of the coordinates, keys, indices and sizes; also returns the selected keys
        coords_tensor, keys_list, indices_list, size_list, selected_keys = sample_subset(coords_tensor,
                                                                                         keys_list,
                                                                                         indices_list,
                                                                                         size_list,
                                                                                         key_mask,
                                                                                         ratio=args.ratio)
        # Add uniform noise in [-0.5, 0.5) * coordinate_noise to the coordinates if specified
        if args.training.coordinate_noise > 0.0:
            coords_tensor = coords_tensor + (torch.rand_like(coords_tensor) - 0.5) * args.training.coordinate_noise

        # Main task of hypernetwork and target network
        # ------------------------------------------------------------------------------------------------------
        # Sample the weights for the target model using hypernetwork
        model_cls, reconstructed_weights = sample_weights(model, model_cls,
                                                          coords_tensor, keys_list, indices_list, size_list, key_mask, selected_keys,
                                                          device=device, NORM=args.dimensions.norm)
        # Forward pass
        predict = model_cls(x)

        # Keep the predicted classes and the labels for the training accuracy
        pred = torch.argmax(predict, dim=-1)

        preds.append(pred)
        gt.append(target)

        # Compute losses
        # ------------------------------------------------------------------------------------------------------
        # Compute classification loss
        cls_loss = criterion(predict, target)
        # Compute regularization loss
        reg_loss = sum([torch.norm(w, p=2) for w in reconstructed_weights])
        # Compute reconstruction loss if ground truth model is available
        if f"{hidden_dim}" in gt_model_dict:
            gt_model = gt_model_dict[f"{hidden_dim}"]
            gt_selected_weights = [
                w for k, w in gt_model.learnable_parameter.items() if k in selected_keys]

            reconstruct_loss = weighted_regression_loss(
                reconstructed_weights, gt_selected_weights)
        else:
            # Zero placeholder, created on the same device as the other loss terms
            reconstruct_loss = torch.tensor(0.0, device=device)
        # Compute the total loss
        loss = args.hyper_model.loss_weight.ce_weight * cls_loss + args.hyper_model.loss_weight.reg_weight * \
            reg_loss + args.hyper_model.loss_weight.recon_weight * reconstruct_loss

        # Compute gradients and update weights
        # ------------------------------------------------------------------------------------------------------
        # Zero the gradients of the updated weights
        for updated_weight in model_cls.parameters():
            updated_weight.grad = None

        # First pass: fills .grad on the target network's parameters.
        # retain_graph=True keeps the graph alive for the second backward call below.
        loss.backward(retain_graph=True)
        # Second pass: feeds the target network's parameter gradients into the generated weights.
        # NOTE(review): this pairs reconstructed_weights with named_parameters() filtered by
        # selected_keys purely by position - confirm both sequences share the same order.
        torch.autograd.backward(reconstructed_weights, [
                                w.grad for k, w in model_cls.named_parameters() if k in selected_keys])

        # Clip the gradients if specified
        if args.training.get('clip_grad', 0.0) > 0:
            torch.nn.utils.clip_grad_value_(
                model.parameters(), args.training.clip_grad)

        # Update the weights
        optimizer.step()

        # Update the EMA if specified
        if ema:
            ema.update()  # Update the EMA after each training step

        # Update the AverageMeter objects
        losses.update(loss.item())
        cls_losses.update(cls_loss.item())
        reg_losses.update(reg_loss.item())
        reconstruct_losses.update(reconstruct_loss.item())

        # Log (or plot) losses
        # ------------------------------------------------------------------------------------------------------
        # Log the losses and learning rate to wandb
        if batch_idx % args.experiment.log_interval == 0:
            wandb.log({
                "Loss": losses.avg,
                "Cls Loss": cls_losses.avg,
                "Reg Loss": reg_losses.avg,
                "Reconstruct Loss": reconstruct_losses.avg,
                "Learning rate": optimizer.param_groups[0]['lr']
            }, step=batch_idx + epoch_idx * len(train_loader))
            # Print the losses and learning rate
            print(
                f"Iteration {batch_idx}: Loss = {losses.avg:.4f}, Reg Loss = {reg_losses.avg:.4f}, Reconstruct Loss = {reconstruct_losses.avg:.4f}, Cls Loss = {cls_losses.avg:.4f}, Learning rate = {optimizer.param_groups[0]['lr']:.4e}")

    # Accuracy over every batch seen in this epoch
    train_acc = accuracy_score(torch.cat(gt).cpu().numpy(), torch.cat(preds).cpu().numpy())

    wandb.log({
        "Training accuracy": train_acc
    })

    # Returns the training loss, the per-dimension model dictionary, the ground-truth model dictionary, and the training accuracy
    return losses.avg, dim_dict, gt_model_dict, train_acc

## Main

### 0 Select the device (GPU if available, otherwise CPU)

In [8]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### 1 Parsing arguments for input

In [None]:
# Experiment configuration file, forwarded to parse_args as --config
CONFIG_PATH = 'neumeta/config/densenet_inr_train/dense_21th_experiment.yaml'
# Forwarded as --ratio; kept as a string because it is placed directly in the argv list
RATIO = '1.0'
# Checkpoint path, forwarded to parse_args as --resume_from
CHECKPOINT_PATH = 'toy/experiments_densenet/dense_21th_experiment/cifar10_nerf_best.pth'

In [10]:
argv_train = ['--config', CONFIG_PATH, '--ratio', RATIO, '--resume_from', CHECKPOINT_PATH]

In [11]:
args = parse_args(argv_train)  # Parse arguments
print_omegaconf(args)  # Print arguments

+--------------------------------------+---------------------------------------------------------------------------------------------------+
|                 Key                  |                                               Value                                               |
+--------------------------------------+---------------------------------------------------------------------------------------------------+
|           experiment.name            |                    densenet_train_36_48_mlp_256_4_coordnoise_unsmooth_50e_bs64                    |
|        experiment.num_epochs         |                                                 50                                                |
|       experiment.log_interval        |                                                 50                                                |
|       experiment.eval_interval       |                                                 1                                                 |
|           e

In [12]:
set_seed(args.experiment.seed)

Setting seed... 42 for reproducibility


### 2 Get training and validation dataloader

In [13]:
train_loader, val_loader = get_dataset('cifar10', args.training.batch_size, strong_transform=args.training.get('strong_aug', None))

Using dataset: cifar10 with batch size: 64 and strong transform: None


### 3 Create target model

#### 3.0 Create the model

In [14]:
model = create_model(args.model.type,
                     layers=args.model.layers,
                     growth=args.model.growth,
                     compression=args.model.compression,
                     bottleneck=args.model.bottleneck,
                     drop_rate=args.model.drop_rate,
                     hidden_dim=args.dimensions.start,
                     path=args.model.pretrained_path).to(device)

Loading model from toy/experiments/densenet_bc_40_12_baseline/densenet_bc_40_12_cifar10_baseline_best.pth


#### 3.1 Print the structure and shape of the model

In [15]:
model

DenseNet3(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BottleneckBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): BottleneckBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (2): BottleneckBlock(
        (bn1): Bat

In [16]:
for i, (k, tensor) in enumerate(model.learnable_parameter.items()):
    print(k, tensor.shape)

block3.layer.5.conv1.weight torch.Size([48, 120, 1, 1])
block3.layer.5.conv1.bias torch.Size([48])
block3.layer.5.conv2.weight torch.Size([12, 48, 3, 3])


In [17]:
# Print the maximum dimension of the model
print(f'Maximum DIM: {find_max_dim(model)}')

Maximum DIM: 120


#### 3.2 Validate the accuracy of pretrained model

In [18]:
# Validate the model for the starting dimension (its pretrained form)
val_loss, acc = validate_single(model, val_loader, nn.CrossEntropyLoss(), args=args)
print(f'Initial Permutated model Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc * 100:.2f}%')

100%|██████████| 157/157 [00:03<00:00, 41.96it/s]

Initial Permutated model Validation Loss: 0.3239, Validation Accuracy: 91.93%





In [19]:
checkpoint = model.learnable_parameter
number_param = len(checkpoint)

In [20]:
# Print the keys of the parameters and the number of parameters
print(f"Parameters keys: {model.keys}")
print(f"Number of parameters to be learned: {number_param}")

Parameters keys: ['block3.layer.5.conv1.weight', 'block3.layer.5.conv1.bias', 'block3.layer.5.conv2.weight']
Number of parameters to be learned: 3


### 4 Create hypernetwork

#### 4.0 Create the model

In [21]:
# Get the hypermodel
hyper_model = get_hypernetwork(args, number_param)

Hyper model type: resmlp
Using scalar 0.1
num_freqs:  16 <class 'int'>


#### 4.1 Print model structure

In [22]:
hyper_model

NeRF_ResMLP_Compose(
  (positional_encoding): PositionalEncoding()
  (model): ModuleList(
    (0-2): 3 x NeRF_MLP_Residual_Scaled(
      (initial_layer): Linear(in_features=198, out_features=128, bias=True)
      (residual_blocks): ModuleList(
        (0-3): 4 x Linear(in_features=128, out_features=128, bias=True)
      )
      (scalars): ParameterList(
          (0): Parameter containing: [torch.float32 of size  (cuda:0)]
          (1): Parameter containing: [torch.float32 of size  (cuda:0)]
          (2): Parameter containing: [torch.float32 of size  (cuda:0)]
          (3): Parameter containing: [torch.float32 of size  (cuda:0)]
      )
      (act): ReLU(inplace=True)
      (output_layer): Linear(in_features=128, out_features=9, bias=True)
    )
  )
)

#### 4.2 Initialize an EMA that tracks a smoothed copy of the hypernetwork weights

In [23]:
# Initialize the EMA
ema = EMA(hyper_model, decay=args.hyper_model.ema_decay)

### 5 Get loss function, optimizer and scheduler

In [24]:
criterion, val_criterion, optimizer, scheduler = get_optimizer(args, hyper_model)

In [25]:
print(f'Criterion: {criterion}\nVal_criterion: {val_criterion}\nOptimizer: {optimizer}\nScheduler: {scheduler}')

Criterion: CrossEntropyLoss()
Val_criterion: CrossEntropyLoss()
Optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.001
    lr: 0.001
    maximize: False
    weight_decay: 0.01
)
Scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x000001AD2A6D4490>


### 6 Training loop

#### 6.1 Initialize training parameters

In [26]:
# Initialize the starting epoch and best accuracy
start_epoch = 0
best_acc = 0.0

#### 6.2 Directory to save the model

In [27]:
# Create the directory to save the model
os.makedirs(args.training.save_model_path, exist_ok=True)

#### 6.3 Resume training loop

In [28]:
args.resume_from

'toy/experiments/densenet_train_36_48_mlp_256_4_coordnoise_unsmooth_50e_bs64/cifar10_nerf_best.pth'

In [29]:
# Disable resuming for this run: this overrides the --resume_from value parsed earlier,
# so the `if args.resume_from:` cell below is skipped and start_epoch / best_acc keep
# their initial values.
args.resume_from = False

In [30]:
# Resume only when a checkpoint path is set (a falsy value skips this cell's body)
if args.resume_from:
    print(f"Resuming from checkpoint: {args.resume_from}")
    # load_checkpoint receives the hypernetwork, optimizer and EMA objects
    # (presumably to restore their state in place - confirm in neumeta.utils)
    checkpoint_info = load_checkpoint(args.resume_from, hyper_model, optimizer, ema)
    # Continue from the epoch and best accuracy stored in the checkpoint
    start_epoch = checkpoint_info['epoch']
    best_acc = checkpoint_info['best_acc']
    print(f"Resuming from epoch: {start_epoch}, best accuracy: {best_acc*100:.2f}%")
    # Note: If there are more elements to retrieve, do so here.

#### 6.4 Initialize model dictionary for each dimension and shuffle it

In [31]:
# Initialize model dictionary
dim_dict, gt_model_dict = init_model_dict(args, device)

Loading model from toy/experiments/densenet_bc_40_12_baseline/densenet_bc_40_12_cifar10_baseline_best.pth


Loading model for dim 48
Loading model from toy/experiments/densenet_bc_40_12_baseline/densenet_bc_40_12_cifar10_baseline_best.pth
Smooth the parameters of the model
Old TV original model: 428.7051086425781
Permuted TV original model: 394.62835693359375


In [32]:
gt_model_dict['48']

DenseNet3(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BottleneckBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): BottleneckBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (2): BottleneckBlock(
        (bn1): Bat

In [33]:
# Validate the model for the starting dimension (its pretrained form)
val_loss, acc = validate_single(gt_model_dict['48'], val_loader, nn.CrossEntropyLoss(), args=args)
print(f'Initial Permutated model Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc * 100:.2f}%')

100%|██████████| 157/157 [00:03<00:00, 44.81it/s]

Initial Permutated model Validation Loss: 0.3239, Validation Accuracy: 91.94%





In [34]:
dim_dict

{'48': (DenseNet3(
    (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (block1): DenseBlock(
      (layer): Sequential(
        (0): BottleneckBlock(
          (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))
          (bn2): Identity()
          (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): BottleneckBlock(
          (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1))
          (bn2): Identity()
          (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )


In [35]:
dim_dict = shuffle_coordinates_all(dim_dict)
dim_dict

{'48': (DenseNet3(
    (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (block1): DenseBlock(
      (layer): Sequential(
        (0): BottleneckBlock(
          (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))
          (bn2): Identity()
          (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (1): BottleneckBlock(
          (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1))
          (bn2): Identity()
          (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )


#### 6.5 Initialize wandb for plotting

In [36]:
initialize_wandb(args)

[34m[1mwandb[0m: Currently logged in as: [33mefradosuryadi[0m ([33mefradosuryadi-universitas-indonesia[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


#### 6.6 Hypernetwork training loop

In [37]:
args.experiment.num_epochs

50

In [38]:
# Main loop: one hypernetwork training epoch, an LR scheduler step, then (every
# eval_interval epochs) validation of a model sampled from the hypernetwork and
# checkpointing when the validation accuracy improves.
for epoch in range(start_epoch, args.experiment.num_epochs):
    # Train the hypernetwork to generate a model with random dimension for one epoch
    train_loss, dim_dict, gt_model_dict, train_acc = train_one_epoch(hyper_model, train_loader, optimizer, criterion, 
                                                                     dim_dict, gt_model_dict, epoch_idx=epoch, ema=ema, 
                                                                     args=args, device=device)
    # Step the scheduler
    scheduler.step()

    # Print the training loss, training accuracy and learning rate (epoch shown 1-based)
    print(f"Epoch [{epoch+1}/{args.experiment.num_epochs}], Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc*100:.2f}, Learning Rate: {scheduler.get_last_lr()[0]:.6f}")

    # If it's time to evaluate the model
    if (epoch + 1) % args.experiment.eval_interval == 0:
        # Apply EMA if it is specified
        if ema:
            # NOTE(review): presumably swaps the EMA-averaged weights into hyper_model
            # for evaluation - confirm against neumeta.utils.EMA
            ema.apply()
        
        # Build the model to validate from the hypernetwork's generated weights.
        # `model` is rebound here, so the next evaluation passes the previously sampled
        # model back into sample_merge_model rather than the one created in section 3.
        model = sample_merge_model(hyper_model, model, args) 
        # Validate the merged model
        val_loss, acc = validate_single(model, val_loader, val_criterion, args=args)

        # If EMA is specified, undo the apply() above
        if ema:
            # NOTE(review): presumably puts the raw (non-averaged) training weights back
            # into hyper_model - confirm against neumeta.utils.EMA
            ema.restore()

        # Log the validation loss and accuracy to wandb (no explicit step is passed)
        wandb.log({
            "Validation Loss": val_loss,
            "Validation Accuracy": acc
        })
        # Print the validation loss and accuracy
        print(f"Epoch [{epoch+1}/{args.experiment.num_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%")
        print('\n\n')

        # Save the checkpoint if the accuracy is better than the previous best.
        # The epoch passed to save_checkpoint and printed below is 0-based, unlike the
        # 1-based "Epoch [n/N]" messages above.
        if acc > best_acc:
            best_acc = acc
            save_checkpoint(f"{args.training.save_model_path}/cifar10_nerf_best.pth",hyper_model,optimizer,ema,epoch,best_acc)
            print(f"Checkpoint saved at epoch {epoch} with accuracy: {best_acc*100:.2f}%")


Iteration 0: Loss = 0.4199, Reg Loss = 1.7900, Reconstruct Loss = 0.0359, Cls Loss = 0.4163, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2422, Reg Loss = 6.5068, Reconstruct Loss = 0.1154, Cls Loss = 0.2306, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2397, Reg Loss = 5.9539, Reconstruct Loss = 0.0980, Cls Loss = 0.2299, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2359, Reg Loss = 5.3496, Reconstruct Loss = 0.0831, Cls Loss = 0.2275, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2359, Reg Loss = 5.0005, Reconstruct Loss = 0.0728, Cls Loss = 0.2286, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2336, Reg Loss = 4.7687, Reconstruct Loss = 0.0671, Cls Loss = 0.2269, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2319, Reg Loss = 4.6707, Reconstruct Loss = 0.0630, Cls Loss = 0.2256, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.2315, Reg Loss = 4.5630, Reconstruct Loss = 0.0596, Cls Loss = 0.2255, Learning rate = 1.0000e-03
Iteration 400: Loss = 0.231

100%|██████████| 157/157 [00:03<00:00, 43.20it/s]


Epoch [1/50], Validation Loss: 1.4880, Validation Accuracy: 60.90%



Checkpoint saved at epoch 0 with accuracy: 60.90%
Iteration 0: Loss = 0.1565, Reg Loss = 9.5124, Reconstruct Loss = 0.1594, Cls Loss = 0.1405, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2335, Reg Loss = 7.1058, Reconstruct Loss = 0.0895, Cls Loss = 0.2245, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2257, Reg Loss = 7.5040, Reconstruct Loss = 0.0722, Cls Loss = 0.2184, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2254, Reg Loss = 7.8336, Reconstruct Loss = 0.0684, Cls Loss = 0.2185, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2214, Reg Loss = 7.5744, Reconstruct Loss = 0.0632, Cls Loss = 0.2150, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2212, Reg Loss = 7.1430, Reconstruct Loss = 0.0582, Cls Loss = 0.2153, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2203, Reg Loss = 6.9225, Reconstruct Loss = 0.0560, Cls Loss = 0.2146, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.2203,

100%|██████████| 157/157 [00:03<00:00, 43.34it/s]


Epoch [2/50], Validation Loss: 1.4482, Validation Accuracy: 61.41%



Checkpoint saved at epoch 1 with accuracy: 61.41%
Iteration 0: Loss = 0.1785, Reg Loss = 6.9306, Reconstruct Loss = 0.0593, Cls Loss = 0.1725, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.2091, Reg Loss = 6.9708, Reconstruct Loss = 0.0658, Cls Loss = 0.2024, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.2055, Reg Loss = 6.8291, Reconstruct Loss = 0.0540, Cls Loss = 0.2000, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.2066, Reg Loss = 6.5154, Reconstruct Loss = 0.0523, Cls Loss = 0.2013, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.2054, Reg Loss = 6.4983, Reconstruct Loss = 0.0496, Cls Loss = 0.2004, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.2057, Reg Loss = 6.3105, Reconstruct Loss = 0.0470, Cls Loss = 0.2009, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.2025, Reg Loss = 6.2882, Reconstruct Loss = 0.0457, Cls Loss = 0.1978, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.2010,

100%|██████████| 157/157 [00:04<00:00, 36.23it/s]


Epoch [3/50], Validation Loss: 1.4584, Validation Accuracy: 61.34%



Iteration 0: Loss = 0.1734, Reg Loss = 4.9650, Reconstruct Loss = 0.0353, Cls Loss = 0.1698, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.1867, Reg Loss = 5.0048, Reconstruct Loss = 0.0401, Cls Loss = 0.1826, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.1875, Reg Loss = 5.9744, Reconstruct Loss = 0.0382, Cls Loss = 0.1837, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.1874, Reg Loss = 6.3277, Reconstruct Loss = 0.0369, Cls Loss = 0.1837, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.1886, Reg Loss = 6.4871, Reconstruct Loss = 0.0386, Cls Loss = 0.1847, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.1870, Reg Loss = 6.4372, Reconstruct Loss = 0.0381, Cls Loss = 0.1831, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.1879, Reg Loss = 6.3617, Reconstruct Loss = 0.0396, Cls Loss = 0.1838, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.1884, Reg Loss = 6.2992, Reconstruct Loss = 0.0392, Cls

100%|██████████| 157/157 [00:04<00:00, 35.29it/s]


Epoch [4/50], Validation Loss: 1.5025, Validation Accuracy: 60.66%



Iteration 0: Loss = 0.2825, Reg Loss = 6.4510, Reconstruct Loss = 0.0236, Cls Loss = 0.2800, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.1641, Reg Loss = 6.5973, Reconstruct Loss = 0.0254, Cls Loss = 0.1615, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.1680, Reg Loss = 7.3025, Reconstruct Loss = 0.0324, Cls Loss = 0.1647, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.1688, Reg Loss = 7.7567, Reconstruct Loss = 0.0379, Cls Loss = 0.1649, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.1714, Reg Loss = 7.7500, Reconstruct Loss = 0.0383, Cls Loss = 0.1675, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.1708, Reg Loss = 7.7048, Reconstruct Loss = 0.0382, Cls Loss = 0.1669, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.1703, Reg Loss = 7.6106, Reconstruct Loss = 0.0382, Cls Loss = 0.1664, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.1707, Reg Loss = 7.5283, Reconstruct Loss = 0.0380, Cls

100%|██████████| 157/157 [00:04<00:00, 35.64it/s]


Epoch [5/50], Validation Loss: 1.4285, Validation Accuracy: 62.01%



Checkpoint saved at epoch 4 with accuracy: 62.01%
Iteration 0: Loss = 0.2319, Reg Loss = 7.4496, Reconstruct Loss = 0.0228, Cls Loss = 0.2296, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.1647, Reg Loss = 7.9139, Reconstruct Loss = 0.0456, Cls Loss = 0.1601, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.1710, Reg Loss = 7.7706, Reconstruct Loss = 0.0438, Cls Loss = 0.1665, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.1719, Reg Loss = 7.5936, Reconstruct Loss = 0.0438, Cls Loss = 0.1674, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.1718, Reg Loss = 7.5298, Reconstruct Loss = 0.0420, Cls Loss = 0.1675, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.1700, Reg Loss = 7.4261, Reconstruct Loss = 0.0407, Cls Loss = 0.1658, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.1690, Reg Loss = 7.3555, Reconstruct Loss = 0.0394, Cls Loss = 0.1650, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.1692,

100%|██████████| 157/157 [00:04<00:00, 36.44it/s]


Epoch [6/50], Validation Loss: 1.4395, Validation Accuracy: 66.20%



Checkpoint saved at epoch 5 with accuracy: 66.20%
Iteration 0: Loss = 0.2190, Reg Loss = 7.0245, Reconstruct Loss = 0.0333, Cls Loss = 0.2156, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.1679, Reg Loss = 7.6529, Reconstruct Loss = 0.0320, Cls Loss = 0.1646, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.1628, Reg Loss = 7.5069, Reconstruct Loss = 0.0324, Cls Loss = 0.1595, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.1628, Reg Loss = 7.3067, Reconstruct Loss = 0.0335, Cls Loss = 0.1594, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.1605, Reg Loss = 7.2519, Reconstruct Loss = 0.0338, Cls Loss = 0.1570, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.1608, Reg Loss = 7.2631, Reconstruct Loss = 0.0333, Cls Loss = 0.1573, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.1587, Reg Loss = 7.2863, Reconstruct Loss = 0.0336, Cls Loss = 0.1553, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.1581,

100%|██████████| 157/157 [00:04<00:00, 36.12it/s]


Epoch [7/50], Validation Loss: 1.4902, Validation Accuracy: 68.50%



Checkpoint saved at epoch 6 with accuracy: 68.50%
Iteration 0: Loss = 0.2469, Reg Loss = 6.0895, Reconstruct Loss = 0.0509, Cls Loss = 0.2418, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.1644, Reg Loss = 6.5459, Reconstruct Loss = 0.0372, Cls Loss = 0.1607, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.1577, Reg Loss = 7.0021, Reconstruct Loss = 0.0381, Cls Loss = 0.1538, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.1547, Reg Loss = 7.0699, Reconstruct Loss = 0.0368, Cls Loss = 0.1510, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.1534, Reg Loss = 7.1596, Reconstruct Loss = 0.0363, Cls Loss = 0.1497, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.1550, Reg Loss = 7.2139, Reconstruct Loss = 0.0358, Cls Loss = 0.1514, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.1561, Reg Loss = 7.1443, Reconstruct Loss = 0.0359, Cls Loss = 0.1525, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.1552,

100%|██████████| 157/157 [00:03<00:00, 43.24it/s]


Epoch [8/50], Validation Loss: 2.3154, Validation Accuracy: 55.65%



Iteration 0: Loss = 0.1141, Reg Loss = 8.1294, Reconstruct Loss = 0.0318, Cls Loss = 0.1108, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.1437, Reg Loss = 8.1876, Reconstruct Loss = 0.0359, Cls Loss = 0.1400, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.1350, Reg Loss = 8.6424, Reconstruct Loss = 0.0329, Cls Loss = 0.1317, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.1316, Reg Loss = 9.1379, Reconstruct Loss = 0.0328, Cls Loss = 0.1282, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.1302, Reg Loss = 9.2481, Reconstruct Loss = 0.0334, Cls Loss = 0.1268, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.1309, Reg Loss = 9.1750, Reconstruct Loss = 0.0337, Cls Loss = 0.1275, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.1305, Reg Loss = 9.2185, Reconstruct Loss = 0.0339, Cls Loss = 0.1270, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.1306, Reg Loss = 9.1390, Reconstruct Loss = 0.0341, Cls

100%|██████████| 157/157 [00:04<00:00, 33.38it/s]


Epoch [9/50], Validation Loss: 7.9931, Validation Accuracy: 22.33%



Iteration 0: Loss = 0.0903, Reg Loss = 9.1710, Reconstruct Loss = 0.0445, Cls Loss = 0.0858, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.1246, Reg Loss = 9.8249, Reconstruct Loss = 0.0355, Cls Loss = 0.1210, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.1202, Reg Loss = 9.8165, Reconstruct Loss = 0.0386, Cls Loss = 0.1163, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.1207, Reg Loss = 9.8800, Reconstruct Loss = 0.0407, Cls Loss = 0.1165, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.1224, Reg Loss = 10.0822, Reconstruct Loss = 0.0395, Cls Loss = 0.1183, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.1209, Reg Loss = 10.3821, Reconstruct Loss = 0.0391, Cls Loss = 0.1169, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.1197, Reg Loss = 10.4945, Reconstruct Loss = 0.0386, Cls Loss = 0.1157, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.1204, Reg Loss = 10.5370, Reconstruct Loss = 0.0379,

100%|██████████| 157/157 [00:04<00:00, 33.67it/s]


Epoch [10/50], Validation Loss: 9.5388, Validation Accuracy: 21.35%



Iteration 0: Loss = 0.1109, Reg Loss = 10.0591, Reconstruct Loss = 0.0477, Cls Loss = 0.1060, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.1104, Reg Loss = 11.0659, Reconstruct Loss = 0.0375, Cls Loss = 0.1066, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.1109, Reg Loss = 11.6304, Reconstruct Loss = 0.0375, Cls Loss = 0.1071, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.1126, Reg Loss = 12.3156, Reconstruct Loss = 0.0366, Cls Loss = 0.1088, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.1117, Reg Loss = 12.2741, Reconstruct Loss = 0.0366, Cls Loss = 0.1079, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.1097, Reg Loss = 12.1598, Reconstruct Loss = 0.0365, Cls Loss = 0.1059, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.1104, Reg Loss = 12.1495, Reconstruct Loss = 0.0363, Cls Loss = 0.1067, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.1096, Reg Loss = 12.1328, Reconstruct Loss = 0.

100%|██████████| 157/157 [00:04<00:00, 32.57it/s]


Epoch [11/50], Validation Loss: 8.5594, Validation Accuracy: 22.52%



Iteration 0: Loss = 0.0761, Reg Loss = 11.2656, Reconstruct Loss = 0.0398, Cls Loss = 0.0721, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0967, Reg Loss = 11.4373, Reconstruct Loss = 0.0384, Cls Loss = 0.0927, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0958, Reg Loss = 12.1807, Reconstruct Loss = 0.0385, Cls Loss = 0.0918, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0942, Reg Loss = 12.1014, Reconstruct Loss = 0.0384, Cls Loss = 0.0902, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0963, Reg Loss = 12.1642, Reconstruct Loss = 0.0380, Cls Loss = 0.0924, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0983, Reg Loss = 12.3138, Reconstruct Loss = 0.0381, Cls Loss = 0.0944, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0982, Reg Loss = 12.2096, Reconstruct Loss = 0.0377, Cls Loss = 0.0943, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0970, Reg Loss = 12.0931, Reconstruct Loss = 0.

100%|██████████| 157/157 [00:04<00:00, 32.41it/s]


Epoch [12/50], Validation Loss: 4.3655, Validation Accuracy: 39.92%



Iteration 0: Loss = 0.1216, Reg Loss = 13.3130, Reconstruct Loss = 0.0362, Cls Loss = 0.1178, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0897, Reg Loss = 12.8028, Reconstruct Loss = 0.0329, Cls Loss = 0.0863, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0948, Reg Loss = 11.6395, Reconstruct Loss = 0.0328, Cls Loss = 0.0914, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0945, Reg Loss = 11.3305, Reconstruct Loss = 0.0337, Cls Loss = 0.0911, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0941, Reg Loss = 11.3179, Reconstruct Loss = 0.0342, Cls Loss = 0.0906, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0956, Reg Loss = 11.4335, Reconstruct Loss = 0.0343, Cls Loss = 0.0920, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0951, Reg Loss = 11.4407, Reconstruct Loss = 0.0341, Cls Loss = 0.0916, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0953, Reg Loss = 11.4210, Reconstruct Loss = 0.

100%|██████████| 157/157 [00:04<00:00, 33.56it/s]


Epoch [13/50], Validation Loss: 2.8987, Validation Accuracy: 58.67%



Iteration 0: Loss = 0.1091, Reg Loss = 10.7780, Reconstruct Loss = 0.0337, Cls Loss = 0.1056, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0886, Reg Loss = 11.2049, Reconstruct Loss = 0.0376, Cls Loss = 0.0847, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0925, Reg Loss = 10.6962, Reconstruct Loss = 0.0377, Cls Loss = 0.0886, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0911, Reg Loss = 10.6241, Reconstruct Loss = 0.0371, Cls Loss = 0.0873, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0905, Reg Loss = 10.5425, Reconstruct Loss = 0.0372, Cls Loss = 0.0867, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0906, Reg Loss = 10.4565, Reconstruct Loss = 0.0370, Cls Loss = 0.0868, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0910, Reg Loss = 10.4193, Reconstruct Loss = 0.0370, Cls Loss = 0.0872, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0894, Reg Loss = 10.4817, Reconstruct Loss = 0.

100%|██████████| 157/157 [00:04<00:00, 32.20it/s]


Epoch [14/50], Validation Loss: 3.8130, Validation Accuracy: 52.63%



Iteration 0: Loss = 0.0459, Reg Loss = 10.8438, Reconstruct Loss = 0.0317, Cls Loss = 0.0426, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0906, Reg Loss = 9.9746, Reconstruct Loss = 0.0321, Cls Loss = 0.0873, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0938, Reg Loss = 9.8265, Reconstruct Loss = 0.0336, Cls Loss = 0.0903, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0897, Reg Loss = 9.8930, Reconstruct Loss = 0.0339, Cls Loss = 0.0862, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0884, Reg Loss = 10.0226, Reconstruct Loss = 0.0343, Cls Loss = 0.0848, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0879, Reg Loss = 10.1111, Reconstruct Loss = 0.0346, Cls Loss = 0.0843, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0875, Reg Loss = 10.1662, Reconstruct Loss = 0.0345, Cls Loss = 0.0840, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0880, Reg Loss = 10.1917, Reconstruct Loss = 0.035

100%|██████████| 157/157 [00:04<00:00, 33.85it/s]


Epoch [15/50], Validation Loss: 4.0393, Validation Accuracy: 50.49%



Iteration 0: Loss = 0.0963, Reg Loss = 9.6251, Reconstruct Loss = 0.0457, Cls Loss = 0.0916, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0859, Reg Loss = 9.6408, Reconstruct Loss = 0.0355, Cls Loss = 0.0823, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0914, Reg Loss = 9.8560, Reconstruct Loss = 0.0329, Cls Loss = 0.0880, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0908, Reg Loss = 9.9580, Reconstruct Loss = 0.0324, Cls Loss = 0.0875, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0899, Reg Loss = 9.9340, Reconstruct Loss = 0.0330, Cls Loss = 0.0865, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0887, Reg Loss = 10.0142, Reconstruct Loss = 0.0330, Cls Loss = 0.0853, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0888, Reg Loss = 10.0569, Reconstruct Loss = 0.0322, Cls Loss = 0.0854, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0877, Reg Loss = 10.0015, Reconstruct Loss = 0.0317,

100%|██████████| 157/157 [00:04<00:00, 32.32it/s]


Epoch [16/50], Validation Loss: 2.8281, Validation Accuracy: 60.21%



Iteration 0: Loss = 0.1027, Reg Loss = 9.1541, Reconstruct Loss = 0.0410, Cls Loss = 0.0985, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0840, Reg Loss = 8.7340, Reconstruct Loss = 0.0313, Cls Loss = 0.0807, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0833, Reg Loss = 9.1506, Reconstruct Loss = 0.0302, Cls Loss = 0.0802, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0845, Reg Loss = 9.0991, Reconstruct Loss = 0.0306, Cls Loss = 0.0813, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0827, Reg Loss = 9.0983, Reconstruct Loss = 0.0312, Cls Loss = 0.0795, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0855, Reg Loss = 8.9953, Reconstruct Loss = 0.0314, Cls Loss = 0.0823, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0863, Reg Loss = 9.0816, Reconstruct Loss = 0.0313, Cls Loss = 0.0831, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0867, Reg Loss = 9.1543, Reconstruct Loss = 0.0308, Cl

100%|██████████| 157/157 [00:04<00:00, 32.50it/s]


Epoch [17/50], Validation Loss: 5.3849, Validation Accuracy: 43.46%



Iteration 0: Loss = 0.1004, Reg Loss = 10.3459, Reconstruct Loss = 0.0248, Cls Loss = 0.0978, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0937, Reg Loss = 10.6937, Reconstruct Loss = 0.0300, Cls Loss = 0.0906, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0895, Reg Loss = 10.8857, Reconstruct Loss = 0.0301, Cls Loss = 0.0863, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0862, Reg Loss = 10.5293, Reconstruct Loss = 0.0298, Cls Loss = 0.0831, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0849, Reg Loss = 10.1538, Reconstruct Loss = 0.0299, Cls Loss = 0.0818, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0851, Reg Loss = 10.0624, Reconstruct Loss = 0.0299, Cls Loss = 0.0820, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0861, Reg Loss = 9.9945, Reconstruct Loss = 0.0297, Cls Loss = 0.0831, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0867, Reg Loss = 9.9744, Reconstruct Loss = 0.03

100%|██████████| 157/157 [00:04<00:00, 33.25it/s]


Epoch [18/50], Validation Loss: 5.3948, Validation Accuracy: 44.81%



Iteration 0: Loss = 0.1000, Reg Loss = 9.0071, Reconstruct Loss = 0.0378, Cls Loss = 0.0961, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0862, Reg Loss = 9.1554, Reconstruct Loss = 0.0311, Cls Loss = 0.0830, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0817, Reg Loss = 9.2371, Reconstruct Loss = 0.0302, Cls Loss = 0.0786, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0838, Reg Loss = 9.4051, Reconstruct Loss = 0.0294, Cls Loss = 0.0808, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0844, Reg Loss = 9.3606, Reconstruct Loss = 0.0290, Cls Loss = 0.0814, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0850, Reg Loss = 9.4340, Reconstruct Loss = 0.0293, Cls Loss = 0.0819, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0850, Reg Loss = 9.5038, Reconstruct Loss = 0.0293, Cls Loss = 0.0820, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0844, Reg Loss = 9.5835, Reconstruct Loss = 0.0291, Cl

100%|██████████| 157/157 [00:04<00:00, 32.91it/s]


Epoch [19/50], Validation Loss: 5.9511, Validation Accuracy: 42.90%



Iteration 0: Loss = 0.0660, Reg Loss = 10.6118, Reconstruct Loss = 0.0215, Cls Loss = 0.0637, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0870, Reg Loss = 11.1892, Reconstruct Loss = 0.0282, Cls Loss = 0.0841, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0886, Reg Loss = 10.7937, Reconstruct Loss = 0.0297, Cls Loss = 0.0855, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0920, Reg Loss = 10.4922, Reconstruct Loss = 0.0301, Cls Loss = 0.0889, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0909, Reg Loss = 10.1723, Reconstruct Loss = 0.0301, Cls Loss = 0.0878, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0898, Reg Loss = 10.1425, Reconstruct Loss = 0.0295, Cls Loss = 0.0867, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0894, Reg Loss = 10.0047, Reconstruct Loss = 0.0298, Cls Loss = 0.0863, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0889, Reg Loss = 9.9198, Reconstruct Loss = 0.0

100%|██████████| 157/157 [00:04<00:00, 33.05it/s]


Epoch [20/50], Validation Loss: 7.0488, Validation Accuracy: 42.25%



Iteration 0: Loss = 0.0400, Reg Loss = 9.2674, Reconstruct Loss = 0.0339, Cls Loss = 0.0365, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0802, Reg Loss = 9.6478, Reconstruct Loss = 0.0278, Cls Loss = 0.0773, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0823, Reg Loss = 10.1281, Reconstruct Loss = 0.0286, Cls Loss = 0.0793, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0854, Reg Loss = 10.1767, Reconstruct Loss = 0.0300, Cls Loss = 0.0823, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0864, Reg Loss = 9.9936, Reconstruct Loss = 0.0312, Cls Loss = 0.0832, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0873, Reg Loss = 9.8845, Reconstruct Loss = 0.0329, Cls Loss = 0.0839, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0849, Reg Loss = 9.9565, Reconstruct Loss = 0.0326, Cls Loss = 0.0815, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0830, Reg Loss = 9.9208, Reconstruct Loss = 0.0322, 

100%|██████████| 157/157 [00:04<00:00, 32.04it/s]


Epoch [21/50], Validation Loss: 4.7418, Validation Accuracy: 48.76%



Iteration 0: Loss = 0.0759, Reg Loss = 10.9435, Reconstruct Loss = 0.0259, Cls Loss = 0.0732, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0807, Reg Loss = 9.6737, Reconstruct Loss = 0.0286, Cls Loss = 0.0777, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0759, Reg Loss = 9.8323, Reconstruct Loss = 0.0271, Cls Loss = 0.0731, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0765, Reg Loss = 9.6278, Reconstruct Loss = 0.0283, Cls Loss = 0.0736, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0793, Reg Loss = 9.6891, Reconstruct Loss = 0.0283, Cls Loss = 0.0764, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0795, Reg Loss = 9.8983, Reconstruct Loss = 0.0285, Cls Loss = 0.0766, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0788, Reg Loss = 10.0045, Reconstruct Loss = 0.0287, Cls Loss = 0.0758, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0806, Reg Loss = 9.9443, Reconstruct Loss = 0.0289, 

100%|██████████| 157/157 [00:04<00:00, 32.65it/s]


Epoch [22/50], Validation Loss: 6.0743, Validation Accuracy: 42.64%



Iteration 0: Loss = 0.1271, Reg Loss = 10.1052, Reconstruct Loss = 0.0335, Cls Loss = 0.1237, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0833, Reg Loss = 10.0462, Reconstruct Loss = 0.0313, Cls Loss = 0.0801, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0802, Reg Loss = 10.0843, Reconstruct Loss = 0.0310, Cls Loss = 0.0770, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0817, Reg Loss = 10.0049, Reconstruct Loss = 0.0300, Cls Loss = 0.0786, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0839, Reg Loss = 9.8252, Reconstruct Loss = 0.0300, Cls Loss = 0.0808, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0833, Reg Loss = 9.6370, Reconstruct Loss = 0.0300, Cls Loss = 0.0802, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0832, Reg Loss = 9.5799, Reconstruct Loss = 0.0299, Cls Loss = 0.0801, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0833, Reg Loss = 9.5642, Reconstruct Loss = 0.0299

100%|██████████| 157/157 [00:04<00:00, 33.00it/s]


Epoch [23/50], Validation Loss: 5.3468, Validation Accuracy: 46.19%



Iteration 0: Loss = 0.0409, Reg Loss = 10.6095, Reconstruct Loss = 0.0240, Cls Loss = 0.0384, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0802, Reg Loss = 10.5454, Reconstruct Loss = 0.0260, Cls Loss = 0.0775, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0827, Reg Loss = 10.0808, Reconstruct Loss = 0.0271, Cls Loss = 0.0799, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0836, Reg Loss = 10.0341, Reconstruct Loss = 0.0265, Cls Loss = 0.0809, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0816, Reg Loss = 10.1012, Reconstruct Loss = 0.0259, Cls Loss = 0.0789, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0830, Reg Loss = 10.2020, Reconstruct Loss = 0.0255, Cls Loss = 0.0803, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0825, Reg Loss = 10.1446, Reconstruct Loss = 0.0253, Cls Loss = 0.0799, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0838, Reg Loss = 9.9618, Reconstruct Loss = 0.0

100%|██████████| 157/157 [00:04<00:00, 32.56it/s]


Epoch [24/50], Validation Loss: 7.2780, Validation Accuracy: 42.47%



Iteration 0: Loss = 0.0753, Reg Loss = 8.8682, Reconstruct Loss = 0.0249, Cls Loss = 0.0728, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0820, Reg Loss = 8.7428, Reconstruct Loss = 0.0282, Cls Loss = 0.0791, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0774, Reg Loss = 9.3528, Reconstruct Loss = 0.0298, Cls Loss = 0.0743, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0795, Reg Loss = 9.6577, Reconstruct Loss = 0.0287, Cls Loss = 0.0765, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0770, Reg Loss = 9.6002, Reconstruct Loss = 0.0284, Cls Loss = 0.0741, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0777, Reg Loss = 9.6565, Reconstruct Loss = 0.0287, Cls Loss = 0.0747, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0785, Reg Loss = 9.7202, Reconstruct Loss = 0.0281, Cls Loss = 0.0756, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0780, Reg Loss = 9.6141, Reconstruct Loss = 0.0280, Cl

100%|██████████| 157/157 [00:04<00:00, 32.68it/s]


Epoch [25/50], Validation Loss: 12.0483, Validation Accuracy: 32.03%



Iteration 0: Loss = 0.1893, Reg Loss = 9.8908, Reconstruct Loss = 0.0222, Cls Loss = 0.1870, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0794, Reg Loss = 9.4277, Reconstruct Loss = 0.0270, Cls Loss = 0.0766, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0759, Reg Loss = 9.3033, Reconstruct Loss = 0.0272, Cls Loss = 0.0731, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0763, Reg Loss = 9.1747, Reconstruct Loss = 0.0270, Cls Loss = 0.0735, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0742, Reg Loss = 9.0970, Reconstruct Loss = 0.0263, Cls Loss = 0.0714, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0731, Reg Loss = 9.0815, Reconstruct Loss = 0.0268, Cls Loss = 0.0703, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0741, Reg Loss = 9.0289, Reconstruct Loss = 0.0268, Cls Loss = 0.0713, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0755, Reg Loss = 9.0555, Reconstruct Loss = 0.0268, C

100%|██████████| 157/157 [00:04<00:00, 32.35it/s]


Epoch [26/50], Validation Loss: 7.4796, Validation Accuracy: 42.26%



Iteration 0: Loss = 0.0615, Reg Loss = 10.0244, Reconstruct Loss = 0.0243, Cls Loss = 0.0590, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0619, Reg Loss = 9.0721, Reconstruct Loss = 0.0247, Cls Loss = 0.0594, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0650, Reg Loss = 9.2333, Reconstruct Loss = 0.0255, Cls Loss = 0.0624, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0678, Reg Loss = 9.1206, Reconstruct Loss = 0.0267, Cls Loss = 0.0650, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0683, Reg Loss = 9.0222, Reconstruct Loss = 0.0276, Cls Loss = 0.0654, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0682, Reg Loss = 9.0381, Reconstruct Loss = 0.0278, Cls Loss = 0.0653, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0692, Reg Loss = 9.1350, Reconstruct Loss = 0.0279, Cls Loss = 0.0663, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0705, Reg Loss = 9.1678, Reconstruct Loss = 0.0276, C

100%|██████████| 157/157 [00:04<00:00, 32.63it/s]


Epoch [27/50], Validation Loss: 15.2969, Validation Accuracy: 25.14%



Iteration 0: Loss = 0.0755, Reg Loss = 9.9238, Reconstruct Loss = 0.0266, Cls Loss = 0.0728, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0691, Reg Loss = 9.4259, Reconstruct Loss = 0.0258, Cls Loss = 0.0664, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0704, Reg Loss = 9.1049, Reconstruct Loss = 0.0270, Cls Loss = 0.0676, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0694, Reg Loss = 9.0601, Reconstruct Loss = 0.0266, Cls Loss = 0.0667, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0689, Reg Loss = 9.0925, Reconstruct Loss = 0.0262, Cls Loss = 0.0662, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0693, Reg Loss = 9.1411, Reconstruct Loss = 0.0264, Cls Loss = 0.0665, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0687, Reg Loss = 9.2283, Reconstruct Loss = 0.0262, Cls Loss = 0.0660, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0696, Reg Loss = 9.1732, Reconstruct Loss = 0.0262, C

100%|██████████| 157/157 [00:04<00:00, 32.96it/s]


Epoch [28/50], Validation Loss: 7.8456, Validation Accuracy: 40.84%



Iteration 0: Loss = 0.0375, Reg Loss = 8.3818, Reconstruct Loss = 0.0250, Cls Loss = 0.0349, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0703, Reg Loss = 8.8478, Reconstruct Loss = 0.0251, Cls Loss = 0.0677, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0711, Reg Loss = 9.0680, Reconstruct Loss = 0.0246, Cls Loss = 0.0686, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0707, Reg Loss = 9.2497, Reconstruct Loss = 0.0244, Cls Loss = 0.0682, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0713, Reg Loss = 9.1665, Reconstruct Loss = 0.0239, Cls Loss = 0.0689, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0713, Reg Loss = 9.0317, Reconstruct Loss = 0.0240, Cls Loss = 0.0688, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0708, Reg Loss = 9.0345, Reconstruct Loss = 0.0240, Cls Loss = 0.0683, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0696, Reg Loss = 9.0382, Reconstruct Loss = 0.0240, Cl

100%|██████████| 157/157 [00:04<00:00, 33.39it/s]


Epoch [29/50], Validation Loss: 13.7934, Validation Accuracy: 28.11%



Iteration 0: Loss = 0.0633, Reg Loss = 8.9175, Reconstruct Loss = 0.0211, Cls Loss = 0.0611, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0649, Reg Loss = 9.9681, Reconstruct Loss = 0.0233, Cls Loss = 0.0625, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0704, Reg Loss = 9.7885, Reconstruct Loss = 0.0239, Cls Loss = 0.0679, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0696, Reg Loss = 9.8146, Reconstruct Loss = 0.0244, Cls Loss = 0.0671, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0699, Reg Loss = 9.8287, Reconstruct Loss = 0.0243, Cls Loss = 0.0674, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0680, Reg Loss = 9.7213, Reconstruct Loss = 0.0243, Cls Loss = 0.0655, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0676, Reg Loss = 9.6682, Reconstruct Loss = 0.0242, Cls Loss = 0.0651, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0672, Reg Loss = 9.7031, Reconstruct Loss = 0.0242, C

100%|██████████| 157/157 [00:04<00:00, 33.75it/s]


Epoch [30/50], Validation Loss: 22.8890, Validation Accuracy: 26.01%



Iteration 0: Loss = 0.0559, Reg Loss = 10.3679, Reconstruct Loss = 0.0219, Cls Loss = 0.0536, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0584, Reg Loss = 10.2665, Reconstruct Loss = 0.0230, Cls Loss = 0.0560, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0573, Reg Loss = 9.9633, Reconstruct Loss = 0.0237, Cls Loss = 0.0548, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0590, Reg Loss = 9.9128, Reconstruct Loss = 0.0234, Cls Loss = 0.0566, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0606, Reg Loss = 9.8846, Reconstruct Loss = 0.0232, Cls Loss = 0.0582, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0605, Reg Loss = 9.8122, Reconstruct Loss = 0.0230, Cls Loss = 0.0581, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0609, Reg Loss = 9.9377, Reconstruct Loss = 0.0229, Cls Loss = 0.0585, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0623, Reg Loss = 9.9302, Reconstruct Loss = 0.0228,

100%|██████████| 157/157 [00:04<00:00, 32.76it/s]


Epoch [31/50], Validation Loss: 19.1323, Validation Accuracy: 21.71%



Iteration 0: Loss = 0.0469, Reg Loss = 9.1809, Reconstruct Loss = 0.0313, Cls Loss = 0.0437, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0605, Reg Loss = 9.7012, Reconstruct Loss = 0.0248, Cls Loss = 0.0580, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0606, Reg Loss = 9.8043, Reconstruct Loss = 0.0238, Cls Loss = 0.0581, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0601, Reg Loss = 9.9345, Reconstruct Loss = 0.0232, Cls Loss = 0.0576, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0616, Reg Loss = 9.9308, Reconstruct Loss = 0.0227, Cls Loss = 0.0592, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0626, Reg Loss = 10.0563, Reconstruct Loss = 0.0227, Cls Loss = 0.0603, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0623, Reg Loss = 9.9819, Reconstruct Loss = 0.0231, Cls Loss = 0.0599, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0620, Reg Loss = 9.9186, Reconstruct Loss = 0.0229, 

100%|██████████| 157/157 [00:04<00:00, 33.13it/s]


Epoch [32/50], Validation Loss: 15.3242, Validation Accuracy: 29.48%



Iteration 0: Loss = 0.0378, Reg Loss = 9.6511, Reconstruct Loss = 0.0245, Cls Loss = 0.0352, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0599, Reg Loss = 9.4878, Reconstruct Loss = 0.0243, Cls Loss = 0.0574, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0589, Reg Loss = 9.5013, Reconstruct Loss = 0.0238, Cls Loss = 0.0565, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0594, Reg Loss = 9.5181, Reconstruct Loss = 0.0233, Cls Loss = 0.0569, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0627, Reg Loss = 9.5012, Reconstruct Loss = 0.0232, Cls Loss = 0.0603, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0636, Reg Loss = 9.4992, Reconstruct Loss = 0.0241, Cls Loss = 0.0611, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0626, Reg Loss = 9.5673, Reconstruct Loss = 0.0247, Cls Loss = 0.0601, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0639, Reg Loss = 9.6218, Reconstruct Loss = 0.0247, C

100%|██████████| 157/157 [00:04<00:00, 31.85it/s]


Epoch [33/50], Validation Loss: 20.6882, Validation Accuracy: 26.99%



Iteration 0: Loss = 0.0990, Reg Loss = 9.4353, Reconstruct Loss = 0.0213, Cls Loss = 0.0968, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0655, Reg Loss = 9.7059, Reconstruct Loss = 0.0226, Cls Loss = 0.0631, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0604, Reg Loss = 9.6730, Reconstruct Loss = 0.0231, Cls Loss = 0.0580, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0628, Reg Loss = 9.7257, Reconstruct Loss = 0.0236, Cls Loss = 0.0604, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0624, Reg Loss = 9.7734, Reconstruct Loss = 0.0236, Cls Loss = 0.0599, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0628, Reg Loss = 9.7722, Reconstruct Loss = 0.0236, Cls Loss = 0.0603, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0629, Reg Loss = 9.7805, Reconstruct Loss = 0.0236, Cls Loss = 0.0604, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0635, Reg Loss = 9.8533, Reconstruct Loss = 0.0236, C

100%|██████████| 157/157 [00:04<00:00, 33.11it/s]


Epoch [34/50], Validation Loss: 24.0492, Validation Accuracy: 25.61%



Iteration 0: Loss = 0.0627, Reg Loss = 9.6338, Reconstruct Loss = 0.0262, Cls Loss = 0.0600, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0607, Reg Loss = 9.7523, Reconstruct Loss = 0.0232, Cls Loss = 0.0583, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0616, Reg Loss = 9.9841, Reconstruct Loss = 0.0226, Cls Loss = 0.0592, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0606, Reg Loss = 10.2219, Reconstruct Loss = 0.0224, Cls Loss = 0.0583, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0632, Reg Loss = 10.2605, Reconstruct Loss = 0.0222, Cls Loss = 0.0609, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0629, Reg Loss = 10.3212, Reconstruct Loss = 0.0221, Cls Loss = 0.0605, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0621, Reg Loss = 10.2849, Reconstruct Loss = 0.0223, Cls Loss = 0.0598, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0619, Reg Loss = 10.3557, Reconstruct Loss = 0.02

100%|██████████| 157/157 [00:04<00:00, 32.27it/s]


Epoch [35/50], Validation Loss: 19.4493, Validation Accuracy: 27.93%



Iteration 0: Loss = 0.0317, Reg Loss = 9.8078, Reconstruct Loss = 0.0241, Cls Loss = 0.0292, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0586, Reg Loss = 9.9515, Reconstruct Loss = 0.0252, Cls Loss = 0.0560, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0608, Reg Loss = 9.8764, Reconstruct Loss = 0.0247, Cls Loss = 0.0583, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0595, Reg Loss = 10.0852, Reconstruct Loss = 0.0244, Cls Loss = 0.0570, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0596, Reg Loss = 10.1747, Reconstruct Loss = 0.0241, Cls Loss = 0.0571, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0603, Reg Loss = 10.2285, Reconstruct Loss = 0.0240, Cls Loss = 0.0578, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0602, Reg Loss = 10.2324, Reconstruct Loss = 0.0241, Cls Loss = 0.0577, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0606, Reg Loss = 10.2384, Reconstruct Loss = 0.02

100%|██████████| 157/157 [00:04<00:00, 32.71it/s]


Epoch [36/50], Validation Loss: 22.0056, Validation Accuracy: 27.84%



Iteration 0: Loss = 0.0579, Reg Loss = 10.2134, Reconstruct Loss = 0.0246, Cls Loss = 0.0554, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0560, Reg Loss = 10.0282, Reconstruct Loss = 0.0212, Cls Loss = 0.0538, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0572, Reg Loss = 10.0011, Reconstruct Loss = 0.0211, Cls Loss = 0.0550, Learning rate = 1.0000e-03
Iteration 150: Loss = 0.0604, Reg Loss = 9.9893, Reconstruct Loss = 0.0209, Cls Loss = 0.0582, Learning rate = 1.0000e-03
Iteration 200: Loss = 0.0600, Reg Loss = 9.9510, Reconstruct Loss = 0.0209, Cls Loss = 0.0578, Learning rate = 1.0000e-03
Iteration 250: Loss = 0.0611, Reg Loss = 9.9730, Reconstruct Loss = 0.0207, Cls Loss = 0.0589, Learning rate = 1.0000e-03
Iteration 300: Loss = 0.0606, Reg Loss = 9.9686, Reconstruct Loss = 0.0209, Cls Loss = 0.0584, Learning rate = 1.0000e-03
Iteration 350: Loss = 0.0617, Reg Loss = 9.9579, Reconstruct Loss = 0.0210

100%|██████████| 157/157 [00:04<00:00, 33.18it/s]


Epoch [37/50], Validation Loss: 19.1785, Validation Accuracy: 28.05%



Iteration 0: Loss = 0.0302, Reg Loss = 10.2267, Reconstruct Loss = 0.0198, Cls Loss = 0.0281, Learning rate = 1.0000e-03
Iteration 50: Loss = 0.0670, Reg Loss = 10.8312, Reconstruct Loss = 0.0209, Cls Loss = 0.0648, Learning rate = 1.0000e-03
Iteration 100: Loss = 0.0613, Reg Loss = 10.6512, Reconstruct Loss = 0.0215, Cls Loss = 0.0590, Learning rate = 1.0000e-03


KeyboardInterrupt: 

In [None]:
# Close the active wandb run; the run history/summary tables shown below are printed by this call
wandb.finish()

0,1
Cls Loss,█████▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▃▃▃▃▃▃▃▃▃▃▂▃▂▂▂▂▁▁
Learning rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss,█▇▇▇▆▇█▇▇▆▆▆▆▆▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▃▂▂▂▂▁▁▂▂▁
Reconstruct Loss,█▇▇▅▅▅▇▅▅▅▄▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁
Reg Loss,▁▁▁▇▇█████▇▆▆▅▄▆▆▆▅▅▄▄▄▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇
Training accuracy,▁▁▂▂▃▄▄▄▅▆▆▇▇▇██
Validation Accuracy,▂▁▂▁▁▁▂▆▆▇██▇▇▆█
Validation Loss,▃▃▂▃▃▃▂▂▁▂▃▄▂██▄

0,1
Cls Loss,0.08349
Learning rate,0.001
Loss,0.0917
Reconstruct Loss,0.00755
Reg Loss,6.62567
Training accuracy,0.97296
Validation Accuracy,0.6941
Validation Loss,1.52875


### 7 Testing loop

In [None]:
# Path of the best hypernetwork checkpoint inside the experiment directory.
# Built with os.path.join so it matches how the other output paths in this
# section are constructed (and tolerates a trailing separator in the directory).
saved_hypernet_path = os.path.join(args.training.save_model_path, 'cifar10_nerf_best.pth')

In [None]:
# Display the resolved checkpoint path as a quick sanity check before loading it
saved_hypernet_path

'toy/experiments/densenet_train_36_48_mlp_256_4_coordnoise_unsmooth_50e_bs64/cifar10_nerf_best.pth'

In [None]:
# Build a fresh hypernetwork instance that will receive the saved weights.
# `args` and `number_param` must already be defined by earlier cells.
# NOTE(review): which device the returned model lives on is decided inside
# get_hypernetwork — confirm it matches `device` used in the testing loop.
hyper_model_test = get_hypernetwork(args, number_param)

Hyper model type: resmlp
Using scalar 0.1
num_freqs:  16 <class 'int'>


In [None]:
# Read the checkpoint file; map_location="cpu" remaps stored tensors to the CPU
# so the file can be read regardless of the device it was saved from.
# NOTE(review): depending on the installed torch version, torch.load may unpickle
# arbitrary objects — only load checkpoints from a trusted source.
checkpoint = torch.load(saved_hypernet_path, map_location="cpu")  # or "cuda" if using GPU
# Copy the saved weights (stored under 'model_state_dict') into the test hypernetwork;
# the returned key-matching report is displayed as the cell output.
hyper_model_test.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [None]:
# Evaluate the loaded hypernetwork for every hidden dimension in [12, 48].
# For each width: build a model of that width, sample/merge weights from the
# hypernetwork into it, validate it, save its state dict, and append the
# metrics to a per-experiment results file.

# The results file does not depend on hidden_dim, so resolve its path once.
filename = f"cifar10_results_{args.experiment.name}.txt"
filepath = os.path.join(args.training.save_model_path, filename)

for hidden_dim in range(12, 49):
    # Create a model for this given dimension
    model_trained = create_model(args.model.type,
                                 layers=args.model.layers,
                                 growth=args.model.growth,
                                 compression=args.model.compression,
                                 bottleneck=args.model.bottleneck,
                                 drop_rate=args.model.drop_rate,
                                 path=args.model.pretrained_path,
                                 hidden_dim=hidden_dim).to(device)

    # If EMA is specified, apply it
    # NOTE(review): `ema` comes from an earlier cell; confirm it wraps the network
    # evaluated here (hyper_model_test), otherwise apply/restore does not affect
    # these results.
    if ema:
        print('Applying EMA')
        ema.apply()

    try:
        # Sample the merged model
        accumulated_model = sample_merge_model(hyper_model_test, model_trained, args, K=100)

        # Validate the merged model
        val_loss, acc = validate_single(accumulated_model, val_loader, val_criterion, args=args)
    finally:
        # Always restore the original weights, even if sampling/validation fails
        # or the cell is interrupted, so later cells never see stale EMA weights.
        if ema:
            ema.restore()

    # Save the model
    save_name = os.path.join(args.training.save_model_path, f"cifar10_{accumulated_model.__class__.__name__}_dim{hidden_dim}_single.pth")
    torch.save(accumulated_model.state_dict(), save_name)

    # Print the results
    print(f"Test using model {args.model}: hidden_dim {hidden_dim}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%")
    print('\n')

    # Write the results. 'a' is used to append the results; a new file will be created if it doesn't exist.
    # Note: re-running this cell appends a second set of lines to the same file.
    with open(filepath, "a") as results_file:
        results_file.write(f"Hidden_dim: {hidden_dim}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%\n")

Loading model from toy/experiments/densenet_bc_40_12_baseline/densenet_bc_40_12_cifar10_baseline_best.pth
Applying EMA


100%|██████████| 157/157 [00:03<00:00, 42.70it/s]

Test using model {'type': 'DenseNet', 'pretrained_path': 'toy/experiments/densenet_bc_40_12_baseline/densenet_bc_40_12_cifar10_baseline_best.pth', 'layers': 40, 'growth': 12, 'compression': 0.5, 'bottleneck': True, 'drop_rate': 0.0}: hidden_dim 48, Validation Loss: 1.3203, Validation Accuracy: 69.94%





