In [1]:
import torch
import torch.nn as nn
import torch_pruning as tp
import matplotlib.pyplot as plt
from torch import optim
import os
import numpy as np
import copy

import torchvision
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader


In [2]:

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_BASE_NAME = "mobilenet_v2-b0353104"

### Data Loaders (Using your provided function structure)

In [3]:
def get_data_loaders(data_dir_path, batch_size=64, val_split=0.1, seed=42):
    abs_data_dir = os.path.abspath(data_dir_path)
    print(f"Attempting to load CIFAR-10 from pre-downloaded directory: {abs_data_dir}")
    expected_cifar_folder = os.path.join(abs_data_dir, 'cifar-10-batches-py')
    if not os.path.exists(expected_cifar_folder):
        print(f"ERROR: Expected CIFAR-10 data folder '{expected_cifar_folder}' not found!")
        raise FileNotFoundError(f"CIFAR-10 data not found at {expected_cifar_folder}")

    transform_cifar = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    try:
        full_train_dataset = torchvision.datasets.CIFAR10(root=abs_data_dir, train=True, download=False, transform=transform_cifar)
        test_dataset = torchvision.datasets.CIFAR10(root=abs_data_dir, train=False, download=False, transform=transform_cifar)
    except RuntimeError as e:
        print(f"ERROR: Failed to load CIFAR-10 from {abs_data_dir}. Error: {e}"); raise
    val_size = int(len(full_train_dataset) * val_split); train_size = len(full_train_dataset) - val_size
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(full_train_dataset, [train_size, val_size], generator=generator)
    num_avail_cpus = len(os.sched_getaffinity(0)) if hasattr(os, 'sched_getaffinity') else os.cpu_count()
    num_workers_val = min(num_avail_cpus, 4) if num_avail_cpus is not None else 2
    pin_memory_flag = True if DEVICE.type == 'cuda' else False
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers_val, pin_memory=pin_memory_flag)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers_val, pin_memory=pin_memory_flag)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers_val, pin_memory=pin_memory_flag)
    print(f"DataLoaders: Train {len(train_dataset)}, Val {len(val_dataset)}, Test {len(test_dataset)}")
    return train_loader, val_loader, test_loader

### Torchvision MobileNetV2 Model Helper

In [4]:
def get_torchvision_mobilenetv2_structure(num_classes_local, device_to_map=DEVICE):
    """
    Creates the torchvision.models.mobilenet_v2 STRUCTURE (weights=None)
    and adapts its classifier for the given number of classes.
    This is used as a template to load a state_dict.
    """
    # print(f"Creating torchvision MobileNetV2 structure (weights=None) for {num_classes_local} classes.")
    tv_model = models.mobilenet_v2(weights=None, num_classes=num_classes_local) # num_classes adapts final layer

    # Ensure the classifier is correctly adapted if num_classes was for something else initially
    # For models.mobilenet_v2(weights=None, num_classes=X), the classifier is already X.
    # This re-adaptation might be redundant if num_classes in constructor works as expected.
    if isinstance(tv_model.classifier, nn.Sequential) and len(tv_model.classifier) > 0:
        last_layer_index = -1 # Typically the Linear layer is last or second last
        if isinstance(tv_model.classifier[last_layer_index], nn.Linear):
             current_out_features = tv_model.classifier[last_layer_index].out_features
             if current_out_features != num_classes_local:
                 print(f"Adapting classifier output from {current_out_features} to {num_classes_local} (should be redundant if constructor worked).")
                 num_ftrs_tv = tv_model.classifier[last_layer_index].in_features
                 tv_model.classifier[last_layer_index] = torch.nn.Linear(num_ftrs_tv, num_classes_local)
        elif len(tv_model.classifier) > 1 and isinstance(tv_model.classifier[-2], nn.Linear): # Check if last is dropout
            current_out_features = tv_model.classifier[-2].out_features
            if current_out_features != num_classes_local:
                print(f"Adapting classifier output from {current_out_features} to {num_classes_local} (should be redundant if constructor worked).")
                num_ftrs_tv = tv_model.classifier[-2].in_features
                tv_model.classifier[-2] = torch.nn.Linear(num_ftrs_tv, num_classes_local)
    return tv_model.to(device_to_map)


def get_ignored_layer_for_torchvision_mobilenetv2(tv_model):
    if isinstance(tv_model.classifier, nn.Sequential) and len(tv_model.classifier) > 0:
        if isinstance(tv_model.classifier[-1], nn.Linear): return [tv_model.classifier[-1]]
        elif len(tv_model.classifier) > 1 and isinstance(tv_model.classifier[-2], nn.Linear): return [tv_model.classifier[-2]]
    elif isinstance(tv_model.classifier, nn.Linear): return [tv_model.classifier]
    print("Warning: Could not identify classifier layer in get_ignored_layer_for_torchvision_mobilenetv2."); return []



### Utility Functions (Save/Load, MACs, Evaluate)

