In [1]:
# !unzip drive/MyDrive/Pets-data/images.zip
# !cp -r drive/MyDrive/Pets-data/annotations .
# !ls

In [2]:
import os
import yaml
from typing import Tuple, List, Optional

import torch
import optuna
import numpy as np
import torchvision
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
from pydantic import BaseModel
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Define constants

CONFIG_FILE_PATH = 'config_optuna.yaml'
IMAGE_COL_IDX = 0
CLASS_ID_COL_IDX = 1
SPECIES_COL_IDX = 2
POSSIBLE_NUM_CLASSES = {2, 37}

In [4]:
class AdamOptimizerConfig(BaseModel):
    lr: float
    weight_decay: float

In [5]:
class LastLayerTrainingConfig(BaseModel):
    unfreeze_epoch: int
    lr: float
    weight_decay: float
    use_train_mode: bool

In [6]:
class Config(BaseModel):
    device: str
    num_classes: int
    batch_size: int
    max_num_epochs: int
    patience: int
    # adam_optimizer_config: AdamOptimizerConfig
    # num_batch_norm_layers_to_train_params: int
    # num_batch_norm_layers_to_update_running_stats: int
    # train_earlier_layers_delay: int
    # n_hidden_layers_to_train: int
    last_layers_training_configs: List[LastLayerTrainingConfig]
    use_pseudo_labelling: int
    labelled_data_ratio: float
    T_1: int
    T_2: int
    ALPHA_F: float
    unlabelled_batch_size: int

In [7]:
# with open(CONFIG_FILE_PATH, encoding='utf-8') as f:
#     config_dict = yaml.load(f, Loader=yaml.FullLoader)

# config_test = Config.model_validate(config_dict)

# assert config_test.num_classes in POSSIBLE_NUM_CLASSES
# # assert 0 <= config.num_batch_norm_layers_to_train_params <= 36  # 36 batch norm layers in resnet34
# # assert 0 <= config.num_batch_norm_layers_to_update_running_stats <= 36  # 36 batch norm layers in resnet34

In [8]:
class ImageDataset(Dataset):
    def __init__(self, filenames: List[str], labels: List[int], use_augmentations: bool, device: str) -> None:
        self.filenames = filenames
        self.labels = labels
        # self.transformation = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        self.transformation = (
            transforms.Compose([
                transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
                # transforms.CenterCrop(size=224),
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(dtype=torch.float),
                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                # torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms(),
                # transforms.RandomCrop((what size, what other size)),
                transforms.CenterCrop(size=256),
                transforms.RandomHorizontalFlip(),
                # transforms.RandomRotation(degrees=(-30, 30), expand=True),
                transforms.RandomResizedCrop(size=224),
                # transforms.RandomErasing(),
            ])
            if use_augmentations
            else torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        )
        self.device = device

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image = Image.open(os.path.join('images', f'{self.filenames[idx]}.jpg')).convert('RGB')
        label = self.labels[idx]

        transformed_img = self.transformation(image)

        return transformed_img.to(self.device), torch.tensor(label).to(self.device)

In [9]:
class ImageDatasetPseudoLabels(Dataset):
    def __init__(self, labelled_filenames: List[str], labels: List[int], unlabelled_filenames: List[str], use_augmentations: bool, device: str) -> None:
        self.labelled_filenames = labelled_filenames
        self.labels = labels
        self.unlabelled_filenames = unlabelled_filenames

        assert len(labelled_filenames) == len(labels), 'labels must have the same size as labelled_filenames'

        # self.transformation = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        self.transformation = (
            transforms.Compose([
                transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
                # transforms.CenterCrop(size=224),
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(dtype=torch.float),
                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                # torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms(),
                # transforms.RandomCrop((what size, what other size)),
                transforms.CenterCrop(size=256),
                transforms.RandomHorizontalFlip(),
                # transforms.RandomRotation(degrees=(-30, 30), expand=True),
                transforms.RandomResizedCrop(size=224),
                # transforms.RandomErasing(),
            ])
            if use_augmentations
            else torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        )
        self.device = device

    def __len__(self) -> int:
        return len(self.labelled_filenames) + len(self.unlabelled_filenames)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if idx >= len(self.labelled_filenames): # Returning an unlabelled image
            image = Image.open(os.path.join('images', f'{self.unlabelled_filenames[idx - len(self.labelled_filenames)]}.jpg')).convert('RGB')
            transformed_img = self.transformation(image)
            return transformed_img.to(self.device), torch.tensor(-1).float().to(self.device)

        image = Image.open(os.path.join('images', f'{self.labelled_filenames[idx]}.jpg')).convert('RGB')
        label = self.labels[idx]

        transformed_img = self.transformation(image)

        return transformed_img.to(self.device), torch.tensor(label).to(self.device)

