In [1]:
# !unzip drive/MyDrive/Pets-data/images.zip
# !cp -r drive/MyDrive/Pets-data/annotations .
# !ls

In [22]:
import os
import yaml
from typing import Tuple, List

import torch
import numpy as np
import torchvision
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
from pydantic import BaseModel
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [2]:
# Define constants

CONFIG_FILE_PATH = 'config_optuna.yaml'
IMAGE_COL_IDX = 0
CLASS_ID_COL_IDX = 1
SPECIES_COL_IDX = 2
POSSIBLE_NUM_CLASSES = {2, 37}

In [3]:
class AdamOptimizerConfig(BaseModel):
    lr: float
    weight_decay: float

In [4]:
class LastLayerTrainingConfig(BaseModel):
    unfreeze_epoch: int
    lr: float
    weight_decay: float
    use_train_mode: bool

In [31]:
class Config(BaseModel):
    device: str
    num_classes: int
    batch_size: int
    max_num_epochs: int
    patience: int
    # adam_optimizer_config: AdamOptimizerConfig
    # num_batch_norm_layers_to_train_params: int
    # num_batch_norm_layers_to_update_running_stats: int
    # train_earlier_layers_delay: int
    # n_hidden_layers_to_train: int
    last_layers_training_configs: List[LastLayerTrainingConfig]
    use_pseudo_labelling: int
    labelled_data_ratio: float
    T_1: int
    T_2: int
    ALPHA_F: float

In [32]:
with open(CONFIG_FILE_PATH, encoding='utf-8') as f:
    config_dict = yaml.load(f, Loader=yaml.FullLoader)

config = Config.model_validate(config_dict)

assert config.num_classes in POSSIBLE_NUM_CLASSES
# assert 0 <= config.num_batch_norm_layers_to_train_params <= 36  # 36 batch norm layers in resnet34
# assert 0 <= config.num_batch_norm_layers_to_update_running_stats <= 36  # 36 batch norm layers in resnet34

In [7]:
class ImageDataset(Dataset):
    def __init__(self, filenames: List[str], labels: List[int], use_augmentations: bool) -> None:
        self.filenames = filenames
        self.labels = labels
        # self.transformation = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        self.transformation = (
            transforms.Compose([
                transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
                # transforms.CenterCrop(size=224),
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(dtype=torch.float),
                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                # torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms(),
                # transforms.RandomCrop((what size, what other size)),
                transforms.CenterCrop(size=256),
                transforms.RandomHorizontalFlip(),
                # transforms.RandomRotation(degrees=(-30, 30), expand=True),
                transforms.RandomResizedCrop(size=224),
                # transforms.RandomErasing(),
            ])
            if use_augmentations
            else torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        )

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image = Image.open(os.path.join('images', f'{self.filenames[idx]}.jpg')).convert('RGB')
        label = self.labels[idx]

        transformed_img = self.transformation(image)

        return transformed_img.to(config.device), torch.tensor(label).to(config.device)

In [12]:
class ImageDatasetPseudoLabels(Dataset):
    def __init__(self, labelled_filenames: List[str], labels: List[int], unlabelled_filenames: List[str], use_augmentations: bool) -> None:
        self.labelled_filenames = labelled_filenames
        self.labels = labels
        self.unlabelled_filenames = unlabelled_filenames

        assert len(labelled_filenames) == len(labels), 'labels must have the same size as labelled_filenames'

        # self.transformation = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        self.transformation = (
            transforms.Compose([
                transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
                # transforms.CenterCrop(size=224),
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(dtype=torch.float),
                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                # torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms(),
                # transforms.RandomCrop((what size, what other size)),
                transforms.CenterCrop(size=256),
                transforms.RandomHorizontalFlip(),
                # transforms.RandomRotation(degrees=(-30, 30), expand=True),
                transforms.RandomResizedCrop(size=224),
                # transforms.RandomErasing(),
            ])
            if use_augmentations
            else torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        )

    def __len__(self) -> int:
        return len(self.labelled_filenames) + len(self.unlabelled_filenames)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if idx >= len(self.labelled_filenames): # Returning an unlabelled image
            image = Image.open(os.path.join('images', f'{self.unlabelled_filenames[idx - len(self.labelled_filenames)]}.jpg')).convert('RGB')
            transformed_img = self.transformation(image)
            return transformed_img.to(config.device), torch.tensor(-1).float().to(config.device)

        image = Image.open(os.path.join('images', f'{self.labelled_filenames[idx]}.jpg')).convert('RGB')
        label = self.labels[idx]

        transformed_img = self.transformation(image)

        return transformed_img.to(config.device), torch.tensor(label).to(config.device)

