In [None]:
!pip install einops



In [None]:
import einops
import torch.nn as nn

In [None]:
class MlpBlock(nn.Module):
    """Two linear layers with a GELU in between, applied to the last axis.

    Parameters
    ----------
    dim : int
        Size of the last axis of the input; the output keeps the same size.
    mlp_dim : int or None
        Width of the hidden layer. ``None`` means "as wide as the input",
        i.e. ``dim``.

    Attributes
    ----------
    linear_1 : nn.Linear
        Projection ``dim -> mlp_dim``.
    activation : nn.GELU
        Non-linearity between the two projections.
    linear_2 : nn.Linear
        Projection ``mlp_dim -> dim``.
    """

    def __init__(self, dim, mlp_dim=None):
        super().__init__()

        # Fall back to a hidden layer of the input width when none is given.
        if mlp_dim is None:
            mlp_dim = dim

        self.linear_1 = nn.Linear(dim, mlp_dim)
        self.activation = nn.GELU()
        self.linear_2 = nn.Linear(mlp_dim, dim)

    def forward(self, x):
        """Apply ``linear_1 -> GELU -> linear_2``.

        ``x`` has shape ``(n_samples, *, dim)``; so does the returned tensor.
        """
        hidden = self.activation(self.linear_1(x))  # (n_samples, *, mlp_dim)
        return self.linear_2(hidden)  # (n_samples, *, dim)


In [None]:
class MixerBlock(nn.Module):
    """One Mixer layer: a token-mixing step followed by a channel-mixing step.

    Each step layer-normalises its input, runs an ``MlpBlock`` and adds the
    result back onto the input (residual connection).

    Parameters
    ----------
    n_patches : int
        Number of patches per sample; the token MLP acts along this axis.
    hidden_dim : int
        Size of each patch embedding; the channel MLP acts along this axis.
    tokens_mlp_dim : int
        Hidden width of the token-mixing MLP.
    channels_mlp_dim : int
        Hidden width of the channel-mixing MLP.
    """

    def __init__(
        self, *, n_patches, hidden_dim, tokens_mlp_dim, channels_mlp_dim
    ):
        super().__init__()

        # Both normalisations act on the embedding axis.
        self.norm_1 = nn.LayerNorm(hidden_dim)
        self.norm_2 = nn.LayerNorm(hidden_dim)

        self.token_mlp_block = MlpBlock(n_patches, tokens_mlp_dim)
        self.channel_mlp_block = MlpBlock(hidden_dim, channels_mlp_dim)

    def forward(self, x):
        """Mix ``x`` of shape ``(n_samples, n_patches, hidden_dim)``.

        The returned tensor has the same shape as ``x``.
        """
        # Token mixing: put the patch axis last so the MLP acts across patches,
        # then restore the original layout before the residual sum.
        patches_last = self.norm_1(x).permute(0, 2, 1)  # (n_samples, hidden_dim, n_patches)
        token_update = self.token_mlp_block(patches_last).permute(0, 2, 1)
        x = x + token_update  # (n_samples, n_patches, hidden_dim)

        # Channel mixing: the embedding axis is already last.
        channel_update = self.channel_mlp_block(self.norm_2(x))
        return x + channel_update  # (n_samples, n_patches, hidden_dim)

