In [None]:
# !pip install torchview torchsummary torchvision kornia torchmetrics matplotlib tqdm path graphviz opencv-python scikit-learn optuna

In [1]:
# deep learning
import torch
import torch.nn as nn
from torch.distributions.transforms import LowerCholeskyTransform
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import DataLoader, Dataset

# vizualisation
import torchsummary

# transforms
import torchvision.transforms as T
import kornia.augmentation as K
from kornia.enhance import normalize
from torchvision.transforms import RandAugment

# metrics
from torchmetrics import Accuracy

# torchvision
import torchvision
import torchvision.transforms as transforms

# plotting
import matplotlib.pyplot as plt
from torchview import draw_graph

from IPython.display import display
from IPython.core.display import SVG, HTML

from tqdm.auto import tqdm

# typing
from typing import Callable

from utils import plot_images, plot_transform
from model import ConvNN, display_model

# os
import os
import path

import random
import numpy as np 

# transformations
# import transform as T
from randaugment import RandAugmentMC

# typing
from typing import Callable, List, Tuple

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

%load_ext autoreload
%autoreload 2

In [2]:
DEFAULT_RANDOM_SEED = 2021

def seedBasic(seed=DEFAULT_RANDOM_SEED):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    
# torch random seed
import torch
def seedTorch(seed=DEFAULT_RANDOM_SEED):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
      
# basic + tensorflow + torch 
def seedEverything(seed=DEFAULT_RANDOM_SEED):
    seedBasic(seed)
    seedTorch(seed)

In [3]:
# Set device
if ((int(torch.__version__.split(".")[0]) >= 2) or (int(torch.__version__.split(".")[1]) >= 13)) and torch.has_mps:
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


In [4]:
IMG_SHAPE = (3, 32, 32)
# See Table 4
TAU = 0.9 
LAMBDA_U = 3
MU = 4 # Coefficient for scaling the unlabeled loss
BATCH_SIZE = 64
LR = 0.03
BETA = 0.9
WEIGHT_DECAY = 0.001

