# Projet Fixmatch

In [None]:
# !pip install torchview torchsummary torchvision kornia torchmetrics matplotlib tqdm path graphviz opencv-python scikit-learn optuna

In [1]:
# deep learning
import torch
import torch.nn as nn
from torch.distributions.transforms import LowerCholeskyTransform
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import DataLoader, Dataset

# vizualisation
import torchsummary

# transforms
import torchvision.transforms as T
import kornia.augmentation as K
from kornia.enhance import normalize
# from torchvision.transforms import RandAugment

# metrics
from torchmetrics import Accuracy

# torchvision
import torchvision
from torchvision import transforms

# plotting
import matplotlib.pyplot as plt
from torchview import draw_graph

from IPython.display import display
from IPython.core.display import SVG, HTML

from tqdm.auto import tqdm

# typing
from typing import Callable

from utils import plot_images, plot_transform
from model import ConvNN, display_model

# os
import os
import path

import random
import numpy as np 

# transformations
# import transform as T
# from randaugment import RandomAugment

# typing
from typing import Callable, List, Tuple

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics.pairwise import cosine_similarity

from randaugment import RandAugmentMC

%load_ext autoreload
%autoreload 2

In [2]:
DEFAULT_RANDOM_SEED = 2021

def seedBasic(seed=DEFAULT_RANDOM_SEED):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    
# torch random seed
import torch
def seedTorch(seed=DEFAULT_RANDOM_SEED):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
      
# basic + tensorflow + torch 
def seedEverything(seed=DEFAULT_RANDOM_SEED):
    seedBasic(seed)
    seedTorch(seed)

In [3]:
# Set device
if ((int(torch.__version__.split(".")[0]) >= 2) or (int(torch.__version__.split(".")[1]) >= 13)) and torch.has_mps:
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


In [4]:
IMG_SHAPE = (3, 32, 32)
# See Table 4
TAU = 0.95 #! 0.95 in the paper
LAMBDA_U = 1 # 1
MU = 7
BATCH_SIZE = 64
LR = 0.03
BETA = 0.9
WEIGHT_DECAY = 0.02 # mixmatch AL paper
# WEIGHT_DECAY = 0.0005 # default paper
BETA_DENSITY = 1

In [5]:
class ConvNN(nn.Module):
    """
    Simple CNN for CIFAR10
    """
    
    def __init__(self):
        super().__init__()
        self.conv_32 = nn.Conv2d(3, 32, kernel_size=3, padding='same')
        self.conv_64 = nn.Conv2d(32, 64, kernel_size=3, padding='same')
        self.conv_96 = nn.Conv2d(64, 96, kernel_size=3, padding='same')
        self.conv_128 = nn.Conv2d(96, 128, kernel_size=3, padding='same')
        self.fc_512 = nn.Linear(512, 512)
        self.fc_10 = nn.Linear(512, 10)
        self.max_pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU(inplace=True)
        self.flatten = nn.Flatten()
        # self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_32(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_64(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_96(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_128(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.flatten(x)
        x = self.fc_512(x)
        x = self.relu(x)
        x = self.fc_10(x)
        # x = self.softmax(x)

        return x

In [6]:
def compute_mean_std(trainLoader) -> Tuple[List[float], List[float]]:
    # initialize the list of means and stds
    mean, std = torch.zeros(3), torch.zeros(3)

    # iterate over the dataset and compute the sum of each channel
    for images, _ in trainLoader:
        mean+= torch.mean(images, dim=[0,2,3])
        std+= torch.std(images, dim=[0,2,3])
    
    # compute the mean and std
    mean = mean/len(trainLoader)
    std = std/len(trainLoader)

    return mean, std

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

if not os.path.exists('./data/mean.pt'):
    mean, std = compute_mean_std(trainloader)
    torch.save(mean, 'data/mean.pt')
    torch.save(std, 'data/std.pt')
else:
    mean, std = torch.load('./data/mean.pt'), torch.load('./data/std.pt')

# to numpy
mean, std = mean.numpy(), std.numpy()

print(f"mean: {mean}, std: {std}")


testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
mean: [0.49135667 0.48212072 0.4465159 ], std: [0.24632095 0.24274482 0.26062676]
Files already downloaded and verified


In [7]:
torch_models = 'torch_models' 
if not os.path.exists(torch_models):
    os.makedirs(torch_models)

## IV. Semi-Supervised Learning: Fixmatch with Active Learning

### IV.1 Fixmatch on 10% train data with Active Learning

In [8]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

TARGET_PROP = 0.10
SUBSET_PROP = 0.005
K_SAMPLES = 50

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=False,
    num_workers=0
)



In [9]:
# transformations
weak_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=0, translate=(0.125, 0.125)),
    transforms.Normalize(mean, std),
])

strong_transform = transforms.Compose([
    RandAugmentMC(n=2, m=10),
    transforms.Normalize(mean, std)
])

In [10]:
def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat_max = torch.max(qb, dim=1)
        

        idx_max = max_qb > TAU
        qb_hat_max = qb_hat_max[idx_max]
        
    return qb_hat_max.detach(), idx_max, max_qb.detach()

In [11]:
model = ConvNN().to(device)

EPOCHS = 5_000

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')
true_unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0, last_epoch=-1)

In [12]:
def information_density(
    model: ConvNN,
    unlabeled_dataloader: DataLoader,
    k_samp: int,
    K_transform: int = 5) -> torch.Tensor:

    model.eval()

    with torch.no_grad():
        for i, (inputs, _) in enumerate(unlabeled_dataloader):
            inputs = inputs.to(device)

            logits_mean = torch.zeros((len(inputs), 10)).to(device)

            for k in range(K_transform):
                # compute weak transform of inputs
                weak_transformed_inputs = weak_transform(inputs)

                # compute predictions of weak transformed inputs
                logits = model(weak_transformed_inputs)

                logits_mean += logits
            
            logits_mean /= K_transform

            qb = torch.softmax(logits_mean, dim=1).to(device)
        
        # get top 2 values of qb along dim=1
        qb1, qb2 = torch.topk(qb, k=2, dim=1, sorted=True)[0][:, 0], torch.topk(qb, k=2, dim=1, sorted=True)[0][:, 1]

        uncertainty = 1 - (qb1 - qb2)

        # get top k_samp indices
        uncert, idx = torch.topk(uncertainty, k=k_samp, dim=0, sorted=True)

        return idx, torch.mean(uncert).item()

In [13]:
# indices, uncert = information_density(model, unlabeled_dataloader, k_samp=K_SAMPLES)