In [13]:
def get_image_names_and_labels(annotations_file_path: str, num_classes: int) -> Tuple[List[str], List[int]]:
    filenames: List[str] = []
    labels: List[int] = []

    with open(annotations_file_path, encoding='utf-8') as f:
        lines = f.readlines()

    label_col_idx = SPECIES_COL_IDX if num_classes == 2 else CLASS_ID_COL_IDX

    for line in lines:
        line_split = line.split()
        filenames.append(line_split[IMAGE_COL_IDX])
        labels.append(int(line_split[label_col_idx]) - 1)

    return filenames, labels

In [14]:
# def get_batch_norm_layers(model: nn.Module) -> List[nn.Module]:
#     return [
#         module
#         for module in model.modules()
#         if isinstance(module, nn.BatchNorm2d)
#     ]

In [15]:
# def get_pretrained_model_with_trainable_last_layer(num_batch_norm_layers_to_train_params: int) -> nn.Module:
#     model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)

#     for param in model.parameters():
#         param.requires_grad = False

#     batch_norm_layers = get_batch_norm_layers(model)
#     batch_norm_layers_to_train_params = batch_norm_layers[-num_batch_norm_layers_to_train_params:] if num_batch_norm_layers_to_train_params else []

#     for batch_norm_layer in batch_norm_layers_to_train_params:
#         for param in batch_norm_layer.parameters():
#             param.requires_grad = True

#     model.fc = nn.Linear(in_features=model.fc.in_features, out_features=config.num_classes)

#     return model.to(config.device)

In [16]:
# def get_pretrained_model_with_trainable_n_layers(num_batch_norm_layers_to_train_params: int, number_of_hidden_layers_to_train: int=0) -> nn.Module:
#     model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)

#     for param in model.parameters():
#         param.requires_grad = False

#     batch_norm_layers = get_batch_norm_layers(model)
#     batch_norm_layers_to_train_params = batch_norm_layers[-num_batch_norm_layers_to_train_params:] if num_batch_norm_layers_to_train_params else []

#     for batch_norm_layer in batch_norm_layers_to_train_params:
#         for param in batch_norm_layer.parameters():
#             param.requires_grad = True

#     model.fc = nn.Linear(in_features=model.fc.in_features, out_features=config.num_classes)

#     trainable_layers = [layer for layer in model.modules() if not isinstance(layer, torchvision.models.resnet.ResNet) and not isinstance(layer, torchvision.models.resnet.BasicBlock) and not isinstance(layer, nn.Sequential) and not isinstance(layer, nn.Sequential) and len(list(layer.parameters())) > 0]
#     trainable_layers = list(reversed(trainable_layers[:-1]))

#     for l in trainable_layers[:number_of_hidden_layers_to_train]:
#         for p in l.parameters():
#             p.requires_grad = True

#     return model.to(config.device)

In [17]:
def get_pretrained_model_and_model_trainable_layers() -> nn.Module:
    model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)

    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=config.num_classes)

    for param in model.parameters():
        param.requires_grad = False

    model_trainable_layers = [
        layer
        for layer in model.modules()
        if (not isinstance(layer, torchvision.models.resnet.ResNet) and
            not isinstance(layer, torchvision.models.resnet.BasicBlock) and
            not isinstance(layer, nn.Sequential) and not isinstance(layer, nn.Sequential) and
            len(list(layer.parameters())) > 0)
    ]

    return model.to(config.device), model_trainable_layers

