# Projet Fixmatch

In [1]:
# !pip install torchview torchsummary torchvision kornia torchmetrics matplotlib tqdm path graphviz opencv-python scikit-learn optuna

In [2]:
# deep learning
import torch
import torch.nn as nn
from torch.distributions.transforms import LowerCholeskyTransform
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import DataLoader, Dataset

# vizualisation
import torchsummary

# transforms
import torchvision.transforms as T
import kornia.augmentation as K
from kornia.enhance import normalize
from torchvision.transforms import RandAugment

# metrics
from torchmetrics import Accuracy

# torchvision
import torchvision
import torchvision.transforms as transforms

# plotting
import matplotlib.pyplot as plt
from torchview import draw_graph

from IPython.display import display
from IPython.core.display import SVG, HTML

from tqdm.auto import tqdm

# typing
from typing import Callable

from utils import plot_images, plot_transform
from model import ConvNN, display_model

# os
import os
import path

import random
import numpy as np 

# transformations
# import transform as T
# from randaugment import RandomAugment

# typing
from typing import Callable, List, Tuple

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics.pairwise import cosine_similarity

%load_ext autoreload
%autoreload 2

In [3]:
DEFAULT_RANDOM_SEED = 2021

def seedBasic(seed=DEFAULT_RANDOM_SEED):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    
# torch random seed
import torch
def seedTorch(seed=DEFAULT_RANDOM_SEED):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
      
# basic + tensorflow + torch 
def seedEverything(seed=DEFAULT_RANDOM_SEED):
    seedBasic(seed)
    seedTorch(seed)

In [4]:
# Set device
if ((int(torch.__version__.split(".")[0]) >= 2) or (int(torch.__version__.split(".")[1]) >= 13)) and torch.has_mps:
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


In [5]:
IMG_SHAPE = (3, 32, 32)
# See Table 4
TAU = 0.9 #! 0.95 in the paper
LAMBDA_U = 3
MU = 4
BATCH_SIZE = 64
LR = 0.03
BETA = 0.9
WEIGHT_DECAY = 0.001
BETA_DENSITY = 1

In [6]:
class ConvNN(nn.Module):
    """
    Simple CNN for CIFAR10
    """
    
    def __init__(self):
        super().__init__()
        self.conv_32 = nn.Conv2d(3, 32, kernel_size=3, padding='same')
        self.conv_64 = nn.Conv2d(32, 64, kernel_size=3, padding='same')
        self.conv_96 = nn.Conv2d(64, 96, kernel_size=3, padding='same')
        self.conv_128 = nn.Conv2d(96, 128, kernel_size=3, padding='same')
        self.fc_512 = nn.Linear(512, 512)
        self.fc_10 = nn.Linear(512, 10)
        self.max_pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU(inplace=True)
        self.flatten = nn.Flatten()
        # self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_32(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_64(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_96(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_128(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.flatten(x)
        x = self.fc_512(x)
        x = self.relu(x)
        x = self.fc_10(x)
        # x = self.softmax(x)

        return x

In [7]:
def compute_mean_std(trainLoader) -> Tuple[List[float], List[float]]:
    # initialize the list of means and stds
    mean, std = torch.zeros(3), torch.zeros(3)

    # iterate over the dataset and compute the sum of each channel
    for images, _ in trainLoader:
        mean+= torch.mean(images, dim=[0,2,3])
        std+= torch.std(images, dim=[0,2,3])
    
    # compute the mean and std
    mean = mean/len(trainLoader)
    std = std/len(trainLoader)

    return mean, std

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

if not os.path.exists('./data/mean.pt'):
    mean, std = compute_mean_std(trainloader)
    torch.save(mean, 'data/mean.pt')
    torch.save(std, 'data/std.pt')
else:
    mean, std = torch.load('./data/mean.pt'), torch.load('./data/std.pt')

# to numpy
mean, std = mean.numpy(), std.numpy()

print(f"mean: {mean}, std: {std}")


testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
mean: [0.4913966  0.48215377 0.44651437], std: [0.246344   0.24280126 0.26067406]
Files already downloaded and verified


In [8]:
torch_models = 'torch_models' 
if not os.path.exists(torch_models):
    os.makedirs(torch_models)

## IV. Semi-Supervised Learning: Fixmatch with Active Learning

### IV.1 Fixmatch on 10% train data with Active Learning

In [9]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

EPOCHS = 50
SUBSET_PROP = 0.01
K_SAMPLES = 64

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# indices of labeled data
labeled_indices = trainset_sup.indices

# indices of unlabeled data
unlabeled_indices = trainset_unsup.indices



In [10]:
# transformations
weak_transform = K.ImageSequential(
    K.RandomHorizontalFlip(p=0.50), 
    K.RandomAffine(degrees=0, translate=(0.125, 0.125)),
)

strong_transform = K.ImageSequential(
    K.auto.RandAugment(n=2, m=10),
)

In [11]:
def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat = torch.max(qb, dim=1)

        idx = max_qb > TAU
        qb_hat = qb_hat[idx]

    return qb_hat.detach(), idx, max_qb.detach()

In [12]:
model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)

In [13]:
def information_density(
    model: ConvNN, 
    unlabeled_inputs: torch.Tensor,
    k_samp: int) -> Tuple[torch.Tensor, torch.Tensor]:

    inputs_norm = normalize(data=unlabeled_inputs, mean=mean, std=std)

    logits = model(inputs_norm)
    qb = torch.softmax(logits, dim=1)
    qb1 = torch.topk(qb, k=2, dim=1)[0][:, 0]
    qb2 = torch.topk(qb, k=2, dim=1)[0][:, 1]

    uncertainty = 1 - (qb1 - qb2)

    input1 = inputs_norm.view(inputs_norm.shape[0], -1)

    cos_sim = cosine_similarity(input1.cpu().numpy())
    information_density = (torch.sum(torch.tensor(cos_sim), dim=1) / inputs_norm.shape[0]) ** BETA_DENSITY

    output = uncertainty * information_density.to(device)

    inf_dens, idx = torch.topk(output, k=k_samp, dim=0)
    
    return inf_dens, idx.cpu().numpy()

# Create a new labeled dataset using active learning
def create_labeled_dataset_active_learning(dataset, selected_indices):
    labeled_dataset = torch.utils.data.Subset(dataset, selected_indices)
    return labeled_dataset

In [14]:
print("Start training")

current_prop = SUBSET_PROP
target_prop = 0.05
max_iter = 300
j = 0