In [10]:
class UnlabelledImageDataset(Dataset):
    def __init__(self, filenames: List[str], use_augmentations: bool, device: str) -> None:
        self.filenames = filenames
        # self.transformation = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        self.transformation = (
            transforms.Compose([
                transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
                # transforms.CenterCrop(size=224),
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(dtype=torch.float),
                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                # torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms(),
                # transforms.RandomCrop((what size, what other size)),
                transforms.CenterCrop(size=256),
                transforms.RandomHorizontalFlip(),
                # transforms.RandomRotation(degrees=(-30, 30), expand=True),
                transforms.RandomResizedCrop(size=224),
                # transforms.RandomErasing(),
            ])
            if use_augmentations
            else torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        )
        self.device = device

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image = Image.open(os.path.join('images', f'{self.filenames[idx]}.jpg')).convert('RGB')

        transformed_img = self.transformation(image)

        return transformed_img.to(self.device)

In [11]:
def get_image_names_and_labels(annotations_file_path: str, num_classes: int) -> Tuple[List[str], List[int]]:
    filenames: List[str] = []
    labels: List[int] = []

    with open(annotations_file_path, encoding='utf-8') as f:
        lines = f.readlines()

    label_col_idx = SPECIES_COL_IDX if num_classes == 2 else CLASS_ID_COL_IDX

    for line in lines:
        line_split = line.split()
        filenames.append(line_split[IMAGE_COL_IDX])
        labels.append(int(line_split[label_col_idx]) - 1)

    return filenames, labels

In [12]:
def get_pretrained_model_and_model_trainable_layers(num_classes: int, device: str) -> nn.Module:
    model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)

    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)

    for param in model.parameters():
        param.requires_grad = False

    model_trainable_layers = [
        layer
        for layer in model.modules()
        if (not isinstance(layer, torchvision.models.resnet.ResNet) and
            not isinstance(layer, torchvision.models.resnet.BasicBlock) and
            not isinstance(layer, nn.Sequential) and not isinstance(layer, nn.Sequential) and
            len(list(layer.parameters())) > 0)
    ]

    return model.to(device), model_trainable_layers

In [13]:
def get_model_accuracy(model: nn.Module, data_loader: DataLoader) -> float:
    correct_predictions_cnt = 0
    total_predictions_cnt = 0
    model.eval()
    with torch.no_grad():
        # for inputs, labels in tqdm(data_loader, desc='Computing accuracy'):
        for inputs, labels in data_loader:
            outputs = model(inputs)
            correct_predictions_cnt += (torch.argmax(outputs, axis=1) == labels).sum()
            total_predictions_cnt += len(outputs)
    return correct_predictions_cnt / total_predictions_cnt

In [14]:
def train_single_epoch(
        model: nn.Module,
        model_trainable_layers: List[nn.Module],
        train_data_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        last_layers_training_configs: List[LastLayerTrainingConfig],
        epoch: int,
        ) -> float:
    model.eval()
    for layer_reverse_idx, layer_training_config in enumerate(last_layers_training_configs):
        layer = model_trainable_layers[-layer_reverse_idx - 1]
        if layer_training_config.unfreeze_epoch <= epoch:
            for param in layer.parameters():
                param.requires_grad = True
            if layer_training_config.use_train_mode:
                layer.train()
        if layer_reverse_idx and layer_training_config.unfreeze_epoch == epoch:
            optimizer.add_param_group({
                'params': layer.parameters(),
                'lr': layer_training_config.lr,
                'weight_decay': layer_training_config.weight_decay,
            })
    train_loss_sum = 0.0
    train_samples_cnt = 0
    # for inputs, labels in tqdm(train_data_loader, desc='Training model'):
    for inputs, labels in train_data_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item() * len(outputs)
        train_samples_cnt += len(outputs)
    return train_loss_sum / train_samples_cnt

In [22]:
def infinite_batch_generator(loader: DataLoader):
    while True:
        for batch in loader:
            yield batch


def get_alpha(t: int, T_1: int, T_2: int, ALPHA_F: float) -> float:
    if t < T_1:
        raise RuntimeError
        # return 0.0
    elif T_1 <= t and t < T_2:
        return ((t - T_1) / (T_2 - T_1)) * ALPHA_F
    return ALPHA_F