In [5]:
def save_model_as_onnx(model_to_save, example_input_cpu_for_onnx, output_path_onnx):
    model_to_save.eval()
    model_cpu_for_export = model_to_save.to('cpu')
    torch.onnx.export(
        model_cpu_for_export, example_input_cpu_for_onnx, output_path_onnx, export_params=True, opset_version=13,
        input_names=['input'], output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    print(f"✅ Model saved as ONNX to {output_path_onnx}")

def load_model_state(model_structure_provider_func, path_to_load_pth, device_to_load, num_classes_for_model_struct):
    model_shell = model_structure_provider_func(num_classes_local=num_classes_for_model_struct, device_to_map=device_to_load)
    model_shell.load_state_dict(torch.load(path_to_load_pth, map_location=device_to_load))
    print(f"✅ State loaded into model structure from: {path_to_load_pth}")
    return model_shell


def calculate_macs_params(model_to_calc, example_input_for_calc):
    # Ensure model and input are on the same device for tp.utils
    target_device = example_input_for_calc.device
    model_on_target_device = model_to_calc.to(target_device)
    macs, params = tp.utils.count_ops_and_params(model_on_target_device, example_input_for_calc)
    return macs, params

def save_model(model_to_save, path_to_save_pth, example_input_cpu_for_onnx_save=None):
    os.makedirs(os.path.dirname(path_to_save_pth), exist_ok=True)
    torch.save(model_to_save.state_dict(), path_to_save_pth)
    print(f"✅ Model state_dict saved to {path_to_save_pth}")
    if example_input_cpu_for_onnx_save is not None:
        onnx_output_path = path_to_save_pth.replace('.pth', '.onnx')
        save_model_as_onnx(model_to_save, example_input_cpu_for_onnx_save, onnx_output_path)

def load_torchvision_model_state(model_constructor_func, path_to_load_pth, device_to_load, num_classes_for_model):
    # model_constructor_func here is 'get_torchvision_mobilenetv2'
    # It needs num_classes and use_pretrained_weights=False to get the correct structure before loading state_dict
    loaded_model = model_constructor_func(num_classes=num_classes_for_model, use_pretrained_weights=False, device_to_map=device_to_load)
    loaded_model.load_state_dict(torch.load(path_to_load_pth, map_location=device_to_load))
    print(f"✅ Torchvision model state loaded from {path_to_load_pth} to {device_to_load}")
    return loaded_model

def evaluate_model_with_loss(model_to_eval, data_loader_for_eval, example_input_for_macs, criterion_for_loss, device_for_eval):
    model_to_eval.eval()
    model_to_eval.to(device_for_eval)
    macs_val, params_val = calculate_macs_params(model_to_eval, example_input_for_macs.to(device_for_eval))
    size_mib_val = params_val * 4 / (1024 * 1024) # Assuming float32
    correct_preds = 0; total_samples_eval = 0; accumulated_loss = 0.0
    with torch.no_grad():
        for data_batch in data_loader_for_eval:
            inputs_batch, labels_batch = [d.to(device_for_eval) for d in data_batch]
            outputs_batch = model_to_eval(inputs_batch)
            loss_batch = criterion_for_loss(outputs_batch, labels_batch)
            accumulated_loss += loss_batch.item() * inputs_batch.size(0)
            _, predicted_labels = torch.max(outputs_batch.data, 1)
            total_samples_eval += labels_batch.size(0)
            correct_preds += (predicted_labels == labels_batch).sum().item()
    avg_loss_val = accumulated_loss / total_samples_eval if total_samples_eval > 0 else float('nan')
    accuracy_val = 100 * correct_preds / total_samples_eval if total_samples_eval > 0 else 0
    return {'macs': macs_val, 'params': params_val, 'size_mib': size_mib_val, 'accuracy': accuracy_val, 'loss': avg_loss_val}

### Pruning Function (Adapted for explicit ignored_layers and torchvision BNs)

In [6]:
def gr_prune_model_by_ratio(model_to_prune_input, example_input_dev_prune, strategy_config_prune,
                            target_pruning_ratio_prune, explicit_ignored_layers_list=None,
                            iterative_steps_general_pruner=5,
                            iterative_steps_taylor_pruner=5):
    pruned_model_state = copy.deepcopy(model_to_prune_input) # Work on a copy
    pruned_model_state.to(example_input_dev_prune.device)
    initial_macs_prune, _ = calculate_macs_params(pruned_model_state, example_input_dev_prune)
    print(f"Initial MACs for ratio {target_pruning_ratio_prune:.2f}: {initial_macs_prune / 1e9:.3f} G")

    current_ignored_layers = explicit_ignored_layers_list if explicit_ignored_layers_list is not None else []
    if not current_ignored_layers: print("Warning: No explicit ignored_layers provided to pruner.")


    if target_pruning_ratio_prune == 0.0:
        print("Target pruning ratio 0.0 - no pruning needed.")
        return pruned_model_state

    if isinstance(strategy_config_prune['importance'], tp.importance.TaylorImportance):
        print(f"TaylorImportance: ratio {target_pruning_ratio_prune:.2f}, {iterative_steps_taylor_pruner} backward passes.")

        # BatchNorm layers in torchvision.models.mobilenet_v2.features sequence
        # that might receive 1x1 spatial input with 32x32 overall input image size.
        # Feature map sizes before these layers if input is 32x32:
        # features[13] (ConvBNReLU): Output before this is block 160->960->160, 2x2 spatially.
        #                      Inside InvertedResidual for features[13] (stride 2, 96->160): hidden=576
        #                      bn after depthwise conv with stride 2 receives 1x1. This is features[13].conv[4] (BN)
        # features[14,15,16] (ConvBNReLU): block 160->960->160, 1x1 spatially. (stride 1, 160->160)
        #                      All BNs inside these will get 1x1. (conv[1], conv[4], conv[7])
        # features[17] (ConvBNReLU): block 160->960->320, 1x1 spatially. (stride 1, 160->320)
        #                      All BNs inside this will get 1x1.
        # features[18] (ConvBN2d): This is the final 1x1 conv before pooling. its BN gets 1x1.

        bn_layers_to_eval_during_taylor = []
        # Helper to add BNs from an InvertedResidual-like block (ConvBNRelu -> DwConvBNRelu -> ConvBN)
        def add_bns_from_inv_res_if_1x1(block_module, spatial_is_1x1_input):
            if hasattr(block_module, 'conv') and isinstance(block_module.conv, nn.Sequential):
                # conv[1] is bn after first 1x1 conv
                if spatial_is_1x1_input and isinstance(block_module.conv[1], nn.BatchNorm2d):
                     bn_layers_to_eval_during_taylor.append(block_module.conv[1])
                # conv[4] is bn after depthwise conv
                # Spatial becomes 1x1 if input was 2x2 and stride=2 OR if input was already 1x1
                # For simplicity here, we assume the outer logic determines if this block *overall* sees 1x1 due to prior stages
                if spatial_is_1x1_input and isinstance(block_module.conv[4], nn.BatchNorm2d):
                     bn_layers_to_eval_during_taylor.append(block_module.conv[4])
                # conv[6] (careful with indexing if last is just BN not ConvBN) / or last BN in the sequence
                if hasattr(block_module.conv, '6') and isinstance(block_module.conv[6], nn.BatchNorm2d) and spatial_is_1x1_input: # if structure is C-BN-ACT-DW-BN-ACT-C-BN
                     bn_layers_to_eval_during_taylor.append(block_module.conv[6])
                elif isinstance(block_module.conv[-1], nn.BatchNorm2d) and spatial_is_1x1_input: # More robust: last layer of seq is BN
                    bn_layers_to_eval_during_taylor.append(block_module.conv[-1])


        # Based on typical MobileNetV2 structure from torchvision.features:
        # features are a nn.Sequential. Strides affect subsequent blocks.
        # 32x32 -> [0] Conv s2 -> 16x16
        # -> [1] InvRes s1 -> 16x16
        # -> [2,3] InvRes s2 (block 3) -> 8x8
        # -> [4,5,6] InvRes s2 (block 6) -> 4x4
        # -> [7,8,9,10] InvRes s2 (block 10) -> 2x2
        # -> [11,12,13] InvRes s1 (block 13, last has stride 2, features[13]) -> features[13] makes it 1x1
        # The following blocks [14] through [17] and the final conv[18] will operate on 1x1.

        if hasattr(pruned_model_state, 'features') and isinstance(pruned_model_state.features, nn.Sequential):
            # Block starting from features[13] output leads to 1x1, so its bn_after_dw may get 1x1 if input to block was 2x2 and stride=2.
            # It's safer to assume blocks from where spatial dim is consistently 1x1.
            # features[13] is type InvertedResidual.
            # Inside models.mobilenet_v2, the blocks are directly in features.
            # features[13].conv[4] (BN after DW conv s=2 on 2x2 input)
            if len(pruned_model_state.features) > 13 and hasattr(pruned_model_state.features[13], 'conv') and isinstance(pruned_model_state.features[13].conv, nn.Sequential):
                 if isinstance(pruned_model_state.features[13].conv[4], nn.BatchNorm2d): bn_layers_to_eval_during_taylor.append(pruned_model_state.features[13].conv[4])
                 if isinstance(pruned_model_state.features[13].conv[-1], nn.BatchNorm2d): bn_layers_to_eval_during_taylor.append(pruned_model_state.features[13].conv[-1])


            # Blocks features[14] through features[17] operate on 1x1 spatial feature maps
            for k in range(14, 18): # Indices for blocks that get 1x1 input
                if len(pruned_model_state.features) > k and hasattr(pruned_model_state.features[k], 'conv') and isinstance(pruned_model_state.features[k].conv, nn.Sequential):
                    # Add all BNs from these blocks: conv[1], conv[4], conv[-1] (final BN)
                    if isinstance(pruned_model_state.features[k].conv[1], nn.BatchNorm2d): bn_layers_to_eval_during_taylor.append(pruned_model_state.features[k].conv[1])
                    if isinstance(pruned_model_state.features[k].conv[4], nn.BatchNorm2d): bn_layers_to_eval_during_taylor.append(pruned_model_state.features[k].conv[4])
                    if isinstance(pruned_model_state.features[k].conv[-1], nn.BatchNorm2d): bn_layers_to_eval_during_taylor.append(pruned_model_state.features[k].conv[-1])

            # The BN in the final 1x1 conv layer (features[18] is often Conv2dNormActivation)
            if len(pruned_model_state.features) > 18 and hasattr(pruned_model_state.features[18], '1') and isinstance(pruned_model_state.features[18][1], nn.BatchNorm2d): # Conv2dNormActivation [0]=conv, [1]=bn
                 bn_layers_to_eval_during_taylor.append(pruned_model_state.features[18][1])

        if bn_layers_to_eval_during_taylor: print(f"Temporarily setting {len(bn_layers_to_eval_during_taylor)} BNs to eval for Taylor.")

        pruner_for_taylor = strategy_config_prune['pruner'](
            pruned_model_state, example_input_dev_prune, importance=strategy_config_prune['importance'],
            iterative_steps=iterative_steps_taylor_pruner, ch_sparsity=target_pruning_ratio_prune,
            root_module_types=[nn.Conv2d], ignored_layers=current_ignored_layers
        )
        original_bn_modes_taylor = {}
        for i in range(iterative_steps_taylor_pruner):
            pruned_model_state.train()
            for bn_l in bn_layers_to_eval_during_taylor: original_bn_modes_taylor[bn_l] = bn_l.training; bn_l.eval()
            pruned_model_state.zero_grad()
            loss_taylor = pruned_model_state(example_input_dev_prune.clone()).mean()
            try: loss_taylor.backward()
            except Exception as e_bw:
                print(f"Backward error in Taylor step {i+1}: {e_bw}"); raise
            finally: # Ensure BN modes are restored
                 for bn_l, mode in original_bn_modes_taylor.items(): bn_l.training = mode
                 original_bn_modes_taylor.clear()
            pruned_model_state.eval()
            pruner_for_taylor.step()
            macs_after_taylor_step, _ = calculate_macs_params(pruned_model_state, example_input_dev_prune)
            print(f"  Taylor Step {i+1}/{iterative_steps_taylor_pruner} (Pruner {pruner_for_taylor.current_step}/{pruner_for_taylor.iterative_steps}): MACs {macs_after_taylor_step / 1e9:.3f} G")
            if pruner_for_taylor.current_step == pruner_for_taylor.iterative_steps: break
    else:
        print(f"Applying {strategy_config_prune['importance'].__class__.__name__} @ ratio {target_pruning_ratio_prune:.2f} "
              f"with {iterative_steps_general_pruner} pruner steps.")
        pruned_model_state.eval() # Non-Taylor pruners generally expect eval mode
        pruner_general = strategy_config_prune['pruner'](
            pruned_model_state, example_input_dev_prune, importance=strategy_config_prune['importance'],
            iterative_steps=iterative_steps_general_pruner, ch_sparsity=target_pruning_ratio_prune,
            root_module_types=[nn.Conv2d], ignored_layers=current_ignored_layers
        )
        pruner_general.step()

    final_macs_prune, _ = calculate_macs_params(pruned_model_state, example_input_dev_prune)
    reduction_prune = (initial_macs_prune - final_macs_prune) / initial_macs_prune * 100 if initial_macs_prune > 0 else 0
    print(f"Pruning for ratio {target_pruning_ratio_prune:.2f} done: MACs {final_macs_prune / 1e9:.3f} G (Reduced by {reduction_prune:.2f}%)")
    return pruned_model_state


### Train the model (with Early Stopping)

In [7]:
def train_model(model_train, train_dl, criterion_train, optimizer_train, device_train, num_epochs_train,
                val_loader=None, log_prefix_train="", es_patience=None, es_metric='val_loss', load_best_es=True):
    model_train.to(device_train)
    history_train = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_metric_val = float('inf') if es_metric == 'val_loss' else float('-inf')
    epochs_no_improve_train = 0; best_model_state_train = None

    for epoch_train in range(num_epochs_train):
        model_train.train(); running_loss_train = 0.0; correct_train = 0; total_train = 0
        for _, data_train in enumerate(train_dl):
            inputs_train, labels_train = [d.to(device_train) for d in data_train]
            optimizer_train.zero_grad(); outputs_train = model_train(inputs_train); loss_train = criterion_train(outputs_train, labels_train)
            loss_train.backward(); optimizer_train.step()
            running_loss_train += loss_train.item(); _, predicted_train = torch.max(outputs_train.data, 1)
            total_train += labels_train.size(0); correct_train += (predicted_train == labels_train).sum().item()

        epoch_loss_train = running_loss_train/len(train_dl) if len(train_dl)>0 else 0
        epoch_acc_train = 100*correct_train/total_train if total_train >0 else 0
        history_train['train_loss'].append(epoch_loss_train); history_train['train_acc'].append(epoch_acc_train)
        log_line = f"E {epoch_train+1}/{num_epochs_train} ({log_prefix_train}): Tr L={epoch_loss_train:.4f}, Acc={epoch_acc_train:.2f}%"

        if val_loader:
            model_train.eval(); running_val_loss_train = 0.0; correct_val_train = 0; total_val_train = 0
            with torch.no_grad():
                for data_val_train in val_loader:
                    inputs_val_train, labels_val_train = [d.to(device_train) for d in data_val_train]
                    outputs_val_train = model_train(inputs_val_train); val_loss_item_train = criterion_train(outputs_val_train, labels_val_train)
                    running_val_loss_train += val_loss_item_train.item(); _, predicted_val_train = torch.max(outputs_val_train.data, 1)
                    total_val_train += labels_val_train.size(0); correct_val_train += (predicted_val_train == labels_val_train).sum().item()
            epoch_val_loss_train = running_val_loss_train/len(val_loader) if len(val_loader)>0 else 0
            epoch_val_acc_train = 100*correct_val_train/total_val_train if total_val_train >0 else 0
            history_train['val_loss'].append(epoch_val_loss_train); history_train['val_acc'].append(epoch_val_acc_train)
            log_line += f", Val L={epoch_val_loss_train:.4f}, Acc={epoch_val_acc_train:.2f}%"

            if es_patience is not None:
                metric_val_check = epoch_val_loss_train if es_metric == 'val_loss' else epoch_val_acc_train
                improved_es = (metric_val_check < best_metric_val) if es_metric == 'val_loss' else (metric_val_check > best_metric_val)
                if improved_es:
                    best_metric_val = metric_val_check; epochs_no_improve_train = 0
                    if load_best_es: best_model_state_train = copy.deepcopy(model_train.state_dict())
                    log_line += " (New best)"
                else: epochs_no_improve_train += 1
        else: # No val_loader
            history_train['val_loss'].append(None); history_train['val_acc'].append(None)
            if es_patience: print(f"ES Warning ({log_prefix_train}): No val_loader provided.")
        print(log_line)
        if es_patience and val_loader and epochs_no_improve_train >= es_patience:
            print(f"EarlyStopping for '{log_prefix_train}' at epoch {epoch_train+1}.")
            if load_best_es and best_model_state_train: model_train.load_state_dict(best_model_state_train)
            break

    if load_best_es and best_model_state_train:
        final_val_metric_train = history_train[es_metric][-1] if history_train[es_metric] and history_train[es_metric][-1] is not None else None
        load_final_best = True
        if final_val_metric_train is not None: # Check if current state is already best
            if es_metric == 'val_loss': load_final_best = best_metric_val < final_val_metric_train
            else: load_final_best = best_metric_val > final_val_metric_train
        if load_final_best :
            print(f"End of Train ({log_prefix_train}): Loading best state ({es_metric}={best_metric_val:.4f}).")
            model_train.load_state_dict(best_model_state_train)
    return model_train, history_train


### Plotting Utilities

In [8]:
from torchinfo import summary as torchinfo_summary # For model architecture summary
def summarize_loaded_initial_model(model_instance, model_name, eval_metrics, example_input_for_summary, output_dir):
    """
    Saves architecture summary and evaluation metrics for a loaded initial model.
    """
    os.makedirs(output_dir, exist_ok=True)

    # 1. Save Architecture Summary
    summary_filename = os.path.join(output_dir, f"{model_name}_architecture_summary.txt")
    try:
        # Ensure model and example_input are on the same device or example_input is on CPU for summary
        # torchinfo typically handles moving model to input_data's device or CPU.
        # Let's make sure example_input_for_summary is on CPU for robust summary generation.
        example_input_summary_cpu = example_input_for_summary.to('cpu')
        model_for_summary_cpu = model_instance.to('cpu')

        model_summary_str = str(torchinfo_summary(model_for_summary_cpu,
                                                  input_data=example_input_summary_cpu, # Use CPU version of input
                                                  verbose=0,
                                                  col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"], # Added kernel_size
                                                  row_settings=["var_names"]))
        with open(summary_filename, 'w') as f:
            f.write(f"Architecture Summary for: {model_name}\n\n")
            f.write(model_summary_str)
        print(f"✅ Architecture summary saved to: {summary_filename}")
        # Move model back to original device if needed, though this function gets a copy typically
        model_instance.to(example_input_for_summary.device) # If model_instance was modified

    except ImportError:
        print("torchinfo not installed. Skipping detailed architecture summary. Saving basic print(model).")
        print("Install torchinfo for detailed summaries: pip install torchinfo")
        with open(summary_filename, 'w') as f:
            f.write(f"Basic Model Structure for: {model_name}\n\n")
            f.write(str(model_instance)) # Fallback to basic model string representation
        print(f"✅ Basic model structure saved to: {summary_filename}")
    except Exception as e:
        print(f"Could not generate torchinfo summary: {e}. Saving basic print(model).")
        with open(summary_filename, 'w') as f:
            f.write(f"Basic Model Structure for: {model_name}\n\n")
            f.write(str(model_instance))
        print(f"✅ Basic model structure saved to: {summary_filename}")


    # 2. Save Evaluation Metrics as Text
    metrics_filename = os.path.join(output_dir, f"{model_name}_evaluation_metrics.txt")
    with open(metrics_filename, 'w') as f:
        f.write(f"Evaluation Metrics for: {model_name} (on Test Set)\n\n")
        if eval_metrics: # Ensure eval_metrics is not None
            for key, value in eval_metrics.items():
                if isinstance(value, float) and key not in ['macs', 'params']:
                    f.write(f"{key.capitalize()}: {value:.4f}\n")
                elif isinstance(value, (float, int)) and key in ['macs', 'params']: # int for params if not float
                     f.write(f"{key.capitalize()}: {value:,.0f} (or {value:.3e})\n") # Formatted with comma and scientific
                else:
                    f.write(f"{key.capitalize()}: {value}\n")
        else:
            f.write("No evaluation metrics provided.\n")
    print(f"✅ Evaluation metrics saved to: {metrics_filename}")

    # 3. Simple Bar Chart for key metrics (if metrics are available)
    if eval_metrics:
        metrics_to_plot_init = {
            'MACs': eval_metrics.get('macs'), # Use .get for safety
            'Params': eval_metrics.get('params'),
            'Test Loss': eval_metrics.get('loss'),
            'Test Acc (%)': eval_metrics.get('accuracy')
        }

        # Filter out None values before plotting
        valid_metrics_for_plot = {k: v for k, v in metrics_to_plot_init.items() if v is not None and not (isinstance(v, float) and np.isnan(v))}

        if len(valid_metrics_for_plot) > 0:
            num_valid_metrics = len(valid_metrics_for_plot)
            # Adjust subplot layout based on number of metrics
            ncols = 2 if num_valid_metrics > 1 else 1
            nrows = (num_valid_metrics + ncols - 1) // ncols # Calculate rows needed

            fig, ax_subplots = plt.subplots(nrows, ncols, figsize=(6*ncols, 5*nrows), squeeze=False) # squeeze=False always returns 2D array
            ax_flat = ax_subplots.ravel() # Flatten to easily iterate

            plot_colors = ['skyblue', 'lightgreen', 'salmon', 'gold', 'lightcoral', 'lightskyblue'] # More colors if needed

            for i, (metric_name_plot, metric_value_plot) in enumerate(valid_metrics_for_plot.items()):
                ax_curr = ax_flat[i]
                ax_curr.bar([metric_name_plot.replace(" (%)","\n(%)")], [metric_value_plot], color=plot_colors[i % len(plot_colors)])
                ax_curr.set_title(metric_name_plot)

                text_val = ""
                if "MACs" in metric_name_plot or "Params" in metric_name_plot:
                    text_val = f"{metric_value_plot:,.0f}\n({metric_value_plot:.2e})"
                    ax_curr.set_yscale('log') # Log scale for MACs/Params makes sense
                elif "Loss" in metric_name_plot:
                     text_val = f"{metric_value_plot:.4f}"
                elif "Acc" in metric_name_plot: # Accuracy
                    text_val = f"{metric_value_plot:.2f}%"
                    ax_curr.set_ylim(min(0, metric_value_plot-10 if metric_value_plot else 0), max(100, metric_value_plot+5 if metric_value_plot else 100) + (5 if metric_value_plot==100 else 0) ) # Dynamic Y lim for acc

                ax_curr.text(0, metric_value_plot, text_val, ha='center', va='bottom' if metric_value_plot > 0 else 'top')
                ax_curr.grid(True, axis='y', linestyle='--', alpha=0.7)


            # Hide any unused subplots
            for j in range(num_valid_metrics, nrows * ncols):
                fig.delaxes(ax_flat[j])

            fig.suptitle(f"Initial Loaded Model Metrics: {model_name}", fontsize=16)
            plt.tight_layout(rect=[0, 0, 1, 0.95])
            plot_path = os.path.join(output_dir, f"{model_name}_metrics_summary_chart.png")
            plt.savefig(plot_path)
            plt.close(fig)
            print(f"✅ Metrics summary chart saved to: {plot_path}")
        else:
            print("No valid metrics to plot for initial model summary chart.")
    else:
        print("No evaluation metrics provided for initial model summary chart.")

In [9]:
def plot_initial_model_training_summary(training_history_plot, model_name_str, output_dir_init_plots):
    # (Identical to previous full script version)
    os.makedirs(output_dir_init_plots, exist_ok=True); epochs_run_plot = len(training_history_plot['train_loss'])
    if epochs_run_plot == 0: print(f"No history for {model_name_str}. Skip plots."); return
    epochs_axis_plot = range(1, epochs_run_plot + 1); plt.figure(figsize=(10, 5))
    plt.plot(epochs_axis_plot, training_history_plot['train_loss'], 'bo-', label='Train Loss')
    if training_history_plot.get('val_loss') and any(v is not None for v in training_history_plot['val_loss']):
        v_loss_epochs = [ep for ep, lo in zip(epochs_axis_plot, training_history_plot['val_loss']) if lo is not None]; v_loss_vals = [lo for lo in training_history_plot['val_loss'] if lo is not None]
        if v_loss_vals: plt.plot(v_loss_epochs, v_loss_vals, 'ro-', label='Val Loss')
    plt.title(f'{model_name_str}: Loss vs. Epochs'); plt.xlabel('Epochs'); plt.ylabel('Loss'); plt.legend(); plt.grid(True)
    plt.savefig(os.path.join(output_dir_init_plots, f'{model_name_str}_loss_vs_epochs.png')); plt.close(); print(f"✅ {model_name_str} Loss/Ep saved to {output_dir_init_plots}")
    fig_la, axs_la = plt.subplots(1, 2 if (training_history_plot.get('val_loss') and any(v is not None for v in training_history_plot['val_loss'])) else 1, figsize=(12, 5), squeeze=False)
    ax1_la = axs_la[0,0]; ax1_la.plot(training_history_plot['train_acc'], training_history_plot['train_loss'], 'bo-', label='Tr L/Acc')
    if epochs_run_plot > 0: ax1_la.scatter(training_history_plot['train_acc'][0], training_history_plot['train_loss'][0], c='g', s=50, zorder=5, label='Start Tr'); ax1_la.scatter(training_history_plot['train_acc'][-1], training_history_plot['train_loss'][-1], c='r', s=50, zorder=5, label='End Tr')
    ax1_la.set_title('Train Loss vs. Acc'); ax1_la.set_xlabel('Tr Acc %'); ax1_la.set_ylabel('Tr Loss'); ax1_la.legend(); ax1_la.grid(True)
    if training_history_plot.get('val_loss') and any(v is not None for v in training_history_plot['val_loss']) and axs_la.shape[1] > 1:
        ax2_la = axs_la[0,1]; v_indices = [i for i, (lo, ac) in enumerate(zip(training_history_plot['val_loss'], training_history_plot['val_acc'])) if lo is not None and ac is not None]
        if v_indices:
            v_acc_p = [training_history_plot['val_acc'][i] for i in v_indices]; v_loss_p = [training_history_plot['val_loss'][i] for i in v_indices]; ax2_la.plot(v_acc_p, v_loss_p, 'ro-', label='Val L/Acc')
            if v_acc_p: ax2_la.scatter(v_acc_p[0], v_loss_p[0], c='lime', s=50, zorder=5, label='Start Val'); ax2_la.scatter(v_acc_p[-1], v_loss_p[-1], c='darkred', s=50, zorder=5, label='End Val')
            ax2_la.set_title('Val Loss vs. Acc'); ax2_la.set_xlabel('Val Acc %'); ax2_la.set_ylabel('Val Loss'); ax2_la.legend(); ax2_la.grid(True)
    fig_la.suptitle(f'{model_name_str}: Loss vs. Acc Dynamics', fontsize=16); plt.tight_layout(rect=[0,0,1,0.96]); plt.savefig(os.path.join(output_dir_init_plots, f'{model_name_str}_loss_vs_accuracy.png')); plt.close(); print(f"✅ {model_name_str} Loss/Acc saved to {output_dir_init_plots}")
def plot_finetuning_curves(history_ft, plot_title_suffix_ft, output_dir_plots_ft, model_macs_val_ft):
    os.makedirs(output_dir_plots_ft, exist_ok=True); actual_epochs = len(history_ft['train_loss']); epochs_range = range(1, actual_epochs + 1); plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1); plt.plot(epochs_range, history_ft['train_loss'][:actual_epochs], 'bo-', label='Tr L');
    if history_ft.get('val_loss') and any(v is not None for v in history_ft['val_loss']): plt.plot(epochs_range, history_ft['val_loss'][:actual_epochs], 'ro-', label='Val L')
    plt.title(f'L ({plot_title_suffix_ft})\nMACs:{model_macs_val_ft:.2e}'); plt.xlabel('Ep'); plt.ylabel('L'); plt.legend(); plt.grid(True)
    plt.subplot(1, 2, 2); plt.plot(epochs_range, history_ft['train_acc'][:actual_epochs], 'bo-', label='Tr Acc');
    if history_ft.get('val_acc') and any(v is not None for v in history_ft['val_acc']): plt.plot(epochs_range, history_ft['val_acc'][:actual_epochs], 'ro-', label='Val Acc')
    plt.title(f'Acc ({plot_title_suffix_ft})\nMACs:{model_macs_val_ft:.2e}'); plt.xlabel('Ep'); plt.ylabel('Acc %'); plt.legend(); plt.grid(True)
    plt.tight_layout(); safe_suffix_ft = plot_title_suffix_ft.replace(' ', '_').replace('/', '_').replace(':', '_'); plt.savefig(os.path.join(output_dir_plots_ft, f'ft_curves_{safe_suffix_ft}.png')); plt.close(); print(f"✅ FT curves for {plot_title_suffix_ft} saved.")
