In [29]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
from torch.utils.data import DataLoader
import lightning as L

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

In [30]:
# MNIST
train_dataset_mnist = MNIST(os.getcwd(), download=True, train=True, transform=transforms.ToTensor())
train_loader_mnist = DataLoader(train_dataset_mnist)
test_dataset_mnist = MNIST(os.getcwd(), download=True, train=False, transform=transforms.ToTensor())
test_loader_mnist = DataLoader(test_dataset_mnist, batch_size=1, shuffle=False)

# CIFAR10
train_dataset_cifar10 = CIFAR10(os.getcwd(), download=True, train=True, transform=transforms.ToTensor())
train_loader_cifar10 = DataLoader(train_dataset_cifar10)
test_dataset_cifar10 = CIFAR10(os.getcwd(), download=True, train=False, transform=transforms.ToTensor())
test_loader_cifar10 = DataLoader(test_dataset_mnist, batch_size=1, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to c:\Users\hugom\Documents\DTU\02456\project\02456-kan-ntk-project\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:07<00:00, 21584759.58it/s]


Extracting c:\Users\hugom\Documents\DTU\02456\project\02456-kan-ntk-project\cifar-10-python.tar.gz to c:\Users\hugom\Documents\DTU\02456\project\02456-kan-ntk-project
Files already downloaded and verified


In [33]:
train_dataset_mnist[0][0].size(), train_dataset_cifar10[0][0].size()

(torch.Size([1, 28, 28]), torch.Size([3, 32, 32]))

In [23]:
class simpleMLP(nn.Module):
    def __init__(self, input_sz, num_classes):
        super().__init__()
        self.NN = nn.Sequential(
            nn.Linear(input_sz, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, num_classes), nn.Softmax(dim=1)  # Specify dim for Softmax
        )

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)  # Flatten keeping the batch dimension
        x = self.NN(x)
        return x

In [None]:
class LitModel(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, targets = batch
        x = self.model(x)
        loss = F.cross_entropy(x, targets)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [None]:
### MNIST ###

# model
mnist_model = LitModel(simpleMLP(28*28, 10)).to(device)

# train model
trainer = L.Trainer()
trainer.fit(model=mnist_model, train_dataloaders=train_loader_mnist)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\hugom\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning\pytorch\loops\utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.

  | Name  | Type      | Params | Mode 
--------------------------------------------
0 | model | simpleMLP | 300 K  | train
--------------------------------------------
300 K     Trainable params
0         Non-trainable params
300 K     Total params
1.204     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode
c:\Users\hugom\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoa

Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [None]:
### CIFAR10 ###

# model
cifar10_model = LitModel(simpleMLP(3*32*32, 10)).to(device)

# train model
trainer = L.Trainer()
trainer.fit(model=cifar10_model, train_dataloaders=train_loader_cifar10)