def train_single_epoch_pseudo(
        model: nn.Module,
        model_trainable_layers: List[nn.Module],
        train_data_loader: DataLoader,
        unlabelled_train_data_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        last_layers_training_configs: List[LastLayerTrainingConfig],
        epoch: int,
        T_1: int,
        T_2: int,
        ALPHA_F: float,
        ) -> float:
    model.eval()
    for layer_reverse_idx, layer_training_config in enumerate(last_layers_training_configs):
        layer = model_trainable_layers[-layer_reverse_idx - 1]
        if layer_training_config.unfreeze_epoch <= epoch:
            for param in layer.parameters():
                param.requires_grad = True
            if layer_training_config.use_train_mode:
                layer.train()
        if layer_reverse_idx and layer_training_config.unfreeze_epoch == epoch:
            optimizer.add_param_group({
                'params': layer.parameters(),
                'lr': layer_training_config.lr,
                'weight_decay': layer_training_config.weight_decay,
            })
    train_loss_sum = 0.0
    train_samples_cnt = 0

    unlabelled_generator = infinite_batch_generator(unlabelled_train_data_loader)

    # for labelled_inputs, labelled_labels in tqdm(train_data_loader, desc='Training model'):
    for labelled_inputs, labelled_labels in train_data_loader:
        optimizer.zero_grad()

        loss = 0.0

        labelled_outputs = model(labelled_inputs)
        labelled_loss = criterion(labelled_outputs, labelled_labels)
        if not np.isnan(labelled_loss.item()):
            loss += labelled_loss

        if epoch >= T_1:
            unlabelled_inputs = next(unlabelled_generator)
            unlabelled_outputs = model(unlabelled_inputs)
            unlabelled_labels = torch.argmax(unlabelled_outputs, axis=1)
            unlabelled_loss = criterion(unlabelled_outputs, unlabelled_labels)
            if not np.isnan(unlabelled_loss.item()):
                loss += get_alpha(epoch, T_1, T_2, ALPHA_F) * unlabelled_loss

        loss.backward()
        optimizer.step()
        # train_loss_sum += loss.item() * (len(unlabelled_outputs) + len(labelled_outputs))
        # train_samples_cnt += len(unlabelled_outputs) + len(labelled_outputs)
        train_loss_sum += loss.item() * len(labelled_outputs)
        train_samples_cnt += len(labelled_outputs)

    return train_loss_sum / train_samples_cnt

In [23]:
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer) -> Tuple[nn.Module, torch.optim.Optimizer]:
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, 'checkpoint.pt')

In [24]:
def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
    checkpoint = torch.load('checkpoint.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [25]:
def get_data_loaders(
        num_classes: int,
        use_pseudo_labelling: bool,
        labelled_data_ratio: float,
        batch_size: int,
        unlabelled_batch_size: int,
        device: str,
        ) -> Tuple[DataLoader, DataLoader, DataLoader, Optional[DataLoader]]:
    filenames_trainval, labels_trainval = get_image_names_and_labels('annotations/trainval.txt', num_classes=num_classes)
    filenames_test, labels_test = get_image_names_and_labels('annotations/test.txt', num_classes=num_classes)
    filenames_train, filenames_val, labels_train, labels_val = train_test_split(filenames_trainval, labels_trainval, test_size=0.2, stratify=labels_trainval)

    if use_pseudo_labelling:
        filenames_train, unlabelled_filenames_train, labels_train, _ = train_test_split(
            filenames_train, labels_train, train_size=labelled_data_ratio, stratify=labels_train
        )
        dataset_train = ImageDataset(filenames_train, labels_train, use_augmentations=True, device=device)
        unlabelled_dataset_train = UnlabelledImageDataset(unlabelled_filenames_train, use_augmentations=True, device=device)
        unlabelled_train_data_loader = DataLoader(unlabelled_dataset_train, batch_size=unlabelled_batch_size, shuffle=True)
    else:
        dataset_train = ImageDataset(filenames_train, labels_train, use_augmentations=True, device=device)
        unlabelled_train_data_loader = None

    dataset_train = ImageDataset(filenames_train, labels_train, use_augmentations=True, device=device)
    dataset_val = ImageDataset(filenames_val, labels_val, use_augmentations=False, device=device)
    dataset_test = ImageDataset(filenames_test, labels_test, use_augmentations=False, device=device)

    train_data_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
    val_data_loader = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)
    test_data_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

    return train_data_loader, val_data_loader, test_data_loader, unlabelled_train_data_loader