def plot_metrics_vs_ratio_all_strategies(results_plot, ratios_plot, out_dir_compare):
    os.makedirs(out_dir_compare, exist_ok=True); strat_plot = list(results_plot.keys())
    if not strat_plot: print("No strats for ratio plots."); return
    ratios_labels_plot = [f"{r:.1f}" for r in ratios_plot]; num_ratios_p = len(ratios_plot); num_strat_p = len(strat_plot); group_w = 0.8; bar_w = group_w / num_strat_p; idx_p = np.arange(num_ratios_p); cmap_p = plt.colormaps.get_cmap('tab10')
    metrics_to_plot = [('macs', 'MACs (Log Scale)', True), ('loss', 'Avg. Test Loss', False), ('accuracy', 'Accuracy (%)', False)]
    for metric_key, y_label, use_log_scale in metrics_to_plot:
        plt.figure(figsize=(max(12, int(1.8*num_ratios_p)), 7));
        for i, s_name_p in enumerate(strat_plot):
            vals_p = [results_plot[s_name_p].get(r_val, {}).get(metric_key, np.nan) for r_val in ratios_plot]; bar_pos_p = idx_p - (group_w/2) + (i*bar_w) + (bar_w/2)
            plt.bar(bar_pos_p, vals_p, bar_w, label=s_name_p, color=cmap_p(i % cmap_p.N))
        plt.xlabel('Pruning Ratio (ch_sparsity)'); plt.ylabel(y_label); plt.title(f'{y_label.split("(")[0].strip()} vs. Pruning Ratio (Fine-tuned)'); plt.xticks(idx_p, ratios_labels_plot)
        if use_log_scale: plt.yscale('log'); plt.grid(True, which="both", ls="-", alpha=0.3)
        else: plt.grid(True, alpha=0.3)
        if metric_key == 'accuracy': plt.ylim(0,100)
        plt.legend(title="Strategy", bbox_to_anchor=(1.02,1), loc='upper left'); plt.tight_layout(rect=[0,0,0.85,1]); plt.savefig(os.path.join(out_dir_compare, f'{metric_key.upper()}_vs_Ratio_by_Strategy.png')); plt.close(); print(f"✅ {metric_key.upper()} vs. Ratio plot saved to {out_dir_compare}")


