# Projet Fixmatch

In [1]:
# !pip install torchview torchsummary torchvision kornia torchmetrics matplotlib tqdm path graphviz opencv-python scikit-learn optuna

In [1]:
# deep learning
import torch
import torch.nn as nn
from torch.distributions.transforms import LowerCholeskyTransform
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import DataLoader, Dataset

# vizualisation
import torchsummary

# transforms
import torchvision.transforms as T
import kornia.augmentation as K
from kornia.enhance import normalize
# from torchvision.transforms import RandAugment

# metrics
from torchmetrics import Accuracy

# torchvision
import torchvision
from torchvision import transforms

# plotting
import matplotlib.pyplot as plt
from torchview import draw_graph

from IPython.display import display
from IPython.core.display import SVG, HTML

from tqdm.auto import tqdm

# typing
from typing import Callable

from utils import plot_images, plot_transform
from model import ConvNN, display_model

# os
import os
import path

import random
import numpy as np 

# transformations
# import transform as T
# from randaugment import RandomAugment

# typing
from typing import Callable, List, Tuple

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics.pairwise import cosine_similarity

from randaugment import RandAugmentMC

%load_ext autoreload
%autoreload 2

In [2]:
DEFAULT_RANDOM_SEED = 2021

def seedBasic(seed=DEFAULT_RANDOM_SEED):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    
# torch random seed
import torch
def seedTorch(seed=DEFAULT_RANDOM_SEED):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
      
# basic + tensorflow + torch 
def seedEverything(seed=DEFAULT_RANDOM_SEED):
    seedBasic(seed)
    seedTorch(seed)

In [3]:
# Set device
if ((int(torch.__version__.split(".")[0]) >= 2) or (int(torch.__version__.split(".")[1]) >= 13)) and torch.has_mps:
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


In [4]:
IMG_SHAPE = (3, 32, 32)
# See Table 4
TAU = 0.9 #! 0.95 in the paper
LAMBDA_U = 3
MU = 4
BATCH_SIZE = 64
LR = 0.03
BETA = 0.9
WEIGHT_DECAY = 0.0005
BETA_DENSITY = 1

In [5]:
class ConvNN(nn.Module):
    """
    Simple CNN for CIFAR10
    """
    
    def __init__(self):
        super().__init__()
        self.conv_32 = nn.Conv2d(3, 32, kernel_size=3, padding='same')
        self.conv_64 = nn.Conv2d(32, 64, kernel_size=3, padding='same')
        self.conv_96 = nn.Conv2d(64, 96, kernel_size=3, padding='same')
        self.conv_128 = nn.Conv2d(96, 128, kernel_size=3, padding='same')
        self.fc_512 = nn.Linear(512, 512)
        self.fc_10 = nn.Linear(512, 10)
        self.max_pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU(inplace=True)
        self.flatten = nn.Flatten()
        # self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_32(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_64(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_96(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_128(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.flatten(x)
        x = self.fc_512(x)
        x = self.relu(x)
        x = self.fc_10(x)
        # x = self.softmax(x)

        return x

In [6]:
def compute_mean_std(trainLoader) -> Tuple[List[float], List[float]]:
    # initialize the list of means and stds
    mean, std = torch.zeros(3), torch.zeros(3)

    # iterate over the dataset and compute the sum of each channel
    for images, _ in trainLoader:
        mean+= torch.mean(images, dim=[0,2,3])
        std+= torch.std(images, dim=[0,2,3])
    
    # compute the mean and std
    mean = mean/len(trainLoader)
    std = std/len(trainLoader)

    return mean, std

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

if not os.path.exists('./data/mean.pt'):
    mean, std = compute_mean_std(trainloader)
    torch.save(mean, 'data/mean.pt')
    torch.save(std, 'data/std.pt')
else:
    mean, std = torch.load('./data/mean.pt'), torch.load('./data/std.pt')

# to numpy
mean, std = mean.numpy(), std.numpy()

print(f"mean: {mean}, std: {std}")


testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
mean: [0.4913966  0.48215377 0.44651437], std: [0.246344   0.24280126 0.26067406]
Files already downloaded and verified


In [7]:
torch_models = 'torch_models' 
if not os.path.exists(torch_models):
    os.makedirs(torch_models)

## IV. Semi-Supervised Learning: Fixmatch with Active Learning

### IV.1 Fixmatch on 10% train data with Active Learning

In [8]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

TARGET_PROP = 0.05
EPOCHS = 150
SUBSET_PROP = 0.01
EPOCHS_AL = 50
K_SAMPLES = int ( (TARGET_PROP * len(trainset) - SUBSET_PROP * len(trainset)) / EPOCHS_AL )

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=False,
    num_workers=0
)



In [9]:
# transformations
weak_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=0, translate=(0.125, 0.125)),
    # transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