In [26]:
def train_and_get_max_val_accuracy(config: Config) -> float:
    model, model_trainable_layers = get_pretrained_model_and_model_trainable_layers(config.num_classes, config.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model_trainable_layers[-1].parameters(),
        lr=config.last_layers_training_configs[0].lr,
        weight_decay=config.last_layers_training_configs[0].weight_decay,
    )

    train_data_loader, val_data_loader, test_data_loader, unlabelled_train_data_loader = get_data_loaders(
        config.num_classes,
        config.use_pseudo_labelling,
        config.labelled_data_ratio,
        config.batch_size,
        config.unlabelled_batch_size,
        config.device,
        
    )
    assert isinstance(unlabelled_train_data_loader, DataLoader)

    max_val_accuracy = float('-inf')
    argmax_epoch = -1

    # for epoch in range(config.max_num_epochs):
    for epoch in tqdm(list(range(config.max_num_epochs))):
        # print(f'Epoch #{epoch}:')
        train_loss = train_single_epoch_pseudo(
            model, model_trainable_layers, train_data_loader, unlabelled_train_data_loader, criterion, optimizer, config.last_layers_training_configs,
            epoch, config.T_1, config.T_2, config.ALPHA_F,
        )
        # print(f'Train loss: {train_loss}')
        val_accuracy = get_model_accuracy(model, val_data_loader)
        if val_accuracy > max_val_accuracy:
            # print(f'Validation accuracy: {100 * val_accuracy:.2f}% (new best)')
            max_val_accuracy = val_accuracy
            argmax_epoch = epoch
            # save_checkpoint(model, optimizer)
            # print('Checkpoint saved')
        else:
            # print(f'Validation accuracy: {100 * val_accuracy:.2f}% (worse than {100 * max_val_accuracy:.2f}% of epoch {argmax_epoch})')
            if epoch > argmax_epoch + config.patience:
                # print(f'Early stopping')
                break
        # # Start training earlier layers
        # if epoch == config.train_earlier_layers_delay and config.n_hidden_layers_to_train > 0:
        #     print (f'Making last {config.n_hidden_layers_to_train} hidden layers trainable')
        #     make_hidden_layers_trainable(model, config.n_hidden_layers_to_train)
        #     optimizer = torch.optim.Adam(model.parameters(), lr=config.adam_optimizer_config.lr / 100, weight_decay=config.adam_optimizer_config.weight_decay)
        # print()
    return max_val_accuracy

In [32]:
def objective(trial: optuna.Trial):
    batch_size_exp = trial.suggest_int("batch_size_exp", 3, 7)
    unlabelled_batch_size_exp = trial.suggest_int("unlabelled_batch_size_exp", 3, 7)
    batch_size = 2 ** batch_size_exp
    unlabelled_batch_size = 2 ** unlabelled_batch_size_exp
    T_1 = trial.suggest_int("T_1", 1, 30)
    T_2 = trial.suggest_int("T_2", T_1 + 1, 60)
    ALPHA_F_exp = trial.suggest_float("ALPHA_F_exp", -2.0, 2.0)
    ALPHA_F = 10.0 ** ALPHA_F_exp
    config = Config(
        device='cuda:0',
        num_classes=37,
        batch_size=batch_size,
        max_num_epochs=100,
        patience=10,
        last_layers_training_configs=[],
        # pseudo-labeling
        use_pseudo_labelling=1,
        labelled_data_ratio=0.1,
        T_1=T_1,  # 1_000
        T_2=T_2,  # 1_000
        ALPHA_F=ALPHA_F,  # 3
        unlabelled_batch_size=unlabelled_batch_size,
    )
    last_layer_lr_exp = trial.suggest_float("last_layer_lr_exp", -4.0, -2.0)
    last_layer_weight_decay_exp = trial.suggest_float("last_layer_weight_decay_exp", -4.0, -2.0)
    last_layer_lr = 10.0 ** last_layer_lr_exp
    last_layer_weight_decay = 10.0 ** last_layer_weight_decay_exp
    config.last_layers_training_configs.append(LastLayerTrainingConfig(
        unfreeze_epoch=0,
        lr=last_layer_lr,
        weight_decay=last_layer_weight_decay,
        use_train_mode=True,
    ))
    second_last_layer_unfreeze_epoch = trial.suggest_int("second_last_layer_unfreeze_epoch", 1, 10)
    second_last_layer_lr_exp = trial.suggest_float("second_last_layer_lr_exp", -5.0, -1.0)
    second_last_layer_weight_decay_exp = trial.suggest_float("second_last_layer_weight_decay_exp", -5.0, -1.0)
    second_last_layer_lr = 10.0 ** second_last_layer_lr_exp
    second_last_layer_weight_decay = 10.0 ** second_last_layer_weight_decay_exp
    config.last_layers_training_configs.append(LastLayerTrainingConfig(
        unfreeze_epoch=second_last_layer_unfreeze_epoch,
        lr=second_last_layer_lr,
        weight_decay=second_last_layer_weight_decay,
        use_train_mode=True,
    ))
    third_last_layer_unfreeze_epoch = trial.suggest_int("third_last_layer_unfreeze_epoch", second_last_layer_unfreeze_epoch, 15)
    third_last_layer_lr_exp = trial.suggest_float("third_last_layer_lr_exp", -5.0, -1.0)
    third_last_layer_weight_decay_exp = trial.suggest_float("third_last_layer_weight_decay_exp", -5.0, -1.0)
    third_last_layer_lr = 10.0 ** third_last_layer_lr_exp
    third_last_layer_weight_decay = 10.0 ** third_last_layer_weight_decay_exp
    config.last_layers_training_configs.append(LastLayerTrainingConfig(
        unfreeze_epoch=third_last_layer_unfreeze_epoch,
        lr=third_last_layer_lr,
        weight_decay=third_last_layer_weight_decay,
        use_train_mode=True,
    ))
    return train_and_get_max_val_accuracy(config)