In [None]:
class MlpMixer(nn.Module):
    """MLP-Mixer image classifier.

    A strided convolution cuts the image into non-overlapping square patches
    and embeds them, a stack of ``MixerBlock`` layers mixes the embeddings,
    and the layer-normalised embeddings are averaged over the patch axis
    before a linear classification head.

    Parameters
    ----------
    image_size : int
        Height and width of the (square) input images.
    patch_size : int
        Height and width of the square patches. The number of patches is
        ``(image_size // patch_size) ** 2``.
    tokens_mlp_dim : int
        Hidden width of the token-mixing MLP in every block.
    channels_mlp_dim : int
        Hidden width of the channel-mixing MLP in every block.
    n_classes : int
        Number of output classes.
    hidden_dim : int
        Size of each patch embedding.
    n_blocks : int
        Number of ``MixerBlock`` layers.
    in_channels : int
        Number of channels of the input images. Defaults to 3, the value
        that was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        tokens_mlp_dim,
        channels_mlp_dim,
        n_classes,
        hidden_dim,
        n_blocks,
        in_channels=3,
    ):
        super().__init__()
        n_patches = (image_size // patch_size) ** 2

        # kernel_size == stride: every patch is embedded exactly once.
        self.patch_embedder = nn.Conv2d(
            in_channels,
            hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.blocks = nn.ModuleList(
            [
                MixerBlock(
                    n_patches=n_patches,
                    hidden_dim=hidden_dim,
                    tokens_mlp_dim=tokens_mlp_dim,
                    channels_mlp_dim=channels_mlp_dim,
                )
                for _ in range(n_blocks)
            ]
        )

        self.pre_head_norm = nn.LayerNorm(hidden_dim)
        self.head_classifier = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Return class logits of shape ``(n_samples, n_classes)`` for a batch of images."""
        x = self.patch_embedder(
            x
        )  # (n_samples, hidden_dim, n_patches ** (1/2), n_patches ** (1/2))
        # Flatten the patch grid into one axis and move the embedding axis last.
        x = einops.rearrange(
            x, "n c h w -> n (h w) c"
        )  # (n_samples, n_patches, hidden_dim)
        for mixer_block in self.blocks:
            x = mixer_block(x)  # (n_samples, n_patches, hidden_dim)

        x = self.pre_head_norm(x)  # (n_samples, n_patches, hidden_dim)
        x = x.mean(dim=1)  # average over patches -> (n_samples, hidden_dim)
        y = self.head_classifier(x)  # (n_samples, n_classes)

        return y

In [None]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm


# --- Optimisation ---
num_epochs = 10        # Training passes over the data per run
batch_size = 64        # Samples per mini-batch
learning_rate = 1e-3   # Learning rate passed to the optimizer

# --- Data geometry (CIFAR-10) ---
image_size = 32        # CIFAR-10 image size
n_classes = 10         # CIFAR-10 classes
patch_size = 4         # Patch size for embedding

# --- Mixer architecture ---
n_blocks = 8            # Number of Mixer blocks
hidden_dim = 128        # Hidden dimension of patch embeddings
tokens_mlp_dim = 128    # Hidden dimension for token mixing
channels_mlp_dim = 128  # Hidden dimension for channel mixing

# Data transformation: convert each image to a tensor, then normalize every
# channel with (mean, std) = (0.5, 0.5).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Both CIFAR-10 splits share the storage location and the preprocessing.
_cifar10_common = dict(root='./data', download=True, transform=transform)

# Training split: reshuffled every epoch.
train_dataset = torchvision.datasets.CIFAR10(train=True, **_cifar10_common)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)

# Test split: fixed order, so predictions of different models line up per image.
test_dataset = torchvision.datasets.CIFAR10(train=False, **_cifar10_common)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)