In [18]:
# def make_hidden_layers_trainable(model: nn.Module, number_of_hidden_layers_to_train: int) -> nn.Module:
#     trainable_layers = [layer for layer in model.modules() if not isinstance(layer, torchvision.models.resnet.ResNet) and not isinstance(layer, torchvision.models.resnet.BasicBlock) and not isinstance(layer, nn.Sequential) and not isinstance(layer, nn.Sequential) and len(list(layer.parameters())) > 0]
#     trainable_layers = list(reversed(trainable_layers[:-1]))

#     for l in trainable_layers[:number_of_hidden_layers_to_train]:
#         for p in l.parameters():
#             p.requires_grad = True

#     return model

In [19]:
def get_model_accuracy(model: nn.Module, data_loader: DataLoader) -> float:
    correct_predictions_cnt = 0
    total_predictions_cnt = 0
    model.eval()
    with torch.no_grad():
        # for inputs, labels in tqdm(data_loader, desc='Computing accuracy'):
        for inputs, labels in data_loader:
            outputs = model(inputs)
            correct_predictions_cnt += (torch.argmax(outputs, axis=1) == labels).sum()
            total_predictions_cnt += len(outputs)
    return correct_predictions_cnt / total_predictions_cnt

In [20]:
# def train_single_epoch_old(
#         model: nn.Module,
#         train_data_loader: DataLoader,
#         criterion: nn.Module,
#         optimizer: torch.optim.Optimizer,
#         num_batch_norm_layers_to_update_running_stats: int,
#         ) -> float:
#     model.train()
#     batch_norm_layers = get_batch_norm_layers(model)
#     batch_norm_layers_to_not_update_running_stats = (
#         batch_norm_layers[:-num_batch_norm_layers_to_update_running_stats]
#         if num_batch_norm_layers_to_update_running_stats
#         else batch_norm_layers
#     )
#     for batch_norm_layer in batch_norm_layers_to_not_update_running_stats:
#         batch_norm_layer.eval()
#     train_loss_sum = 0.0
#     train_samples_cnt = 0
#     for inputs, labels in tqdm(train_data_loader, desc='Training model'):
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
#         train_loss_sum += loss.item() * len(outputs)
#         train_samples_cnt += len(outputs)
#     return train_loss_sum / train_samples_cnt

In [21]:
def train_single_epoch(
        model: nn.Module,
        model_trainable_layers: List[nn.Module],
        train_data_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        last_layers_training_configs: List[LastLayerTrainingConfig],
        epoch: int,
        ) -> float:
    model.eval()
    for layer_reverse_idx, layer_training_config in enumerate(last_layers_training_configs):
        layer = model_trainable_layers[-layer_reverse_idx - 1]
        if layer_training_config.unfreeze_epoch <= epoch:
            for param in layer.parameters():
                param.requires_grad = True
            if layer_training_config.use_train_mode:
                layer.train()
        if layer_reverse_idx and layer_training_config.unfreeze_epoch == epoch:
            optimizer.add_param_group({
                'params': layer.parameters(),
                'lr': layer_training_config.lr,
                'weight_decay': layer_training_config.weight_decay,
            })
    train_loss_sum = 0.0
    train_samples_cnt = 0
    # for inputs, labels in tqdm(train_data_loader, desc='Training model'):
    for inputs, labels in train_data_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item() * len(outputs)
        train_samples_cnt += len(outputs)
    return train_loss_sum / train_samples_cnt

In [23]:
def alpha(t):
    if t < config.T_1:
        return 0
    elif config.T_1 <= t and t < config.T_2:
        return ((t - config.T_1) / (config.T_2-config.T_1)) * config.ALPHA_F
    return config.ALPHA_F