In [14]:
def least_confidence(
    model: ConvNN,
    unlabeled_dataloader: torch.utils.data.DataLoader,
    k_samp: int) -> torch.Tensor:
    
    probs = []

    model.eval()

    with torch.no_grad():
        for i, (inputs, _) in enumerate(unlabeled_dataloader):
            inputs = inputs.to(device)

            # inputs transforms
            weak_transformed_inputs = weak_transform(inputs)

            logits = model(weak_transformed_inputs)
            qb = torch.softmax(logits, dim=1)
            probs.append(qb.cpu().data)
    
    # transform to tensor
    probs_tensor = torch.cat(probs, dim=0)

    U = probs_tensor.max(1)[0]

    selected_indices = U.sort()[1][:k_samp]

    return selected_indices, torch.mean(torch.sort(U, descending=False)[0][:k_samp]).item()


# Create a new labeled dataset using active learning
def create_labeled_dataset_active_learning(dataset, selected_indices):
    dataset = torch.utils.data.Subset(dataset, selected_indices)
    return dataset

In [15]:
print("Start training")

current_prop = SUBSET_PROP

train_losses = []
test_losses = []
added_samp = 0
uncertainty = 0

AL_activated = False
main_activated = True
train_algo = 20_000

step = 0
for j in range(EPOCHS):
    model.train()

    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    moving_avg_pred_labeled = 0
    moving_avg_pred_unlabeled = 0

    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {j: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, _ = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        step += len(labeled_inputs)
        

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # Compute mask, confidence
        pseudo_labels, idx, max_qb = mask(model, weak_unlabeled_inputs)
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()
            
        # backward pass + optimize
        loss.backward()

        # clamp gradients, just in case
        # for p in filter(lambda p: p.grad is not None, model.parameters()): p.grad.data.clamp_(min=-.1, max=.1)
        
        optimizer.step()

        
        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled 

        # update progress bar
        pbar.set_postfix({
            "lab_loss": labeled_loss.item(),
            "unlab_loss": unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "avg_confidence": torch.mean(max_qb).item(),
            "n_unlabeled": running_n_unlabeled,
            "lab_prop": current_prop,
            "uncertainty": uncertainty,
            "lr": optimizer.param_groups[0]['lr'],
            "[main_activ, AL_activ]": [main_activated, AL_activated],
            "step": step,
        })

    # start adding labels after 50 epochs for 50 epochs
    # if step >= 262_144 and main_activated:
    if step >= train_algo:
        AL_activated = True
        main_activated = False
        step = train_algo//4

    if AL_activated and step >= train_algo//4:
        if current_prop < TARGET_PROP:
            # compute information density
            selected_indices, uncertainty = information_density(model, unlabeled_dataloader, K_SAMPLES)

            # select indices from unlabeled dataset
            trainset_sup_new = create_labeled_dataset_active_learning(trainset_unsup, selected_indices)

            # concat new trainset with labeled trainset
            trainset_sup = torch.utils.data.ConcatDataset([trainset_sup, trainset_sup_new])

            # create labeled dataloader
            labeled_dataloader = torch.utils.data.DataLoader(trainset_sup, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

            current_prop = len(trainset_sup) / len(trainset)

            step = 0
    

    # update loss
    train_losses.append(running_loss / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()
    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)

Start training


Epoch     0:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 10.04%


Epoch     1:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     2:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     3:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     4:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     5:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     6:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     7:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     8:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     9:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch    10:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 12.21%


Epoch    11:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 16.34%


Epoch    12:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 18.23%


Epoch    13:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 19.5%


Epoch    14:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 18.99%


Epoch    15:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 20.91%


Epoch    16:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 19.24%


Epoch    17:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 20.6%


Epoch    18:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 21.79%


Epoch    19:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 21.47%


Epoch    20:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 21.57%


Epoch    21:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 21.51%


Epoch    22:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 21.0%


Epoch    23:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 22.52%


Epoch    24:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 23.73%


Epoch    25:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 23.93%


Epoch    26:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 23.58%


Epoch    27:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 24.02%


Epoch    28:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 24.81%


Epoch    29:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 23.09%


Epoch    30:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 24.2%


Epoch    31:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 25.01%


Epoch    32:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 23.76%


Epoch    33:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 25.16%


Epoch    34:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 25.06%


Epoch    35:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 23.97%


Epoch    36:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 25.22%


Epoch    37:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 24.51%


Epoch    38:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 26.19%


Epoch    39:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 25.03%


Epoch    40:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 24.85%


Epoch    41:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 24.85%


Epoch    42:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 25.52%


Epoch    43:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 26.76%


Epoch    44:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 26.79%


Epoch    45:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 24.94%


Epoch    46:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 25.56%


Epoch    47:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 24.7%


Epoch    48:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 24.74%


Epoch    49:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 25.68%


Epoch    50:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 27.0%


Epoch    51:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 23.13%


Epoch    52:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 26.85%


Epoch    53:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 26.1%


Epoch    54:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 27.3%


Epoch    55:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 27.64%


Epoch    56:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 29.06%


Epoch    57:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 23.64%


Epoch    58:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 25.95%


Epoch    59:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 28.22%


Epoch    60:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 28.28%


Epoch    61:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 27.36%


Epoch    62:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 26.2%


Epoch    63:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 27.48%


Epoch    64:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 27.33%


Epoch    65:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 27.22%


Epoch    66:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 29.26%


Epoch    67:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 26.77%


Epoch    68:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 27.03%


Epoch    69:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 26.39%


Epoch    70:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 29.04%


Epoch    71:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 22.12%


Epoch    72:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 28.5%


Epoch    73:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 25.29%


Epoch    74:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 28.26%


Epoch    75:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 28.68%


Epoch    76:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 29.94%


Epoch    77:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 23.74%


Epoch    78:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 27.16%


Epoch    79:   0%|          | 0/4 [00:00<?, ?batch/s]

Test Accuracy: 26.92%


Epoch    80:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 28.48%


Epoch    81:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 28.31%


Epoch    82:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 27.95%


Epoch    83:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 27.15%


Epoch    84:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 25.68%


Epoch    85:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 31.26%


Epoch    86:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 28.31%


Epoch    87:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 25.13%


Epoch    88:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 30.94%


Epoch    89:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 29.48%


Epoch    90:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 22.23%


Epoch    91:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 26.91%


Epoch    92:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 26.19%


Epoch    93:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 27.04%


Epoch    94:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 28.62%


Epoch    95:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 26.22%


Epoch    96:   0%|          | 0/5 [00:00<?, ?batch/s]

Test Accuracy: 29.15%


Epoch    97:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 25.94%


Epoch    98:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 28.51%


Epoch    99:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 30.37%


Epoch   100:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 30.74%


Epoch   101:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 30.33%


Epoch   102:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 31.74%


Epoch   103:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 32.23%


Epoch   104:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 29.38%


Epoch   105:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 29.67%


Epoch   106:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 31.53%


Epoch   107:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 31.15%


Epoch   108:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 31.51%


Epoch   109:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 29.43%


Epoch   110:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 31.69%


Epoch   111:   0%|          | 0/6 [00:00<?, ?batch/s]

Test Accuracy: 29.91%


Epoch   112:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 28.34%


Epoch   113:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 29.69%


Epoch   114:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 26.6%


Epoch   115:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 30.24%


Epoch   116:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 30.83%


Epoch   117:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 27.25%


Epoch   118:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 31.99%


Epoch   119:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 30.16%


Epoch   120:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 31.78%


Epoch   121:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 30.92%


Epoch   122:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 30.5%


Epoch   123:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 32.22%


Epoch   124:   0%|          | 0/7 [00:00<?, ?batch/s]

Test Accuracy: 24.11%


Epoch   125:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 31.81%


Epoch   126:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 29.55%


Epoch   127:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 32.34%


Epoch   128:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 29.98%


Epoch   129:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 31.56%


Epoch   130:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 30.88%


Epoch   131:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 18.92%


Epoch   132:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 27.15%


Epoch   133:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 27.79%


Epoch   134:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 34.0%


Epoch   135:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 31.08%


Epoch   136:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 27.43%


Epoch   137:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 33.02%


Epoch   138:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 34.54%


Epoch   139:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 32.7%


Epoch   140:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 32.65%


Epoch   141:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 29.41%


Epoch   142:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 30.89%


Epoch   143:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 33.42%


Epoch   144:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 30.72%


Epoch   145:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 31.09%


Epoch   146:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 32.78%


Epoch   147:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 31.71%


Epoch   148:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 31.99%


Epoch   149:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 30.22%


Epoch   150:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 31.64%


Epoch   151:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 32.7%


Epoch   152:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 31.29%


Epoch   153:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 30.58%


Epoch   154:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 33.6%


Epoch   155:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 31.39%


Epoch   156:   0%|          | 0/9 [00:00<?, ?batch/s]

Test Accuracy: 32.23%


Epoch   157:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 30.31%


Epoch   158:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 32.85%


Epoch   159:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 29.98%


Epoch   160:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 29.22%


Epoch   161:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 32.34%


Epoch   162:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 30.95%


Epoch   163:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 30.69%


Epoch   164:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 33.25%


Epoch   165:   0%|          | 0/10 [00:00<?, ?batch/s]

Test Accuracy: 33.47%


Epoch   166:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 30.71%


Epoch   167:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 28.97%


Epoch   168:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 31.47%


Epoch   169:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 31.48%


Epoch   170:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 33.7%


Epoch   171:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 31.33%


Epoch   172:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 33.54%


Epoch   173:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 30.35%


Epoch   174:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 34.69%


Epoch   175:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 31.34%


Epoch   176:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 35.2%


Epoch   177:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 33.5%


Epoch   178:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 30.78%


Epoch   179:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 33.5%


Epoch   180:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 32.91%


Epoch   181:   0%|          | 0/11 [00:00<?, ?batch/s]

Test Accuracy: 33.94%


Epoch   182:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 35.77%


Epoch   183:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 31.29%


Epoch   184:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 33.69%


Epoch   185:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 32.96%


Epoch   186:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 32.49%


Epoch   187:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 30.52%


Epoch   188:   0%|          | 0/12 [00:00<?, ?batch/s]

Test Accuracy: 34.46%


Epoch   189:   0%|          | 0/13 [00:00<?, ?batch/s]

Test Accuracy: 27.69%


Epoch   190:   0%|          | 0/13 [00:00<?, ?batch/s]

Test Accuracy: 34.79%


Epoch   191:   0%|          | 0/13 [00:00<?, ?batch/s]

Test Accuracy: 33.59%


Epoch   192:   0%|          | 0/13 [00:00<?, ?batch/s]

Test Accuracy: 33.37%


Epoch   193:   0%|          | 0/13 [00:00<?, ?batch/s]

Test Accuracy: 33.65%


Epoch   194:   0%|          | 0/13 [00:00<?, ?batch/s]

Test Accuracy: 31.41%


Epoch   195:   0%|          | 0/13 [00:00<?, ?batch/s]

Test Accuracy: 33.0%


Epoch   196:   0%|          | 0/14 [00:00<?, ?batch/s]

Test Accuracy: 33.42%


Epoch   197:   0%|          | 0/14 [00:00<?, ?batch/s]

Test Accuracy: 32.57%


Epoch   198:   0%|          | 0/14 [00:00<?, ?batch/s]

Test Accuracy: 30.7%


Epoch   199:   0%|          | 0/14 [00:00<?, ?batch/s]

Test Accuracy: 32.52%


Epoch   200:   0%|          | 0/14 [00:00<?, ?batch/s]

Test Accuracy: 33.37%


Epoch   201:   0%|          | 0/14 [00:00<?, ?batch/s]

Test Accuracy: 32.06%


Epoch   202:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 31.92%


Epoch   203:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 27.41%


Epoch   204:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 29.63%


Epoch   205:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 31.43%


Epoch   206:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 32.88%


Epoch   207:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 33.59%


Epoch   208:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 31.62%


Epoch   209:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 35.3%


Epoch   210:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 33.11%


Epoch   211:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 31.72%


Epoch   212:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 33.11%


Epoch   213:   0%|          | 0/15 [00:00<?, ?batch/s]

Test Accuracy: 34.4%


Epoch   214:   0%|          | 0/16 [00:00<?, ?batch/s]

Test Accuracy: 35.43%


Epoch   215:   0%|          | 0/16 [00:00<?, ?batch/s]

Test Accuracy: 34.0%


Epoch   216:   0%|          | 0/16 [00:00<?, ?batch/s]

Test Accuracy: 32.57%


Epoch   217:   0%|          | 0/16 [00:00<?, ?batch/s]

Test Accuracy: 30.48%


Epoch   218:   0%|          | 0/16 [00:00<?, ?batch/s]

Test Accuracy: 31.96%


Epoch   219:   0%|          | 0/17 [00:00<?, ?batch/s]

Test Accuracy: 32.54%


Epoch   220:   0%|          | 0/17 [00:00<?, ?batch/s]

Test Accuracy: 28.83%


Epoch   221:   0%|          | 0/17 [00:00<?, ?batch/s]

Test Accuracy: 33.59%


Epoch   222:   0%|          | 0/17 [00:00<?, ?batch/s]

Test Accuracy: 32.99%


Epoch   223:   0%|          | 0/17 [00:00<?, ?batch/s]

Test Accuracy: 27.79%


Epoch   224:   0%|          | 0/18 [00:00<?, ?batch/s]

Test Accuracy: 34.74%


Epoch   225:   0%|          | 0/18 [00:00<?, ?batch/s]

Test Accuracy: 30.97%


Epoch   226:   0%|          | 0/18 [00:00<?, ?batch/s]

Test Accuracy: 35.39%


Epoch   227:   0%|          | 0/18 [00:00<?, ?batch/s]

Test Accuracy: 31.94%


Epoch   228:   0%|          | 0/18 [00:00<?, ?batch/s]

Test Accuracy: 32.68%


Epoch   229:   0%|          | 0/18 [00:00<?, ?batch/s]

Test Accuracy: 35.03%


Epoch   230:   0%|          | 0/18 [00:00<?, ?batch/s]

Test Accuracy: 36.21%


Epoch   231:   0%|          | 0/18 [00:00<?, ?batch/s]

Test Accuracy: 32.69%


Epoch   232:   0%|          | 0/18 [00:00<?, ?batch/s]

Test Accuracy: 33.36%


Epoch   233:   0%|          | 0/18 [00:00<?, ?batch/s]

Test Accuracy: 31.83%


Epoch   234:   0%|          | 0/19 [00:00<?, ?batch/s]

Test Accuracy: 36.06%


Epoch   235:   0%|          | 0/19 [00:00<?, ?batch/s]

Test Accuracy: 33.9%


Epoch   236:   0%|          | 0/19 [00:00<?, ?batch/s]

Test Accuracy: 35.11%


Epoch   237:   0%|          | 0/19 [00:00<?, ?batch/s]

Test Accuracy: 34.07%


Epoch   238:   0%|          | 0/19 [00:00<?, ?batch/s]

Test Accuracy: 33.44%


Epoch   239:   0%|          | 0/20 [00:00<?, ?batch/s]

Test Accuracy: 33.81%


Epoch   240:   0%|          | 0/20 [00:00<?, ?batch/s]

Test Accuracy: 35.16%


Epoch   241:   0%|          | 0/20 [00:00<?, ?batch/s]

Test Accuracy: 33.58%


Epoch   242:   0%|          | 0/20 [00:00<?, ?batch/s]

Test Accuracy: 33.33%


Epoch   243:   0%|          | 0/21 [00:00<?, ?batch/s]

Test Accuracy: 36.14%


Epoch   244:   0%|          | 0/21 [00:00<?, ?batch/s]

Test Accuracy: 27.65%


Epoch   245:   0%|          | 0/21 [00:00<?, ?batch/s]

Test Accuracy: 29.72%


Epoch   246:   0%|          | 0/21 [00:00<?, ?batch/s]

Test Accuracy: 32.98%


Epoch   247:   0%|          | 0/22 [00:00<?, ?batch/s]

Test Accuracy: 34.51%


Epoch   248:   0%|          | 0/22 [00:00<?, ?batch/s]

Test Accuracy: 34.52%


Epoch   249:   0%|          | 0/22 [00:00<?, ?batch/s]

Test Accuracy: 30.03%


Epoch   250:   0%|          | 0/22 [00:00<?, ?batch/s]

Test Accuracy: 31.89%


Epoch   251:   0%|          | 0/22 [00:00<?, ?batch/s]

Test Accuracy: 32.51%


Epoch   252:   0%|          | 0/22 [00:00<?, ?batch/s]

Test Accuracy: 33.78%


Epoch   253:   0%|          | 0/22 [00:00<?, ?batch/s]

Test Accuracy: 34.42%


Epoch   254:   0%|          | 0/22 [00:00<?, ?batch/s]

Test Accuracy: 33.11%


Epoch   255:   0%|          | 0/23 [00:00<?, ?batch/s]

Test Accuracy: 34.54%


Epoch   256:   0%|          | 0/23 [00:00<?, ?batch/s]

Test Accuracy: 34.92%


Epoch   257:   0%|          | 0/23 [00:00<?, ?batch/s]

Test Accuracy: 30.85%


Epoch   258:   0%|          | 0/23 [00:00<?, ?batch/s]

Test Accuracy: 36.32%


Epoch   259:   0%|          | 0/24 [00:00<?, ?batch/s]

Test Accuracy: 34.3%


Epoch   260:   0%|          | 0/24 [00:00<?, ?batch/s]

Test Accuracy: 33.42%


Epoch   261:   0%|          | 0/24 [00:00<?, ?batch/s]

Test Accuracy: 32.38%


Epoch   262:   0%|          | 0/24 [00:00<?, ?batch/s]

Test Accuracy: 34.47%


Epoch   263:   0%|          | 0/25 [00:00<?, ?batch/s]

Test Accuracy: 34.16%


Epoch   264:   0%|          | 0/25 [00:00<?, ?batch/s]

Test Accuracy: 33.04%


Epoch   265:   0%|          | 0/25 [00:00<?, ?batch/s]

Test Accuracy: 28.63%


Epoch   266:   0%|          | 0/25 [00:00<?, ?batch/s]

Test Accuracy: 30.0%


Epoch   267:   0%|          | 0/25 [00:00<?, ?batch/s]

Test Accuracy: 35.02%


Epoch   268:   0%|          | 0/25 [00:00<?, ?batch/s]

Test Accuracy: 33.18%


Epoch   269:   0%|          | 0/25 [00:00<?, ?batch/s]

Test Accuracy: 34.98%


Epoch   270:   0%|          | 0/25 [00:00<?, ?batch/s]

Test Accuracy: 34.22%


Epoch   271:   0%|          | 0/26 [00:00<?, ?batch/s]

Test Accuracy: 35.52%


Epoch   272:   0%|          | 0/26 [00:00<?, ?batch/s]

Test Accuracy: 34.46%


Epoch   273:   0%|          | 0/26 [00:00<?, ?batch/s]

Test Accuracy: 32.64%


Epoch   274:   0%|          | 0/26 [00:00<?, ?batch/s]

Test Accuracy: 32.0%


Epoch   275:   0%|          | 0/27 [00:00<?, ?batch/s]

Test Accuracy: 33.76%


Epoch   276:   0%|          | 0/27 [00:00<?, ?batch/s]

Test Accuracy: 33.14%


Epoch   277:   0%|          | 0/27 [00:00<?, ?batch/s]

Test Accuracy: 36.26%


Epoch   278:   0%|          | 0/28 [00:00<?, ?batch/s]

Test Accuracy: 33.34%


Epoch   279:   0%|          | 0/28 [00:00<?, ?batch/s]

Test Accuracy: 30.93%


Epoch   280:   0%|          | 0/28 [00:00<?, ?batch/s]

Test Accuracy: 32.07%


Epoch   281:   0%|          | 0/29 [00:00<?, ?batch/s]

Test Accuracy: 26.17%


Epoch   282:   0%|          | 0/29 [00:00<?, ?batch/s]

Test Accuracy: 30.51%


Epoch   283:   0%|          | 0/29 [00:00<?, ?batch/s]

Test Accuracy: 33.28%


Epoch   284:   0%|          | 0/29 [00:00<?, ?batch/s]

Test Accuracy: 30.56%


Epoch   285:   0%|          | 0/29 [00:00<?, ?batch/s]

Test Accuracy: 31.63%


Epoch   286:   0%|          | 0/29 [00:00<?, ?batch/s]

Test Accuracy: 33.2%


Epoch   287:   0%|          | 0/30 [00:00<?, ?batch/s]

Test Accuracy: 35.28%


Epoch   288:   0%|          | 0/30 [00:00<?, ?batch/s]

Test Accuracy: 33.6%


Epoch   289:   0%|          | 0/30 [00:00<?, ?batch/s]

Test Accuracy: 25.68%


Epoch   290:   0%|          | 0/31 [00:00<?, ?batch/s]

Test Accuracy: 35.34%


Epoch   291:   0%|          | 0/31 [00:00<?, ?batch/s]

Test Accuracy: 35.01%


Epoch   292:   0%|          | 0/31 [00:00<?, ?batch/s]

Test Accuracy: 33.16%


Epoch   293:   0%|          | 0/32 [00:00<?, ?batch/s]

Test Accuracy: 33.25%


Epoch   294:   0%|          | 0/32 [00:00<?, ?batch/s]

Test Accuracy: 31.46%


Epoch   295:   0%|          | 0/32 [00:00<?, ?batch/s]

Test Accuracy: 32.67%


Epoch   296:   0%|          | 0/33 [00:00<?, ?batch/s]

Test Accuracy: 34.04%


Epoch   297:   0%|          | 0/33 [00:00<?, ?batch/s]

Test Accuracy: 30.16%


Epoch   298:   0%|          | 0/33 [00:00<?, ?batch/s]

Test Accuracy: 33.77%


Epoch   299:   0%|          | 0/33 [00:00<?, ?batch/s]

Test Accuracy: 35.36%


Epoch   300:   0%|          | 0/33 [00:00<?, ?batch/s]

Test Accuracy: 34.23%


Epoch   301:   0%|          | 0/33 [00:00<?, ?batch/s]

Test Accuracy: 34.68%


Epoch   302:   0%|          | 0/34 [00:00<?, ?batch/s]

Test Accuracy: 30.85%


Epoch   303:   0%|          | 0/34 [00:00<?, ?batch/s]

Test Accuracy: 34.12%


Epoch   304:   0%|          | 0/34 [00:00<?, ?batch/s]

Test Accuracy: 34.53%


Epoch   305:   0%|          | 0/35 [00:00<?, ?batch/s]

Test Accuracy: 35.13%


Epoch   306:   0%|          | 0/35 [00:00<?, ?batch/s]

Test Accuracy: 34.9%


Epoch   307:   0%|          | 0/35 [00:00<?, ?batch/s]

Test Accuracy: 34.13%


Epoch   308:   0%|          | 0/36 [00:00<?, ?batch/s]

Test Accuracy: 34.11%


Epoch   309:   0%|          | 0/36 [00:00<?, ?batch/s]

Test Accuracy: 29.94%


Epoch   310:   0%|          | 0/36 [00:00<?, ?batch/s]

Test Accuracy: 32.87%


Epoch   311:   0%|          | 0/36 [00:00<?, ?batch/s]

Test Accuracy: 33.84%


Epoch   312:   0%|          | 0/36 [00:00<?, ?batch/s]

Test Accuracy: 32.75%


Epoch   313:   0%|          | 0/36 [00:00<?, ?batch/s]

Test Accuracy: 34.66%


Epoch   314:   0%|          | 0/37 [00:00<?, ?batch/s]

Test Accuracy: 33.63%


Epoch   315:   0%|          | 0/37 [00:00<?, ?batch/s]

Test Accuracy: 35.42%


Epoch   316:   0%|          | 0/37 [00:00<?, ?batch/s]

Test Accuracy: 34.23%


Epoch   317:   0%|          | 0/38 [00:00<?, ?batch/s]

Test Accuracy: 32.96%


Epoch   318:   0%|          | 0/38 [00:00<?, ?batch/s]

Test Accuracy: 32.82%


Epoch   319:   0%|          | 0/38 [00:00<?, ?batch/s]

Test Accuracy: 34.09%


Epoch   320:   0%|          | 0/39 [00:00<?, ?batch/s]

Test Accuracy: 34.64%


Epoch   321:   0%|          | 0/39 [00:00<?, ?batch/s]

Test Accuracy: 34.15%


Epoch   322:   0%|          | 0/39 [00:00<?, ?batch/s]

Test Accuracy: 30.55%


Epoch   323:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 32.9%


Epoch   324:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 36.44%


Epoch   325:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 35.11%


Epoch   326:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 35.45%


Epoch   327:   0%|          | 0/41 [00:00<?, ?batch/s]

Test Accuracy: 30.08%


Epoch   328:   0%|          | 0/41 [00:00<?, ?batch/s]

Test Accuracy: 32.08%


Epoch   329:   0%|          | 0/42 [00:00<?, ?batch/s]

Test Accuracy: 33.37%


Epoch   330:   0%|          | 0/42 [00:00<?, ?batch/s]

Test Accuracy: 30.6%


Epoch   331:   0%|          | 0/43 [00:00<?, ?batch/s]

Test Accuracy: 33.13%


Epoch   332:   0%|          | 0/43 [00:00<?, ?batch/s]

Test Accuracy: 34.55%


Epoch   333:   0%|          | 0/43 [00:00<?, ?batch/s]

Test Accuracy: 33.41%


Epoch   334:   0%|          | 0/43 [00:00<?, ?batch/s]

Test Accuracy: 33.08%


Epoch   335:   0%|          | 0/44 [00:00<?, ?batch/s]

Test Accuracy: 30.37%


Epoch   336:   0%|          | 0/44 [00:00<?, ?batch/s]

Test Accuracy: 31.81%


Epoch   337:   0%|          | 0/45 [00:00<?, ?batch/s]

Test Accuracy: 35.56%


Epoch   338:   0%|          | 0/45 [00:00<?, ?batch/s]

Test Accuracy: 34.02%


Epoch   339:   0%|          | 0/46 [00:00<?, ?batch/s]

Test Accuracy: 33.58%


Epoch   340:   0%|          | 0/46 [00:00<?, ?batch/s]

Test Accuracy: 33.7%


Epoch   341:   0%|          | 0/47 [00:00<?, ?batch/s]

Test Accuracy: 30.39%


Epoch   342:   0%|          | 0/47 [00:00<?, ?batch/s]

Test Accuracy: 34.09%


Epoch   343:   0%|          | 0/47 [00:00<?, ?batch/s]

Test Accuracy: 33.01%


Epoch   344:   0%|          | 0/47 [00:00<?, ?batch/s]

Test Accuracy: 33.66%


Epoch   345:   0%|          | 0/48 [00:00<?, ?batch/s]

Test Accuracy: 34.86%


Epoch   346:   0%|          | 0/48 [00:00<?, ?batch/s]

Test Accuracy: 31.73%


Epoch   347:   0%|          | 0/49 [00:00<?, ?batch/s]

Test Accuracy: 31.81%


Epoch   348:   0%|          | 0/49 [00:00<?, ?batch/s]

Test Accuracy: 33.09%


Epoch   349:   0%|          | 0/50 [00:00<?, ?batch/s]

Test Accuracy: 31.38%


Epoch   350:   0%|          | 0/50 [00:00<?, ?batch/s]

Test Accuracy: 32.62%


Epoch   351:   0%|          | 0/50 [00:00<?, ?batch/s]

Test Accuracy: 31.55%


Epoch   352:   0%|          | 0/50 [00:00<?, ?batch/s]

Test Accuracy: 31.05%


Epoch   353:   0%|          | 0/51 [00:00<?, ?batch/s]

Test Accuracy: 34.6%


Epoch   354:   0%|          | 0/51 [00:00<?, ?batch/s]

Test Accuracy: 35.33%


Epoch   355:   0%|          | 0/52 [00:00<?, ?batch/s]

Test Accuracy: 31.96%


Epoch   356:   0%|          | 0/52 [00:00<?, ?batch/s]

Test Accuracy: 34.02%


Epoch   357:   0%|          | 0/53 [00:00<?, ?batch/s]

Test Accuracy: 32.05%


Epoch   358:   0%|          | 0/53 [00:00<?, ?batch/s]

Test Accuracy: 30.77%


Epoch   359:   0%|          | 0/54 [00:00<?, ?batch/s]

Test Accuracy: 32.27%


Epoch   360:   0%|          | 0/54 [00:00<?, ?batch/s]

Test Accuracy: 32.47%


Epoch   361:   0%|          | 0/54 [00:00<?, ?batch/s]

Test Accuracy: 33.91%


Epoch   362:   0%|          | 0/54 [00:00<?, ?batch/s]

Test Accuracy: 33.38%


Epoch   363:   0%|          | 0/55 [00:00<?, ?batch/s]

Test Accuracy: 32.57%


Epoch   364:   0%|          | 0/55 [00:00<?, ?batch/s]

Test Accuracy: 32.52%


Epoch   365:   0%|          | 0/56 [00:00<?, ?batch/s]

Test Accuracy: 31.34%


Epoch   366:   0%|          | 0/56 [00:00<?, ?batch/s]

Test Accuracy: 30.19%


Epoch   367:   0%|          | 0/57 [00:00<?, ?batch/s]

Test Accuracy: 33.45%


Epoch   368:   0%|          | 0/57 [00:00<?, ?batch/s]

Test Accuracy: 29.66%


Epoch   369:   0%|          | 0/58 [00:00<?, ?batch/s]

Test Accuracy: 34.1%


Epoch   370:   0%|          | 0/58 [00:00<?, ?batch/s]

Test Accuracy: 23.55%


Epoch   371:   0%|          | 0/58 [00:00<?, ?batch/s]

Test Accuracy: 30.07%


Epoch   372:   0%|          | 0/58 [00:00<?, ?batch/s]

Test Accuracy: 33.6%


Epoch   373:   0%|          | 0/59 [00:00<?, ?batch/s]

Test Accuracy: 34.51%


Epoch   374:   0%|          | 0/59 [00:00<?, ?batch/s]

Test Accuracy: 34.44%


Epoch   375:   0%|          | 0/60 [00:00<?, ?batch/s]

Test Accuracy: 34.17%


Epoch   376:   0%|          | 0/60 [00:00<?, ?batch/s]

Test Accuracy: 33.64%


Epoch   377:   0%|          | 0/61 [00:00<?, ?batch/s]

Test Accuracy: 27.6%


Epoch   378:   0%|          | 0/61 [00:00<?, ?batch/s]

Test Accuracy: 31.65%


Epoch   379:   0%|          | 0/61 [00:00<?, ?batch/s]

Test Accuracy: 29.6%


Epoch   380:   0%|          | 0/61 [00:00<?, ?batch/s]

Test Accuracy: 31.16%


Epoch   381:   0%|          | 0/62 [00:00<?, ?batch/s]

Test Accuracy: 33.71%


Epoch   382:   0%|          | 0/62 [00:00<?, ?batch/s]

Test Accuracy: 29.34%


Epoch   383:   0%|          | 0/63 [00:00<?, ?batch/s]

Test Accuracy: 32.49%


Epoch   384:   0%|          | 0/63 [00:00<?, ?batch/s]

Test Accuracy: 30.04%


Epoch   385:   0%|          | 0/64 [00:00<?, ?batch/s]

Test Accuracy: 28.96%


Epoch   386:   0%|          | 0/64 [00:00<?, ?batch/s]

Test Accuracy: 34.11%


Epoch   387:   0%|          | 0/65 [00:00<?, ?batch/s]

Test Accuracy: 33.83%


Epoch   388:   0%|          | 0/65 [00:00<?, ?batch/s]

Test Accuracy: 31.51%


Epoch   389:   0%|          | 0/65 [00:00<?, ?batch/s]

Test Accuracy: 31.85%


Epoch   390:   0%|          | 0/65 [00:00<?, ?batch/s]

Test Accuracy: 30.93%


Epoch   391:   0%|          | 0/66 [00:00<?, ?batch/s]

Test Accuracy: 30.72%


Epoch   392:   0%|          | 0/66 [00:00<?, ?batch/s]

Test Accuracy: 33.23%


Epoch   393:   0%|          | 0/67 [00:00<?, ?batch/s]

Test Accuracy: 32.44%


Epoch   394:   0%|          | 0/67 [00:00<?, ?batch/s]

Test Accuracy: 27.45%


Epoch   395:   0%|          | 0/68 [00:00<?, ?batch/s]

Test Accuracy: 31.93%


Epoch   396:   0%|          | 0/68 [00:00<?, ?batch/s]

Test Accuracy: 31.73%


Epoch   397:   0%|          | 0/68 [00:00<?, ?batch/s]

Test Accuracy: 32.22%


Epoch   398:   0%|          | 0/68 [00:00<?, ?batch/s]

Test Accuracy: 33.69%


Epoch   399:   0%|          | 0/69 [00:00<?, ?batch/s]

Test Accuracy: 32.56%


Epoch   400:   0%|          | 0/69 [00:00<?, ?batch/s]

Test Accuracy: 32.83%


Epoch   401:   0%|          | 0/70 [00:00<?, ?batch/s]

Test Accuracy: 32.56%


Epoch   402:   0%|          | 0/70 [00:00<?, ?batch/s]

Test Accuracy: 30.44%


Epoch   403:   0%|          | 0/71 [00:00<?, ?batch/s]

Test Accuracy: 30.85%


Epoch   404:   0%|          | 0/71 [00:00<?, ?batch/s]

Test Accuracy: 32.86%


Epoch   405:   0%|          | 0/72 [00:00<?, ?batch/s]

Test Accuracy: 30.28%


Epoch   406:   0%|          | 0/72 [00:00<?, ?batch/s]

Test Accuracy: 32.28%


Epoch   407:   0%|          | 0/72 [00:00<?, ?batch/s]

Test Accuracy: 31.37%


Epoch   408:   0%|          | 0/72 [00:00<?, ?batch/s]

Test Accuracy: 32.27%


Epoch   409:   0%|          | 0/73 [00:00<?, ?batch/s]

Test Accuracy: 32.36%


Epoch   410:   0%|          | 0/73 [00:00<?, ?batch/s]

Test Accuracy: 33.72%


Epoch   411:   0%|          | 0/74 [00:00<?, ?batch/s]

Test Accuracy: 31.26%


Epoch   412:   0%|          | 0/74 [00:00<?, ?batch/s]

Test Accuracy: 32.17%


Epoch   413:   0%|          | 0/75 [00:00<?, ?batch/s]

Test Accuracy: 30.23%


Epoch   414:   0%|          | 0/75 [00:00<?, ?batch/s]

Test Accuracy: 32.67%


Epoch   415:   0%|          | 0/75 [00:00<?, ?batch/s]

Test Accuracy: 31.79%


Epoch   416:   0%|          | 0/75 [00:00<?, ?batch/s]

Test Accuracy: 31.64%


Epoch   417:   0%|          | 0/76 [00:00<?, ?batch/s]

Test Accuracy: 32.99%


Epoch   418:   0%|          | 0/76 [00:00<?, ?batch/s]

Test Accuracy: 32.82%


Epoch   419:   0%|          | 0/77 [00:00<?, ?batch/s]

Test Accuracy: 33.04%


Epoch   420:   0%|          | 0/77 [00:00<?, ?batch/s]

Test Accuracy: 31.29%


Epoch   421:   0%|          | 0/78 [00:00<?, ?batch/s]

Test Accuracy: 29.62%


Epoch   422:   0%|          | 0/78 [00:00<?, ?batch/s]

Test Accuracy: 33.16%


Epoch   423:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.5%


Epoch   424:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 28.41%


Epoch   425:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 27.1%


Epoch   426:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.65%


Epoch   427:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.16%


Epoch   428:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.23%


Epoch   429:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 28.87%


Epoch   430:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.21%


Epoch   431:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.21%


Epoch   432:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.39%


Epoch   433:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.02%


Epoch   434:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.19%


Epoch   435:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 29.93%


Epoch   436:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.01%


Epoch   437:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.63%


Epoch   438:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.77%


Epoch   439:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.13%


Epoch   440:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 28.76%


Epoch   441:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.37%


Epoch   442:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.61%


Epoch   443:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 28.63%


Epoch   444:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.15%


Epoch   445:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.3%


Epoch   446:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 35.28%


Epoch   447:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.63%


Epoch   448:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.17%


Epoch   449:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.44%


Epoch   450:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 34.2%


Epoch   451:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 27.67%


Epoch   452:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 35.32%


Epoch   453:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.76%


Epoch   454:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.73%


Epoch   455:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.83%


Epoch   456:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.07%


Epoch   457:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.01%


Epoch   458:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.53%


Epoch   459:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 26.37%


Epoch   460:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.08%


Epoch   461:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.62%


Epoch   462:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.33%


Epoch   463:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.81%


Epoch   464:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.81%


Epoch   465:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.85%


Epoch   466:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.36%


Epoch   467:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.77%


Epoch   468:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 29.47%


Epoch   469:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.46%


Epoch   470:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.11%


Epoch   471:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.17%


Epoch   472:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.98%


Epoch   473:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.05%


Epoch   474:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.94%


Epoch   475:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 34.0%


Epoch   476:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.28%


Epoch   477:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.76%


Epoch   478:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.51%


Epoch   479:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.0%


Epoch   480:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.28%


Epoch   481:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.42%


Epoch   482:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.67%


Epoch   483:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 28.27%


Epoch   484:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 28.66%


Epoch   485:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.02%


Epoch   486:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 28.47%


Epoch   487:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.31%


Epoch   488:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.73%


Epoch   489:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.57%


Epoch   490:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 29.92%


Epoch   491:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.12%


Epoch   492:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.1%


Epoch   493:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.74%


Epoch   494:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.22%


Epoch   495:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 29.23%


Epoch   496:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.18%


Epoch   497:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.12%


Epoch   498:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 29.64%


Epoch   499:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.28%


Epoch   500:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.6%


Epoch   501:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.75%


Epoch   502:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.87%


Epoch   503:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.48%


Epoch   504:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 28.27%


Epoch   505:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.35%


Epoch   506:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.96%


Epoch   507:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.73%


Epoch   508:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.13%


Epoch   509:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.08%


Epoch   510:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.98%


Epoch   511:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.43%


Epoch   512:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.01%


Epoch   513:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 27.83%


Epoch   514:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.21%


Epoch   515:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.53%


Epoch   516:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 32.92%


Epoch   517:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 28.68%


Epoch   518:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 31.97%


Epoch   519:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.62%


Epoch   520:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 28.18%


Epoch   521:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 25.93%


Epoch   522:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 30.28%


Epoch   523:   0%|          | 0/79 [00:00<?, ?batch/s]

Test Accuracy: 33.33%


Epoch   524:   0%|          | 0/79 [00:00<?, ?batch/s]

In [None]:
trainset_sup_new = create_labeled_dataset_active_learning(trainset_unsup, selected_indices)

In [None]:
trainset_sup_new[-1]

In [None]:
plt.imshow(trainset_sup_new[-1][0].permute(1, 2, 0).cpu().numpy())
plt.title(f'Pred: {trainset_sup_new[-1][1]}')
plt.show()

plt.imshow(trainset_sup[-1][0].permute(1, 2, 0).cpu().numpy())
plt.title(f'Pred: {trainset_sup[-1][1]}')
plt.show()

In [None]:
# plot losses
plt.figure(figsize=(10, 7))
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.title('Loss at the end of each epoch')
plt.legend()
plt.show()

# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot()
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_10_fixmatch_AL.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(TARGET_PROP*100)}% - {test_accuracy:.2f}% - Active Learning")
fig1.savefig(f"./figures/test_score_{TARGET_PROP}_fixmatch_AL.png")

### IV.2 Fixmatch on 5% train data with Active Learning

In [None]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

TARGET_PROP = 0.05
EPOCHS = 300
SUBSET_PROP = 0.005
EPOCHS_AL = 50
K_SAMPLES = 50

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# transformations
weak_transform = K.ImageSequential(
    K.RandomHorizontalFlip(p=0.50), 
    K.RandomAffine(degrees=0, translate=(0.125, 0.125)),
)

strong_transform = K.ImageSequential(
    K.auto.RandAugment(n=2, m=10), # randaugment + cutout
)

def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat = torch.max(qb, dim=1)

        idx = max_qb > TAU
        qb_hat = qb_hat[idx]

    return qb_hat.detach(), idx, max_qb.detach()

model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')

optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)