# Define the cosine learning rate decay function
# lr_lambda = lambda step: LR * torch.cos(torch.tensor((7 * torch.pi * (step)) / (16 * max_iter//3))) * 100 / 3

# cosine annealing scheduler
# lr_lambda = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iter, eta_min=0, last_epoch=-1)

# Create a learning rate scheduler with the cosine decay function
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iter//2, eta_min=0, last_epoch=-1)


train_losses = []
test_losses = []
added_samp = 0

while j <= max_iter and current_prop <= target_prop:
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    # max_confidence = 0
    moving_avg_pred_labeled = 0
    moving_avg_pred_unlabeled = 0


    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {j: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, unlabeled_labels = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        

        # Zero the parameter gradients
        optimizer.zero_grad()

        # normalize labeled and unlabeled inputs
        labeled_inputs_norm = normalize(data=labeled_inputs, mean=mean, std=std)
        unlabeled_inputs_norm = normalize(data=unlabeled_inputs, mean=mean, std=std)

        # Forward pass 
        labeled_outputs_norm = model(labeled_inputs_norm)
        unlabeled_outputs_norm = model(unlabeled_inputs_norm)

        # Compute moving average of labeled and unlabeled predictions
        moving_avg_pred_labeled = (i * moving_avg_pred_labeled + labeled_outputs_norm.shape[0]) / (i + 1)
        moving_avg_pred_unlabeled = (i * moving_avg_pred_unlabeled + unlabeled_outputs_norm.shape[0]) / (i + 1)

        # ratio
        ratio = moving_avg_pred_labeled / moving_avg_pred_unlabeled

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # normalize
        weak_labeled_inputs = normalize(data=weak_labeled_inputs, mean=mean, std=std)
        weak_unlabeled_inputs = normalize(data=weak_unlabeled_inputs, mean=mean, std=std)
        strong_unlabeled_inputs = normalize(data=strong_unlabeled_inputs, mean=mean, std=std)

        # prediction on weak augmented unlabeled data
        qb = model(weak_unlabeled_inputs)
        qb = torch.softmax(qb, dim=1)
        qb_norm = qb * ratio

        # normalize
        qb_tilde = qb_norm / torch.sum(qb_norm, dim=1, keepdim=True)

        # compute mask
        max_qb_tilde, qb_tilde_hat = torch.max(qb_tilde, dim=1)
        idx = max_qb_tilde > TAU

        # pseudo labels
        pseudo_labels = qb_tilde_hat[idx]

        # mask strong augmented unlabeled data
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)
            # outputs = torch.softmax(outputs, dim=1)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)
            # labeled_outputs = torch.softmax(labeled_outputs, dim=1)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()


        # backward pass + optimize
        loss.backward()
        optimizer.step()

        
        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled 

        # update progress bar
        pbar.set_postfix({
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "avg confidence": torch.mean(max_qb_tilde).item(),
            "n_unlabeled": running_n_unlabeled,
            "current_prop": current_prop,
            "lr": optimizer.param_groups[0]['lr']
        })

    if j < max_iter and j%5 == 0 and current_prop < target_prop:
        # compute information density
        scores_density, idx_density = information_density(model, unlabeled_inputs, K_SAMPLES)

        # add selected indices to labeled indices
        labeled_indices = np.concatenate((labeled_indices, idx_density))

        # unique
        labeled_indices = np.unique(labeled_indices)

        # create new dataset from labeled indices
        trainset_sup = create_labeled_dataset_active_learning(trainset, labeled_indices)

        # update dataloader
        labeled_dataloader = torch.utils.data.DataLoader(trainset_sup, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

        current_prop = len(trainset_sup) / len(trainset)

        print(f'Information density: {scores_density.mean()}')
        print(f'Density indices: {idx_density}')

    # update loss
    train_losses.append(running_loss / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)

    # increment iteration
    j += 1

Start training


Epoch     0:   0%|          | 0/8 [00:00<?, ?batch/s]

Information density: 0.06339705735445023
Density indices: [181 127 183 219 141 208 214  52 234  97  85 187 179 218 227 199   3 130
 134 210  61 249  18 173 113 221  95   7  63 206 251 120  12  32  13   9
  79 162 229  47  78  96  77 226 166  38 184 145  90 246 158  71  41 118
 253  69   4 198   5  48 239 153  42 131]
Test Accuracy: 10.0%


Epoch     1:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 10.07%


Epoch     2:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 10.68%


Epoch     3:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 17.15%


Epoch     4:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 20.4%


Epoch     5:   0%|          | 0/9 [00:00<?, ?batch/s]

Information density: 0.05281660705804825
Density indices: [ 62 149  27 115 230 251 116   1 128 253  95 191 119 202 212 132  57  79
 120 137  91 181 101 234  78  26  23 235 147 142 102  13 192 178 186 206
  59  53  84  81 172 126 245 153  36 248  37  30  22  68   9 175 242 135
 106 250 184  35  80 133 237 200 177   5]
Test Accuracy: 21.45%


Epoch     6:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 24.56%


Epoch     7:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 24.16%


Epoch     8:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 26.99%


Epoch     9:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 31.97%


Epoch    10:   0%|          | 0/10 [00:00<?, ?batch/s]

Information density: 0.06354041397571564
Density indices: [ 69  97 173 123 241  82 214 189 116 122 212 127 181 236 102 218  42  50
 130  33 188  47 197 135  60 111 109 166  94 252  67 161 114 152 229  52
  88 105  24 153  58 121  23 209 140 120 215 156   7 133  37 217  49  35
 186  27 175 142 204 223 164 185  22 238]
Test Accuracy: 29.43%


Epoch    11:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 30.47%


Epoch    12:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 29.46%


Epoch    13:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 35.44%


Epoch    14:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 31.08%


Epoch    15:   0%|          | 0/11 [00:00<?, ?batch/s]

Information density: 0.04950621351599693
Density indices: [235 167  58 125  60 126  93 103  20  15 186 147 201 156  11 244  55 152
  59  24 104 191 185 192  46 107  18  47  49 142 205 157   2  70 241  98
 137  26  63  51 237  99 230  10 115 209 243 114 108  13  85 252   4  76
 163 242  68 198  94  39  29  52 150  34]
Test Accuracy: 31.53%


Epoch    16:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 35.47%


Epoch    17:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 34.06%


Epoch    18:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 34.22%


Epoch    19:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 36.44%


Epoch    20:   0%|          | 0/11 [00:00<?, ?batch/s]

Information density: 0.035349950194358826
Density indices: [138  29  81   8 145 187  23 225 158 233  77  82 213  86 211 209 186  92
  70 143  71  57  13 134  98 112  18   1  63  15 109  85 115  31 194 236
 110 150  60  20 157 244  79 116 117  59 142 107  69  16 247 172 237 101
 169  74 118 204  53 105  52 214 146 139]
Test Accuracy: 36.91%


Epoch    21:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 37.54%


Epoch    22:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 37.81%


Epoch    23:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 39.92%


Epoch    24:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 40.94%


Epoch    25:   0%|          | 0/11 [00:00<?, ?batch/s]

Information density: 0.036071695387363434
Density indices: [166   9 114 241  22  13 155 142 169 201  57  26 186 173  70 179  17 172
 239 227 126 100 123 129 195 197  53 248 152  14  88  49 103 235 243 157
  18 192 203  87 190 187 181  21 207  80 205 218 111   0  19 222  40 160
 202  60 110  36  28 119 188  83 161 185]
Test Accuracy: 41.14%


Epoch    26:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 40.31%


Epoch    27:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 38.46%


Epoch    28:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 35.33%


Epoch    29:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 37.53%


Epoch    30:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.04484696686267853
Density indices: [ 16 155  29 246  43 124  41  91 127  62  36   9  92 174 148 150 222  46
  18  56 125   5 171 130 180 205 207  25  15 249 192  19  33 138 238 175
 228  45 172 204  54  28 209 170 166 118 131  83 137 194  75 142 143 109
 188  47 132 159 244  24 250 103 139  50]
Test Accuracy: 41.96%


Epoch    31:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 40.46%


Epoch    32:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 41.86%


Epoch    33:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 42.74%


Epoch    34:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 41.68%


Epoch    35:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.04516717046499252
Density indices: [102  96 133  43 221 208 121 227 220 174  47 183 137  61 107 123 157 138
 224  28 139  68 122 175 177  94 141 231 120  98 226  89   0 225  63 232
  30 194 106   6  11  23 114 235 152 251  81 197  45  69  57 193  14  90
  39 117 199  32 201 233 167   5 189 135]
Test Accuracy: 41.17%


Epoch    36:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 41.87%


Epoch    37:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 40.18%


Epoch    38:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 43.31%


Epoch    39:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 43.82%


Epoch    40:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.04933123290538788
Density indices: [248 155 190 252  79 142 206 187   1 214  91  19 192  22  46 138  86 185
  92  42  44 136 188  80 158 249  60 117 164  88  37   3  63  47  52 240
 203  15 123 107 103  45 172 106  32 100  35 147 254 175 128  61  95 134
  81  34 225  58 112 127 193 156  17 216]
Test Accuracy: 44.21%


Epoch    41:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 38.21%


Epoch    42:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 42.73%


Epoch    43:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 40.1%


Epoch    44:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 43.74%


Epoch    45:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.058108314871788025
Density indices: [ 28  98 220 143 236  81  42 116 166 206 209 252 241  60 201  89  99 149
  68 217 138 122  70 183  44 186 123 198 126 242  61 152  71 229 216 194
 218 124 184  40 173 208 247 160  69 202 162 253 238  95 101  94  78 254
 148 246 158  80 180  25 109 240 245  53]
Test Accuracy: 39.21%


Epoch    46:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 42.08%


Epoch    47:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.49%


Epoch    48:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 43.37%


Epoch    49:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 45.51%


Epoch    50:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.028300132602453232
Density indices: [226  67 195 170 106  59 193  46  58 112  60 254 198  92  40 194 244 108
  33 162 171 242  30  71  45 131  28  48 200  95 145  19  66 223 149 203
 103 100 230  37 231  65 178  91  49 118 147 201   4 102   3  62  87  43
 165 130  69 141 120  85 169 192 152  93]
Test Accuracy: 27.92%


Epoch    51:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 37.29%


Epoch    52:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 41.48%


Epoch    53:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 41.72%


Epoch    54:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 40.05%


Epoch    55:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.029531415551900864
Density indices: [ 13  86  14  67  90 113  57   4  31  76 131  47 239 165 207 186 107 245
  19 242   2 172 224 249  77 147 137 173   3 251  74  30 187 185 164 132
 157 183   0 243  29  64 227 248 221  43 218 228  38 110 112  55 194  10
 176 114  40 134 174  87  83 206 196 105]
Test Accuracy: 35.35%


Epoch    56:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 41.56%


Epoch    57:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 45.61%


Epoch    58:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.94%


Epoch    59:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.55%


Epoch    60:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.03860954940319061
Density indices: [157  21 136 127 102  30   0  97 220 132  59 255   1 246  80  24 230 252
 119  58 139  51 163  11  68 162  79 180  27 206  73  33 236  34 122  31
  65  91  75 156  63 159 210  85 224  64 229 144 149 141  76 116 216  84
 231 219  87  14 143 146  25 140 254 182]
Test Accuracy: 43.24%


Epoch    61:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.16%


Epoch    62:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 45.09%


Epoch    63:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.59%


Epoch    64:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.39%


Epoch    65:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.03801226615905762
Density indices: [213  91 134 166 210  97 241 190 130  95 165  73  86  16  46 199  84 154
  54 129 219   7 104  42 248 232  32 176  92 101  50 113   5 163  39 177
 234  94 209 127 157 153 211 239  98  10  27 188  59 110 246  13 126 124
 255 214 221 171 185 148 197 253 205 217]
Test Accuracy: 45.82%


Epoch    66:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.2%


Epoch    67:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.68%


Epoch    68:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 47.18%


Epoch    69:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 42.73%


Epoch    70:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.02591361477971077
Density indices: [182   9 143 216  24 228  91  82 209 240 170 165 224 204 239 234 171  98
 232 255  95  80 120 133 152 105 162  18 168   7 154 235   8  46  58 214
 115   3  36  86 188  43  12 206  13  44 198  96 173 130   4 179 200 166
 123 217 107 103 174  63  21 155 189 164]
Test Accuracy: 45.98%


Epoch    71:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 43.25%


Epoch    72:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.65%


Epoch    73:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.99%


Epoch    74:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.54%


Epoch    75:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.035929806530475616
Density indices: [130 247 217 138 133  92  21  65 152 233 150 167 122 236 166 235 155 157
  16  81  93 245 165  84 196 204  11 136 143  26  49 237 104 171   9 221
 153 193 148 113 141 100  30  23  61  76 216 118 218  70 197 116  90  46
  98 139  95  78 180  38 249 252 226 227]
Test Accuracy: 45.82%


Epoch    76:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.12%


Epoch    77:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 47.01%


Epoch    78:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.38%


Epoch    79:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.39%


Epoch    80:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.025854067876935005
Density indices: [123 103  59 227  48  78 100  86  29 169 237 157 131 249  63 243 148 170
  90 135 164  36 182 203 231  94 158 216   7 149  31 207  37 239  45 101
 108 186 240 138  19 180 189 156 209 232 144 130  27 120 246 150 245 179
  12 122 229 205  98 133 248 145 218 125]
Test Accuracy: 47.12%


Epoch    81:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 45.98%


Epoch    82:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 47.46%


Epoch    83:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.58%


Epoch    84:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 42.82%


Epoch    85:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.042322903871536255
Density indices: [  6 231 102 148   4  75 120 180  16  85  17  22 111 240 195 212 171 253
  94 219  59 241  60 174  53 225  82 115 246 140 162 126  32 166 192 127
  25 118  50 226 108  78 158  52 160 129 163 139 194 134  99  23 150  89
 167  67 141  55   0 222  12 237  39 221]
Test Accuracy: 47.51%


Epoch    86:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 48.11%


Epoch    87:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 47.78%


Epoch    88:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.71%


Epoch    89:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 48.64%


Epoch    90:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.023200945928692818
Density indices: [211 236 203  17 124   8  74 117  55 156 169  85  24  84 187 226 127 104
 114 111  26 116  50 253   6 223 210 153 245  48 185 172 123 142 221  79
 165 115  54 113  81  27 248 144 243 220  94  33 175  61 206  92 118  25
  76 103 105 139 125  31  35  96 229 207]
Test Accuracy: 48.05%


Epoch    91:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.27%


Epoch    92:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 45.24%


Epoch    93:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.65%


Epoch    94:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.72%


Epoch    95:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.038747139275074005
Density indices: [210 236 152 237 202  41 220 120 149  98 248 163  11 140 101 179 186  67
  44 119 154 219 139 217  46 190 113  81 178  31   9 241  54  82  53  78
 136 216 123 207 230  10 215  65 188 156 159 157 168 184 182  85  24  19
  59  88 198 148 174 212  39  56  17 160]
Test Accuracy: 48.36%


Epoch    96:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.18%


Epoch    97:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.9%


Epoch    98:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.84%


Epoch    99:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 45.95%


Epoch   100:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.028457453474402428
Density indices: [110 128  36 226 215 234 213 250  12  94 216  24  98 116 138  97 155 180
 114  90 179 145 160  15 120 167 109 188  70   0 157  87 184 251  81 225
 137  31 246 107  92 174 247 201 214 144  71 112 131 205  60  14  32  51
  85 211   8 135 249 236 197  27 149  68]
Test Accuracy: 48.97%


Epoch   101:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.62%


Epoch   102:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.51%


Epoch   103:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 45.55%


Epoch   104:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.07%


Epoch   105:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.031025022268295288
Density indices: [  2 131 223  17  22  19 188  90  81 180  31   7 135  42   3 115 114  67
 249 141 182 152 142 139 157 252  49 128  99 168 202  18  21  37 140  13
  56 211  83  55 221 137 197 163 125 101  66 155  75 106  86 116 254 172
  24 195  26  53 143 174 181  27  63  76]
Test Accuracy: 49.19%


Epoch   106:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 43.71%


Epoch   107:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 43.65%


Epoch   108:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.97%


Epoch   109:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 48.4%


Epoch   110:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.04665159434080124
Density indices: [212 101 191 107 237  27 170 248 211 173  90  40  83 127 251 222  26 239
  65 229  16  15   3 190 226 178 238  38  28  60  79  87 163 124 151  43
 235  10  97 228  67  12 145 143  76 253  82 183  39 176 126  49  88  20
 108 158 120 172 129   2  54  24 243 247]
Test Accuracy: 47.19%


Epoch   111:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.96%


Epoch   112:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 46.95%


Epoch   113:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 48.54%


Epoch   114:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 45.83%


Epoch   115:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.022191237658262253
Density indices: [ 50 225 197 105  75  11  89 151  53 155 158 173   8  74 213 253 161 250
 117 163 116 159 229 232 218  17  36  48  22 169 104 235 127 204  16   7
  95  20 132 224 233 245  65  18 191  42  92  21 210  33 124 251 178   3
 208 130 247 238 123 192 187 110 179 207]
Test Accuracy: 49.5%


Epoch   116:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.63%


Epoch   117:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 48.91%


Epoch   118:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 47.39%


Epoch   119:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 44.34%


Epoch   120:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.02932329662144184
Density indices: [200 123   9  59 110 178 203  29 218 129 161 109 119  78  55  54  17  33
 181 252 108   3 151  70  98 115  74  85 120 167 140 230 170  66 149  73
 221 247  40  46 128   1 148 239 209  21 111  31 155  20  48   7 118  27
 234 228 246 195 138 127  65 214   5 158]
Test Accuracy: 49.34%


Epoch   121:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.67%


Epoch   122:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.19%


Epoch   123:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 48.18%


Epoch   124:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.45%


Epoch   125:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.03147299587726593
Density indices: [ 51 194 227   4 178 219  68  44 243 140 136 100 251  19  99 113 252 236
 233 123  82 228  56 134 120 179  43  54 253  20 240 216 160 164 231 153
  60  15 187 143 145 105  41 173  61  34 244 212 138 203 126 237  23  38
 154 122 162  95 131 193 149  55 209 103]
Test Accuracy: 46.16%


Epoch   126:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.73%


Epoch   127:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.29%


Epoch   128:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.81%


Epoch   129:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.32%


Epoch   130:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.03728277236223221
Density indices: [110 216  75  28 187 162  50 136  43 117 219 207   2 152 160 218  71 164
  22 126  15 124  46 202 232  10 130  92  98 177 221  60  56 210  38 239
 157 188 108 140  90  57 237 101 178  62 142 176  79 109  68  20  61 115
 243 153 132 241  95  25  24 139 171 215]
Test Accuracy: 51.53%


Epoch   131:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.31%


Epoch   132:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.95%


Epoch   133:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.8%


Epoch   134:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.45%


Epoch   135:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.024505916982889175
Density indices: [178 234 114  74  21 101  34 119 189   6 160 122 236 238 252  50  33  88
  56 253 156 146  52 232  23  49 220  85 173 103 121  18 250 130 111   2
 225 104  68  38 141 241  10  72 233 210 222 244 145 240 181  14 166  44
  22 108   1  57  82  80 191 204 177  24]
Test Accuracy: 47.05%


Epoch   136:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.46%


Epoch   137:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.56%


Epoch   138:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.19%


Epoch   139:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.95%


Epoch   140:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.03808615356683731
Density indices: [194  62  24  84 199 124 248 212 222 150 152  63 251  43 206  87  86 100
  80  88 242  53   1  18 127  52 209  64 149 202 175 178  34 131 168 189
 101 244  39  14 217  77 223 233  27 249 159  66 215  58 235 169 207  45
  17 198 190  98 148 145 165 108 250  70]
Test Accuracy: 49.16%


Epoch   141:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.45%


Epoch   142:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.04%


Epoch   143:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 48.21%


Epoch   144:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.53%


Epoch   145:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.037917766720056534
Density indices: [128 141   6 114 191  50 252 199 250 216 180  77  67 233  15 236 109  73
  96 246 184   0 197 163 162 207  25  78  32 183 249 125 186  38  27 187
 156 210  21 228  91  19 136 111 232 204  54  34 212  85 200 142   5 133
 161 108 242   3   1 167 150 220 177 217]
Test Accuracy: 50.96%


Epoch   146:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.56%


Epoch   147:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.22%


Epoch   148:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.4%


Epoch   149:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.2%


Epoch   150:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.023118505254387856
Density indices: [215  78 207 182 175 190 120 227 249  47  62 150  71  27 248 144 246  35
 200  60 126 118 197  10 244  17 165 127   6 243 232 206 229 103 239 176
 192 149 225 209  52 208 255  50 186  73  30 147 101 185 141 233  40  48
  98 169 114  32  15 137 110 177  29  13]
Test Accuracy: 51.48%


Epoch   151:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.0%


Epoch   152:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 49.5%


Epoch   153:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.83%


Epoch   154:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.48%


Epoch   155:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.02615845948457718
Density indices: [ 89 245  48  54 202 132  68 235 121  86 198  81  28  27 170  69  95 164
 211 169 153 144 146 140  21  18 223  47  88 225  83   6 126 177  35  15
 239 166 254 229  63 240 168 217  94  40 129  25 218 101 215  84 147  23
  26  45   0  96 171 176 151 123  11 111]
Test Accuracy: 53.38%


Epoch   156:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.05%


Epoch   157:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.56%


Epoch   158:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.54%


Epoch   159:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.85%


Epoch   160:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.029283832758665085
Density indices: [188 210 207 136 134 212  66 111 112   8 179 171 168 194 131  51 140  74
  22  86  57  75  30 247  18 217 157  28  97 230 110 197  48 180 147 156
  38  55 185  17  42 209  56 100  76  53 170 138  80 118 242 145 226  58
 221  68 231  49 220 208  34 201 240 123]
Test Accuracy: 51.62%


Epoch   161:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.21%


Epoch   162:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.89%


Epoch   163:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.43%


Epoch   164:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.99%


Epoch   165:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.01718064397573471
Density indices: [198  47 115 229 246  78 136  49 111   7 236 127  30 234 118 207   9  55
 135 175 123 162 122 125 143 113 212   1  26  11 180 221 196  35 251 110
 124 210  32  92  48  95  97   8 151 166  12 238  40  15 159 188 213 142
 132   4 171   6 119 150  42 193 197 203]
Test Accuracy: 48.95%


Epoch   166:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.57%


Epoch   167:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 48.94%


Epoch   168:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.04%


Epoch   169:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.52%


Epoch   170:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.03380497545003891
Density indices: [226   2  19  87 174  94  44 246 207   3 110 202 177  65  82  15 237  35
 188 122 206  33 155 219 146  10 105 158 150 205 180  30 193 210  51 159
 250 189 108 160 192 195  84 103 130  93 230  69 120  56 136 235  26 168
 101 197 151 249  27  17 106 186 109 185]
Test Accuracy: 53.93%


Epoch   171:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.62%


Epoch   172:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.27%


Epoch   173:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.08%


Epoch   174:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.55%


Epoch   175:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.02798948809504509
Density indices: [ 78 250  29 100 253 204  30  72  38   9  77  33  86 191  88 244 120  70
  92  52  23 222 136  26 140 109 194   0  48 118 122  49 154 132 245  94
 101  61  14 219 207 230 201 141 171  53 110 111 247 224  18 242 196 139
 243 190 206 134   3  42  82  66  50 218]
Test Accuracy: 49.17%


Epoch   176:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.29%


Epoch   177:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.12%


Epoch   178:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.97%


Epoch   179:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.2%


Epoch   180:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.024287380278110504
Density indices: [ 78  26  85 230 223  75 120 134 165  20  83 153 178   6 240   5  92 102
 189  69 107 206  23 131  57 227 253  47 212 106  32 141 216 114 188 101
 176 121 226  86  81 237 246  98  70 233 150  79  34 103 140 208 144  80
  82 177  18   9  76 169  89 149  31 133]
Test Accuracy: 54.47%


Epoch   181:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.97%


Epoch   182:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.33%


Epoch   183:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.14%


Epoch   184:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.4%


Epoch   185:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.03410399332642555
Density indices: [ 97  21 178  80 109 133  58 235  68  51  44  20 241 198 182  19 210 153
  92 194  84  69 158  57  87 129 147 242 225  41  14 108 125  23  12 190
 138  72 214  67 240  90 131 171 150  22 197  26 220 170 232  25 162 168
  99 254  48 212 123 127  32 173 219 201]
Test Accuracy: 53.55%


Epoch   186:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.27%


Epoch   187:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.66%


Epoch   188:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.98%


Epoch   189:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.82%


Epoch   190:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.03065178170800209
Density indices: [218  30  34 169   1   5 146 113  33 221  71  11 244  99 154 196 180 225
 138 240   8 201  57 243  64 176 234 251 226 222  46  73 115  70  97  39
 207  75 195 192  91 242 155 174  88  22 123  55 254  42 216  43 110 235
   4 224 198  65 135 175  48 253 105 217]
Test Accuracy: 54.81%


Epoch   191:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.6%


Epoch   192:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.31%


Epoch   193:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 50.94%


Epoch   194:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.78%


Epoch   195:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.022588517516851425
Density indices: [228 166 213 158  10 124  65 195 126 251  74  38 204 224 112  25 233 184
  92  83  99 190  28 170  55  20 200  76  81  71 134 171   8 173   9 211
  62 174  69 156 189 152 100 108 231 101 168 249 109 151  22 161  91 137
 143  80  14 214 179 123 116 169  66  68]
Test Accuracy: 52.92%


Epoch   196:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.93%


Epoch   197:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 51.59%


Epoch   198:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.66%


Epoch   199:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.8%


Epoch   200:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.025568991899490356
Density indices: [100 150 248 140  94  93  49 117  67  19 110 189 246 157 124 185  42 147
  16  11 227  71 102 217  92 219  95 247  17  68  20 239 251 123  97 130
 252 174 222   0 159 241 235  45  12 155 242  41 122  90 129 142  47 205
 226 192 234 138   3  65  61 199 116 233]
Test Accuracy: 55.24%


Epoch   201:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.12%


Epoch   202:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.6%


Epoch   203:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.35%


Epoch   204:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.49%


Epoch   205:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.01869536004960537
Density indices: [218  75 205  17 220 159 189  49 130 120 173 228 169 134  50 151  23 131
  43  59 176 251  61 207 239 178  12  51  27  68   2 233  74  31  85 171
  29  97 153  22 115 156  56 174  35  72  77  30 129  83  79 213 237  60
 172 196 255 222 186 144 143 157   6 235]
Test Accuracy: 53.51%


Epoch   206:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.44%


Epoch   207:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.21%


Epoch   208:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.1%


Epoch   209:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.2%


Epoch   210:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.0312618613243103
Density indices: [170  91  92  44  24 129 225  42 213  29 147  20 199  16  97 117 112 241
 251 118  61 185  60 140 195 142 216 148  88  80  30 239  43 218  82 111
  36  27  48 230 126  21  45  65 133 101 110 186 125 188 155  14  83 104
  79 248 187 154 149 135 108  22 103 246]
Test Accuracy: 55.3%


Epoch   211:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.55%


Epoch   212:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.85%


Epoch   213:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.17%


Epoch   214:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.67%


Epoch   215:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.0330323725938797
Density indices: [144 135  30 180  81 187  78 233 106  87 212  99  45 207 194 198  25  59
 218 161 155 196  19 201  67 225 253 247 197 114 152 231  34 221 245 235
 146 211 242  31 210  35 133  46 181  71   2 213 141 217  43  90  23 234
 222 139 183 149 115  70  51 239 206 184]
Test Accuracy: 55.84%


Epoch   216:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.68%


Epoch   217:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.19%


Epoch   218:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.32%


Epoch   219:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 52.43%


Epoch   220:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.029788676649332047
Density indices: [250 240  21 253 159 204 180  60 168 210 182 186  75  64 144 252 134  80
  53  88 161 248 231 153 238  89  29 110  62  24 213  79 155  70 119 171
 114  63  50 207 109 245  86  43 146  52 234 188 243  58 181 135 177 221
 150 163   0 194 192 102  20  19  38  61]
Test Accuracy: 56.32%


Epoch   221:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.43%


Epoch   222:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.82%


Epoch   223:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.93%


Epoch   224:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 53.11%


Epoch   225:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.02650558203458786
Density indices: [114  49  26   9 237  70 171 127 200 236 172  89  16  37 208 243 140 119
  73   8 166 203  64 207 196 108 116   0 222  38  91  74 111   4 255 205
 126  18  84 194 135 132 195  63 163  62 129 157  96  68  33  79 216 248
  13 101  54 238  46 245 113 229  25 155]
Test Accuracy: 52.06%


Epoch   226:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.58%


Epoch   227:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.95%


Epoch   228:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 54.86%


Epoch   229:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.38%


Epoch   230:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.025249261409044266
Density indices: [220  46 144  99  82 175 185 196 198 130 104 125 155  77 106 171  91  60
  73  51 203 205 113  31 137 120 237  41  92   4 123 100  75 208 164 210
  48  58 145  72 191 174  12 231 255   8 121  97  79 206 245 182 103 163
 246  37  42  17 157 109 128  96 243   3]
Test Accuracy: 55.54%


Epoch   231:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.15%


Epoch   232:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.0%


Epoch   233:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.93%


Epoch   234:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.37%


Epoch   235:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.025006037205457687
Density indices: [ 55 185  68   2 219 248   9 197 222 224   7  62 133 111  94 160 175 183
 117 253  23 238 140 157 167  59 131 106 190 182 184 235  69 123  78  28
 146  66  87 121 230  31  56 145 199 247 119 196 128  57 188 168 102 151
  90  12 169  53 161  54  13 127 226  46]
Test Accuracy: 56.78%


Epoch   236:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.9%


Epoch   237:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 55.85%


Epoch   238:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.31%


Epoch   239:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.76%


Epoch   240:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.023636426776647568
Density indices: [ 13 213 205 148 152   1  56  74  47  42 231 102  95  35 158 133  26 191
 104  80  66 189  17 208  14 228 226 185  49   0 200  90 244 173 184  46
  27  48 107 194 227  62 186 230  97 127 175  82  33 206  36 250 110  92
 219  15 249   3 187  84  70  41 251 245]
Test Accuracy: 57.0%


Epoch   241:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.15%


Epoch   242:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.18%


Epoch   243:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.68%


Epoch   244:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.72%


Epoch   245:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.026297403499484062
Density indices: [210 141  83 238 200 112 122 101  12  54  84   8 193 125  37  62 175  31
 234 233 108 241 194  72 179  75  76 133 245  23  92 126 100 227  86 134
  48  97  36 153  74 216   5 209 161  21 223  67 131 219 187 148  14  39
 184  18 138 154 124  85  56 102 215 174]
Test Accuracy: 57.48%


Epoch   246:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.94%


Epoch   247:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.17%


Epoch   248:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.23%


Epoch   249:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.81%


Epoch   250:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.03317137435078621
Density indices: [246 171  70  23   6  88 211 113 153 168  83 221 244 124 195  58 121 150
  19 115 134 226   2  14  18 190 245 110 157  24 155 139  22 143 130 209
 192 250 252  60  80 166 107  30 162 200  36  82  43 146 223  38 151 141
 137 179 215  52  89  95  42 201 208 133]
Test Accuracy: 58.18%


Epoch   251:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.8%


Epoch   252:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.16%


Epoch   253:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.46%


Epoch   254:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.53%


Epoch   255:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.014252269640564919
Density indices: [235 163  28 253 192 216 181  59 244 174 197 131 170 245 111  81  68  22
  18   8  50 113 252 217  95  36  31 118 105 183   9 191 132  20  44 162
  58 133 226 164 173 115   4 168 180  30 242  32 147  75  46 212 187  96
 229   3 175  14 182 243  52 137 206 193]
Test Accuracy: 57.37%


Epoch   256:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.44%


Epoch   257:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.2%


Epoch   258:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.54%


Epoch   259:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.29%


Epoch   260:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.03426231071352959
Density indices: [124   5 196 143 142 131 128 235  84 125  13 120 194  46 119  32 236  81
  94 214  47  56  69  48 230 139 207 127  77 208 241   3 217  10  53  92
 211 188 154 216 162 135 147 245 137 160 242 225 178  42  51 232 226  96
  25 187   6 146  19 102 212 184  62 103]
Test Accuracy: 57.58%


Epoch   261:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.86%


Epoch   262:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.73%


Epoch   263:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.16%


Epoch   264:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.77%


Epoch   265:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.018886957317590714
Density indices: [ 12 231 131 187 183 210 191 181 146  50 184   2 242   9 249 169 215  96
 174  82 171 133 244   7  91 125 177 197 103  43 106  69 222 224 152 217
  32 124  42 164 206  13  95  75  56  19  83  37 156 148 162 216 221  25
  48  21  92  51 239 115 200 135 236 158]
Test Accuracy: 56.82%


Epoch   266:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.85%


Epoch   267:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.25%


Epoch   268:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.13%


Epoch   269:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.44%


Epoch   270:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.02475501410663128
Density indices: [134  88 207 231 192 165  31 152 196 122 142 255 200  54  20 184 157 206
  15 247  86   2 125  81 166  13 101 161  27 253 191 224  19 137 129  43
  28  57  14  37 178 250  82 245  80  34  72 209 228 151  30  97  10  78
  46  95 124  50 168 227 119 180 183 203]
Test Accuracy: 57.8%


Epoch   271:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.97%


Epoch   272:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.05%


Epoch   273:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.26%


Epoch   274:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.2%


Epoch   275:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.028590135276317596
Density indices: [ 83 213 251 150 109 171 122 173 195 252  52  85 146 206 168 157 135 151
  36 219  25 142 183  15   9 113 240  70  58 100  76   1  49  48 178  11
  64 103  72 114  24 190 130 194 163 214 211 249 197  32 234 126 119 201
 199 137 172 205  60 203 120 216 236 212]
Test Accuracy: 57.99%


Epoch   276:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.63%


Epoch   277:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.68%


Epoch   278:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.75%


Epoch   279:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.88%


Epoch   280:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.029789693653583527
Density indices: [185  18  93 126  75  70 166 243 205  97  67  73 193 254 183  77 179  63
  69 213 237 241 155 216 218 225 176  39  37 135  38  98 239 159 154  17
 168 220 203  88 107  72   1 242  21 235 147 164 163 134 157  19 212 255
 149  34 122  32 119  71 106  41  36 110]
Test Accuracy: 57.92%


Epoch   281:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.11%


Epoch   282:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.1%


Epoch   283:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.06%


Epoch   284:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.98%


Epoch   285:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.033307310193777084
Density indices: [160 131 245  53 235  96 215 136   1 229  65 177 208 218 227 163 192 125
  25 176 112  82 221   3 119 180 148 201 100 172  85  35 123  78  23  98
  57 212 198  50   0 188 157 223  31 143  36 109 169 104  72  97  73  28
 181 238 190 113 138  88 137 196 179 144]
Test Accuracy: 58.02%


Epoch   286:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.09%


Epoch   287:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.11%


Epoch   288:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.08%


Epoch   289:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.17%


Epoch   290:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.02621760219335556
Density indices: [ 37  87 133 183 150  43 222 217 215  36 200  51 167  59  46  23 241  17
 237 207 210 214  93  26 142 137   2 163 112 189 141 175  28  85 248 243
  90  45 130 249 121 202  49 253  56 206 195 213   0 216  62  86 131  50
  54 153 208 104 159  99  68 204  77  74]
Test Accuracy: 58.16%


Epoch   291:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.14%


Epoch   292:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.11%


Epoch   293:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.13%


Epoch   294:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.13%


Epoch   295:   0%|          | 0/12 [00:00<?, ?batch/s]

Information density: 0.023918725550174713
Density indices: [  0 215 218 105 255 202 173  69 160 153  89 116 244 110  63  36 191  44
 126 207  66 192  80 118  71  32  31 197 146 124 150 236   4  82 216 106
 217  15   5 190 168 222 156 233 188  65  60 159   7 209 204 211 133  79
 147 231 243  61  46 174 100 210 246  94]
Test Accuracy: 58.13%


Epoch   296:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.12%


Epoch   297:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.11%


Epoch   298:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.1%


Epoch   299:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.1%


Epoch   300:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.1%


Epoch   301:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.1%


Epoch   302:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.12%


Epoch   303:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.13%


Epoch   304:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.13%


Epoch   305:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.09%


Epoch   306:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.09%


Epoch   307:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.04%


Epoch   308:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.05%


Epoch   309:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.98%


Epoch   310:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.98%


Epoch   311:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.95%


Epoch   312:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.99%


Epoch   313:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.04%


Epoch   314:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.13%


Epoch   315:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.05%


Epoch   316:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.96%


Epoch   317:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.95%


Epoch   318:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.73%


Epoch   319:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.86%


Epoch   320:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.76%


Epoch   321:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.67%


Epoch   322:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.72%


Epoch   323:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.82%


Epoch   324:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.14%


Epoch   325:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.16%


Epoch   326:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.09%


Epoch   327:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.89%


Epoch   328:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.44%


Epoch   329:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.38%


Epoch   330:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.97%


Epoch   331:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.39%


Epoch   332:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.38%


Epoch   333:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.39%


Epoch   334:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.71%


Epoch   335:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.14%


Epoch   336:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.33%


Epoch   337:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.09%


Epoch   338:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.73%


Epoch   339:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.81%


Epoch   340:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.56%


Epoch   341:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.89%


Epoch   342:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.33%


Epoch   343:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.25%


Epoch   344:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.13%


Epoch   345:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.29%


Epoch   346:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.67%


Epoch   347:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.68%


Epoch   348:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.66%


Epoch   349:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.53%


Epoch   350:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.02%


Epoch   351:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.77%


Epoch   352:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.23%


Epoch   353:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.97%


Epoch   354:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.81%


Epoch   355:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.72%


Epoch   356:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 58.06%


Epoch   357:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.99%


Epoch   358:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.9%


Epoch   359:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.03%


Epoch   360:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.13%


Epoch   361:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.37%


Epoch   362:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.53%


Epoch   363:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 56.99%


Epoch   364:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 57.54%


Epoch   365:   0%|          | 0/12 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [15]:
idx_density

array([114,  49,  26,   9, 237,  70, 171, 127, 200, 236, 172,  89,  16,
        37, 208, 243, 140, 119,  73,   8, 166, 203,  64, 207, 196, 108,
       116,   0, 222,  38,  91,  74, 111,   4, 255, 205, 126,  18,  84,
       194, 135, 132, 195,  63, 163,  62, 129, 157,  96,  68,  33,  79,
       216, 248,  13, 101,  54, 238,  46, 245, 113, 229,  25, 155])

In [None]:
# plot losses
plt.figure(figsize=(10, 7))
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.title('Loss at the end of each epoch')
plt.legend()
plt.show()

# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot()
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_10_fixmatch_AL.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(SUBSET_PROP*100)}% - {test_accuracy:.2f}% - Active Learning")
fig1.savefig(f"./figures/test_score_{SUBSET_PROP}_fixmatch_AL.png")

