In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

In [2]:
# Square resolution (pixels) every image is resized to, and the mini-batch
# size used by all three data loaders.
BATCH_SIZE = 32
IMG_SIZE = 299

In [3]:
# Data transforms
# Per-channel ImageNet statistics, as expected by torchvision's pretrained weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light geometric augmentation, then tensor
# conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomRotation(10),  # uniform angle in [-10, 10] degrees
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: same resize and normalisation, no random augmentation.
val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [4]:
# Data loaders
# ImageFolder derives each sample's class label from its sub-directory name.
train_dataset = datasets.ImageFolder(root='./data/train/', transform=train_transform)
val_dataset = datasets.ImageFolder(root='./data/val/', transform=val_transform)
test_dataset = datasets.ImageFolder(root='./data/test/', transform=val_transform)

# Only the training split is shuffled; evaluation batches keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
def squash(x, dim=-1):
    """Capsule non-linearity: keep the direction of ``x`` along ``dim`` and
    shrink its length ``|x|`` to ``|x|^2 / (1 + |x|^2)``.

    Up to the 1e-8 stabiliser, the output length is always below 1.
    The stabiliser sits under the square root so all-zero vectors map to
    zero instead of producing a division by zero.
    """
    norm_sq = x.pow(2).sum(dim=dim, keepdim=True)
    shrink = norm_sq / (1 + norm_sq)
    length = torch.sqrt(norm_sq + 1e-8)
    return shrink * x / length

class CapsuleLayer(nn.Module):
    """Capsule layer with routing-by-agreement.

    Each output capsule has its own ``[in_channels, out_channels]``
    transformation matrix, shared by every input point. Every input point is
    projected by each matrix. The projections are combined over
    ``routing_iters`` rounds of dynamic routing, and the result is squashed.

    Args:
        num_capsules: number of output capsules.
        in_channels: length of each input vector (dimension 1 of the input).
        out_channels: length of each output capsule vector.
        routing_iters: number of routing iterations; must be >= 1.

    Shapes:
        input:  ``[batch_size, in_channels, num_points]``
        output: ``[batch_size, num_capsules, out_channels]``

    Raises:
        ValueError: if ``routing_iters < 1``.
    """

    def __init__(self, num_capsules, in_channels, out_channels, routing_iters=3):
        super().__init__()
        # forward() only assigns its result inside the routing loop, so zero
        # iterations would surface later as an UnboundLocalError.
        if routing_iters < 1:
            raise ValueError(f"routing_iters must be >= 1, got {routing_iters}")
        self.num_capsules = num_capsules
        self.routing_iters = routing_iters
        self.in_channels = in_channels
        self.out_channels = out_channels

        # One transformation matrix per output capsule, with a small random init.
        self.W = nn.Parameter(torch.randn(1, num_capsules, in_channels, out_channels) * 0.01)

    def forward(self, x):
        batch_size = x.size(0)

        # [B, in_channels, P] -> [B, 1, P, in_channels]: each point is a row vector.
        x = x.permute(0, 2, 1).unsqueeze(1)

        W = self.W.expand(batch_size, self.num_capsules, self.in_channels, self.out_channels)

        # Predicted vectors u_hat: [B, num_capsules, P, out_channels]
        u_hat = torch.matmul(x, W)

        # Routing logits b: [B, num_capsules, P, 1]. They are allocated directly
        # with u_hat's device and dtype instead of on the CPU and then copied.
        b = u_hat.new_zeros(batch_size, self.num_capsules, u_hat.size(2), 1)

        for i in range(self.routing_iters):
            # Coupling coefficients: softmax over output capsules for each point.
            c = F.softmax(b, dim=1)

            # Weighted sum over input points -> [B, num_capsules, out_channels]
            s = (c * u_hat).sum(dim=2)
            v = squash(s, dim=-1)

            if i < self.routing_iters - 1:
                # Agreement <u_hat, v> per (capsule, point): [B, num_capsules, P, 1]
                agreement = torch.matmul(u_hat, v.unsqueeze(2).transpose(-1, -2))
                b = b + agreement

        return v