strong_transform = transforms.Compose([
    # transforms.RandomHorizontalFlip(p=0.5),
    # transforms.RandomAffine(degrees=10, translate=(0.125, 0.125)),
    # transforms.RandAugment(num_ops=2, magnitude=10),
    RandAugmentMC(n=2, m=10),
    # transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
    

In [10]:
def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat_max = torch.max(qb, dim=1)
        

        idx_max = max_qb > TAU
        qb_hat_max = qb_hat_max[idx_max]



    return qb_hat_max.detach(), idx_max, max_qb.detach()

In [11]:
model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')
true_unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0, last_epoch=-1)

In [12]:
def information_density(
    model: ConvNN, 
    inputs: torch.Tensor,
    k_samp: int,
    K_transform: int = 5) -> torch.Tensor:

    qb_transformed = torch.zeros((K_transform, len(inputs), 10))
    with torch.no_grad():
        model.train()
        for k in range(K_transform):
            # compute weak transformation of inputs norm
            weak_transformed_inputs = strong_transform(inputs)

            # compute predictions of weak transformed inputs norm
            logits = model(weak_transformed_inputs)
            qb = torch.softmax(logits, dim=1)

            # store
            qb_transformed[k] = qb
    
    # compute mean
    qb = torch.mean(qb_transformed, dim=0)

    # get top 2 values of qb
    qb1, qb2 = torch.topk(qb, k=2, dim=1, sorted=True)[0][:, 0], torch.topk(qb, k=2, dim=1, sorted=True)[0][:, 1]

    uncertainty = 1 - (qb1 - qb2)

    # get top k_samp indices
    _, idx = torch.topk(uncertainty, k=k_samp, dim=0, sorted=True)

    return idx, torch.mean(torch.sort(uncertainty, descending=True)[0][:k_samp]).item()

def least_confidence(
    model: ConvNN,
    unlabeled_dataloader: torch.utils.data.DataLoader,
    k_samp: int) -> torch.Tensor:

    # unlabeled_indices = np.arange(len(trainset_unsup))
    
    probs = []

    model.eval()

    with torch.no_grad():
        for i, (inputs, _) in enumerate(unlabeled_dataloader):
            inputs = inputs.to(device)

            # inputs transforms
            weak_transformed_inputs = weak_transform(inputs)

            logits = model(weak_transformed_inputs)
            qb = torch.softmax(logits, dim=1)
            probs.append(qb.cpu().data)
    
    # transform to tensor
    probs_tensor = torch.cat(probs, dim=0)

    U = probs_tensor.max(1)[0]

    selected_indices = U.sort()[1][:k_samp]

    return selected_indices, torch.mean(torch.sort(U, descending=False)[0][:k_samp]).item()


# Create a new labeled dataset using active learning
def create_labeled_dataset_active_learning(dataset, selected_indices):
    dataset = torch.utils.data.Subset(dataset, selected_indices)
    return dataset

In [13]:
print("Start training")

current_prop = SUBSET_PROP

train_losses = []
test_losses = []
added_samp = 0
uncertainty = 0

for j in range(EPOCHS):
    model.train()

    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    moving_avg_pred_labeled = 0
    moving_avg_pred_unlabeled = 0

    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {j: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, unlabeled_labels = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # Compute mask, confidence
        pseudo_labels, idx, max_qb = mask(model, weak_unlabeled_inputs)
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]
        unlabeled_labels = unlabeled_labels[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)
            
            true_unlabeled_loss = torch.sum(true_unlabeled_criterion(unlabeled_outputs, unlabeled_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)
            true_unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()
            

        # backward pass + optimize
        loss.backward()

        # clamp gradients, just in case
        # for p in filter(lambda p: p.grad is not None, model.parameters()): p.grad.data.clamp_(min=-.1, max=.1)
        
        optimizer.step()

        
        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled 

        # update progress bar
        pbar.set_postfix({
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "true unlabeled loss": true_unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "avg confidence": torch.mean(max_qb).item(),
            "n_unlabeled": running_n_unlabeled,
            "current_prop": current_prop,
            "uncertainty": uncertainty,
            "lr": optimizer.param_groups[0]['lr']
        })

    # start adding labels after 50 epochs for 50 epochs
    if j >= 50:
        if current_prop < TARGET_PROP:
            # compute information density
            selected_indices, uncertainty = least_confidence(model, unlabeled_dataloader, K_SAMPLES)
            print(selected_indices)

            # select indices from unlabeled dataset
            trainset_sup_new = create_labeled_dataset_active_learning(trainset_unsup, selected_indices)

            # concat new trainset with labeled trainset
            trainset_sup = torch.utils.data.ConcatDataset([trainset_sup, trainset_sup_new])

            # create labeled dataloader
            labeled_dataloader = torch.utils.data.DataLoader(trainset_sup, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

            current_prop = len(trainset_sup) / len(trainset)
    

    # update loss
    train_losses.append(running_loss / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    # if n_unlabeled != 0:
    #     # plot an image of the batch
    #     image = strong_unlabeled_inputs[0].cpu()
    #     # image = image * std + mean
    #     image = image.permute(1, 2, 0).cpu().numpy() * std + mean
    #     plt.imshow(image)
    #     plt.title(f'Pred: {unlabeled_outputs.argmax(dim=1)[0]}, true: {unlabeled_labels[0]}')
    #     plt.show()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)

Start training


Epoch     0:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     1:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     2:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     3:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 12.81%


Epoch     4:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 20.29%


Epoch     5:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 20.0%


Epoch     6:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 21.13%


Epoch     7:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 23.41%


Epoch     8:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 25.77%


Epoch     9:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 24.94%


Epoch    10:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 26.47%


Epoch    11:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 26.0%


Epoch    12:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 29.7%


Epoch    13:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 25.73%


Epoch    14:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 31.19%


Epoch    15:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 29.13%


Epoch    16:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 28.1%


Epoch    17:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 33.3%


Epoch    18:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 30.67%


Epoch    19:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 33.01%


Epoch    20:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 29.59%


Epoch    21:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 31.35%


Epoch    22:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 32.68%


Epoch    23:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 36.38%


Epoch    24:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 34.82%


Epoch    25:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.15%


Epoch    26:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 35.37%


Epoch    27:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 35.95%


Epoch    28:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.66%


Epoch    29:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.28%


Epoch    30:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.76%


Epoch    31:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.89%


Epoch    32:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 36.4%


Epoch    33:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.3%


Epoch    34:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.24%


Epoch    35:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.71%


Epoch    36:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.01%


Epoch    37:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.02%


Epoch    38:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.17%


Epoch    39:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 36.96%


Epoch    40:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.1%


Epoch    41:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.77%


Epoch    42:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.7%


Epoch    43:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.53%


Epoch    44:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 36.44%


Epoch    45:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.91%


Epoch    46:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.62%


Epoch    47:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.94%


Epoch    48:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.88%


Epoch    49:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 35.46%


Epoch    50:   0%|          | 0/8 [00:00<?, ?batch/s]

tensor([31089, 18962, 22213, 34052, 33721, 30967, 12121, 20486, 30543, 18105,
        38959, 42567, 39634,  2911, 44137,  9799, 14371, 31345, 38874, 19534,
        46120,  6460, 39644, 23118, 22661, 26420, 10912, 11214, 35768, 35102,
        45149, 13048, 14645, 47195, 13338, 45126, 45636, 44530, 43037, 31761])
Test Accuracy: 41.54%


Epoch    51:   0%|          | 0/9 [00:00<?, ?batch/s]

tensor([39201, 36105, 13551, 37665, 24511,  9396, 22227,  3237, 21891, 25698,
         6004, 17593, 36306,  2349, 13389, 46052,  7265, 34361, 34019, 28894,
        22702, 16142,  5283,  9875, 28652, 22024, 11709,  4518, 44824, 44137,
          492, 11902,  6093,  8611, 15377, 37164,  7699,  6566, 39906, 13552])
Test Accuracy: 37.95%


Epoch    52:   0%|          | 0/10 [00:00<?, ?batch/s]

tensor([18005, 24450, 26216, 12025, 21577, 13502,  8465, 30046, 40380, 22995,
          985,  3217, 33095, 15649,   111, 41750, 25662,   148, 37819, 41824,
        39722, 32953, 43742, 25293, 17145, 13634, 32195, 20794, 46708, 14875,
        40916, 12742, 22875, 45080, 39002,  8487, 15633, 17542, 39862,  9816])
Test Accuracy: 34.66%


Epoch    53:   0%|          | 0/10 [00:00<?, ?batch/s]

tensor([21357, 46429, 44650, 14875, 11483, 32562,  4825,  7476, 16802,  7078,
         7751, 25738, 26096, 37891, 16476, 10862, 42580, 40171, 47671, 35835,
        17168,  5212, 35360, 38201, 13858,  6819,  8386,  8293, 13874,   108,
        18319, 23929,  9378,  8933,  3858, 27226,  2492, 11019, 28765, 44457])
Test Accuracy: 41.51%


Epoch    54:   0%|          | 0/11 [00:00<?, ?batch/s]

tensor([27369, 20359,  4825, 25644, 44457, 34050, 42312,  5246, 13565,  4740,
        24841, 16996,  9346, 14468,  4543, 47591, 11705, 15103, 48719, 11483,
         3229, 45541, 20657, 17220, 36788, 38554, 27251, 36669, 26907, 39692,
        35835,  7941, 28971,   501, 12892, 43676, 46796, 25627, 44141, 21123])
Test Accuracy: 41.16%


Epoch    55:   0%|          | 0/11 [00:00<?, ?batch/s]

tensor([37768,  4212, 30022, 31976, 10816, 41818, 22998,  3100, 28765,  6804,
        43500, 43783, 21596, 23803, 46668, 40431, 14810, 27255, 26287, 31067,
        47782, 43226,  7901,  3222, 41636, 16545, 35089,  2240, 17311, 43681,
        37052,  7014, 41133,  3833, 42357, 47310, 48703, 26319,  8685, 11483])
Test Accuracy: 41.92%


Epoch    56:   0%|          | 0/12 [00:00<?, ?batch/s]

tensor([11641, 26002, 23267, 13690, 26758, 47683, 45268, 37642, 42357, 36583,
        18304, 39721, 31761, 12687, 32514,  6158, 31801, 23551, 36046, 22382,
        36452, 43355, 10579,  4000, 34992, 36661, 16667, 35642, 27860, 11606,
         6496, 12121, 47080,  4192, 44078, 48703, 10816, 36574, 28702,  2004])
Test Accuracy: 39.11%


Epoch    57:   0%|          | 0/13 [00:00<?, ?batch/s]

tensor([45601,  5537, 29925, 37946,  5370,  5254, 47559,  3635, 30543, 16192,
        20359, 20168, 49209, 33415, 34331, 10309, 38724, 18575, 16933, 24917,
        30905, 31345, 39397, 31226,  6236,  6425, 46221,  1129, 10065, 48703,
         5169,  9862,  1744, 49443, 27078, 17213, 27407, 14854, 43577, 35089])
Test Accuracy: 43.44%


Epoch    58:   0%|          | 0/13 [00:00<?, ?batch/s]

tensor([ 4887,  8188, 23788, 49011, 43398, 11763, 14905, 22314, 36830,  9387,
        36681, 41788, 16732,  5519,  9335, 14741, 46653, 46772, 19910, 45590,
        18679, 22419, 10822,  1435, 25647, 30804, 42355, 37045, 32919, 19338,
        47856, 17488,  6503, 33509, 16100, 26555, 25666, 10680, 41446, 39298])
Test Accuracy: 36.07%


Epoch    59:   0%|          | 0/14 [00:00<?, ?batch/s]

tensor([31181, 25420,   896, 14394, 40202,  8246, 38651, 44070, 31335, 38276,
        36775,  1034, 13340, 37814,  8417, 49854, 24562, 14712, 23880, 36574,
         4331, 10313, 47219, 14490, 37875, 49146, 16687, 44506, 11214, 11094,
         5586, 38292, 30804, 38388, 48909,  4203, 17168, 15103, 33606, 34272])
Test Accuracy: 40.42%


Epoch    60:   0%|          | 0/15 [00:00<?, ?batch/s]

tensor([30016, 34557, 20622, 33606, 18586,  7374,  9798, 10758, 18319, 40451,
        43923, 24445, 25632, 38292, 16942, 18770, 42984, 12594, 27226, 11612,
        39722, 49809, 45968,  7905,  4664, 45238,   233,  9258, 15103, 47541,
        25172, 22770, 45401,  9933, 38889, 27712, 20115, 34885, 49026, 37221])
Test Accuracy: 40.69%


Epoch    61:   0%|          | 0/15 [00:00<?, ?batch/s]

tensor([16996, 41818, 43767, 12787,  1026, 23871, 47344, 28549, 39292, 15382,
         7867, 28560, 10951, 14306, 49578, 13879, 29034, 35369, 37023, 15559,
        26506, 40212, 28947, 32769, 36563, 36567, 37364, 28528, 28856, 26902,
         4332, 43586, 49713,  7990, 39204, 18708, 46893, 32505,  6023, 22038])
Test Accuracy: 42.07%


Epoch    62:   0%|          | 0/16 [00:00<?, ?batch/s]

tensor([44716, 28797,   451, 11377,  4393, 46522, 28640,  4498, 37272, 18824,
        11492, 12769, 11700, 46384, 33952,  5001, 39834,  2467,  2165, 29239,
        29903, 18556, 41556, 28135,   840, 12381, 21363, 31797, 16115, 11581,
        33518, 15762,  2284, 38609, 22798, 27369, 31190, 32250,  9734, 13536])
Test Accuracy: 40.81%


Epoch    63:   0%|          | 0/16 [00:00<?, ?batch/s]

tensor([39204, 18473, 11933, 14159, 15362, 15493, 32790, 10584, 11345,  1742,
        32863, 30327,  6101, 29117,  2911, 48991, 20841, 33888, 19960, 39297,
        49209,  9303, 19864, 16149,  3885,  1557, 14277, 27529,  5626,  4795,
        11469, 30107, 17474, 49291, 49404,  2507, 10614, 35108, 40163,  8032])
Test Accuracy: 44.01%


Epoch    64:   0%|          | 0/17 [00:00<?, ?batch/s]

tensor([43767,  9658, 24018, 28199,  4199,  7905,  6101, 10654,  4825, 42021,
        35089, 34195, 45581, 22768, 23715,   825, 30680, 32130, 17690, 19653,
        12864, 42179,  6371, 29649, 40643, 31769, 27019, 27214, 32871,  6885,
         9116,  1744, 31739, 33385, 34017,  3303, 19942, 33312, 19340, 19920])
Test Accuracy: 44.71%


Epoch    65:   0%|          | 0/18 [00:00<?, ?batch/s]

tensor([43876, 22398, 43341, 47109, 17843, 14159, 15761, 10209, 29583, 17146,
        22313,  6384, 35741, 40431, 19000,  4657, 40338, 21539, 13974,  7616,
        18780, 38276, 45424, 33162, 39634, 14854, 47346, 40705, 43681, 47197,
        46429, 32253, 45330, 14091, 49444, 30022, 37549,  6122, 31465, 29893])
Test Accuracy: 45.83%


Epoch    66:   0%|          | 0/18 [00:00<?, ?batch/s]

tensor([39778, 26730, 33551,  1359, 37344, 19949,  2579,  2971, 46909, 35642,
        30691,  7905, 29583, 11181, 19000, 30421, 15907, 14616, 46575, 14415,
        15607, 17882, 15478, 26069, 31769,  2317, 34342, 22439, 45607, 42818,
        31122, 22658, 37768, 22768, 38590,  9935,  9891, 23534,  8410, 35202])
Test Accuracy: 43.79%


Epoch    67:   0%|          | 0/19 [00:00<?, ?batch/s]

tensor([11456, 19704, 40271, 12697, 36090, 40431,  3452,  8267, 25131, 22037,
        46284, 20636, 43015, 20179, 26289, 13858, 28196, 14415, 11181, 18105,
         4332, 44805,  5963, 25009,  7997, 38634, 48789, 42098, 12786, 13539,
        33292, 28190, 30620, 43797, 12705, 10546, 35642, 28273,  3987, 41632])
Test Accuracy: 45.86%


Epoch    68:   0%|          | 0/20 [00:00<?, ?batch/s]

tensor([ 7424,  3823,  1783,  9611, 33415, 47782, 36695, 46807, 31455, 17902,
        27031, 35958, 19174, 40396, 11854,  4654, 28659, 25673,  5443, 44600,
        15331, 16192, 48637, 46183, 10639, 38047, 33644,  3346, 20908, 31004,
        48271, 30734, 12449, 40356, 15531, 12953, 46958, 37888, 22509,  1331])
Test Accuracy: 45.11%


Epoch    69:   0%|          | 0/20 [00:00<?, ?batch/s]

tensor([47782, 38731, 37693, 36899, 21175, 11902, 20925, 39437, 15690, 30470,
        42406,  1653,   502, 35243, 43876, 40431,  9528, 26419,  5700, 22055,
         5625, 33360, 14019,  7763, 10346, 22337, 27221, 19248,  5018,  9662,
          497, 41105,  8246,  7747, 19008, 41500, 16448, 18135, 16095, 42683])
Test Accuracy: 45.94%


Epoch    70:   0%|          | 0/21 [00:00<?, ?batch/s]

tensor([27494, 37076, 12787, 44005, 49713,  1055, 29903,  6374, 35545, 38564,
        47452, 34487,  1457, 10951,  4332, 27221, 25982, 21856, 37922, 13206,
          825, 33724,  2667, 48703, 36304, 17098, 30338, 14992, 49992, 30136,
         8746, 35325, 39121, 24883, 40836, 26926,  1889, 48473, 16333, 27407])
Test Accuracy: 45.31%


Epoch    71:   0%|          | 0/21 [00:00<?, ?batch/s]

tensor([23661, 11469, 39714, 40714, 25216, 10261, 20490, 44282, 16476,  3358,
        31079, 25399, 10951, 19534, 30972, 47152, 16812, 23257, 35897,  9541,
        42040, 37443, 19520, 28791, 21410, 30827, 48295, 29023, 29297, 47834,
         5855, 13332, 45892, 48264,  6023,  6966, 34152, 31565,  2511, 44271])
Test Accuracy: 48.4%


Epoch    72:   0%|          | 0/22 [00:00<?, ?batch/s]

tensor([27329,   949, 14749, 43396, 22350, 40989,  7936, 18526,  1026, 40241,
        30495, 17318, 28199, 44852, 36441, 22545, 46677, 46267, 31992, 42289,
        22651, 30312, 22459, 34226, 49387, 20343, 40572, 46612,  5370, 47421,
         9061, 14388,  6122, 13244, 27226, 36955, 47364, 38743, 11742, 43923])
Test Accuracy: 45.59%


Epoch    73:   0%|          | 0/23 [00:00<?, ?batch/s]

tensor([  409, 26439, 32578, 39899, 27255, 46036, 31489,  6352, 46070, 18931,
        19635, 40774, 38681,  2517, 12092,  1384, 38134, 19417, 39753, 22509,
        24841, 43073, 44890, 47270,  6588, 33378, 18279, 29951,  5317, 49964,
        12334, 25427,  4096, 30885, 49208,  2284, 16610, 48012,  4100, 30165])
Test Accuracy: 45.63%


Epoch    74:   0%|          | 0/23 [00:00<?, ?batch/s]

tensor([31079, 11670, 33821,  3637, 21369, 22378,  4631, 48264,  6782, 34503,
        21249, 44517, 36872, 15742, 31864, 41183,   991, 17370, 14201, 14240,
         8367, 47337, 30505,  8225,  2021, 37325, 48422, 19508, 39437, 39402,
        43186,  9885,  1130, 30859, 13206, 38378, 44877,  9061, 30260, 46608])
Test Accuracy: 48.99%


Epoch    75:   0%|          | 0/24 [00:00<?, ?batch/s]

tensor([40030, 32243,  3950, 39121, 14284, 22554, 48957, 11214, 23871, 35107,
        31992,  6489,   932, 22768, 44434, 49477, 20794,  7490, 37886, 14966,
          169, 22977, 14619, 14769,  9885, 36442,  4445, 15803, 36901, 45851,
        16545, 40781, 44470, 43190, 46284, 23527, 46128, 22710, 15891, 14491])
Test Accuracy: 44.27%


Epoch    76:   0%|          | 0/25 [00:00<?, ?batch/s]

tensor([ 3293, 34833, 27682, 18526, 40692, 37557, 46036, 22652, 29356, 45671,
         3395,  3744, 31933, 12831, 11410,  6023,  2664,  7370,  6259, 27919,
        45636, 29201, 22840, 13038, 11527,   825, 13389, 34487,  8926, 39915,
         8041, 10731, 10758,  4631, 49580, 40529, 37124, 46313,  6799,  4885])
Test Accuracy: 47.3%


Epoch    77:   0%|          | 0/25 [00:00<?, ?batch/s]

tensor([49845,  8942,  3637, 13340, 17164, 41952, 48725, 20991, 42289, 16837,
        19534, 33716,  8367, 23826, 41525,  7639,  9708, 16732, 14454, 26748,
         9434, 49536, 19017, 12305, 24290, 47749,  1004,  4560, 20059,  4689,
        33376, 47355, 12836, 18418, 27360, 49404, 47297, 28672, 26419, 33765])
Test Accuracy: 48.73%


Epoch    78:   0%|          | 0/26 [00:00<?, ?batch/s]

tensor([ 4014, 38990, 30258, 46429, 34064,  9215, 49963,  4113,  3637,  2667,
        33187, 42305, 35824, 32643, 48649, 14202, 13257, 26439, 41032, 35658,
        27221, 18129, 36084, 49928, 37076, 36901, 28719,  4662,  2664, 39882,
        42765,   825, 11583, 46897, 31770, 29940, 20298, 34709, 31032,   722])
Test Accuracy: 47.71%


Epoch    79:   0%|          | 0/26 [00:00<?, ?batch/s]

tensor([10855, 16601, 12305, 31933, 18327, 12396, 45161, 37662, 17617, 24511,
         5370, 10118, 44160, 28641, 43820, 39437, 13678,  5700,  5614, 45352,
        45074, 14479, 24991,  7310,  5401, 13107, 20018,  2678, 21788, 15289,
         7370, 39505, 21196, 25086, 30753,  1653,  6467, 18188, 49620, 32700])
Test Accuracy: 49.98%


Epoch    80:   0%|          | 0/27 [00:00<?, ?batch/s]

tensor([44598,  7828,  6460, 47109, 30688, 12787,  2980,  8182, 35663,  8838,
         6467, 25206, 20374,  2029, 45954, 36737, 43362, 40989, 14798, 18416,
        41812, 46519, 41952, 21637, 27115, 20202, 32643, 48392,  6039, 18619,
        27197, 44855, 18481, 43600,  5566, 19640,  2056,  2899, 10855,  4307])
Test Accuracy: 49.59%


Epoch    81:   0%|          | 0/28 [00:00<?, ?batch/s]

tensor([35082, 45904, 33574, 44185,  6424, 39093,   466, 32986, 25428, 11639,
        48567, 11547, 43584, 14448, 49291, 12347, 11218, 10601, 34004, 49580,
        30870,  5623, 17615,  4657,  2980, 43918, 28142, 49663, 37904,  1594,
        12643, 11764, 34698, 36478, 23327, 16484,  6961, 17202, 28812, 28515])
Test Accuracy: 49.65%


Epoch    82:   0%|          | 0/28 [00:00<?, ?batch/s]

tensor([ 6889, 14704, 28777, 21369, 43073,  4295,  5873, 43256, 38044, 18876,
        21784, 19944,   727, 13124, 14479, 20290, 30735,  8677, 11760, 44246,
        12076,  7000, 26067, 38722, 11639, 44406, 23715, 29154, 27040, 33608,
        21843, 38329, 47325, 30137,  5763, 33415, 21058, 16808, 22055, 17292])
Test Accuracy: 49.52%


Epoch    83:   0%|          | 0/29 [00:00<?, ?batch/s]

tensor([36235,  8869, 23488,   224, 14723, 16258, 33574, 13402, 10965, 44628,
        28528,  9632, 28447,  3795, 48361,  3950, 16605, 42117, 41791, 14642,
        28703,  8696, 36351, 19965, 22201,  6460, 41875, 10556,  7198, 16689,
        10816,  5855,  9484, 39969,  4957, 22710, 32714,  2392,  5136,  1272])
Test Accuracy: 49.89%


Epoch    84:   0%|          | 0/30 [00:00<?, ?batch/s]

tensor([43597, 36052, 21200, 25428,  5873,  2542, 43188, 11469, 29342, 48499,
         6455, 32932, 35002,    86, 47585, 49502, 10596, 40989, 45921, 31815,
        40212,  7034, 25500, 35500, 46714,  3941, 19868,  7881,  9632, 49644,
        11964, 27329, 37414,  3624, 35296, 28954, 15626,  5615, 16933, 48725])
Test Accuracy: 51.8%


Epoch    85:   0%|          | 0/30 [00:00<?, ?batch/s]

tensor([34064, 16933, 23755,  6885, 19629, 17931, 25227, 40432,  9360, 41058,
         1375, 18380, 44854, 27023, 42602, 12735,  7943, 33588, 35685, 27948,
        31156, 28791,  8669, 22271, 42371, 13854, 27104, 23722, 27918,  5599,
         7039, 26581, 28116,  2086, 21784, 43826, 13747, 33324, 49447, 43362])
Test Accuracy: 52.7%


Epoch    86:   0%|          | 0/31 [00:00<?, ?batch/s]

tensor([42155, 18172,  2015, 49027, 24021, 28199, 37250, 11217, 25416, 19802,
        13618, 34698, 17942,  1628,  8041, 42981, 30620, 36046, 41667, 19948,
        12489, 34908, 38869,  5736, 27329, 14526,  3272,  9116, 18979, 40532,
        18833, 33574, 43953, 21200, 16764, 24738, 25355,  7044, 23696, 36795])
Test Accuracy: 51.19%


Epoch    87:   0%|          | 0/31 [00:00<?, ?batch/s]

tensor([37209, 46817,  9150, 20975, 19153, 12396,  7990, 30133, 10596, 23294,
        29087, 48409, 43362, 27787, 27556, 21591,  5445, 36617, 41952, 42274,
         6023,  7940, 29481, 35023, 17983, 25516, 46796, 13832, 11728,  3709,
        40594,  5370, 30461, 13049, 18188,  9752, 33574, 11763,  3351, 46087])
Test Accuracy: 51.93%


Epoch    88:   0%|          | 0/32 [00:00<?, ?batch/s]

tensor([42762, 11470,  2121, 29108, 24260, 20496, 24292, 10596, 29334, 26730,
        10725, 40935, 38336, 31573, 15370, 11867, 38189, 35800, 13015,  9663,
        30738, 32647,   221, 26286, 30302, 13900, 10816, 28433, 37698, 48751,
        23681, 22323, 27197, 16069, 32535, 49992,  1684, 20059, 11656, 26935])
Test Accuracy: 49.5%


Epoch    89:   0%|          | 0/33 [00:00<?, ?batch/s]

tensor([36901,  1889, 41682, 37478, 11809, 33989,  7407, 32672, 34490, 45395,
        43597, 43953, 28612, 18336,  8447, 47977, 40988, 18418,  3285, 39734,
         1969,  4615,  5104,  9092,   104, 39356, 17389, 31340, 14684, 12987,
         6531,  1782,   337, 39256, 27021,  7496,  7237, 40774, 10872, 26754])
Test Accuracy: 51.61%


Epoch    90:   0%|          | 0/33 [00:00<?, ?batch/s]

tensor([31004,  1048, 46653, 36263, 23583,  6838, 18833, 35921, 34561, 11984,
        17922, 27418, 44713, 28868, 21418, 30140,  9501, 23849, 11274, 48228,
        28715, 15323, 48858,  3739, 41978,  1254,  6208, 31301, 47465, 23071,
        12613, 46385, 10311, 45164, 17920, 11798, 43953,  6962, 43323, 49335])
Test Accuracy: 53.33%


Epoch    91:   0%|          | 0/34 [00:00<?, ?batch/s]

tensor([ 1310, 49958, 39495, 39453,   602, 40490, 37642, 25599, 49125,  2667,
        16205, 49712, 23628, 19706, 47397, 18169, 11455, 16449, 33967, 13388,
        36368,  3424, 17679, 34896, 35824, 11786, 11790,  6375, 40310, 29140,
        25724, 34376, 27739,     5, 32653, 41197, 43509, 14064, 36052, 29947])
Test Accuracy: 54.35%


Epoch    92:   0%|          | 0/35 [00:00<?, ?batch/s]

tensor([41681, 45374, 47226, 18254, 22187, 40079, 34376, 42839, 19851, 29481,
         9632,  6014, 33187, 12579, 26812, 18252, 21229, 40614, 11133, 36875,
         8591, 48441, 24931, 39795, 49936, 23849,  3293, 19267, 38187, 45991,
         2630,  5588, 37832, 10474, 20923, 36734,  5013, 32578, 14392, 27963])
Test Accuracy: 53.14%


Epoch    93:   0%|          | 0/35 [00:00<?, ?batch/s]

tensor([16902, 43127, 21632,  1601, 42023, 16483, 23337, 29423, 20730, 13017,
        12668,  6797,  5445,  7295, 37849, 31413,  6278,  1310,  1254, 11438,
        31842, 20841,   380, 25472, 19856, 21693, 29441, 42176, 27063, 23986,
          302, 12433,  3341, 30416, 12516, 27682, 41564, 33691, 41820, 33815])
Test Accuracy: 54.27%


Epoch    94:   0%|          | 0/36 [00:00<?, ?batch/s]

tensor([40099, 43769, 31323, 40739, 11048, 33605, 24021, 29346,  1553, 10108,
        29481, 11790, 47832, 42727, 31180, 38012, 10435, 14035, 25964, 13081,
        14198, 10873, 45789, 20250, 12350, 37478, 29565,  6320, 21886, 42762,
        28299, 32643, 40462, 46384, 29087, 26428,  4624, 35981, 42237, 44807])
Test Accuracy: 54.21%


Epoch    95:   0%|          | 0/36 [00:00<?, ?batch/s]

tensor([25705, 22281, 41676, 27047, 30137, 38627, 41576, 49757, 38555, 41409,
        17339,  4978, 19009, 44005, 26998,     5, 36981, 17499, 28559,  4699,
        39530,  3557, 26287,  7742, 31629, 16894, 25029, 12672, 46017,  4994,
        30241,  6753, 37726,  8869, 25918,  3572, 28860, 36529, 31374, 28845])
Test Accuracy: 55.45%


Epoch    96:   0%|          | 0/37 [00:00<?, ?batch/s]

tensor([37366, 30121, 35824, 25471, 12812, 18916, 21500, 29481, 19508, 13550,
        13854, 48138,  9515, 40438, 47046,  8660, 19851, 37320, 12687, 29838,
         6279, 49003, 14011, 49650, 41186, 44767, 14964, 35908,  4451, 13856,
        48344, 29218, 13365, 48007, 25564, 25706, 18041, 32656,  9703, 32633])
Test Accuracy: 54.42%


Epoch    97:   0%|          | 0/38 [00:00<?, ?batch/s]

tensor([33022, 43815, 42812, 15917, 36843, 22615, 48344, 20560,  2544, 33196,
        18294,  9760, 30621,  7342, 10148,  9786, 41631,   820, 37811, 40774,
        42794, 16552, 30201, 48086, 14520, 47002, 27499,  6797, 37344, 29161,
        24426, 40669,  4784,  5073, 41606,   647,   135, 23983, 21805,   701])
Test Accuracy: 56.25%


Epoch    98:   0%|          | 0/38 [00:00<?, ?batch/s]

tensor([13159, 25460, 49992, 49201, 32802, 11583, 23903, 38692, 11813, 33575,
        18059, 19713, 46703, 41721, 40614, 23071,  6012, 22351, 24087, 18580,
        28139, 43848, 46447, 22275, 49357, 22126, 43982,  2656, 44598, 46060,
        16401, 46256, 42010, 27488, 16601, 16213, 17343, 21577, 26750, 48143])
Test Accuracy: 56.73%


Epoch    99:   0%|          | 0/39 [00:00<?, ?batch/s]

tensor([  472, 24124, 18700, 31288,  3590, 10855, 36052, 12396, 46267, 23567,
        19944, 38436, 32073, 38943,  5014,  3739,   835,  4318, 47690, 44393,
        30446, 33701, 41353, 36901, 34555, 29441, 49279, 31700, 31864, 30224,
        11988, 26088, 34254,  2791, 34071, 40852, 45352,  5071, 33553, 14029])
Test Accuracy: 57.05%


Epoch   100:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 55.66%


Epoch   101:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 57.6%


Epoch   102:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 57.98%


Epoch   103:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 57.99%


Epoch   104:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 58.07%


Epoch   105:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 57.61%


Epoch   106:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 58.45%


Epoch   107:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 59.83%


Epoch   108:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 59.89%


Epoch   109:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 60.28%


Epoch   110:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 59.76%


Epoch   111:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 60.41%


Epoch   112:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 61.14%


Epoch   113:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 60.55%


Epoch   114:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 60.31%


Epoch   115:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 61.43%


Epoch   116:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 61.92%


Epoch   117:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 62.18%


Epoch   118:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 60.14%


Epoch   119:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 62.1%


Epoch   120:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 61.49%


Epoch   121:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 61.27%


Epoch   122:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 62.94%


Epoch   123:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 61.96%


Epoch   124:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 62.49%


Epoch   125:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 62.56%


Epoch   126:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 62.71%


Epoch   127:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 62.99%


Epoch   128:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.12%


Epoch   129:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 62.61%


Epoch   130:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.0%


Epoch   131:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 62.81%


Epoch   132:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.25%


Epoch   133:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 62.96%


Epoch   134:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.39%


Epoch   135:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.25%


Epoch   136:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.59%


Epoch   137:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.88%


Epoch   138:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.91%


Epoch   139:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.7%


Epoch   140:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.79%


Epoch   141:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.96%


Epoch   142:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.93%


Epoch   143:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.58%


Epoch   144:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.84%


Epoch   145:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.73%


Epoch   146:   0%|          | 0/40 [00:00<?, ?batch/s]

Test Accuracy: 63.94%


Epoch   147:   0%|          | 0/40 [00:00<?, ?batch/s]

In [None]:
trainset_sup_new = create_labeled_dataset_active_learning(trainset_unsup, selected_indices)

In [None]:
trainset_sup_new[-1]

In [None]:
plt.imshow(trainset_sup_new[-1][0].permute(1, 2, 0).cpu().numpy())
plt.title(f'Pred: {trainset_sup_new[-1][1]}')
plt.show()

plt.imshow(trainset_sup[-1][0].permute(1, 2, 0).cpu().numpy())
plt.title(f'Pred: {trainset_sup[-1][1]}')
plt.show()

In [None]:
# plot losses
plt.figure(figsize=(10, 7))
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.title('Loss at the end of each epoch')
plt.legend()
plt.show()

# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot()
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_10_fixmatch_AL.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(TARGET_PROP*100)}% - {test_accuracy:.2f}% - Active Learning")
fig1.savefig(f"./figures/test_score_{TARGET_PROP}_fixmatch_AL.png")

### III.2 Fixmatch on 5% train data

In [None]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

EPOCHS = 50
SUBSET_PROP = 0.05

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# transformations
weak_transform = K.ImageSequential(
    K.RandomHorizontalFlip(p=0.50), 
    K.RandomAffine(degrees=0, translate=(0.125, 0.125)),
)

strong_transform = K.ImageSequential(
    K.auto.RandAugment(n=2, m=10), # randaugment + cutout
)

def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat = torch.max(qb, dim=1)

        idx = max_qb > TAU
        qb_hat = qb_hat[idx]

    return qb_hat.detach(), idx, max_qb.detach()

model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')

optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)


# Define the cosine learning rate decay function
lr_lambda = lambda step: LR * torch.cos(torch.tensor((7 * torch.pi * (step)) / (16 * EPOCHS))) * 100 / 3

# Create a learning rate scheduler with the cosine decay function
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# scheduler = None

In [None]:
print("Start training")

train_losses = []
test_losses = []

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    max_confidence = 0


    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {epoch: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, _ = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # normalize
        weak_labeled_inputs = normalize(data=weak_labeled_inputs, mean=mean, std=std)
        weak_unlabeled_inputs = normalize(data=weak_unlabeled_inputs, mean=mean, std=std)
        strong_unlabeled_inputs = normalize(data=strong_unlabeled_inputs, mean=mean, std=std)

        # Compute mask, confidence
        pseudo_labels, idx, max_qb = mask(model, weak_unlabeled_inputs)
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)
            # outputs = torch.softmax(outputs, dim=1)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)
            # labeled_outputs = torch.softmax(labeled_outputs, dim=1)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()


        # backward pass + optimize
        loss.backward()
        optimizer.step()

        

        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled
        max_confidence = max(max_confidence, max_qb.max().item())

        

        # update progress bar
        pbar.set_postfix({
            "total loss": loss.item(),
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "confidence": max_confidence,
            "n_unlabeled": running_n_unlabeled,
            "lr": optimizer.param_groups[0]['lr'].item()
        })

    # update loss
    train_losses.append(running_loss / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)


In [None]:
# plot losses
plt.figure(figsize=(10, 7))
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.title('Loss at the end of each epoch')
plt.legend()
plt.show()

# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot()
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_5_fixmatch.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(SUBSET_PROP*100)}% - {test_accuracy:.2f}%")
fig1.savefig(f"./figures/test_score_{SUBSET_PROP}_fixmatch.png")

### III.3 Fixmatch on 1% train data

In [None]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

EPOCHS = 50
SUBSET_PROP = 0.01

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# transformations
weak_transform = K.ImageSequential(
    K.RandomHorizontalFlip(p=0.50), 
    K.RandomAffine(degrees=0, translate=(0.125, 0.125)),
)

strong_transform = K.ImageSequential(
    K.auto.RandAugment(n=2, m=10), # randaugment + cutout
)

def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat = torch.max(qb, dim=1)

        idx = max_qb > TAU
        qb_hat = qb_hat[idx]

    return qb_hat.detach(), idx, max_qb.detach()

model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')

optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)


