In [1]:
!pip install einops

Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
Installing collected packages: einops
Successfully installed einops-0.6.1




In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from einops.layers.torch import Rearrange

# Hyperparameters
batch_size = 16  # Reduced batch size
learning_rate = 0.001
num_epochs = 10

# MNIST dataset and DataLoader
transform = transforms.Compose([transforms.Resize(32),
                                transforms.ToTensor()])
mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist_dataset, batch_size=batch_size, shuffle=True)



# Vision Transformer model
class ViT(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, d_model, num_heads, num_layers):
        super(ViT, self).__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_embed = nn.Sequential(
            nn.Conv2d(1, d_model, kernel_size=patch_size, stride=patch_size),
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads),
            num_layers=num_layers
        )
        self.classification_head = nn.Sequential(
            nn.Linear(d_model, num_classes),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)  # Global average pooling
        x = self.classification_head(x)
        return x




# Model parameters
cnn_channels = 16
image_size = 32
patch_size = 8
num_classes = 10
d_model = 64
num_heads = 4
num_layers = 4



class CombinedModel(nn.Module):
    def __init__(self, cnn_channels, image_size, patch_size, num_classes, d_model, num_heads, num_layers):
        super(CombinedModel, self).__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(1, cnn_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        self.vit = ViT(image_size=image_size // patch_size,
                       patch_size=1,
                       num_classes=num_classes,
                       d_model=d_model,
                       num_heads=num_heads,
                       num_layers=num_layers)

        # Calculate the CNN output size for the linear layer's input size
        cnn_out_shape = self.cnn(torch.randn(1, 1, image_size, image_size)).shape
        cnn_features_size = cnn_channels * (cnn_out_shape[2] // 2) * (cnn_out_shape[3] // 2)

        self.classification_head = nn.Sequential(
            nn.Linear(cnn_features_size + d_model, num_classes)
        )

    def forward(self, x):
        cnn_out = self.cnn(x)
        cnn_global_avg = nn.functional.adaptive_avg_pool2d(cnn_out, (1, 1)).view(cnn_out.size(0), -1)
        vit_out = self.vit(x)
        combined_features = torch.cat((cnn_global_avg, vit_out), dim=1)
        logits = self.classification_head(combined_features)
        return logits



# Initialize the model, loss, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = CombinedModel(cnn_channels=cnn_channels, image_size=image_size, patch_size=patch_size,
                      num_classes=num_classes, d_model=d_model, num_heads=num_heads, num_layers=num_layers).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Gradient accumulation steps
accumulation_steps = 4
total_loss = 0

# Lists to store training accuracy values
train_accuracies = []

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_train_correct = 0
    total_train_samples = 0
    
    for batch_idx, (images, labels) in enumerate(data_loader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        
        total_loss += loss.item()
        
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()  # Reset gradients
            total_loss /= accumulation_steps
            total_loss = 0
        
        _, predicted = torch.max(outputs, 1)
        total_train_samples += labels.size(0)
        total_train_correct += (predicted == labels).sum().item()
    
    train_accuracy = total_train_correct / total_train_samples
    train_accuracies.append(train_accuracy)
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Train Acc: {train_accuracy:.4f}')

print("Training Finished!")

OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 4.00 GiB total capacity; 3.42 GiB already allocated; 0 bytes free; 3.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF