In [1]:
# !pip install lmdb

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, auc
import os

In [3]:
print('🕺 lets gooooooooooooooooooooooooooo' if torch.cuda.is_available() else '💩 f@*# the cpu')

🕺 lets gooooooooooooooooooooooooooo


In [15]:
class Length(nn.Module):
    def forward(self, inputs):
        return torch.sqrt((inputs ** 2).sum(dim=-1))

class Mask(nn.Module):
    def forward(self, inputs):
        if isinstance(inputs, list):
            inputs, mask = inputs
        else:
            x = torch.sqrt((inputs ** 2).sum(dim=-1))
            mask = torch.nn.functional.one_hot(x.argmax(dim=1), num_classes=x.size(1))
        inputs_masked = torch.einsum('bij,bj->bi', inputs, mask)
        return inputs_masked

class CapsuleLayer(nn.Module):
    def __init__(self, num_capsule=16, dim_vector=64, num_routing=3):
        super(CapsuleLayer, self).__init__()
        self.num_capsule = num_capsule
        self.dim_vector = dim_vector
        self.num_routing = num_routing

    def forward(self, x):
        batch_size = x.size(0)
        u_hat = x.view(batch_size, -1, 1, self.dim_vector)
        b_ij = torch.zeros(1, u_hat.size(1), self.num_capsule, 1)
        for i in range(self.num_routing):
            c_ij = torch.nn.functional.softmax(b_ij, dim=2)
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = squash(s_j)
            if i < self.num_routing - 1:
                b_ij = b_ij + (u_hat * v_j).sum(dim=-1, keepdim=True)
        return v_j.squeeze()

def squash(tensor, dim=-1):
    squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * tensor / torch.sqrt(squared_norm + 1e-9)

class PyramidBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(PyramidBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.ln1 = nn.LayerNorm([out_channels, 128, 128]) 
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.ln2 = nn.LayerNorm([out_channels, 64, 64]) 
        self.dropout = nn.Dropout(0.1)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x, residual):
        x = torch.cat([x, residual], dim=1)
        x = torch.relu(self.ln1(self.conv1(x)))
        x = self.pool(x)
        residual = x
        x = torch.relu(self.ln2(self.conv2(x)))
        x = self.dropout(x)
        x = self.pool(x)
        return x, residual

class LaplacianNet(nn.Module):
    def __init__(self, input_shape, num_classes, num_routing=3):
        super(LaplacianNet, self).__init__()
        self.pyramid_block1 = PyramidBlock(input_shape[0] * 2, 64)
        self.pyramid_block2 = PyramidBlock(64 + input_shape[0] * 2, 128)
        self.pyramid_block3 = PyramidBlock(128 + 64 + input_shape[0] * 2, 256)
        
        self.primary_caps = nn.Conv2d(256, 32 * 8, kernel_size=11, stride=1, padding=0)
        self.primary_caps_reshape = nn.Sequential(
            nn.Conv2d(32 * 8, 32 * 8, kernel_size=1),
            nn.BatchNorm2d(32 * 8),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(32 * 8 * 2 * 2, 32 * 8)
        )
        self.capsule_layer = CapsuleLayer(num_capsule=num_classes, dim_vector=8, num_routing=num_routing)
        self.length = Length()
        self.decoder = nn.Sequential(
            nn.Linear(8 * num_classes, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, np.prod(input_shape)),
            nn.Sigmoid(),
            nn.Unflatten(1, input_shape)
        )

    def forward(self, inputs, y=None):
        x_g1, x_g2, x_g3, x_g4, x_l1, x_l2, x_l3, x_l4 = inputs
        
        x1, r1 = self.pyramid_block1(x_g1, x_l1)
        x2, r2 = self.pyramid_block2(x1, torch.cat([x_g2, x_l2], dim=1))
        x3, _ = self.pyramid_block3(x2, torch.cat([x_g3, x_l3], dim=1))
        x = torch.cat([x3, x_g4, x_l4], dim=1)
        
        x = self.primary_caps(x)
        x = self.primary_caps_reshape(x)
        x = self.capsule_layer(x)
        
        if y is not None:
            masked = Mask()([x, y])
        else:
            masked = Mask()(x)
            
        decoded = self.decoder(masked)
        length = self.length(x)
        return length, decoded