# Define the cosine learning rate decay function
lr_lambda = lambda step: LR * torch.cos(torch.tensor((7 * torch.pi * (step)) / (16 * EPOCHS))) * 100 / 3

# Create a learning rate scheduler with the cosine decay function
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# scheduler = None

In [None]:
print("Start training")

current_prop = SUBSET_PROP

train_losses = []
test_losses = []
added_samp = 0
uncertainty = 0

for j in range(EPOCHS):
    model.train()

    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    moving_avg_pred_labeled = 0
    moving_avg_pred_unlabeled = 0

    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {j: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, unlabeled_labels = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # Compute mask, confidence
        pseudo_labels, idx, max_qb = mask(model, weak_unlabeled_inputs)
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]
        unlabeled_labels = unlabeled_labels[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)
            
            true_unlabeled_loss = torch.sum(true_unlabeled_criterion(unlabeled_outputs, unlabeled_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)
            true_unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()
            

        # backward pass + optimize
        loss.backward()

        # clamp gradients, just in case
        # for p in filter(lambda p: p.grad is not None, model.parameters()): p.grad.data.clamp_(min=-.1, max=.1)
        
        optimizer.step()

        
        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled 

        # update progress bar
        pbar.set_postfix({
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "true unlabeled loss": true_unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "avg confidence": torch.mean(max_qb).item(),
            "n_unlabeled": running_n_unlabeled,
            "current_prop": current_prop,
            "uncertainty": uncertainty,
            "lr": optimizer.param_groups[0]['lr']
        })

    # start adding labels after 50 epochs for 50 epochs
    if j >= 50:
        if current_prop < TARGET_PROP:
            # compute information density
            selected_indices, uncertainty = least_confidence(model, unlabeled_dataloader, K_SAMPLES)
            print(selected_indices)

            # select indices from unlabeled dataset
            trainset_sup_new = create_labeled_dataset_active_learning(trainset_unsup, selected_indices)

            # concat new trainset with labeled trainset
            trainset_sup = torch.utils.data.ConcatDataset([trainset_sup, trainset_sup_new])

            # create labeled dataloader
            labeled_dataloader = torch.utils.data.DataLoader(trainset_sup, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

            current_prop = len(trainset_sup) / len(trainset)
    

    # update loss
    train_losses.append(running_loss / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    # if n_unlabeled != 0:
    #     # plot an image of the batch
    #     image = strong_unlabeled_inputs[0].cpu()
    #     # image = image * std + mean
    #     image = image.permute(1, 2, 0).cpu().numpy() * std + mean
    #     plt.imshow(image)
    #     plt.title(f'Pred: {unlabeled_outputs.argmax(dim=1)[0]}, true: {unlabeled_labels[0]}')
    #     plt.show()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)

In [None]:
# plot losses
plt.figure(figsize=(10, 7))
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.title('Loss at the end of each epoch')
plt.legend()
plt.show()

# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot()
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_5_fixmatch.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(SUBSET_PROP*100)}% - {test_accuracy:.2f}%")
fig1.savefig(f"./figures/test_score_{SUBSET_PROP}_fixmatch.png")

### IV.3 Fixmatch on 1% train data with Active Learning

In [None]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

TARGET_PROP = 0.01
EPOCHS = 300
SUBSET_PROP = 0.005
K_SAMPLES = 50

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=False,
    num_workers=0
)
# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# transformations
weak_transform = K.ImageSequential(
    K.RandomHorizontalFlip(p=0.50), 
    K.RandomAffine(degrees=0, translate=(0.125, 0.125)),
)

strong_transform = K.ImageSequential(
    K.auto.RandAugment(n=2, m=10), # randaugment + cutout
)

def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat = torch.max(qb, dim=1)

        idx = max_qb > TAU
        qb_hat = qb_hat[idx]

    return qb_hat.detach(), idx, max_qb.detach()

model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')

optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)