### III.2 Fixmatch on 5% train data

In [None]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

EPOCHS = 50
SUBSET_PROP = 0.05

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# transformations
weak_transform = K.ImageSequential(
    K.RandomHorizontalFlip(p=0.50), 
    K.RandomAffine(degrees=0, translate=(0.125, 0.125)),
)

strong_transform = K.ImageSequential(
    K.auto.RandAugment(n=2, m=10), # randaugment + cutout
)

def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat = torch.max(qb, dim=1)

        idx = max_qb > TAU
        qb_hat = qb_hat[idx]

    return qb_hat.detach(), idx, max_qb.detach()

model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')

optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)


# Define the cosine learning rate decay function
lr_lambda = lambda step: LR * torch.cos(torch.tensor((7 * torch.pi * (step)) / (16 * EPOCHS))) * 100 / 3

# Create a learning rate scheduler with the cosine decay function
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# scheduler = None

In [None]:
print("Start training")

train_losses = []
test_losses = []

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    max_confidence = 0


    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {epoch: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, _ = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # normalize
        weak_labeled_inputs = normalize(data=weak_labeled_inputs, mean=mean, std=std)
        weak_unlabeled_inputs = normalize(data=weak_unlabeled_inputs, mean=mean, std=std)
        strong_unlabeled_inputs = normalize(data=strong_unlabeled_inputs, mean=mean, std=std)

        # Compute mask, confidence
        pseudo_labels, idx, max_qb = mask(model, weak_unlabeled_inputs)
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)
            # outputs = torch.softmax(outputs, dim=1)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)
            # labeled_outputs = torch.softmax(labeled_outputs, dim=1)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()


        # backward pass + optimize
        loss.backward()
        optimizer.step()

        

        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled
        max_confidence = max(max_confidence, max_qb.max().item())

        

        # update progress bar
        pbar.set_postfix({
            "total loss": loss.item(),
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "confidence": max_confidence,
            "n_unlabeled": running_n_unlabeled,
            "lr": optimizer.param_groups[0]['lr'].item()
        })

    # update loss
    train_losses.append(running_loss / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)