def train_single_epoch_pseudo(
        model: nn.Module,
        model_trainable_layers: List[nn.Module],
        train_data_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        last_layers_training_configs: List[LastLayerTrainingConfig],
        epoch: int,
        ) -> float:
    model.eval()
    for layer_reverse_idx, layer_training_config in enumerate(last_layers_training_configs):
        layer = model_trainable_layers[-layer_reverse_idx - 1]
        if layer_training_config.unfreeze_epoch <= epoch:
            for param in layer.parameters():
                param.requires_grad = True
            if layer_training_config.use_train_mode:
                layer.train()
        if layer_reverse_idx and layer_training_config.unfreeze_epoch == epoch:
            optimizer.add_param_group({
                'params': layer.parameters(),
                'lr': layer_training_config.lr,
                'weight_decay': layer_training_config.weight_decay,
            })
    train_loss_sum = 0.0
    train_samples_cnt = 0

    for inputs, labels in tqdm(train_data_loader, desc='Training model'):
        optimizer.zero_grad()

        labelled_idxs = torch.tensor([x[0] for x in enumerate(labels) if not x[1] == -1]).long().to(config.device)
        unlabelled_idxs = torch.tensor([x[0] for x in enumerate(labels) if x[1] == -1]).long().to(config.device)

        labelled_labels = torch.index_select(labels, 0, labelled_idxs).long()

        outputs = model(inputs)

        labelled_outputs = torch.index_select(outputs, 0, labelled_idxs)
        unlabelled_outputs = torch.index_select(outputs, 0, unlabelled_idxs)

        unlabelled_labels = torch.argmax(unlabelled_outputs, axis=1)

        labelled_loss = criterion(labelled_outputs, labelled_labels)
        unlabelled_loss = criterion(unlabelled_outputs, unlabelled_labels)
        loss = 0

        if not np.isnan(labelled_loss.item()):
            loss = loss + labelled_loss
        if not np.isnan(unlabelled_loss.item()):
            loss = loss + alpha(epoch) * unlabelled_loss

        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item() * len(outputs)
        train_samples_cnt += len(outputs)

    return train_loss_sum / train_samples_cnt

In [24]:
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer) -> Tuple[nn.Module, torch.optim.Optimizer]:
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, 'checkpoint.pt')