In [33]:
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)

[I 2024-05-24 07:22:38,257] A new study created in memory with name: no-name-593a90bf-1d98-466a-909b-435f3c20a0b9
 13%|█▎        | 13/100 [00:58<06:33,  4.53s/it]
[I 2024-05-24 07:23:37,413] Trial 0 finished with value: 0.25679346919059753 and parameters: {'batch_size_exp': 7, 'unlabelled_batch_size_exp': 4, 'T_1': 2, 'T_2': 19, 'ALPHA_F_exp': 1.609556906321632, 'last_layer_lr_exp': -3.0121971398823826, 'last_layer_weight_decay_exp': -2.0848795855724744, 'second_last_layer_unfreeze_epoch': 8, 'second_last_layer_lr_exp': -1.629507688307899, 'second_last_layer_weight_decay_exp': -3.3468064319548483, 'third_last_layer_unfreeze_epoch': 15, 'third_last_layer_lr_exp': -2.927638952220023, 'third_last_layer_weight_decay_exp': -1.4510043759696747}. Best is trial 0 with value: 0.25679346919059753.
 28%|██▊       | 28/100 [02:27<06:18,  5.26s/it]
[I 2024-05-24 07:26:04,873] Trial 1 finished with value: 0.875 and parameters: {'batch_size_exp': 5, 'unlabelled_batch_size_exp': 6, 'T_1': 24, 'T_2': 4

KeyboardInterrupt: 

In [31]:
study.best_trial

FrozenTrial(number=31, state=TrialState.COMPLETE, values=[0.904891312122345], datetime_start=datetime.datetime(2024, 5, 24, 3, 17, 7, 298202), datetime_complete=datetime.datetime(2024, 5, 24, 3, 21, 8, 100381), params={'batch_size_exp': 6, 'unlabelled_batch_size_exp': 7, 'T_1': 22, 'T_2': 30, 'ALPHA_F_exp': -1.9186627229256172, 'last_layer_lr_exp': -3.437742747579238, 'last_layer_weight_decay_exp': -3.8574542923939004, 'second_last_layer_unfreeze_epoch': 6, 'second_last_layer_lr_exp': -4.205220815004278, 'second_last_layer_weight_decay_exp': -3.818874540626863, 'third_last_layer_unfreeze_epoch': 8, 'third_last_layer_lr_exp': -3.03043588451279, 'third_last_layer_weight_decay_exp': -4.190767281932155}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'batch_size_exp': IntDistribution(high=7, log=False, low=3, step=1), 'unlabelled_batch_size_exp': IntDistribution(high=7, log=False, low=3, step=1), 'T_1': IntDistribution(high=30, log=False, low=1, step=1), 'T_2': IntD

In [23]:
load_checkpoint(model, optimizer)
print('Checkpoint loaded')

test_accuracy = get_model_accuracy(model, test_data_loader)

print(f'Test accuracy: {100 * test_accuracy:.2f}%')

NameError: name 'model' is not defined