In [None]:
# plot losses
plt.figure(figsize=(10, 7))
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.title('Loss at the end of each epoch')
plt.legend()
plt.show()

# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot()
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_5_fixmatch.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(SUBSET_PROP*100)}% - {test_accuracy:.2f}%")
fig1.savefig(f"./figures/test_score_{SUBSET_PROP}_fixmatch.png")

### III.3 Fixmatch on 1% train data

In [None]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

EPOCHS = 50
SUBSET_PROP = 0.01

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# transformations
weak_transform = K.ImageSequential(
    K.RandomHorizontalFlip(p=0.50), 
    K.RandomAffine(degrees=0, translate=(0.125, 0.125)),
)

strong_transform = K.ImageSequential(
    K.auto.RandAugment(n=2, m=10), # randaugment + cutout
)

def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat = torch.max(qb, dim=1)

        idx = max_qb > TAU
        qb_hat = qb_hat[idx]

    return qb_hat.detach(), idx, max_qb.detach()

model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')

optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)


# Define the cosine learning rate decay function
lr_lambda = lambda step: LR * torch.cos(torch.tensor((7 * torch.pi * (step)) / (16 * EPOCHS))) * 100 / 3

# Create a learning rate scheduler with the cosine decay function
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# scheduler = None

In [None]:
print("Start training")