def margin_loss(y_true, y_pred):
    alpha_margin = torch.clamp(0.9 - y_pred, min=0.0) ** 2
    beta_margin = torch.clamp(y_pred - 0.1, min=0.0) ** 2
    L = y_true * alpha_margin + 0.5 * (1 - y_true) * beta_margin
    return L.mean()

class ImagePyramid:
    def __init__(self, data, **kw):
        self.data = data
        self.name = kw.get("name", "imgpyr")
        self.verbose = kw.get("verbose", True)
        self.num_classes = kw.get("num_classes", 16)
        self.input_shape = kw.get("input_shape", (1, 128, 128))
        self.batch_size = kw.get("batch_size", 32)
        self.recon_loss = kw.get("recon_loss", nn.L1Loss())
        self.watch = kw.get("monitor", "val_loss")
        self.wait = kw.get("patience", 10)
        self.path = kw.get("path", f"weights/{self.name}_weights.pth")
        self.__get_model__()

    def __get_model__(self):
        self.model = LaplacianNet(self.input_shape, self.num_classes)
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4, amsgrad=True)
        self.criterion = margin_loss
        self.__setup_callbacks__()

    def __setup_callbacks__(self):
        self.early_stop = None  # Implement early stopping if needed
        self.checkpoint = self.path
        self.reduce_on_plateau = None  # Implement LR reduction if needed

    def fit(self):
        dataloader = DataLoader(self.data, batch_size=self.batch_size, shuffle=True)
        best_loss = float('inf')
        patience_counter = 0

        for epoch in range(1000):
            self.model.train()
            train_loss = 0.0
            for data in dataloader:
                inputs, y = data
                self.optimizer.zero_grad()
                output, reconstructed = self.model(inputs, y)
                loss = self.criterion(y, output) + self.recon_loss(reconstructed, inputs[0])
                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()

            train_loss /= len(dataloader)
            print(f'Epoch {epoch}, Loss: {train_loss}')

            if train_loss < best_loss:
                best_loss = train_loss
                torch.save(self.model.state_dict(), self.checkpoint)
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter > self.wait:
                print("Early stopping triggered")
                break

    def evaluate(self, test_data):
        test_dataloader = DataLoader(test_data, batch_size=self.batch_size, shuffle=False)
        self.model.load_state_dict(torch.load(self.checkpoint))
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for data in test_dataloader:
                inputs, y = data
                output, _ = self.model(inputs)
                pred = output.argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)

        accuracy = 100 * correct / total
        print(f'Accuracy: {accuracy}%')

In [5]:
class ImageDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.to_tensor = transforms.ToTensor()

    def upscale(self, img, target_size):
        return transforms.functional.resize(img, target_size)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]

        # Gaussian pyramid levels as PIL Images
        g1 = transforms.ToPILImage()(img)
        g2 = transforms.Resize((64, 64))(g1)
        g3 = transforms.Resize((32, 32))(g1)
        g4 = transforms.Resize((16, 16))(g1)

        # Convert back to tensors
        g1 = self.to_tensor(g1)
        g2 = self.to_tensor(g2)
        g3 = self.to_tensor(g3)
        g4 = self.to_tensor(g4)

        # Laplacian pyramid differences
        l1 = g1
        l2 = transforms.Resize((64, 64))(g1 - self.upscale(g2, g1.shape[-2:]))
        l3 = transforms.Resize((32, 32))(g2 - self.upscale(g3, g2.shape[-2:]))
        l4 = transforms.Resize((16, 16))(g3 - self.upscale(g4, g3.shape[-2:]))

        return [g1, g2, g3, g4, l1, l2, l3, l4], label

def plot_curves(train_loss, val_loss, dataset_name):
    plt.figure()
    plt.plot(train_loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'Training and Validation Loss Curves for {dataset_name}')
    plt.savefig(f'results/{dataset_name}_loss_curve.png')
    plt.close()