In [25]:
def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
    checkpoint = torch.load('checkpoint.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [26]:
filenames_trainval, labels_trainval = get_image_names_and_labels('annotations/trainval.txt', num_classes=config.num_classes)
filenames_test, labels_test = get_image_names_and_labels('annotations/test.txt', num_classes=config.num_classes)
filenames_train, filenames_val, labels_train, labels_val = train_test_split(filenames_trainval, labels_trainval, test_size=0.2, stratify=labels_trainval)


if config.use_pseudo_labelling:
    filenames_train, unlabelled_filenames_train, labels_train, _ = train_test_split(filenames_train, labels_train, train_size=config.labelled_data_ratio, stratify=labels_train)
    dataset_train = ImageDatasetPseudoLabels(filenames_train, labels_train, unlabelled_filenames_train, use_augmentations=True)
    labelled_dataset_train = ImageDataset(filenames_train, labels_train, use_augmentations=True)
    labelled_train_data_loader = DataLoader(labelled_dataset_train, batch_size=config.batch_size, shuffle=True)
else:
    dataset_train = ImageDataset(filenames_train, labels_train, use_augmentations=True)


dataset_train = ImageDataset(filenames_train, labels_train, use_augmentations=True)
dataset_val = ImageDataset(filenames_val, labels_val, use_augmentations=False)
dataset_test = ImageDataset(filenames_test, labels_test, use_augmentations=False)

train_data_loader = DataLoader(dataset_train, batch_size=config.batch_size, shuffle=True)
val_data_loader = DataLoader(dataset_val, batch_size=config.batch_size, shuffle=False)
test_data_loader = DataLoader(dataset_test, batch_size=config.batch_size, shuffle=False)

In [27]:
import optuna

  from .autonotebook import tqdm as notebook_tqdm


In [38]:
def train_and_get_max_val_accuracy(config: Config) -> float:
    model, model_trainable_layers = get_pretrained_model_and_model_trainable_layers()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model_trainable_layers[-1].parameters(),
        lr=config.last_layers_training_configs[0].lr,
        weight_decay=config.last_layers_training_configs[0].weight_decay,
    )

    max_val_accuracy = float('-inf')
    argmax_epoch = -1

    for epoch in range(config.max_num_epochs):
        # print(f'Epoch #{epoch}:')
        train_loss = train_single_epoch_pseudo(
            model, model_trainable_layers, train_data_loader, criterion, optimizer, config.last_layers_training_configs, epoch
        )
        print(f'Train loss: {train_loss}')
        val_accuracy = get_model_accuracy(model, val_data_loader)
        if val_accuracy > max_val_accuracy:
            print(f'Validation accuracy: {100 * val_accuracy:.2f}% (new best)')
            max_val_accuracy = val_accuracy
            argmax_epoch = epoch
            # save_checkpoint(model, optimizer)
            # print('Checkpoint saved')
        else:
            print(f'Validation accuracy: {100 * val_accuracy:.2f}% (worse than {100 * max_val_accuracy:.2f}% of epoch {argmax_epoch})')
            if epoch > argmax_epoch + config.patience:
                print(f'Early stopping')
                break
        # # Start training earlier layers
        # if epoch == config.train_earlier_layers_delay and config.n_hidden_layers_to_train > 0:
        #     print (f'Making last {config.n_hidden_layers_to_train} hidden layers trainable')
        #     make_hidden_layers_trainable(model, config.n_hidden_layers_to_train)
        #     optimizer = torch.optim.Adam(model.parameters(), lr=config.adam_optimizer_config.lr / 100, weight_decay=config.adam_optimizer_config.weight_decay)
        # print()
    return max_val_accuracy

In [43]:
def objective(trial: optuna.Trial):
    batch_size_exp = trial.suggest_int("batch_size_exp", 3, 7)
    batch_size = 2 ** batch_size_exp
    config = Config(
        device='cuda:0',
        num_classes=37,
        batch_size=batch_size,
        max_num_epochs=40,
        patience=10,
        last_layers_training_configs=[],
        # pseudo-labeling
        use_pseudo_labelling=1,
        labelled_data_ratio=0.1,
        T_1=1_000,
        T_2=1_000,
        ALPHA_F=3,
    )
    last_layer_lr_exp = trial.suggest_float("last_layer_lr_exp", -4.0, -2.0)
    last_layer_weight_decay_exp = trial.suggest_float("last_layer_weight_decay_exp", -4.0, -2.0)
    last_layer_lr = 10.0 ** last_layer_lr_exp
    last_layer_weight_decay = 10.0 ** last_layer_weight_decay_exp
    config.last_layers_training_configs.append(LastLayerTrainingConfig(
        unfreeze_epoch=0,
        lr=last_layer_lr,
        weight_decay=last_layer_weight_decay,
        use_train_mode=True,
    ))
    second_last_layer_unfreeze_epoch = trial.suggest_int("second_last_layer_unfreeze_epoch", 1, 10)
    second_last_layer_lr_exp = trial.suggest_float("second_last_layer_lr_exp", -5.0, -1.0)
    second_last_layer_weight_decay_exp = trial.suggest_float("second_last_layer_weight_decay_exp", -5.0, -1.0)
    second_last_layer_lr = 10.0 ** second_last_layer_lr_exp
    second_last_layer_weight_decay = 10.0 ** second_last_layer_weight_decay_exp
    config.last_layers_training_configs.append(LastLayerTrainingConfig(
        unfreeze_epoch=second_last_layer_unfreeze_epoch,
        lr=second_last_layer_lr,
        weight_decay=second_last_layer_weight_decay,
        use_train_mode=True,
    ))
    third_last_layer_unfreeze_epoch = trial.suggest_int("third_last_layer_unfreeze_epoch", second_last_layer_unfreeze_epoch, 15)
    third_last_layer_lr_exp = trial.suggest_float("third_last_layer_lr_exp", -5.0, -1.0)
    third_last_layer_weight_decay_exp = trial.suggest_float("third_last_layer_weight_decay_exp", -5.0, -1.0)
    third_last_layer_lr = 10.0 ** third_last_layer_lr_exp
    third_last_layer_weight_decay = 10.0 ** third_last_layer_weight_decay_exp
    config.last_layers_training_configs.append(LastLayerTrainingConfig(
        unfreeze_epoch=third_last_layer_unfreeze_epoch,
        lr=third_last_layer_lr,
        weight_decay=third_last_layer_weight_decay,
        use_train_mode=True,
    ))
    return train_and_get_max_val_accuracy(config)