# Define the cosine learning rate decay function
lr_lambda = lambda step: LR * torch.cos(torch.tensor((7 * torch.pi * (step)) / (16 * EPOCHS))) * 100 / 3

# Create a learning rate scheduler with the cosine decay function
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# scheduler = None

In [None]:
print("Start training")

train_losses = []
test_losses = []

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    max_confidence = 0


    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {epoch: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, _ = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # normalize
        weak_labeled_inputs = normalize(data=weak_labeled_inputs, mean=mean, std=std)
        weak_unlabeled_inputs = normalize(data=weak_unlabeled_inputs, mean=mean, std=std)
        strong_unlabeled_inputs = normalize(data=strong_unlabeled_inputs, mean=mean, std=std)

        # Compute mask, confidence
        pseudo_labels, idx, max_qb = mask(model, weak_unlabeled_inputs)
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)
            # outputs = torch.softmax(outputs, dim=1)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)
            # labeled_outputs = torch.softmax(labeled_outputs, dim=1)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()


        # backward pass + optimize
        loss.backward()
        optimizer.step()

        

        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled
        max_confidence = max(max_confidence, max_qb.max().item())

        

        # update progress bar
        pbar.set_postfix({
            "total loss": loss.item(),
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "confidence": max_confidence,
            "n_unlabeled": running_n_unlabeled,
            "lr": optimizer.param_groups[0]['lr'].item()
        })

    # update loss
    train_losses.append(running_loss / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)


In [None]:
# plot losses
plt.figure(figsize=(10, 7))
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.title('Loss at the end of each epoch')
plt.legend()
plt.show()

# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot()
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_1_fixmatch.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(SUBSET_PROP*100)}% - {test_accuracy:.2f}%")
fig1.savefig(f"./figures/test_score_{SUBSET_PROP}_fixmatch.png")