### Main Workflow

In [10]:
def main():
    print(f"Using device: {DEVICE}")
    cfg = {
        'strategies': {
            'MagnitudeL2': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.MagnitudeImportance(p=2)},
            'BNScale': {'pruner': tp.pruner.BNScalePruner, 'importance': tp.importance.BNScaleImportance()},
            'Random': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.RandomImportance()},
        },
        'pruning_ratios_to_test': [0.0, 0.2, 0.5, 0.7],
        'iterative_steps_pruner_general': 5,
        'iterative_steps_taylor_pruning': 5,
        'fine_tune_epochs': 30,
        'early_stopping_patience': 7,
        'early_stopping_metric': 'val_loss',
        'load_best_model_on_early_stop': True,
        'cifar10_data_root': './data',
        'val_split_for_loader': 0.1, 'data_loader_seed': 42,
        'local_mobilenetv2_pth_path': './base/mobilenet_v2-b0353104.pth',  # UPDATE THIS
        'output_dir_base': f'./output_{MODEL_BASE_NAME}_local_pth_final',
        'num_classes': 10, 'batch_size': 128,
        'learning_rate_finetune': 0.0001,
    }

    # Setup directories and check .pth file
    abs_local_pth_path = os.path.abspath(cfg['local_mobilenetv2_pth_path'])
    if not os.path.exists(abs_local_pth_path):
        print(f"ERROR: PTH file not found: {abs_local_pth_path}"); return
    os.makedirs(cfg['output_dir_base'], exist_ok=True)
    initial_model_info_dir = os.path.join(cfg['output_dir_base'], "initial_model_info")
    os.makedirs(initial_model_info_dir, exist_ok=True)
    overall_comparison_plots_out_dir = os.path.join(cfg['output_dir_base'], "overall_comparison_plots")
    os.makedirs(overall_comparison_plots_out_dir, exist_ok=True)

    # Data loaders
    print("Initializing DataLoaders...")
    train_loader, val_loader, test_loader = get_data_loaders(
        data_dir_path=cfg['cifar10_data_root'], batch_size=cfg['batch_size'],
        val_split=cfg['val_split_for_loader'], seed=cfg['data_loader_seed']
    )
    example_input_cpu_onnx_main = torch.randn(1, 3, 32, 32)
    example_input_dev_main = example_input_cpu_onnx_main.to(DEVICE)
    criterion_main = nn.CrossEntropyLoss().to(DEVICE)

    # Load and adapt the model
    print(f"--- Loading and Adapting Model from Local PTH: {abs_local_pth_path} ---")
    temp_model_for_loading = get_torchvision_mobilenetv2_structure(num_classes_local=1000, device_to_map=DEVICE)
    try:
        temp_model_for_loading.load_state_dict(torch.load(abs_local_pth_path, map_location=DEVICE))
    except RuntimeError:
        print("Attempting to load local .pth into a 10-class structure directly...")
        try:
            temp_model_for_loading = get_torchvision_mobilenetv2_structure(num_classes_local=cfg['num_classes'], device_to_map=DEVICE)
            temp_model_for_loading.load_state_dict(torch.load(abs_local_pth_path, map_location=DEVICE))
        except Exception as e2:
            print(f"CRITICAL ERROR: Failed to load local .pth file. Error: {e2}"); return
    final_classifier_layer_tv = None
    if isinstance(temp_model_for_loading.classifier, nn.Sequential) and len(temp_model_for_loading.classifier) > 0:
        idx = -1
        if isinstance(temp_model_for_loading.classifier[idx], nn.Linear): final_classifier_layer_tv = temp_model_for_loading.classifier[idx]
        elif len(temp_model_for_loading.classifier) > 1 and isinstance(temp_model_for_loading.classifier[-2], nn.Linear): final_classifier_layer_tv = temp_model_for_loading.classifier[-2]; idx=-2
    if final_classifier_layer_tv and final_classifier_layer_tv.out_features != cfg['num_classes']:
        num_ftrs = final_classifier_layer_tv.in_features
        if idx == -1: temp_model_for_loading.classifier[-1] = nn.Linear(num_ftrs, cfg['num_classes'])
        else: temp_model_for_loading.classifier[-2] = nn.Linear(num_ftrs, cfg['num_classes'])
    initial_dense_model_ready = temp_model_for_loading.to(DEVICE)

    # **Added: Initial Training**
    print("--- Training the initial model on CIFAR-10 ---")
    initial_dense_model_ready, initial_model_hist = train_model(
        initial_dense_model_ready, train_loader, criterion_main,
        optim.Adam(initial_dense_model_ready.parameters(), lr=cfg['learning_rate_finetune']),
        DEVICE, cfg['fine_tune_epochs'], val_loader=val_loader,
        log_prefix_train="Initial Training", es_patience=cfg['early_stopping_patience'],
        es_metric=cfg['early_stopping_metric'], load_best_es=cfg['load_best_model_on_early_stop']
    )
    trained_initial_model_path = os.path.join(cfg['output_dir_base'], f"{MODEL_BASE_NAME}_trained_initial.pth")
    save_model(initial_dense_model_ready, trained_initial_model_path, example_input_cpu_onnx_main)
    print(f"Trained initial model saved to: {trained_initial_model_path}")

    # **Modified: Baseline Evaluation with Trained Model**
    print(f"\n--- Evaluating Baseline {MODEL_BASE_NAME} (Trained Initial) on TEST SET ---")
    baseline_metrics = evaluate_model_with_loss(initial_dense_model_ready, test_loader, example_input_dev_main, criterion_main, DEVICE)
    print(f"Baseline Metrics (Trained Initial, on Test Set): {baseline_metrics}")

    # Summarize the trained model
    summarize_loaded_initial_model(
        model_instance=initial_dense_model_ready,
        model_name=f"{MODEL_BASE_NAME}_TrainedInitial",
        eval_metrics=baseline_metrics,
        example_input_for_summary=example_input_dev_main,
        output_dir=initial_model_info_dir
    )
    plot_initial_model_training_summary(initial_model_hist, f"{MODEL_BASE_NAME}_InitialTraining", initial_model_info_dir)

    # Initialize results with baseline
    results_all = {s_name: {0.0: baseline_metrics} for s_name in cfg['strategies'].keys()}

    # Pruning and fine-tuning loop
    for strat_name, strat_details in cfg['strategies'].items():
        strat_out_dir = os.path.join(cfg['output_dir_base'], "strategies_results", strat_name)
        os.makedirs(strat_out_dir, exist_ok=True)
        strat_finetune_plots_out_dir = os.path.join(strat_out_dir, "finetuning_plots")
        os.makedirs(strat_finetune_plots_out_dir, exist_ok=True)

        for ratio_val in cfg['pruning_ratios_to_test']:
            if ratio_val == 0.0: continue
            ratio_fname_str = f"R{ratio_val:.1f}".replace('.', 'p')
            print(f"\n--- Proc: Strat '{strat_name}', Ratio {ratio_val:.2f} ({ratio_fname_str}) ---")

            # **Modified: Load the trained model**
            model_to_prune_strat_ratio = load_model_state(
                get_torchvision_mobilenetv2_structure,
                trained_initial_model_path,  # Use the trained model
                DEVICE,
                num_classes_for_model_struct=cfg['num_classes']
            )
            ignored_layers_tv = get_ignored_layer_for_torchvision_mobilenetv2(model_to_prune_strat_ratio)

            print(f"--- Pruning: {strat_name}, ratio {ratio_val:.2f} ---")
            pruned_model_strat_ratio = gr_prune_model_by_ratio(
                model_to_prune_strat_ratio, example_input_dev_main, strat_details, ratio_val,
                explicit_ignored_layers_list=ignored_layers_tv, iterative_steps_general_pruner=cfg['iterative_steps_pruner_general'],
                iterative_steps_taylor_pruner=cfg['iterative_steps_taylor_pruning']
            )
            pruned_bf_ft_fname = f"{MODEL_BASE_NAME}_{strat_name}_{ratio_fname_str}_pruned_bf_ft"
            save_model(pruned_model_strat_ratio, os.path.join(strat_out_dir, f"{pruned_bf_ft_fname}.pth"), example_input_cpu_onnx_main)
            metrics_bf_ft = evaluate_model_with_loss(pruned_model_strat_ratio, test_loader, example_input_dev_main, criterion_main, DEVICE)
            print(f"Metrics '{strat_name}' @ {ratio_fname_str} (Pruned, Pre-FT, Test Set): {metrics_bf_ft}")

            print(f"--- Fine-tuning: {strat_name}, ratio {ratio_fname_str} ---")
            fine_tuned_model_strat_ratio, ft_hist = train_model(
                pruned_model_strat_ratio, train_loader, criterion_main,
                optim.Adam(pruned_model_strat_ratio.parameters(), lr=cfg['learning_rate_finetune']),
                DEVICE, cfg['fine_tune_epochs'], val_loader=val_loader,
                log_prefix_train=f"{strat_name}_{ratio_fname_str}", es_patience=cfg['early_stopping_patience'],
                es_metric=cfg['early_stopping_metric'], load_best_es=cfg['load_best_model_on_early_stop']
            )
            macs_ft_val, _ = calculate_macs_params(fine_tuned_model_strat_ratio, example_input_dev_main)
            plot_finetuning_curves(ft_hist, f"{MODEL_BASE_NAME}_{strat_name}_{ratio_fname_str}", strat_finetune_plots_out_dir, macs_ft_val)
            final_metrics_strat_ratio = evaluate_model_with_loss(fine_tuned_model_strat_ratio, test_loader, example_input_dev_main, criterion_main, DEVICE)
            results_all[strat_name][ratio_val] = final_metrics_strat_ratio
            print(f"Metrics '{strat_name}' @ {ratio_fname_str} (Fine-tuned, Test Set): {final_metrics_strat_ratio}")
            final_model_fname = f"{MODEL_BASE_NAME}_{strat_name}_{ratio_fname_str}_final"
            save_model(fine_tuned_model_strat_ratio, os.path.join(strat_out_dir, f"{final_model_fname}.pth"), example_input_cpu_onnx_main)

    plot_metrics_vs_ratio_all_strategies(results_all, cfg['pruning_ratios_to_test'], overall_comparison_plots_out_dir)
    print(f"\nAll experiments complete! Outputs in: {os.path.abspath(cfg['output_dir_base'])}")