train_losses = []
test_losses = []

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    max_confidence = 0


    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {epoch: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, _ = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # normalize
        weak_labeled_inputs = normalize(data=weak_labeled_inputs, mean=mean, std=std)
        weak_unlabeled_inputs = normalize(data=weak_unlabeled_inputs, mean=mean, std=std)
        strong_unlabeled_inputs = normalize(data=strong_unlabeled_inputs, mean=mean, std=std)

        # Compute mask, confidence
        pseudo_labels, idx, max_qb = mask(model, weak_unlabeled_inputs)
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)
            # outputs = torch.softmax(outputs, dim=1)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)
            # labeled_outputs = torch.softmax(labeled_outputs, dim=1)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()


        # backward pass + optimize
        loss.backward()
        optimizer.step()

        

        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled
        max_confidence = max(max_confidence, max_qb.max().item())

        

        # update progress bar
        pbar.set_postfix({
            "total loss": loss.item(),
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "confidence": max_confidence,
            "n_unlabeled": running_n_unlabeled,
            "lr": optimizer.param_groups[0]['lr'].item()
        })

    # update loss
    train_losses.append(running_loss / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)


In [None]:
# plot losses
plt.figure(figsize=(10, 7))
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.title('Loss at the end of each epoch')
plt.legend()
plt.show()

# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot()
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_1_fixmatch.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(SUBSET_PROP*100)}% - {test_accuracy:.2f}%")
fig1.savefig(f"./figures/test_score_{SUBSET_PROP}_fixmatch.png")