In [1]:
import torch
print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"CUDA Device Count: {torch.cuda.device_count()}")
print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")


CUDA Available: True
CUDA Device Count: 1
CUDA Device Name: Quadro P2000


In [None]:
import os
import torch.nn as nn
from einops.layers.torch import Rearrange
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchsummary import summary

class MixerBlock(nn.Module):
    def __init__(self, dim, num_patches):
        super().__init__()
        self.pre_layer_norm = nn.LayerNorm(dim)
        self.post_layer_norm = nn.LayerNorm(dim)
        
        self.token_mixer = nn.Sequential(
            nn.Linear(num_patches, dim),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(dim, num_patches),
            nn.Dropout(0.3)
        )
        
        self.channel_mixer = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(dim, dim),
            nn.Dropout(0.3)
        )
        
    def forward(self, x):
        pre_ln = self.pre_layer_norm(x)
        tm_out = self.token_mixer(pre_ln.transpose(1, 2)).transpose(1, 2)
        tm_out = tm_out + x
        post_ln = self.post_layer_norm(tm_out)
        cm_out = self.channel_mixer(post_ln) + tm_out
        return cm_out

class MLPMixer(nn.Module):
    def __init__(self, input_size, patch_size, dim=512, img_channel=3, layers=12, num_classes=2): 
        super().__init__()
        assert (input_size[0] % patch_size[0]) == 0, 'H must be divisible by patch size'
        assert (input_size[1] % patch_size[1]) == 0, 'W must be divisible by patch size'
        
        num_patches = int(input_size[0] / patch_size[0] * input_size[1] / patch_size[1])
        patch_dim = img_channel * patch_size[0] * patch_size[1]
        
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size[0], p2=patch_size[1]),
            nn.Linear(patch_dim, dim)
        )
         
        self.network = nn.Sequential(*[MixerBlock(dim, num_patches) for _ in range(layers)])
        
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(dim, num_classes)
    
    def forward(self, x):
        x = self.to_patch_embedding(x)
        x = self.network(x)
        return self.classifier(self.pool(x.transpose(1, 2)).squeeze(2))

def get_data_loader(data_dir, batch_size=32, image_size=(256, 256)):
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    
    dataset = datasets.ImageFolder(root=data_dir, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

def train(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        # Calculate accuracy
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    
    avg_loss = running_loss / len(train_loader)
    accuracy = correct / total
    return avg_loss, accuracy

def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
    
    return running_loss / len(val_loader), correct / len(val_loader.dataset)

def get_test_data_loader(data_dir, batch_size=32, image_size=(256, 256)):
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    
    dataset = datasets.ImageFolder(root=data_dir, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    
    accuracy = correct / total
    return accuracy

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = MLPMixer(
        input_size=(256, 256),
        patch_size=(16, 16),
        dim=512,
        layers=12,
        num_classes=2  # Binary classification
    ).to(device)
    
    summary(model, input_size=(3, 256, 256))

    train_loader = get_data_loader('./data/split_BC-15/train', batch_size=32)
    val_loader = get_data_loader('./data/split_BC-15/validation', batch_size=32)
    test_loader = get_test_data_loader('./data/split_BC-15/test', batch_size=32)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    best_val_loss = float('inf')
    patience = 10
    early_stop_counter = 0

    for epoch in range(100):  # Increased number of epochs
        train_loss, train_accuracy = train(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy = validate(model, val_loader, criterion, device)
        
        print(f'Epoch [{epoch+1}/100], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')
        
        # Check if validation loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
            # Save model if validation loss improves
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            early_stop_counter += 1
            
        # Early stopping check
        if early_stop_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
    
    # Load the best model for testing
    model.load_state_dict(torch.load('best_model.pth'))
    test_accuracy = evaluate(model, test_loader, device)
    print(f'Test Accuracy: {test_accuracy:.4f}')

if __name__ == '__main__':
    main()


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1             [-1, 256, 768]               0
            Linear-2             [-1, 256, 512]         393,728
         LayerNorm-3             [-1, 256, 512]           1,024
            Linear-4             [-1, 512, 512]         131,584
              GELU-5             [-1, 512, 512]               0
           Dropout-6             [-1, 512, 512]               0
            Linear-7             [-1, 512, 256]         131,328
           Dropout-8             [-1, 512, 256]               0
         LayerNorm-9             [-1, 256, 512]           1,024
           Linear-10             [-1, 256, 512]         262,656
             GELU-11             [-1, 256, 512]               0
          Dropout-12             [-1, 256, 512]               0
           Linear-13             [-1, 256, 512]         262,656
          Dropout-14             [-1, 2

  model.load_state_dict(torch.load('best_model.pth'))


Test Accuracy: 0.9198