In [11]:
if __name__ == "__main__":
    main()

Using device: cuda
Initializing DataLoaders...
Attempting to load CIFAR-10 from pre-downloaded directory: /home/muis/thesis/github-repo/master-thesis/cnn/mobile_net_v2/data
DataLoaders: Train 45000, Val 5000, Test 10000
--- Loading and Adapting Model from Local PTH: /home/muis/thesis/github-repo/master-thesis/cnn/mobile_net_v2/base/mobilenet_v2-b0353104.pth ---
--- Training the initial model on CIFAR-10 ---
E 1/30 (Initial Training): Tr L=1.2105, Acc=57.37%, Val L=0.9065, Acc=68.20% (New best)
E 2/30 (Initial Training): Tr L=0.7921, Acc=72.04%, Val L=0.7653, Acc=72.92% (New best)
E 3/30 (Initial Training): Tr L=0.6308, Acc=77.93%, Val L=0.7290, Acc=74.86% (New best)
E 4/30 (Initial Training): Tr L=0.5157, Acc=82.02%, Val L=0.6947, Acc=75.80% (New best)
E 5/30 (Initial Training): Tr L=0.4289, Acc=84.89%, Val L=0.6902, Acc=76.60% (New best)
E 6/30 (Initial Training): Tr L=0.3407, Acc=88.08%, Val L=0.6908, Acc=77.58%
E 7/30 (Initial Training): Tr L=0.2723, Acc=90.63%, Val L=0.7793, Acc=75