def train(model, train_loader, criterion, optimizer):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Parameters
    ----------
    model : nn.Module
        Model to optimise; it is switched to training mode.
    train_loader : iterable
        Sized iterable yielding ``(images, labels)`` batches.
    criterion : callable
        Loss called as ``criterion(outputs, labels)``.
    optimizer : torch.optim.Optimizer
        Optimizer stepped once per batch.

    Returns
    -------
    float
        Sum of the per-batch loss values divided by ``len(train_loader)``.
    """
    model.train()

    # Use the device the model lives on rather than a module-level ``device``
    # name, so the function does not depend on another cell having run first.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    total_loss = 0.0
    for images, labels in tqdm(train_loader):
        if target_device is not None:
            images, labels = images.to(target_device), labels.to(target_device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

# def evaluate(model, test_loader):
#     model.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)
#             outputs = model(images)
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#     return correct / total

def evaluate(model, test_loader):
    """Collect the model's predictions on ``test_loader``.

    Parameters
    ----------
    model : nn.Module
        Model to evaluate; it is switched to evaluation mode.
    test_loader : iterable
        Iterable yielding ``(images, labels)`` batches.

    Returns
    -------
    predictions : list
        Predicted class index of every sample, in loader order.
    correct_predictions : list
        For every sample, whether the prediction equals the label.
    """
    model.eval()

    # Use the device the model lives on rather than a module-level ``device``
    # name, so the function does not depend on another cell having run first.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    predictions = []
    correct_predictions = []
    with torch.no_grad():
        for images, labels in test_loader:
            if target_device is not None:
                images, labels = images.to(target_device), labels.to(target_device)
            # Index of the largest score along the class axis.
            predicted = model(images).argmax(dim=1)
            predictions.extend(predicted.cpu().numpy())
            correct_predictions.extend((predicted == labels).cpu().numpy())
    return predictions, correct_predictions

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 52.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import numpy as np

num_runs = 2  # Number of random initializations
results = []  # Test accuracy of every run, in run order
models = []   # Trained model of every run, in run order

# The device is the same for every run, so choose it once. It stays a
# module-level name so that code looking up a global `device` keeps working.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

for run in range(num_runs):
    print(f"Run {run + 1}/{num_runs}")

    # Set a different random seed for each run
    torch.manual_seed(run)
    np.random.seed(run)

    # Initialize the model, criterion, and optimizer
    model = MlpMixer(
        image_size=image_size,
        patch_size=patch_size,
        tokens_mlp_dim=tokens_mlp_dim,
        channels_mlp_dim=channels_mlp_dim,
        n_classes=n_classes,
        hidden_dim=hidden_dim,
        n_blocks=n_blocks
    )
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, criterion, optimizer)
        print(f"Epoch {epoch + 1}/{num_epochs},  Train Loss: {train_loss:.4f}")

    models.append(model)

# Evaluate every trained model on the test set. The raw per-image results are
# kept in `all_predictions` / `correct_predictions_list`; only a compact
# per-run accuracy is printed instead of dumping thousands of values.
all_predictions = []
correct_predictions_list = []
for run_index, model in enumerate(models):
    preds, corrects = evaluate(model, test_loader)
    all_predictions.append(preds)
    correct_predictions_list.append(corrects)

    # Fraction of correctly classified test images for this run.
    accuracy = sum(bool(c) for c in corrects) / max(len(corrects), 1)
    results.append(accuracy)
    print(f"Run {run_index + 1}/{num_runs},  Test Accuracy: {accuracy:.4f}")


Run 1/2
cuda


100%|██████████| 782/782 [00:24<00:00, 32.11it/s]


Epoch 1/10,  Train Loss: 1.5238


100%|██████████| 782/782 [00:24<00:00, 32.36it/s]


Epoch 2/10,  Train Loss: 1.1823


100%|██████████| 782/782 [00:23<00:00, 32.67it/s]


Epoch 3/10,  Train Loss: 1.0325


100%|██████████| 782/782 [00:24<00:00, 31.94it/s]


Epoch 4/10,  Train Loss: 0.9204


100%|██████████| 782/782 [00:24<00:00, 32.22it/s]


Epoch 5/10,  Train Loss: 0.8364


100%|██████████| 782/782 [00:24<00:00, 32.10it/s]


Epoch 6/10,  Train Loss: 0.7575


100%|██████████| 782/782 [00:24<00:00, 31.61it/s]


Epoch 7/10,  Train Loss: 0.6843


100%|██████████| 782/782 [00:24<00:00, 31.98it/s]


Epoch 8/10,  Train Loss: 0.6099


100%|██████████| 782/782 [00:24<00:00, 31.75it/s]


Epoch 9/10,  Train Loss: 0.5418


100%|██████████| 782/782 [00:24<00:00, 31.89it/s]


Epoch 10/10,  Train Loss: 0.4735
Run 2/2
cuda


100%|██████████| 782/782 [00:24<00:00, 32.00it/s]


Epoch 1/10,  Train Loss: 1.5050


100%|██████████| 782/782 [00:24<00:00, 32.27it/s]


Epoch 2/10,  Train Loss: 1.1512


100%|██████████| 782/782 [00:24<00:00, 31.78it/s]


Epoch 3/10,  Train Loss: 0.9890


100%|██████████| 782/782 [00:24<00:00, 32.00it/s]


Epoch 4/10,  Train Loss: 0.8787


100%|██████████| 782/782 [00:24<00:00, 31.86it/s]


Epoch 5/10,  Train Loss: 0.7949


100%|██████████| 782/782 [00:24<00:00, 31.99it/s]


Epoch 6/10,  Train Loss: 0.7202


100%|██████████| 782/782 [00:24<00:00, 31.59it/s]


Epoch 7/10,  Train Loss: 0.6479


100%|██████████| 782/782 [00:24<00:00, 32.04it/s]


Epoch 8/10,  Train Loss: 0.5798


100%|██████████| 782/782 [00:24<00:00, 31.88it/s]


Epoch 9/10,  Train Loss: 0.5099


100%|██████████| 782/782 [00:24<00:00, 31.47it/s]


Epoch 10/10,  Train Loss: 0.4436
[[3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 3, 7, 8, 6, 7, 0, 2, 9, 5, 2, 4, 0, 9, 6, 3, 4, 4, 3, 9, 9, 7, 9, 9, 5, 4, 6, 5, 6, 0, 9, 5, 9, 7, 6, 9, 8, 6, 6, 8, 8, 7, 3, 5, 5, 7, 5, 6, 1, 6, 2, 1, 2, 3, 7, 2, 5, 8, 8, 0, 2, 1, 3, 5, 8, 8, 1, 1, 7, 2, 7, 7, 2, 8, 8, 8, 3, 8, 6, 4, 6, 6, 0, 0, 7, 4, 4, 6, 3, 1, 1, 5, 6, 8, 7, 5, 0, 2, 2, 1, 5, 0, 5, 3, 5, 8, 7, 1, 2, 8, 2, 8, 5, 3, 0, 4, 1, 9, 9, 1, 3, 9, 7, 2, 8, 3, 5, 6, 5, 8, 0, 3, 6, 5, 3, 8, 9, 6, 9, 0, 3, 2, 9, 3, 4, 2, 1, 5, 6, 0, 4, 8, 4, 5, 4, 9, 0, 9, 8, 9, 9, 3, 7, 3, 0, 0, 5, 2, 6, 6, 8, 2, 6, 3, 8, 5, 8, 0, 1, 7, 2, 8, 8, 3, 8, 3, 0, 0, 7, 1, 5, 0, 5, 7, 0, 6, 8, 5, 5, 8, 0, 4, 9, 0, 7, 7, 3, 9, 5, 9, 9, 3, 4, 9, 9, 5, 1, 5, 1, 8, 0, 4, 2, 6, 5, 1, 1, 0, 9, 0, 2, 1, 8, 2, 0, 5, 3, 9, 9, 4, 8, 3, 0, 8, 9, 8, 1, 0, 3, 0, 0, 2, 4, 7, 0, 2, 4, 6, 5, 8, 0, 0, 2, 4, 7, 9, 0, 6, 1, 9, 9, 0, 0, 7, 9, 1, 2, 6, 1, 5, 2, 6, 0, 0, 6, 6, 6, 5, 8, 6, 0, 8, 2, 1, 4, 8, 6, 0, 3, 4, 0, 2, 7, 5, 5, 5, 5,

In [None]:
import numpy as np

# Stack the per-model lists into arrays with one row per model and one
# column per test image.
all_predictions = np.array(all_predictions)
correct_predictions_list = np.array(correct_predictions_list)

# True where every model predicts the same class as the first model.
# Computed once and reused for both statistics below.
models_agree = np.all(all_predictions == all_predictions[0, :], axis=0)

# Check for agreement and correct classification
agreement_and_correct = np.all(correct_predictions_list, axis=0) & models_agree

# Count how many images are agreed upon and correctly classified
num_agreed_correct = np.sum(agreement_and_correct)
total_images = len(agreement_and_correct)

print(f"Agreement and correct classification on {num_agreed_correct}/{total_images} images.")

# Find where all models are incorrect
all_incorrect = np.all(correct_predictions_list == 0, axis=0)

# Find where all models made the same incorrect prediction
same_incorrect_prediction = all_incorrect & models_agree

# Count the number of such cases
num_same_misclassifications = np.sum(same_incorrect_prediction)

print(f"Agreement on incorrect classification on {num_same_misclassifications}/{total_images} images.")

Agreement and correct classification on 5753/10000 images.
Agreement on incorrect classication on 1186/10000 images.