In [6]:
class CapsuleNetwork(nn.Module):
    """Image classifier built from frozen VGG16 features followed by two capsule layers.

    Pipeline:
      1. Pretrained VGG16 convolutional stack, with its last module dropped and
         its weights frozen.
      2. A 3x3 conv to 256 channels, followed by ReLU.
      3. Primary capsule layer: 32 capsules of length 8, routed over every
         spatial position.
      4. Class capsule layer: ``num_classes`` capsules of length 16.

    The forward pass returns each class capsule's length as the class score.

    Args:
        num_classes: number of output classes (one class capsule each).
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Pretrained VGG16 conv stack; `[:-1]` drops its last module (the final
        # max-pool in torchvision's VGG16 definition).
        vgg = models.vgg16(weights='DEFAULT')
        self.features = nn.Sequential(*list(vgg.features.children())[:-1])

        # Freeze the backbone so only the layers defined below receive gradients.
        for param in self.features.parameters():
            param.requires_grad = False

        self.conv = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)

        # Primary capsules: route the 256-d feature vector at every spatial
        # position into 32 capsules of length 8.
        self.primary_capsules = CapsuleLayer(
            num_capsules=32,
            in_channels=256,
            out_channels=8
        )

        # Class ("digit") capsules: one 16-d capsule per class.
        self.digit_capsules = CapsuleLayer(
            num_capsules=num_classes,
            in_channels=8,
            out_channels=16
        )

    def forward(self, x):
        """Return class scores of shape ``[batch_size, num_classes]`` (capsule lengths)."""
        x = self.features(x)
        x = F.relu(self.conv(x))

        # [B, 256, H, W] -> [B, 256, H*W]: each spatial position is one input point.
        x = x.view(x.size(0), x.size(1), -1)

        # -> [B, 32, 8]
        x = self.primary_capsules(x)

        # -> [B, 8, 32]: the capsule dimension becomes the channel axis, and the
        # 32 primary capsules become the input points of the next layer.
        x = x.transpose(1, 2)

        # -> [B, num_classes, 16]
        x = self.digit_capsules(x)

        # Capsule length as the class score. The epsilon keeps the backward pass
        # finite: without it, the derivative of sqrt at 0 is infinite, and an
        # all-zero capsule would produce NaN gradients.
        classes = torch.sqrt((x ** 2).sum(dim=-1) + 1e-8)
        return classes

In [7]:
# Margin loss function
def margin_loss(predictions, labels, lambda_=0.5, m_plus=0.9, m_minus=0.1):
    """Capsule-network margin loss.

    For each class k the loss is:

        T_k * max(0, m_plus - p_k)^2
        + lambda_ * (1 - T_k) * max(0, p_k - m_minus)^2

    The per-class terms are summed over dim 1 and averaged over the batch.

    Args:
        predictions: class scores, e.g. capsule lengths, laid out (batch, classes).
        labels: targets of the same shape; one-hot in this notebook.
        lambda_: down-weighting factor for the absent-class term.
        m_plus: score a present class should reach.
        m_minus: score an absent class should stay below.

    Returns:
        Scalar tensor: the mean over the batch of the per-sample loss.
    """
    present_term = labels * (m_plus - predictions).clamp(min=0) ** 2
    absent_term = lambda_ * (1 - labels) * (predictions - m_minus).clamp(min=0) ** 2
    per_sample = (present_term + absent_term).sum(dim=1)
    return per_sample.mean()

# Training setup
model = CapsuleNetwork(num_classes=2)
model = model.cuda()

# Adam over all parameters. The frozen VGG weights (requires_grad=False) never
# get a .grad, so Adam skips them.
optimizer = torch.optim.Adam(params=model.parameters())

# Reduce the learning rate once the monitored metric (passed to
# scheduler.step) has not decreased for more than 5 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=5,
)

In [8]:
# Training function: one optimisation pass over the training set
def train_epoch(model, train_loader, optimizer, loss_fn=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: network mapping a batch to per-class scores, shaped
            (batch, num_classes).
        train_loader: iterable of ``(data, labels)`` batches, where ``labels``
            holds integer class indices.
        optimizer: optimiser already bound to ``model``'s parameters.
        loss_fn: ``loss_fn(scores, one_hot_targets) -> scalar``. It is assumed
            to return a per-batch mean. Defaults to ``margin_loss``.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen. Each batch's loss is
        weighted by its size, so a short final batch is not over-counted.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = margin_loss
    # Follow whatever device the model lives on instead of hard-coding CUDA.
    device = next(model.parameters()).device

    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        output = model(data)
        # Take the class count from the model output rather than a hard-coded 2,
        # so CapsuleNetwork(num_classes=k) works unchanged.
        target = F.one_hot(labels, num_classes=output.size(1)).float()
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (output.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return total_loss / total, correct / total


In [9]:
# Validation function
def validate(model, val_loader, loss_fn=None):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Puts the model in eval mode and disables autograd.

    Args:
        model: network mapping a batch to per-class scores, shaped
            (batch, num_classes).
        val_loader: iterable of ``(data, labels)`` batches, where ``labels``
            holds integer class indices.
        loss_fn: ``loss_fn(scores, one_hot_targets) -> scalar``. It is assumed
            to return a per-batch mean. Defaults to ``margin_loss``.

    Returns:
        ``(mean_loss, accuracy)`` over every sample. Each batch's loss is
        weighted by its size, so a short final batch is not over-counted.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = margin_loss
    # Follow whatever device the model lives on instead of hard-coding CUDA.
    device = next(model.parameters()).device

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, labels in val_loader:
            data, labels = data.to(device), labels.to(device)

            output = model(data)
            # Take the class count from the model output rather than a hard-coded 2.
            target = F.one_hot(labels, num_classes=output.size(1)).float()

            batch_size = labels.size(0)
            val_loss += loss_fn(output, target).item() * batch_size
            correct += (output.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("val_loader yielded no samples")
    return val_loss / total, correct / total


In [10]:
# Training loop
num_epochs = 100
best_val_acc = 0

for epoch in range(num_epochs):
    # Optimise on the training split, then score on the validation split.
    train_loss, train_acc = train_epoch(model, train_loader, optimizer)
    val_loss, val_acc = validate(model, val_loader)

    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

    # Checkpoint only on a strict improvement in validation accuracy.
    improved = val_acc > best_val_acc
    if improved:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')

    report = (
        f'Epoch: {epoch}',
        f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}',
        f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}',
    )
    for line in report:
        print(line)


Epoch: 0
Train Loss: 0.0940, Train Acc: 0.8570
Val Loss: 0.1080, Val Acc: 0.8750


KeyboardInterrupt: 

In [None]:
# Test the model
# Reload the checkpoint with the best validation accuracy. weights_only=True
# restricts unpickling to tensors and plain containers, which is all a
# state_dict needs. map_location places the tensors on the model's device.
best_state = torch.load(
    'best_model.pth',
    map_location=next(model.parameters()).device,
    weights_only=True,
)
model.load_state_dict(best_state)
test_loss, test_acc = validate(model, test_loader)
print(f'\nTest accuracy: {test_acc:.4f}')