# Define the cosine learning rate decay function
lr_lambda = lambda step: LR * torch.cos(torch.tensor((7 * torch.pi * (step)) / (16 * EPOCHS))) * 100 / 3

# Create a learning rate scheduler with the cosine decay function
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# scheduler = None

In [None]:
print("Start training")

current_prop = SUBSET_PROP

train_losses = []
test_losses = []
added_samp = 0
uncertainty = 0

for j in range(EPOCHS):
    model.train()

    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    moving_avg_pred_labeled = 0
    moving_avg_pred_unlabeled = 0

    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {j: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, unlabeled_labels = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # Compute mask, confidence
        pseudo_labels, idx, max_qb = mask(model, weak_unlabeled_inputs)
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]
        unlabeled_labels = unlabeled_labels[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)
            
            true_unlabeled_loss = torch.sum(true_unlabeled_criterion(unlabeled_outputs, unlabeled_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)
            true_unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()
            

        # backward pass + optimize
        loss.backward()

        # clamp gradients, just in case
        # for p in filter(lambda p: p.grad is not None, model.parameters()): p.grad.data.clamp_(min=-.1, max=.1)
        
        optimizer.step()

        
        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled 

        # update progress bar
        pbar.set_postfix({
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "true unlabeled loss": true_unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "avg confidence": torch.mean(max_qb).item(),
            "n_unlabeled": running_n_unlabeled,
            "current_prop": current_prop,
            "uncertainty": uncertainty,
            "lr": optimizer.param_groups[0]['lr']
        })

    # start adding labels after 50 epochs for 50 epochs
    if j >= 50:
        if current_prop < TARGET_PROP:
            # compute information density
            selected_indices, uncertainty = least_confidence(model, unlabeled_dataloader, K_SAMPLES)
            print(selected_indices)

            # select indices from unlabeled dataset
            trainset_sup_new = create_labeled_dataset_active_learning(trainset_unsup, selected_indices)

            # concat new trainset with labeled trainset
            trainset_sup = torch.utils.data.ConcatDataset([trainset_sup, trainset_sup_new])

            # create labeled dataloader
            labeled_dataloader = torch.utils.data.DataLoader(trainset_sup, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

            current_prop = len(trainset_sup) / len(trainset)
    

    # update loss
    train_losses.append(running_loss / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    # if n_unlabeled != 0:
    #     # plot an image of the batch
    #     image = strong_unlabeled_inputs[0].cpu()
    #     # image = image * std + mean
    #     image = image.permute(1, 2, 0).cpu().numpy() * std + mean
    #     plt.imshow(image)
    #     plt.title(f'Pred: {unlabeled_outputs.argmax(dim=1)[0]}, true: {unlabeled_labels[0]}')
    #     plt.show()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)

In [None]:
# plot losses
plt.figure(figsize=(10, 7))
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.title('Loss at the end of each epoch')
plt.legend()
plt.show()

# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot()
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_1_fixmatch.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(SUBSET_PROP*100)}% - {test_accuracy:.2f}%")
fig1.savefig(f"./figures/test_score_{SUBSET_PROP}_fixmatch.png")