Pruning for ratio 0.20 done: MACs 0.006 G (Reduced by 8.05%)
✅ Model state_dict saved to ./output_mobilenet_v2-b0353104_local_pth_final/strategies_results/MagnitudeL2/mobilenet_v2-b0353104_MagnitudeL2_R0p2_pruned_bf_ft.pth
✅ Model saved as ONNX to ./output_mobilenet_v2-b0353104_local_pth_final/strategies_results/MagnitudeL2/mobilenet_v2-b0353104_MagnitudeL2_R0p2_pruned_bf_ft.onnx
Metrics 'MagnitudeL2' @ R0p2 (Pruned, Pre-FT, Test Set): {'macs': 5997155.0, 'params': 2057471, 'size_mib': 7.848628997802734, 'accuracy': 61.54, 'loss': 1.1281787590026855}
--- Fine-tuning: MagnitudeL2, ratio R0p2 ---
E 1/30 (MagnitudeL2_R0p2): Tr L=0.5731, Acc=79.80%, Val L=0.6849, Acc=75.46% (New best)
E 2/30 (MagnitudeL2_R0p2): Tr L=0.4242, Acc=84.80%, Val L=0.7062, Acc=76.78%
E 3/30 (MagnitudeL2_R0p2): Tr L=0.3441, Acc=87.79%, Val L=0.6939, Acc=77.78%
E 4/30 (MagnitudeL2_R0p2): Tr L=0.2697, Acc=90.56%, Val L=0.7427, Acc=77.20%
E 5/30 (MagnitudeL2_R0p2): Tr L=0.2089, Acc=92.74%, Val L=0.7967, Acc=77.20%
E 