In [31]:
# ----------------------------------------------------------------------------
# ----------------- Importing Libraries  -------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate
import torchvision
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Normalize, Resize
from torchvision.utils import make_grid
from PIL import Image

# ----------------- FIXED VALUES ------------------------------------
RANDOM_SEED = 42


# MixUp


In [41]:
# ------------------------------ MIXUP / Define MixUp class ------------------------------
class MixUp(torch.nn.Module):
    def __init__(
        self,
        sampling_method: int = 1,
        num_classes: int = 10,
        alpha: float = 1.0,
        uni_range: list = [0.0, 1.0],
    ):
        super().__init__()
        self.sampling_method = sampling_method
        self.num_classes = num_classes
        self.alpha = alpha
        self.uni_range = uni_range

    def __call__(self, images, labels):
        # Set random seeds for reproducibility
        np.random.seed(RANDOM_SEED)
        torch.manual_seed(RANDOM_SEED)

        # Sample lambda for mixup based on the specified sampling method
        if self.sampling_method == 1:
            lam = np.random.beta(self.alpha, self.alpha)
        elif self.sampling_method == 2:
            lam = np.random.uniform(self.uni_range[0], self.uni_range[1])

        # Perform mixup on images
        index = torch.randperm(images.size(0))
        mixed_x = lam * images + (1 - lam) * images[index, :]

        # Convert labels to one-hot encoding and perform mixup
        y_one_hot = F.one_hot(labels, num_classes=self.num_classes).float()
        mixed_y = lam * y_one_hot + (1 - lam) * y_one_hot[index]

        return mixed_x, mixed_y


# Custom collate function that applies MixUp and retains original images
def collate_fn(batch):
    batch = default_collate(batch)
    images, labels = batch
    mixup = MixUp(sampling_method=1, num_classes=10, alpha=1.0, uni_range=[0.0, 1.0])
    mixed_x, mixed_y = mixup(images, labels)
    return mixed_x, images, mixed_y


# Data transformation
transform = transforms.Compose(
    [
        Resize((224, 224)),  # Resize images to fit the input size expected by MixUp
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training dataset and DataLoader
trainset = CIFAR10(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=32, shuffle=True, collate_fn=collate_fn
)


# Visualization function
def visualize_mixup(mixup, dataloader, num_images=16):
    # Set random seeds for reproducibility
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)

    mixed_x, images, mixed_y = next(iter(dataloader))
    mixed_images, mixed_labels = mixup(images, mixed_y)

    # Select a subset of images for visualization
    mixed_images = mixed_images[:num_images]

    # Create a grid of images
    grid = make_grid(mixed_images, nrow=int(np.sqrt(num_images)))
    ndarr = (
        grid.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )
    img = Image.fromarray(ndarr)
    img.save("mixup.png")


# Usage example
mixup = MixUp(sampling_method=1, num_classes=10, alpha=1.0, uni_range=[0.0, 1.0])
visualize_mixup(mixup, trainloader)


Files already downloaded and verified


RuntimeError: one_hot is only applicable to index tensor.

In [33]:
# Define the MixUp class here (as previously discussed)


# # Visualization of MixUp-augmented and original images
# mixup.visualize(trainloader, num_images=16)

# import torchvision
# import numpy as np
# from torchvision.utils import make_grid
# from PIL import Image

# RANDOM_SEED = 42


# import torchvision
# import numpy as np
# from torchvision.utils import make_grid
# from PIL import Image

# RANDOM_SEED = 42


Files already downloaded and verified


ValueError: not enough values to unpack (expected 4, got 2)

In [35]:
# Usage example, assuming 'dataloader' is defined and provides CIFAR-10 or similar images
mixup = MixUp(sampling_method=1, num_classes=10, alpha=1.0, uni_range=[0.0, 1.0])
visualize_mixup(mixup, trainloader)


NameError: name 'lam' is not defined

Files already downloaded and verified
Montage of augmented images saved to 'mixup.png'


# ViT


In [2]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_32


class ViT(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.vit = vit_b_32(pretrained=True)

        # Freeze all layers in the pretrained model
        for param in self.vit.parameters():
            param.requires_grad = False

        # Replace the head with a new linear layer
        self.vit.heads.head = nn.Linear(self.vit.heads.head.in_features, num_classes)

    def forward(self, x):
        x = self.vit(x)
        return x


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.utils.prune as prune
from torchvision.transforms import Resize
from models import ViT
from data import MixUp

PRUNING_AMOUNT = 0.1


def apply_pruning(module, amount=PRUNING_AMOUNT):
    """Apply unstructured pruning based on the L1 norm of weights."""
    for m in module.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            prune.l1_unstructured(m, name="weight", amount=amount)


def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def train_with_mixup(sampling_method, num_epochs=20):

    # Defining the data transformation for CIFAR-10
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),  # Resize images to 224x224 pixels
            transforms.ToTensor(),  # Convert images to PyTorch tensors
            transforms.Normalize(
                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
            ),  # Normalize the images
        ]
    )

    # Load the CIFAR-10 dataset - train and test
    trainset = torchvision.datasets.CIFAR10(
        root="data", train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

    testset = torchvision.datasets.CIFAR10(
        root="data", train=False, download=True, transform=transform
    )
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

    # Define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the model, loss function, and optimizer
    # Ensure the SimplifiedViT class is correctly initialized as per your modifications
    net = ViT().to(device)
    net.vit.heads.head.apply(initialize_weights)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        net.parameters(), lr=0.01, momentum=0.9
    )  # v2 - lr=0.001 brought very low results with SimplifiedViT v1 -> lr=0.01
    mixup = MixUp(alpha=1.0, sampling_method=sampling_method, seed=42)

    # v2 - Introduce a learning rate scheduler
    scheduler = lr_scheduler.StepLR(
        optimizer, step_size=5, gamma=0.1
    )  # Adjust learning rate every 5 epochs

    train_acc, test_acc = [], []  # Initialize accuracy lists

    for epoch in range(num_epochs):
        running_loss, correct, total = 0.0, 0, 0

        # Training loop
        net.train()  # Set the model to training mode
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            inputs, targets_a, targets_b, lam = mixup(inputs, labels)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)  # Get the predicted labels
            total += labels.size(0)
            correct += (
                (
                    lam * (predicted == targets_a).float()
                    + (1 - lam) * (predicted == targets_b).float()
                )
                .sum()
                .item()
            )

        # v4 - Prunning
        # Apply pruning at specified epochs and gradually increase the amount
        if epoch % 5 == 4:  # Example: Apply pruning every 5 epochs
            prune_amount = 0.05 + 0.05 * (
                epoch // 5
            )  # Increase pruning amount gradually
            apply_pruning(net, amount=prune_amount)
            print(f"Applied pruning with amount {prune_amount:.2f}")

        # v2 - Step the learning rate scheduler
        scheduler.step()

        train_acc.append(100 * correct / total)
        print(f"Epoch {epoch+1} - Training accuracy: {train_acc[-1]:.2f}%")

        # Test loop
        net.eval()  # Set the model to evaluation mode
        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_acc.append(100 * correct / total)
        print(f"Epoch {epoch+1} - Test accuracy: {test_acc[-1]:.2f}%")

    # Save the trained model
    model_path = os.path.join(".", f"model_sampling_{sampling_method}.pth")
    torch.save(net.state_dict(), model_path)
    print(f"Model with sampling method {sampling_method} saved to {model_path}")

    return train_acc, test_acc


