In [1]:
import mapd
import torchvision

In [2]:
MNIST_ROOT = "data"
torchvision.datasets.MNIST(root=MNIST_ROOT, download=True)

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train

In [3]:
from torchvision.datasets import MNIST

In [4]:
from torch.utils.data import Dataset

class IDXDataset(Dataset):
    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index], index

In [5]:
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x
    
model = Net()

In [9]:
import lightning as L
from torch.nn import functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch


class ResNet18(mapd.MAPDModule):
    def __init__(
        self,
        max_epochs: int = 10,
        lr: float = 0.05,
        momentum: float = 0.9,
        weight_decay: float = 0.0005
    ):
        super().__init__()
        self.model = model

        self.max_epochs = max_epochs
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.save_hyperparameters(ignore=["model"])

    def forward(self, x):
        return self.model(x)
    
    def batch_loss(self, logits, y) -> torch.Tensor:
        return F.cross_entropy(logits, y, reduction="none")

    def training_step(self, batch, batch_idx):
        x, y = batch
        
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.log_batch_loss(logits, y)
        
        return loss

    def configure_optimizers(self):
        optimizer = SGD(
            self.parameters(),
            lr=self.lr
        )

        return {"optimizer": optimizer}

In [7]:
from torch.utils.data import random_split, DataLoader
from torchvision import transforms

class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "data", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        self.mnist_train = IDXDataset(self.mnist_train)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

In [8]:
module = ResNet18()
dm = MNISTDataModule()

trainer = L.Trainer(accelerator="cpu")

trainer.fit(module, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 21.8 K
-------------------------------
21.8 K    Trainable params
0         Non-trainable params
21.8 K    Total params
0.087     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]



tensor(2.0119, grad_fn=<MeanBackward0>)
torch.Size([55000])
torch.Size([55000])
tensor(2.2761, grad_fn=<MeanBackward0>)
torch.Size([55000])
torch.Size([55000])
tensor(2.3075, grad_fn=<MeanBackward0>)
torch.Size([55000])
torch.Size([55000])
tensor(2.3068, grad_fn=<MeanBackward0>)
torch.Size([55000])
torch.Size([55000])


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