In [5]:
class ConvNN(nn.Module):
    """
    Simple CNN for CIFAR10
    """
    
    def __init__(self):
        super().__init__()
        self.conv_32 = nn.Conv2d(3, 32, kernel_size=3, padding='same')
        self.conv_64 = nn.Conv2d(32, 64, kernel_size=3, padding='same')
        self.conv_96 = nn.Conv2d(64, 96, kernel_size=3, padding='same')
        self.conv_128 = nn.Conv2d(96, 128, kernel_size=3, padding='same')
        self.fc_512 = nn.Linear(512, 512)
        self.fc_10 = nn.Linear(512, 10)
        self.max_pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU(inplace=True)
        self.flatten = nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_32(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_64(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_96(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv_128(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.flatten(x)
        x = self.fc_512(x)
        x = self.relu(x)
        x = self.fc_10(x)

        return x

In [6]:
def compute_mean_std(trainLoader) -> Tuple[List[float], List[float]]:
    # initialize the list of means and stds
    mean, std = torch.zeros(3), torch.zeros(3)

    # iterate over the dataset and compute the sum of each channel
    for images, _ in trainLoader:
        mean+= torch.mean(images, dim=[0,2,3])
        std+= torch.std(images, dim=[0,2,3])
    
    # compute the mean and std
    mean = mean/len(trainLoader)
    std = std/len(trainLoader)

    return mean, std

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

if not os.path.exists('./data/mean.pt'):
    mean, std = compute_mean_std(trainloader)
    torch.save(mean, 'data/mean.pt')
    torch.save(std, 'data/std.pt')
else:
    mean, std = torch.load('./data/mean.pt'), torch.load('./data/std.pt')

# to numpy
mean, std = mean.numpy(), std.numpy()

print(f"mean: {mean}, std: {std}")


testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
mean: [0.4913966  0.48215377 0.44651437], std: [0.246344   0.24280126 0.26067406]
Files already downloaded and verified


In [7]:
torch_models = 'torch_models' 
if not os.path.exists(torch_models):
    os.makedirs(torch_models)

## IV. Semi-Supervised Learning: Fixmatch - Data Alignment

### IV.1 Fixmatch on 10% train data - Data Alignment

In [None]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

EPOCHS = 300
SUBSET_PROP = 0.10
CIFAR10_class_distribution = 0.10

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

In [None]:
# transformations
weak_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=0, translate=(0.125, 0.125)),
    # transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

strong_transform = transforms.Compose([
    # transforms.RandomHorizontalFlip(p=0.5),
    # transforms.RandomAffine(degrees=10, translate=(0.125, 0.125)),
    # transforms.RandAugment(num_ops=2, magnitude=10),
    RandAugmentMC(n=2, m=10),
    # transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
    

In [None]:
def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat = torch.max(qb, dim=1)

        idx = max_qb > TAU
        qb_hat = qb_hat[idx]

    return qb_hat.detach(), idx, max_qb.detach()

In [None]:
model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')

optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
print("Start training")

train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    running_accuracy = 0
    moving_avg_pred_labeled = 0
    moving_avg_pred_unlabeled = 0


    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {epoch: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, _ = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()

        # normalize labeled and unlabeled inputs
        labeled_inputs_norm, unlabeled_inputs_norm = normalize(labeled_inputs, mean, std), normalize(unlabeled_inputs, mean, std)

        unlabeled_outputs_norm = model(unlabeled_inputs_norm)

        # Compute moving average of labeled and unlabeled predictions
        moving_avg_pred_unlabeled = (i * moving_avg_pred_unlabeled + unlabeled_outputs_norm.shape[0]) / (i + 1)

        # ratio
        ratio = CIFAR10_class_distribution / moving_avg_pred_unlabeled

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # prediction on weak augmented unlabeled data
        qb = model(weak_unlabeled_inputs)
        qb = torch.softmax(qb, dim=1)
        qb_norm = qb * ratio

        # normalize
        qb_tilde = qb_norm / torch.sum(qb_norm, dim=1, keepdim=True)

        # compute mask
        max_qb_tilde, qb_tilde_hat = torch.max(qb_tilde, dim=1)
        idx = max_qb_tilde > TAU

        # pseudo labels
        pseudo_labels = qb_tilde_hat[idx]

        # mask strong augmented unlabeled data
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)
            # outputs = torch.softmax(outputs, dim=1)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)
            # labeled_outputs = torch.softmax(labeled_outputs, dim=1)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()


        # backward pass + optimize
        loss.backward()
        optimizer.step()

        

        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled
        running_accuracy += 100 * correct / total

        

        # update progress bar
        pbar.set_postfix({
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "avg confidence": torch.mean(max_qb_tilde).item(),
            "n_unlabeled": running_n_unlabeled,
            "lr": optimizer.param_groups[0]['lr']
        })

    # update loss
    train_losses.append(running_loss / (i + 1))
    train_accuracies.append(running_accuracy / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)
        test_accuracies.append(test_accuracy)

In [None]:
# plot losses and accuracies
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

ax1.plot(train_losses, label="train")
ax1.plot(test_losses, label="test")
ax1.set_title("Loss")

ax2.plot(train_accuracies, label="train")
ax2.plot(test_accuracies, label="test")
ax2.set_title("Accuracy")

plt.legend()
plt.show()

# save plot
fig.savefig(f"figures/unsup_DA_10_losses_accuracies.png")

In [None]:
# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    fig, ax = plt.subplots(figsize=(10, 10))
    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot(ax=ax)
    plt.tight_layout()
    plt.show()

    # save plot
    fig.savefig(f"figures/unsup_DA_10_confusion_matrix.png")

In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_10_fixmatch_DA.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(SUBSET_PROP*100)}% - {test_accuracy:.2f}% - Distribution Alignment")
fig1.savefig(f"./figures/test_score_10_fixmatch_DA.png")

### IV.2 Fixmatch on 5% train data - Data Alignment

In [None]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

EPOCHS = 200 # 55
SUBSET_PROP = 0.05
CIFAR10_class_distribution = 0.10

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# transformations
weak_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=0, translate=(0.125, 0.125)),
    # transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

strong_transform = transforms.Compose([
    # transforms.RandomHorizontalFlip(p=0.5),
    # transforms.RandomAffine(degrees=10, translate=(0.125, 0.125)),
    # transforms.RandAugment(num_ops=2, magnitude=10),
    RandAugmentMC(n=2, m=10),
    # transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
    

def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat = torch.max(qb, dim=1)

        idx = max_qb > TAU
        qb_hat = qb_hat[idx]

    return qb_hat.detach(), idx, max_qb.detach()

model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')

optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
print("Start training")

train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    running_accuracy = 0
    moving_avg_pred_labeled = 0
    moving_avg_pred_unlabeled = 0


    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {epoch: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, _ = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()

        # normalize labeled and unlabeled inputs
        labeled_inputs_norm, unlabeled_inputs_norm = normalize(labeled_inputs, mean, std), normalize(unlabeled_inputs, mean, std)

        # labeled_outputs_norm = model(labeled_inputs_norm)
        unlabeled_outputs_norm = model(unlabeled_inputs_norm)

        # Compute moving average of labeled and unlabeled predictions
        # moving_avg_pred_labeled = (i * moving_avg_pred_labeled + labeled_outputs_norm.shape[0]) / (i + 1)
        moving_avg_pred_unlabeled = (i * moving_avg_pred_unlabeled + unlabeled_outputs_norm.shape[0]) / (i + 1)

        # ratio
        ratio = CIFAR10_class_distribution / moving_avg_pred_unlabeled
        # model.train()

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)
        strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

        # prediction on weak augmented unlabeled data
        qb = model(weak_unlabeled_inputs)
        qb = torch.softmax(qb, dim=1)
        qb_norm = qb * ratio

        # normalize
        qb_tilde = qb_norm / torch.sum(qb_norm, dim=1, keepdim=True)

        # compute mask
        max_qb_tilde, qb_tilde_hat = torch.max(qb_tilde, dim=1)
        idx = max_qb_tilde > TAU

        # pseudo labels
        pseudo_labels = qb_tilde_hat[idx]

        # mask strong augmented unlabeled data
        strong_unlabeled_inputs = strong_unlabeled_inputs[idx]

        n_labeled, n_unlabeled = weak_labeled_inputs.size(0), strong_unlabeled_inputs.size(0)

        if n_unlabeled != 0:
            # Concatenate labeled and unlabeled data
            inputs_all = torch.cat((weak_labeled_inputs, strong_unlabeled_inputs))
            labels_all = torch.cat((labels, pseudo_labels))

            # forward pass
            outputs = model(inputs_all)
            # outputs = torch.softmax(outputs, dim=1)

            # split labeled and unlabeled outputs
            labeled_outputs, unlabeled_outputs = outputs[:n_labeled], outputs[n_labeled:]

            # compute losses
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels_all.size(0)
            correct += (outputs.argmax(dim=1) == labels_all).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)
            # labeled_outputs = torch.softmax(labeled_outputs, dim=1)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()


        # backward pass + optimize
        loss.backward()
        optimizer.step()

        

        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += n_unlabeled
        running_accuracy += 100 * correct / total

        

        # update progress bar
        pbar.set_postfix({
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "avg confidence": torch.mean(max_qb_tilde).item(),
            "n_unlabeled": running_n_unlabeled,
            "lr": optimizer.param_groups[0]['lr']
        })

    # update loss
    train_losses.append(running_loss / (i + 1))
    train_accuracies.append(running_accuracy / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)
        test_accuracies.append(test_accuracy)

In [None]:
# plot losses and accuracies
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

ax1.plot(train_losses, label="train")
ax1.plot(test_losses, label="test")
ax1.set_title("Loss")

ax2.plot(train_accuracies, label="train")
ax2.plot(test_accuracies, label="test")
ax2.set_title("Accuracy")

plt.legend()
plt.show()

# save plot
fig.savefig(f"figures/unsup_DA_5_losses_accuracies.png")

In [None]:
# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    fig, ax = plt.subplots(figsize=(10, 10))
    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot(ax=ax)
    plt.tight_layout()
    plt.show()

    # save plot
    fig.savefig(f"figures/unsup_DA_5_confusion_matrix.png")

In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_5_fixmatch_DA.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(SUBSET_PROP*100)}% - {test_accuracy:.2f}% - Data Alignment")
fig1.savefig(f"./figures/test_score_5_fixmatch_DA.png")

### IV.3 Fixmatch on 1% train data - Data Alignment

In [8]:
# Define your dataset and dataloaders for labeled and unlabeled data
seedEverything()

EPOCHS = 300
SUBSET_PROP = 0.01
CIFAR10_class_distribution = 0.10 

# 10% labeled data and 100% unlabeled (see note 2 in paper)
trainset_sup, _ = torch.utils.data.random_split(trainset, [SUBSET_PROP, 1-SUBSET_PROP])

trainset_unsup, _ = torch.utils.data.random_split(trainset, [1, 0])

labeled_dataloader = DataLoader(
    trainset_sup,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

unlabeled_dataloader = DataLoader(
    trainset_unsup,
    batch_size=MU*BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# transformations
weak_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=0, translate=(0.125, 0.125)),
    # transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

strong_transform = transforms.Compose([
    # transforms.RandomHorizontalFlip(p=0.5),
    # transforms.RandomAffine(degrees=10, translate=(0.125, 0.125)),
    # transforms.RandAugment(num_ops=2, magnitude=10),
    RandAugmentMC(n=2, m=10),
    # transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
    

def mask(model, weak_unlabeled_data):
    with torch.no_grad():
        model.train()

        qb = model(weak_unlabeled_data)

        # qb = logits.copy()
        qb = torch.softmax(qb, dim=1)

        max_qb, qb_hat = torch.max(qb, dim=1)

        idx = max_qb > TAU
        qb_hat = qb_hat[idx]

    return qb_hat.detach(), idx, max_qb.detach()

model = ConvNN().to(device)

# criterion and optimizer
labeled_criterion = nn.CrossEntropyLoss(reduction='none')
unlabeled_criterion = nn.CrossEntropyLoss(reduction='none')

optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=BETA, weight_decay=WEIGHT_DECAY, nesterov=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)



In [9]:
print("Start training")

train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    running_n_unlabeled = 0
    running_accuracy = 0
    moving_avg_pred_labeled = 0
    moving_avg_pred_unlabeled = 0


    pbar = tqdm(zip(labeled_dataloader, unlabeled_dataloader), total=min(len(labeled_dataloader), len(unlabeled_dataloader)), unit="batch", desc=f"Epoch {epoch: >5}")

    for i, (labeled_data, unlabeled_data) in enumerate(pbar):
        # Get labeled and unlabeled data
        labeled_inputs, labels = labeled_data[0].to(device), labeled_data[1].to(device)
        unlabeled_inputs, _ = unlabeled_data[0].to(device), unlabeled_data[1].to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()

        # normalize labeled and unlabeled inputs
        labeled_inputs_norm, unlabeled_inputs_norm = normalize(labeled_inputs, mean, std), normalize(unlabeled_inputs, mean, std)

        # labeled_outputs_norm = model(labeled_inputs_norm)
        unlabeled_outputs_norm = model(unlabeled_inputs_norm)

        # Compute moving average of labeled and unlabeled predictions
        # moving_avg_pred_labeled = (i * moving_avg_pred_labeled + labeled_outputs_norm.shape[0]) / (i + 1)
        moving_avg_pred_unlabeled = (i * moving_avg_pred_unlabeled + unlabeled_outputs_norm.shape[0]) / (i + 1)

        # ratio
        ratio = CIFAR10_class_distribution / moving_avg_pred_unlabeled
        # model.train()

        # Apply weak augmentation to labeled data
        weak_labeled_inputs = weak_transform(labeled_inputs)

        # Apply strong augmentation + weak augmentation to unlabeled data
        weak_unlabeled_inputs = weak_transform(unlabeled_inputs)

        # prediction on weak augmented unlabeled data
        qb = model(weak_unlabeled_inputs)
        qb = torch.softmax(qb, dim=1)
        qb_norm = qb * ratio

        # normalize
        qb_tilde = qb_norm / torch.sum(qb_norm, dim=1, keepdim=True)

        # compute mask
        max_qb_tilde, qb_tilde_hat = torch.max(qb_tilde, dim=1)
        idx = max_qb_tilde >= TAU

        # pseudo labels
        pseudo_labels = qb_tilde_hat[idx]

        n_labeled = weak_labeled_inputs.size(0)

        if idx.sum().cpu().item() != 0:
            # forward weak augmented labeled data
            labeled_outputs = model(weak_labeled_inputs)

            # transform unlabeled data
            strong_unlabeled_inputs = strong_transform(unlabeled_inputs)

            # apply mask
            strong_unlabeled_inputs = strong_unlabeled_inputs[idx]

            # forward pass
            unlabeled_outputs = model(strong_unlabeled_inputs)

            # compute unlabeled loss
            unlabeled_loss = torch.sum(unlabeled_criterion(unlabeled_outputs, pseudo_labels)) / (MU * BATCH_SIZE)


            # compute labeled loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item() + (unlabeled_outputs.argmax(dim=1) == pseudo_labels).sum().item()
            
        else:
            # forward pass
            labeled_outputs = model(weak_labeled_inputs)

            # compute loss
            labeled_loss = torch.sum(labeled_criterion(labeled_outputs, labels)) / BATCH_SIZE
            unlabeled_loss = torch.tensor(0, device=device)

            # compute total loss
            loss = labeled_loss + LAMBDA_U * unlabeled_loss

            # compute accuracy
            total += labels.size(0)
            correct += (labeled_outputs.argmax(dim=1) == labels).sum().item()


        # backward pass + optimize
        loss.backward()
        optimizer.step()

        

        # update statistics
        running_loss += loss.item()
        running_n_unlabeled += idx.sum().cpu().item()
        running_accuracy += 100 * correct / total

        

        # update progress bar
        pbar.set_postfix({
            "labeled loss": labeled_loss.item(),
            "unlabeled loss": unlabeled_loss.item(),
            "accuracy": 100 * correct / total,
            "avg confidence": torch.mean(max_qb_tilde).item(),
            "n_unlabeled": running_n_unlabeled,
            "lr": optimizer.param_groups[0]['lr']
        })

    # update loss
    train_losses.append(running_loss / (i + 1))
    train_accuracies.append(running_accuracy / (i + 1))

    # scheduler step
    if scheduler is not None:
        scheduler.step()

    
    # Evaluate the model on the test set
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # normalize
            images = normalize(data=images, mean=mean, std=std)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
        
        test_accuracy = 100.0 * test_correct / test_total
        print(f'Test Accuracy: {test_accuracy}%')

        # update loss
        test_losses.append(torch.sum(labeled_criterion(outputs, labels)).item() / BATCH_SIZE)
        test_accuracies.append(test_accuracy)

Start training


Epoch     0:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     1:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 10.0%


Epoch     2:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 10.13%


Epoch     3:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 14.53%


Epoch     4:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 20.09%


Epoch     5:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 18.54%


Epoch     6:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 22.17%


Epoch     7:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 23.76%


Epoch     8:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 23.66%


Epoch     9:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 24.23%


Epoch    10:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 25.23%


Epoch    11:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 27.3%


Epoch    12:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 28.01%


Epoch    13:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 31.02%


Epoch    14:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 26.52%


Epoch    15:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 32.12%


Epoch    16:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 32.45%


Epoch    17:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 31.94%


Epoch    18:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 33.8%


Epoch    19:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 33.78%


Epoch    20:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 28.37%


Epoch    21:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 36.83%


Epoch    22:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 35.13%


Epoch    23:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 33.86%


Epoch    24:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 36.85%


Epoch    25:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 35.54%


Epoch    26:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 34.48%


Epoch    27:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 35.88%


Epoch    28:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 35.17%


Epoch    29:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 31.98%


Epoch    30:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 35.5%


Epoch    31:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 31.5%


Epoch    32:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.29%


Epoch    33:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.62%


Epoch    34:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.8%


Epoch    35:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.36%


Epoch    36:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.34%


Epoch    37:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 34.8%


Epoch    38:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 36.15%


Epoch    39:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.31%


Epoch    40:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 35.56%


Epoch    41:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.79%


Epoch    42:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.51%


Epoch    43:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.1%


Epoch    44:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.08%


Epoch    45:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.87%


Epoch    46:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 35.82%


Epoch    47:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.82%


Epoch    48:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.27%


Epoch    49:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.15%


Epoch    50:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.78%


Epoch    51:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.02%


Epoch    52:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.61%


Epoch    53:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.86%


Epoch    54:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.35%


Epoch    55:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.96%


Epoch    56:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.59%


Epoch    57:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.43%


Epoch    58:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.55%


Epoch    59:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.22%


Epoch    60:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.71%


Epoch    61:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 37.84%


Epoch    62:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 38.91%


Epoch    63:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.81%


Epoch    64:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.94%


Epoch    65:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.56%


Epoch    66:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.78%


Epoch    67:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.3%


Epoch    68:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.06%


Epoch    69:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.3%


Epoch    70:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.51%


Epoch    71:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.91%


Epoch    72:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.76%


Epoch    73:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.95%


Epoch    74:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 39.19%


Epoch    75:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.52%


Epoch    76:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.87%


Epoch    77:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.34%


Epoch    78:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.54%


Epoch    79:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.0%


Epoch    80:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.09%


Epoch    81:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.09%


Epoch    82:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.34%


Epoch    83:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.57%


Epoch    84:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.3%


Epoch    85:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 40.48%


Epoch    86:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.2%


Epoch    87:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.52%


Epoch    88:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.45%


Epoch    89:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.58%


Epoch    90:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.55%


Epoch    91:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.74%


Epoch    92:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.47%


Epoch    93:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.52%


Epoch    94:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.81%


Epoch    95:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.53%


Epoch    96:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.92%


Epoch    97:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.88%


Epoch    98:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.51%


Epoch    99:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.72%


Epoch   100:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.47%


Epoch   101:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.51%


Epoch   102:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.96%


Epoch   103:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.3%


Epoch   104:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.14%


Epoch   105:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.0%


Epoch   106:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.06%


Epoch   107:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.17%


Epoch   108:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.98%


Epoch   109:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.8%


Epoch   110:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.39%


Epoch   111:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.93%


Epoch   112:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.49%


Epoch   113:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.82%


Epoch   114:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.21%


Epoch   115:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.43%


Epoch   116:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.56%


Epoch   117:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.57%


Epoch   118:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 41.23%


Epoch   119:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.96%


Epoch   120:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.66%


Epoch   121:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.86%


Epoch   122:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.51%


Epoch   123:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.68%


Epoch   124:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.17%


Epoch   125:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.68%


Epoch   126:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.47%


Epoch   127:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.59%


Epoch   128:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.61%


Epoch   129:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.39%


Epoch   130:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.67%


Epoch   131:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.93%


Epoch   132:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.39%


Epoch   133:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.33%


Epoch   134:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.24%


Epoch   135:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.54%


Epoch   136:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.46%


Epoch   137:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.58%


Epoch   138:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.46%


Epoch   139:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.76%


Epoch   140:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.42%


Epoch   141:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.9%


Epoch   142:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.74%


Epoch   143:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.61%


Epoch   144:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.74%


Epoch   145:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.89%


Epoch   146:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.51%


Epoch   147:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.43%


Epoch   148:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.91%


Epoch   149:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.37%


Epoch   150:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.32%


Epoch   151:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.36%


Epoch   152:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.26%


Epoch   153:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.29%


Epoch   154:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.18%


Epoch   155:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.25%


Epoch   156:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.91%


Epoch   157:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.36%


Epoch   158:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.44%


Epoch   159:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.14%


Epoch   160:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.72%


Epoch   161:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.61%


Epoch   162:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.39%


Epoch   163:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.39%


Epoch   164:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.45%


Epoch   165:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.84%


Epoch   166:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.08%


Epoch   167:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.93%


Epoch   168:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.25%


Epoch   169:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.13%


Epoch   170:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.09%


Epoch   171:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.8%


Epoch   172:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.56%


Epoch   173:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.64%


Epoch   174:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.57%


Epoch   175:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.03%


Epoch   176:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 42.63%


Epoch   177:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.85%


Epoch   178:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.34%


Epoch   179:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.14%


Epoch   180:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.99%


Epoch   181:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.4%


Epoch   182:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.12%


Epoch   183:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.97%


Epoch   184:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.33%


Epoch   185:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.2%


Epoch   186:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.73%


Epoch   187:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.33%


Epoch   188:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.01%


Epoch   189:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.89%


Epoch   190:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.95%


Epoch   191:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.5%


Epoch   192:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.65%


Epoch   193:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.46%


Epoch   194:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.82%


Epoch   195:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.51%


Epoch   196:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.64%


Epoch   197:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.11%


Epoch   198:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.95%


Epoch   199:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.73%


Epoch   200:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.17%


Epoch   201:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.34%


Epoch   202:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.73%


Epoch   203:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.18%


Epoch   204:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.93%


Epoch   205:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.84%


Epoch   206:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.02%


Epoch   207:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.72%


Epoch   208:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.52%


Epoch   209:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 43.31%


Epoch   210:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.45%


Epoch   211:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.44%


Epoch   212:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.49%


Epoch   213:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.89%


Epoch   214:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.31%


Epoch   215:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.78%


Epoch   216:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.8%


Epoch   217:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.78%


Epoch   218:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.71%


Epoch   219:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.85%


Epoch   220:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.39%


Epoch   221:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.35%


Epoch   222:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.52%


Epoch   223:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.47%


Epoch   224:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.13%


Epoch   225:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 44.92%


Epoch   226:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.94%


Epoch   227:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.88%


Epoch   228:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.02%


Epoch   229:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.18%


Epoch   230:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.06%


Epoch   231:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.27%


Epoch   232:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.14%


Epoch   233:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.92%


Epoch   234:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.51%


Epoch   235:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.09%


Epoch   236:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.18%


Epoch   237:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.02%


Epoch   238:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.68%


Epoch   239:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 45.69%


Epoch   240:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.01%


Epoch   241:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.7%


Epoch   242:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.71%


Epoch   243:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.75%


Epoch   244:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.75%


Epoch   245:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.64%


Epoch   246:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.18%


Epoch   247:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.74%


Epoch   248:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.83%


Epoch   249:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.83%


Epoch   250:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.43%


Epoch   251:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.49%


Epoch   252:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.29%


Epoch   253:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.15%


Epoch   254:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.89%


Epoch   255:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.06%


Epoch   256:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.82%


Epoch   257:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.59%


Epoch   258:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.91%


Epoch   259:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.91%


Epoch   260:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.99%


Epoch   261:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.88%


Epoch   262:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.97%


Epoch   263:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.0%


Epoch   264:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.63%


Epoch   265:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.62%


Epoch   266:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.4%


Epoch   267:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.42%


Epoch   268:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.69%


Epoch   269:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.59%


Epoch   270:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 46.84%


Epoch   271:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.08%


Epoch   272:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.06%


Epoch   273:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.07%


Epoch   274:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.15%


Epoch   275:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.1%


Epoch   276:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.26%


Epoch   277:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.4%


Epoch   278:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.68%


Epoch   279:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.79%


Epoch   280:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.84%


Epoch   281:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.99%


Epoch   282:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.94%


Epoch   283:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 47.98%


Epoch   284:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.06%


Epoch   285:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.04%


Epoch   286:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.02%


Epoch   287:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.06%


Epoch   288:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.05%


Epoch   289:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.1%


Epoch   290:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.06%


Epoch   291:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.11%


Epoch   292:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.11%


Epoch   293:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.12%


Epoch   294:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.15%


Epoch   295:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.18%


Epoch   296:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.18%


Epoch   297:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.18%


Epoch   298:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.18%


Epoch   299:   0%|          | 0/8 [00:00<?, ?batch/s]

Test Accuracy: 48.18%


In [None]:
# plot losses and accuracies
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

ax1.plot(train_losses, label="train")
ax1.plot(test_losses, label="test")
ax1.set_title("Loss")

ax2.plot(train_accuracies, label="train")
ax2.plot(test_accuracies, label="test")
ax2.set_title("Accuracy")

plt.legend()
plt.show()

# save plot
fig.savefig(f"figures/unsup_DA_1_losses_accuracies.png")

In [None]:
# plot confusion matrix
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0
y_true = []
y_pred = []

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        y_true.append(labels.cpu().numpy())
        y_pred.append(predicted.cpu().numpy())
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    fig, ax = plt.subplots(figsize=(10, 10))
    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot(ax=ax)
    plt.tight_layout()
    plt.show()

    # save plot
    fig.savefig(f"figures/unsup_DA_1_confusion_matrix.png")

In [None]:
# Evaluation on the test set
model.eval()  # Set the model to evaluation mode
test_correct = 0
test_total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # normalize
        images = normalize(data=images, mean=mean, std=std)
        
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

test_accuracy = 100.0 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy}%')

# save the model
torch.save(model.state_dict(), f"{torch_models}/model_1_fixmatch_DA.pth")

test_image, test_labels = testloader.__iter__().__next__()
test_image = test_image.to(device)
outputs_test = model(test_image)
label_pred_test = outputs_test.argmax(dim=1)

# descale the images
test_image = test_image#  * torch.tensor(std, device=device).view(1, 3, 1, 1) + torch.tensor(mean, device=device).view(1, 3, 1, 1)

fig1 = plot_images(test_image, test_labels, label_pred_test, classes, figure_name=f"Test score with Fixmatch - {int(SUBSET_PROP*100)}% - {test_accuracy:.2f}% - Data Alignment")
fig1.savefig(f"./figures/test_score_1_fixmatch_DA.png")