if __name__ == "__main__":
    print("Training with sampling method 1 (beta distribution)")
    train_acc_1, test_acc_1 = train_with_mixup(sampling_method=1)

    print("Training with sampling method 2 (uniform distribution)")
    train_acc_2, test_acc_2 = train_with_mixup(sampling_method=2)

    # Report test set performance
    print("Test set performance for sampling method 1:")
    for epoch, acc in enumerate(test_acc_1):
        print(f"Epoch {epoch+1} - Test accuracy: {acc:.2f}%")

    print("Test set performance for sampling method 2:")
    for epoch, acc in enumerate(test_acc_2):
        print(f"Epoch {epoch+1} - Test accuracy: {acc:.2f}%")


In [None]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from PIL import Image, ImageDraw, ImageFont
from models import ViT


def visualize_results(model_path, testloader, classes, num_images=36):
    # Load the trained model
    net = ViT()
    net.load_state_dict(torch.load(model_path))
    net.eval()

    # Get a batch of test images
    dataiter = iter(testloader)
    images, labels = next(dataiter)

    # Make predictions on the test images
    images = images.cuda()
    outputs = net(images)
    _, predicted = torch.max(outputs, 1)

    # Create a montage of the test images with labels
    montage = make_grid(images[:num_images], nrow=6, padding=2).cpu()
    montage_image = transforms.ToPILImage()(montage)

    # Add labels to the montage
    draw = ImageDraw.Draw(montage_image)
    font = ImageFont.truetype("arial.ttf", 12)

    for i in range(num_images):
        x = i % 6 * montage_image.width // 6 + 5
        y = i // 6 * montage_image.height // 6 + 5
        label_text = f"Truth: {classes[labels[i]]}\nPredicted: {classes[predicted[i]]}"
        draw.text((x, y), label_text, font=font, fill="black")

    # Save the montage as "result.png"
    result_path = os.path.join(os.path.dirname(model_path), "result.png")
    montage_image.save(result_path)
    print(f"Montage of test images with labels saved to {result_path}")
