In [None]:
!pip install pynvml

Collecting pynvml
  Downloading pynvml-11.5.0-py3-none-any.whl (53 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pynvml
Successfully installed pynvml-11.5.0


In [None]:
!nvidia-smi

Thu Jun 20 23:58:41 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   49C    P8              13W /  72W |      1MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class KANLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x):
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid = self.grid
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x, y):
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(0, 1)
        B = y.transpose(0, 1)
        solution = torch.linalg.lstsq(A, B).solution
        result = solution.permute(2, 0, 1)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x):
        original_shape = x.shape
        if x.dim() == 3:
            x = x.reshape(-1, x.size(-1))  # Flatten to 2D tensor

        assert x.dim() == 2 and x.size(1) == self.in_features
        self.acts = self.base_activation(x)  # Store activations


        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        if len(original_shape) == 3:
            output = output.reshape(original_shape[0], original_shape[1], -1)  # Reshape back to 3D tensor

        return output

    @torch.no_grad()
    def update_grid(self, x, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)
        splines = splines.permute(1, 0, 2)
        orig_coeff = self.scaled_spline_weight
        orig_coeff = orig_coeff.permute(1, 2, 0)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)
        unreduced_spline_output = unreduced_spline_output.permute(1, 0, 2)

        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )

class ImageToPatches(nn.Module):
    def __init__(self, patch_size):
        super().__init__()
        self.P = patch_size

    def forward(self, x):
        P = self.P
        B,C,H,W = x.shape                       # [B,C,H,W]                 4D Image
        x = x.reshape(B,C, H//P, P , W//P, P)   # [B,C, H//P, P, W//P, P]   6D Patches
        x = x.permute(0,2,4, 1,3,5)             # [B, H//P, W//P, C, P, P]  6D Swap Axes
        x = x.reshape(B, H//P * W//P, C*P*P)    # [B, H//P * W//P, C*P*P]   3D Patches
                                                # [B, n_tokens, n_pixels]
        return x

class PerPatchKAN(nn.Module):
    def __init__(self, n_pixels, n_channel):
        super().__init__()
        self.kan = KANLinear(n_pixels, n_channel)

    def forward(self, x):
        output = self.kan(x)
        return output  # x*w:  [B, n_tokens, n_pixels] x [n_pixels, n_channel]
                       #       [B, n_tokens, n_channel]

class TokenMixingKAN(nn.Module):
    def __init__(self, n_tokens, n_channel, n_hidden):
        super().__init__()
        self.layer_norm = nn.LayerNorm([n_tokens, n_channel])
        self.kan1 = KANLinear(n_tokens, n_hidden)
        self.kan2 = KANLinear(n_hidden, n_tokens)
        self.activations = None

    def forward(self, X):
        z = self.layer_norm(X)                  # z:    [B, n_tokens, n_channel]
        z = z.permute(0, 2, 1)                  # z:    [B, n_channel, n_tokens]
        z = self.kan1(z)                        # z:    [B, n_channel, n_hidden]
        z = self.kan2(z)                        # z:    [B, n_hidden, n_tokens]
        z = z.permute(0, 2, 1)                  # z:    [B, n_tokens, n_channel]
        U = X + z                               # U:    [B, n_tokens, n_channel]
        self.activations = U  # Store activations
        return U

class ChannelMixingKAN(nn.Module):
    def __init__(self, n_tokens, n_channel, n_hidden):
        super().__init__()
        self.layer_norm = nn.LayerNorm([n_tokens, n_channel])
        self.kan3 = KANLinear(n_channel, n_hidden)
        self.kan4 = KANLinear(n_hidden, n_channel)
        self.activations = None

    def forward(self, U):
        z = self.layer_norm(U)                  # z: [B, n_tokens, n_channel]
        z = self.kan3(z)                        # z: [B, n_tokens, n_hidden]
        z = self.kan4(z)                        # z: [B, n_tokens, n_channel]
        Y = U + z                               # Y: [B, n_tokens, n_channel]
        self.activations = Y  # Store activations
        return Y

class OutputKAN(nn.Module):
    def __init__(self, n_tokens, n_channel, n_output):
        super().__init__()
        self.layer_norm = nn.LayerNorm([n_tokens, n_channel])
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)  # Global average pooling
        self.out_kan = KANLinear(n_channel, n_output)
        self.activations = None

    def forward(self, x):
        x = self.layer_norm(x)                  # x: [B, n_tokens, n_channel]
        x = x.permute(0, 2, 1)                  # [B, n_tokens, n_channel] -> [B, n_channel, n_tokens]
        x = self.global_avg_pool(x)             # [B, n_channel, n_tokens] -> [B, n_channel, 1]
        x = x.squeeze(-1)                       # [B, n_channel, 1] -> [B, n_channel]
        output = self.out_kan(x)                # x: [B, n_output]
        self.activations = self.out_kan.acts  # Store activations
        return output

class KAN_Mixer(nn.Module):
    def __init__(self, n_layers, n_channel, n_hidden, n_output, image_size, patch_size, n_image_channel):
        super().__init__()

        n_tokens = (image_size // patch_size)**2
        n_pixels = n_image_channel * patch_size**2

        self.ImageToPatch = ImageToPatches(patch_size = patch_size)
        self.PerPatchKAN = PerPatchKAN(n_pixels, n_channel)
        self.MixerStack = nn.Sequential(*[
            nn.Sequential(
                TokenMixingKAN(n_tokens, n_channel, n_hidden),
                ChannelMixingKAN(n_tokens, n_channel, n_hidden)
            ) for _ in range(n_layers)
        ])
        self.OutputKAN = OutputKAN(n_tokens, n_channel, n_output)

    def forward(self, x):
        x = self.ImageToPatch(x)
        x = self.PerPatchKAN(x)
        x = self.MixerStack(x)
        return self.OutputKAN(x)

# Example usage:
n_layers = 2
n_channel = 16
n_hidden = 32
n_output = 10
image_size = 32
patch_size = 2
n_image_channel = 3

model = KAN_Mixer(n_layers, n_channel, n_hidden, n_output, image_size, patch_size, n_image_channel)

# Create a dummy input tensor with the shape [batch_size, n_image_channel, image_height, image_width]
x = torch.randn(1, n_image_channel, image_size, image_size)

# Forward pass
output = model(x)
print(output.shape)  # Should be [batch_size, n_output]

torch.Size([1, 10])


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, CIFAR100, MNIST
from torch.utils.data import DataLoader, random_split
import time
import psutil
import os
import gc
import pandas as pd

# The full classes are assumed to be already provided based on your previous implementation

# Define transformations for the dataset
transform_cifar = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

test_transform_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

transform_mnist = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

def get_dataloaders(dataset, transform, test_transform, batch_size=32):
    if dataset == 'CIFAR10':
        train_dataset = CIFAR10(root='data', train=True, transform=transform, download=True)
        test_dataset = CIFAR10(root='data', train=False, transform=test_transform, download=True)
    elif dataset == 'CIFAR100':
        train_dataset = CIFAR100(root='data', train=True, transform=transform, download=True)
        test_dataset = CIFAR100(root='data', train=False, transform=test_transform, download=True)
    elif dataset == 'MNIST':
        train_dataset = MNIST(root='data', train=True, transform=transform, download=True)
        test_dataset = MNIST(root='data', train=False, transform=test_transform, download=True)
    else:
        raise ValueError("Dataset not supported")

    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

def calculate_memory():
    return psutil.virtual_memory()._asdict()

def train_and_evaluate(model, train_loader, val_loader, device, dataset, n_epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.to(device)

    results = []

    for epoch in range(n_epochs):
        model.train()
        start_time = time.time()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        end_time = time.time()
        epoch_time = end_time - start_time

        model.eval()
        val_loss = 0
        correct = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_loss += criterion(outputs, labels).item()
                pred = outputs.argmax(dim=1, keepdim=True)
                correct += pred.eq(labels.view_as(pred)).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy = correct / len(val_loader.dataset)

        gpu_memory_allocated = torch.cuda.memory_allocated() / 1024**2
        gpu_memory_reserved = torch.cuda.memory_reserved() / 1024**2

        results.append({
            'dataset': dataset,
            'epoch': epoch + 1,
            'epoch_time': epoch_time,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'gpu_memory_allocated_MB': gpu_memory_allocated,
            'gpu_memory_reserved_MB': gpu_memory_reserved
        })

        print(f'Epoch {epoch+1}/{n_epochs} - Time: {epoch_time:.2f}s')
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
        print(f'GPU memory usage: {gpu_memory_allocated:.2f} MB')
        print(f'GPU memory reserved: {gpu_memory_reserved:.2f} MB')

    torch.cuda.empty_cache()
    gc.collect()
    return results

def test(model, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    start_time = time.time()
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            test_loss += criterion(outputs, labels).item()
            pred = outputs.argmax(dim=1, keepdim=True)
            correct += pred.eq(labels.view_as(pred)).sum().item()
    end_time = time.time()
    test_time = end_time - start_time

    test_loss /= len(test_loader.dataset)
    test_accuracy = correct / len(test_loader.dataset)

    gpu_memory_allocated = torch.cuda.memory_allocated() / 1024**2
    gpu_memory_reserved = torch.cuda.memory_reserved() / 1024**2

    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')
    print(f'Test Time: {test_time:.2f}s')
    print(f'GPU memory usage during test: {gpu_memory_allocated:.2f} MB')
    print(f'GPU memory reserved during test: {gpu_memory_reserved:.2f} MB')

    return test_loss, test_accuracy, test_time, gpu_memory_allocated, gpu_memory_reserved

def main():
    datasets = ['CIFAR10', 'CIFAR100', 'MNIST']
    batch_size = 32
    n_epochs = 10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_results = []

    for dataset in datasets:
        print(f"Training on {dataset}")
        if dataset == 'MNIST':
            train_loader, val_loader, test_loader = get_dataloaders(dataset, transform_mnist, transform_mnist, batch_size)
        else:
            train_loader, val_loader, test_loader = get_dataloaders(dataset, transform_cifar, test_transform_cifar, batch_size)

        n_layers = 2
        n_channel = 16
        n_hidden = 32
        n_output = 10 if dataset != 'CIFAR100' else 100
        image_size = 32
        patch_size = 2
        n_image_channel = 3 if dataset != 'MNIST' else 1

        model = KAN_Mixer(n_layers, n_channel, n_hidden, n_output, image_size, patch_size, n_image_channel)

        # Reset GPU memory stats
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        train_results = train_and_evaluate(model, train_loader, val_loader, device, dataset, n_epochs)
        test_loss, test_accuracy, test_time, test_gpu_allocated, test_gpu_reserved = test(model, test_loader, device)

        for result in train_results:
            result.update({
                'test_loss': test_loss,
                'test_accuracy': test_accuracy,
                'test_time': test_time,
                'test_gpu_memory_allocated_MB': test_gpu_allocated,
                'test_gpu_memory_reserved_MB': test_gpu_reserved
            })

        all_results.extend(train_results)

    # Save all results to one DataFrame
    df = pd.DataFrame(all_results)
    df.to_csv('all_datasets_results.csv', index=False)

if __name__ == "__main__":
    main()

Training on CIFAR10
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 41973103.67it/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Epoch 1/10 - Time: 43.92s
Validation Loss: 0.0517, Validation Accuracy: 0.3934
GPU memory usage: 26.98 MB
GPU memory reserved: 418.00 MB
Epoch 2/10 - Time: 42.86s
Validation Loss: 0.0473, Validation Accuracy: 0.4511
GPU memory usage: 26.98 MB
GPU memory reserved: 418.00 MB
Epoch 3/10 - Time: 42.64s
Validation Loss: 0.0451, Validation Accuracy: 0.4739
GPU memory usage: 26.98 MB
GPU memory reserved: 418.00 MB
Epoch 4/10 - Time: 42.84s
Validation Loss: 0.0424, Validation Accuracy: 0.5134
GPU memory usage: 26.98 MB
GPU memory reserved: 418.00 MB
Epoch 5/10 - Time: 42.88s
Validation Loss: 0.0414, Validation Accuracy: 0.5253
GPU memory usage: 26.98 MB
GPU memory reserved: 418.00 MB
Epoch 6/10 - Time: 42.57s
Validation Loss: 0.0404, Validation Accuracy: 0.5322
GPU memory usage: 26.98 MB
GPU memory reserved: 418.00 MB
Epoch 7/10 - Time: 42.50s
Validation Loss: 0.0399, Validation Accuracy: 0.5358
GPU memory usa

100%|██████████| 169001437/169001437 [00:03<00:00, 44250469.49it/s]


Extracting data/cifar-100-python.tar.gz to data
Files already downloaded and verified
Epoch 1/10 - Time: 42.78s
Validation Loss: 0.1231, Validation Accuracy: 0.0964
GPU memory usage: 27.20 MB
GPU memory reserved: 428.00 MB
Epoch 2/10 - Time: 42.78s
Validation Loss: 0.1156, Validation Accuracy: 0.1297
GPU memory usage: 27.20 MB
GPU memory reserved: 428.00 MB
Epoch 3/10 - Time: 42.73s
Validation Loss: 0.1122, Validation Accuracy: 0.1492
GPU memory usage: 27.20 MB
GPU memory reserved: 428.00 MB
Epoch 4/10 - Time: 42.68s
Validation Loss: 0.1086, Validation Accuracy: 0.1641
GPU memory usage: 27.20 MB
GPU memory reserved: 428.00 MB
Epoch 5/10 - Time: 42.61s
Validation Loss: 0.1064, Validation Accuracy: 0.1764
GPU memory usage: 27.20 MB
GPU memory reserved: 428.00 MB
Epoch 6/10 - Time: 42.51s
Validation Loss: 0.1053, Validation Accuracy: 0.1918
GPU memory usage: 27.20 MB
GPU memory reserved: 428.00 MB
Epoch 7/10 - Time: 42.74s
Validation Loss: 0.1035, Validation Accuracy: 0.1910
GPU memory us

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, CIFAR100, MNIST
from torch.utils.data import DataLoader, random_split
import time
import psutil
import os
import gc
import pandas as pd

# The full classes are assumed to be already provided based on your previous implementation

# Define transformations for the dataset
transform_cifar = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

test_transform_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

transform_mnist = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

def get_dataloaders(dataset, transform, test_transform, batch_size=32):
    if dataset == 'CIFAR10':
        train_dataset = CIFAR10(root='data', train=True, transform=transform, download=True)
        test_dataset = CIFAR10(root='data', train=False, transform=test_transform, download=True)
    elif dataset == 'CIFAR100':
        train_dataset = CIFAR100(root='data', train=True, transform=transform, download=True)
        test_dataset = CIFAR100(root='data', train=False, transform=test_transform, download=True)
    elif dataset == 'MNIST':
        train_dataset = MNIST(root='data', train=True, transform=transform, download=True)
        test_dataset = MNIST(root='data', train=False, transform=test_transform, download=True)
    else:
        raise ValueError("Dataset not supported")

    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

def calculate_memory():
    return psutil.virtual_memory()._asdict()

def train_and_evaluate(model, train_loader, val_loader, device, dataset, n_epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.to(device)

    results = []

    for epoch in range(n_epochs):
        model.train()
        start_time = time.time()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        end_time = time.time()
        epoch_time = end_time - start_time

        model.eval()
        val_loss = 0
        correct = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_loss += criterion(outputs, labels).item()
                pred = outputs.argmax(dim=1, keepdim=True)
                correct += pred.eq(labels.view_as(pred)).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy = correct / len(val_loader.dataset)

        gpu_memory_allocated = torch.cuda.memory_allocated() / 1024**2
        gpu_memory_reserved = torch.cuda.memory_reserved() / 1024**2

        results.append({
            'dataset': dataset,
            'epoch': epoch + 1,
            'epoch_time': epoch_time,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'gpu_memory_allocated_MB': gpu_memory_allocated,
            'gpu_memory_reserved_MB': gpu_memory_reserved
        })

        print(f'Epoch {epoch+1}/{n_epochs} - Time: {epoch_time:.2f}s')
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
        print(f'GPU memory usage: {gpu_memory_allocated:.2f} MB')
        print(f'GPU memory reserved: {gpu_memory_reserved:.2f} MB')

    torch.cuda.empty_cache()
    gc.collect()
    return results

def test(model, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    start_time = time.time()
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            test_loss += criterion(outputs, labels).item()
            pred = outputs.argmax(dim=1, keepdim=True)
            correct += pred.eq(labels.view_as(pred)).sum().item()
    end_time = time.time()
    test_time = end_time - start_time

    test_loss /= len(test_loader.dataset)
    test_accuracy = correct / len(test_loader.dataset)

    gpu_memory_allocated = torch.cuda.memory_allocated() / 1024**2
    gpu_memory_reserved = torch.cuda.memory_reserved() / 1024**2

    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')
    print(f'Test Time: {test_time:.2f}s')
    print(f'GPU memory usage during test: {gpu_memory_allocated:.2f} MB')
    print(f'GPU memory reserved during test: {gpu_memory_reserved:.2f} MB')

    return test_loss, test_accuracy, test_time, gpu_memory_allocated, gpu_memory_reserved

def main():
    datasets = ['CIFAR10', 'CIFAR100', 'MNIST']
    batch_size = 32
    n_epochs = 10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_results = []

    for dataset in datasets:
        print(f"Training on {dataset}")
        if dataset == 'MNIST':
            train_loader, val_loader, test_loader = get_dataloaders(dataset, transform_mnist, transform_mnist, batch_size)
        else:
            train_loader, val_loader, test_loader = get_dataloaders(dataset, transform_cifar, test_transform_cifar, batch_size)

        n_layers = 2
        n_channel = 32
        n_hidden = 64
        n_output = 10 if dataset != 'CIFAR100' else 100
        image_size = 32
        patch_size = 2
        n_image_channel = 3 if dataset != 'MNIST' else 1

        model = KAN_Mixer(n_layers, n_channel, n_hidden, n_output, image_size, patch_size, n_image_channel)

        # Reset GPU memory stats
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        train_results = train_and_evaluate(model, train_loader, val_loader, device, dataset, n_epochs)
        test_loss, test_accuracy, test_time, test_gpu_allocated, test_gpu_reserved = test(model, test_loader, device)

        for result in train_results:
            result.update({
                'test_loss': test_loss,
                'test_accuracy': test_accuracy,
                'test_time': test_time,
                'test_gpu_memory_allocated_MB': test_gpu_allocated,
                'test_gpu_memory_reserved_MB': test_gpu_reserved
            })

        all_results.extend(train_results)

    # Save all results to one DataFrame
    df = pd.DataFrame(all_results)
    df.to_csv('all_datasets_results.csv', index=False)

if __name__ == "__main__":
    main()

Training on CIFAR10
Files already downloaded and verified
Files already downloaded and verified
Epoch 1/10 - Time: 43.28s
Validation Loss: 0.0479, Validation Accuracy: 0.4397
GPU memory usage: 55.68 MB
GPU memory reserved: 846.00 MB
Epoch 2/10 - Time: 43.11s
Validation Loss: 0.0440, Validation Accuracy: 0.4830
GPU memory usage: 55.68 MB
GPU memory reserved: 846.00 MB
Epoch 3/10 - Time: 43.20s
Validation Loss: 0.0414, Validation Accuracy: 0.5222
GPU memory usage: 55.68 MB
GPU memory reserved: 846.00 MB
Epoch 4/10 - Time: 43.36s
Validation Loss: 0.0396, Validation Accuracy: 0.5459
GPU memory usage: 55.68 MB
GPU memory reserved: 846.00 MB
Epoch 5/10 - Time: 43.15s
Validation Loss: 0.0375, Validation Accuracy: 0.5691
GPU memory usage: 55.68 MB
GPU memory reserved: 846.00 MB
Epoch 6/10 - Time: 43.02s
Validation Loss: 0.0368, Validation Accuracy: 0.5740
GPU memory usage: 55.68 MB
GPU memory reserved: 846.00 MB
Epoch 7/10 - Time: 43.24s
Validation Loss: 0.0359, Validation Accuracy: 0.5934
GPU

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, CIFAR100, MNIST
from torch.utils.data import DataLoader, random_split
import time
import psutil
import os
import gc
import pandas as pd

# The full classes are assumed to be already provided based on your previous implementation

# Define transformations for the dataset
transform_cifar = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

test_transform_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

transform_mnist = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

def get_dataloaders(dataset, transform, test_transform, batch_size=32):
    if dataset == 'CIFAR10':
        train_dataset = CIFAR10(root='data', train=True, transform=transform, download=True)
        test_dataset = CIFAR10(root='data', train=False, transform=test_transform, download=True)
    elif dataset == 'CIFAR100':
        train_dataset = CIFAR100(root='data', train=True, transform=transform, download=True)
        test_dataset = CIFAR100(root='data', train=False, transform=test_transform, download=True)
    elif dataset == 'MNIST':
        train_dataset = MNIST(root='data', train=True, transform=transform, download=True)
        test_dataset = MNIST(root='data', train=False, transform=test_transform, download=True)
    else:
        raise ValueError("Dataset not supported")

    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

def calculate_memory():
    return psutil.virtual_memory()._asdict()

def train_and_evaluate(model, train_loader, val_loader, device, dataset, n_epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.to(device)

    results = []

    for epoch in range(n_epochs):
        model.train()
        start_time = time.time()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        end_time = time.time()
        epoch_time = end_time - start_time

        model.eval()
        val_loss = 0
        correct = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_loss += criterion(outputs, labels).item()
                pred = outputs.argmax(dim=1, keepdim=True)
                correct += pred.eq(labels.view_as(pred)).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy = correct / len(val_loader.dataset)

        gpu_memory_allocated = torch.cuda.memory_allocated() / 1024**2
        gpu_memory_reserved = torch.cuda.memory_reserved() / 1024**2

        results.append({
            'dataset': dataset,
            'epoch': epoch + 1,
            'epoch_time': epoch_time,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'gpu_memory_allocated_MB': gpu_memory_allocated,
            'gpu_memory_reserved_MB': gpu_memory_reserved
        })

        print(f'Epoch {epoch+1}/{n_epochs} - Time: {epoch_time:.2f}s')
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
        print(f'GPU memory usage: {gpu_memory_allocated:.2f} MB')
        print(f'GPU memory reserved: {gpu_memory_reserved:.2f} MB')

    torch.cuda.empty_cache()
    gc.collect()
    return results

def test(model, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    start_time = time.time()
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            test_loss += criterion(outputs, labels).item()
            pred = outputs.argmax(dim=1, keepdim=True)
            correct += pred.eq(labels.view_as(pred)).sum().item()
    end_time = time.time()
    test_time = end_time - start_time

    test_loss /= len(test_loader.dataset)
    test_accuracy = correct / len(test_loader.dataset)

    gpu_memory_allocated = torch.cuda.memory_allocated() / 1024**2
    gpu_memory_reserved = torch.cuda.memory_reserved() / 1024**2

    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')
    print(f'Test Time: {test_time:.2f}s')
    print(f'GPU memory usage during test: {gpu_memory_allocated:.2f} MB')
    print(f'GPU memory reserved during test: {gpu_memory_reserved:.2f} MB')

    return test_loss, test_accuracy, test_time, gpu_memory_allocated, gpu_memory_reserved

def main():
    datasets = ['CIFAR10', 'CIFAR100', 'MNIST']
    batch_size = 32
    n_epochs = 10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_results = []

    for dataset in datasets:
        print(f"Training on {dataset}")
        if dataset == 'MNIST':
            train_loader, val_loader, test_loader = get_dataloaders(dataset, transform_mnist, transform_mnist, batch_size)
        else:
            train_loader, val_loader, test_loader = get_dataloaders(dataset, transform_cifar, test_transform_cifar, batch_size)

        n_layers = 2
        n_channel = 64
        n_hidden = 128
        n_output = 10 if dataset != 'CIFAR100' else 100
        image_size = 32
        patch_size = 2
        n_image_channel = 3 if dataset != 'MNIST' else 1

        model = KAN_Mixer(n_layers, n_channel, n_hidden, n_output, image_size, patch_size, n_image_channel)

        # Reset GPU memory stats
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        train_results = train_and_evaluate(model, train_loader, val_loader, device, dataset, n_epochs)
        test_loss, test_accuracy, test_time, test_gpu_allocated, test_gpu_reserved = test(model, test_loader, device)

        for result in train_results:
            result.update({
                'test_loss': test_loss,
                'test_accuracy': test_accuracy,
                'test_time': test_time,
                'test_gpu_memory_allocated_MB': test_gpu_allocated,
                'test_gpu_memory_reserved_MB': test_gpu_reserved
            })

        all_results.extend(train_results)

    # Save all results to one DataFrame
    df = pd.DataFrame(all_results)
    df.to_csv('all_datasets_results.csv', index=False)

if __name__ == "__main__":
    main()

Training on CIFAR10
Files already downloaded and verified
Files already downloaded and verified
Epoch 1/10 - Time: 110.14s
Validation Loss: 0.0464, Validation Accuracy: 0.4614
GPU memory usage: 77.56 MB
GPU memory reserved: 1704.00 MB
Epoch 2/10 - Time: 110.16s
Validation Loss: 0.0407, Validation Accuracy: 0.5320
GPU memory usage: 77.56 MB
GPU memory reserved: 1704.00 MB
Epoch 3/10 - Time: 110.15s
Validation Loss: 0.0392, Validation Accuracy: 0.5464
GPU memory usage: 77.56 MB
GPU memory reserved: 1704.00 MB
Epoch 4/10 - Time: 110.16s
Validation Loss: 0.0368, Validation Accuracy: 0.5792
GPU memory usage: 77.56 MB
GPU memory reserved: 1704.00 MB
Epoch 5/10 - Time: 110.13s
Validation Loss: 0.0351, Validation Accuracy: 0.5996
GPU memory usage: 77.56 MB
GPU memory reserved: 1704.00 MB
Epoch 6/10 - Time: 110.17s
Validation Loss: 0.0339, Validation Accuracy: 0.6132
GPU memory usage: 77.56 MB
GPU memory reserved: 1704.00 MB
Epoch 7/10 - Time: 110.12s
Validation Loss: 0.0329, Validation Accurac