def plot_confusion_matrix(cm, dataset_name, phase):
    plt.figure(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix for {dataset_name} - {phase}')
    plt.savefig(f'results/{dataset_name}_confusion_matrix_{phase}.png')
    plt.close()

def plot_roc_auc(y_true, y_score, dataset_name):
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver Operating Characteristic for {dataset_name}')
    plt.legend(loc='lower right')
    plt.savefig(f'results/{dataset_name}_roc_auc.png')
    plt.close()


In [6]:
os.makedirs('results', exist_ok=True)
os.makedirs('weights', exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

datasets = {
    'CIFAR10': torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform),
    'CIFAR100': torchvision.datasets.CIFAR100(root='../data', train=True, download=True, transform=transform),
    'MNIST': torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform),
    'FashionMNIST': torchvision.datasets.FashionMNIST(root='../data', train=True, download=True, transform=transform),
    'SVHN': torchvision.datasets.SVHN(root='../data', split='train', download=True, transform=transform),
    'STL10': torchvision.datasets.STL10(root='../data', split='train', download=True, transform=transform),
    # 'ImageNet': torchvision.datasets.ImageNet(root='../data', split='train', download=True, transform=transform),
    'Caltech101': torchvision.datasets.Caltech101(root='../data', download=True, transform=transform),
    'Caltech256': torchvision.datasets.Caltech256(root='../data', download=True, transform=transform),
    'CelebA': torchvision.datasets.CelebA(root='../data', split='train', download=True, transform=transform),
    # 'LSUN': torchvision.datasets.LSUN(root='../data', classes='train', transform=transform,),
    'Omniglot': torchvision.datasets.Omniglot(root='../data', download=True, transform=transform),
    'OxfordIIITPet': torchvision.datasets.OxfordIIITPet(root='../data', split='trainval', download=True, transform=transform),
    # 'StanfordCars': torchvision.datasets.StanfordCars(root='../data', split='train', download=True, transform=transform),
    'SBD': torchvision.datasets.SBDataset(root='../data', image_set='train', download=False)
}

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ../data/train_32x32.mat
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [16]:
def train_and_evaluate(dataset_name, train_data, test_data, device):
    train_loader = DataLoader(ImageDataset(train_data), batch_size=32, shuffle=True)
    test_loader = DataLoader(ImageDataset(test_data), batch_size=32, shuffle=False)

    num_classes = len(train_data.classes)
    input_shape = train_data[0][0].shape
    print(input_shape)
    model = LaplacianNet(input_shape, num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4, amsgrad=True)
    criterion = margin_loss

    best_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []

    for epoch in range(100):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs = [inp.to(device) for inp in inputs]
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs, reconstructed = model(inputs, labels)
            loss = criterion(labels, outputs) + nn.L1Loss()(reconstructed, inputs[0])
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        val_loss = 0.0
        y_true = []
        y_pred = []
        y_score = []
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = [inp.to(device) for inp in inputs]
                labels = labels.to(device)
                outputs, _ = model(inputs)
                loss = criterion(labels, outputs)
                val_loss += loss.item()

                y_true.extend(labels.cpu().numpy())
                y_pred.extend(outputs.argmax(dim=1).cpu().numpy())
                y_score.extend(outputs.cpu().numpy())

        val_loss /= len(test_loader)
        val_losses.append(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), f'weights/{dataset_name}_best.pth')
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter > 10:  # patience
            print("Early stopping triggered")
            break

        print(f'Epoch {epoch + 1}/{100}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

    plot_curves(train_losses, val_losses, dataset_name)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_score = np.array(y_score)

    cm_train = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(cm_train, dataset_name, 'train')

    roc_auc = roc_auc_score(y_true, y_score, multi_class='ovr')
    plot_roc_auc(y_true, y_score, dataset_name)

    return model

In [17]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

for name, dataset in datasets.items():
    print(f"Training on {name} dataset")
    train_data = dataset
    test_data = torchvision.datasets.__dict__[name](root='./data', train=False, download=True, transform=transform)
    model = train_and_evaluate(name, train_data, test_data, device)

Training on CIFAR10 dataset
Files already downloaded and verified
torch.Size([3, 128, 128])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 64 but got size 128 for tensor number 1 in the list.