In [44]:
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)

[I 2024-05-22 09:07:14,629] A new study created in memory with name: no-name-b3c4aa57-aa7e-428f-b0f1-3429b80fb546
Training model: 100%|██████████| 10/10 [00:01<00:00,  5.86it/s]


Train loss: 3.5890536356945426


[W 2024-05-22 09:07:18,863] Trial 0 failed with parameters: {'batch_size_exp': 5, 'last_layer_lr_exp': -2.2777062255790144, 'last_layer_weight_decay_exp': -2.9649743716673593, 'second_last_layer_unfreeze_epoch': 4, 'second_last_layer_lr_exp': -1.605894325812685, 'second_last_layer_weight_decay_exp': -2.2251746661797274, 'third_last_layer_unfreeze_epoch': 12, 'third_last_layer_lr_exp': -1.2197903756926198, 'third_last_layer_weight_decay_exp': -1.473720350516397} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "c:\Users\georg\anaconda3\envs\kth_deep_learning_project\lib\site-packages\optuna\study\_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "C:\Users\georg\AppData\Local\Temp\ipykernel_7672\3602300019.py", line 50, in objective
    return train_and_get_max_val_accuracy(config)
  File "C:\Users\georg\AppData\Local\Temp\ipykernel_7672\3791173904.py", line 20, in train_and_get_max_val_accuracy
    val_accuracy 

KeyboardInterrupt: 

In [25]:
study.best_trial

FrozenTrial(number=11, state=TrialState.COMPLETE, values=[0.9510869979858398], datetime_start=datetime.datetime(2024, 5, 21, 21, 52, 13, 728210), datetime_complete=datetime.datetime(2024, 5, 21, 22, 1, 12, 874608), params={'batch_size_exp': 5, 'last_layer_lr_exp': -2.8552450749872538, 'last_layer_weight_decay_exp': -3.262432218991849, 'second_last_layer_unfreeze_epoch': 1, 'second_last_layer_lr_exp': -2.635150903871032, 'second_last_layer_weight_decay_exp': -2.394793163781819, 'third_last_layer_unfreeze_epoch': 4, 'third_last_layer_lr_exp': -4.8965012088919515, 'third_last_layer_weight_decay_exp': -3.543387070729529}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'batch_size_exp': IntDistribution(high=7, log=False, low=3, step=1), 'last_layer_lr_exp': FloatDistribution(high=-2.0, log=False, low=-4.0, step=None), 'last_layer_weight_decay_exp': FloatDistribution(high=-2.0, log=False, low=-4.0, step=None), 'second_last_layer_unfreeze_epoch': IntDistribution(high=1

In [None]:
load_checkpoint(model, optimizer)
print('Checkpoint loaded')

test_accuracy = get_model_accuracy(model, test_data_loader)

print(f'Test accuracy: {100 * test_accuracy:.2f}%')