In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms

import pytorch_lightning as pl

In [2]:
def set_device():
    """Pick the torch device for this notebook.

    Returns:
        torch.device('mps') when torch reports the MPS backend as available,
        otherwise torch.device('cpu').
    """
    mps_ready = torch.backends.mps.is_available()
    return torch.device('mps') if mps_ready else torch.device('cpu')

In [3]:
# Select MPS when torch reports it available, otherwise CPU (see set_device).
# NOTE(review): `device` is not passed to pl.Trainer or used to move any tensors in the
# cells shown here, so it currently has no effect on where training runs.
device = set_device()

In [4]:
# MNIST train/test splits stored under ./data (download=True lets torchvision fetch them);
# both splits convert images with the same ToTensor transform.
to_tensor = transforms.ToTensor()
train, test = (
    torchvision.datasets.MNIST('./data', train=is_train, download=True, transform=to_tensor)
    for is_train in (True, False)
)

In [5]:
bs = 32

# Only the training split is shuffled; test batches follow the dataset's order.
train_loader, test_loader = (
    DataLoader(train, batch_size=bs, shuffle=True),
    DataLoader(test, batch_size=bs),
)

In [6]:
class CNN(pl.LightningModule):
    """Single-convolution image classifier: Conv2d -> ReLU -> Flatten -> Linear.

    Args:
        lr: Adam learning rate. NOTE(review): 0.01 is 10x torch.optim.Adam's default
            (1e-3); kept for backward compatibility, consider tuning.
        n_classes: number of output logits.
        n_filters: number of convolution output channels.
        kernel_size: square convolution kernel size (int).
        in_channels, height, width: per-sample input shape (C, H, W). The Linear
            layer's input size is derived from these; defaults match 1x28x28 images.
    """

    def __init__(self, lr=0.01, n_classes=10, n_filters=16, kernel_size=3,
                 in_channels=1, height=28, width=28):
        super().__init__()
        # Stores the constructor arguments in self.hparams (and in checkpoints), so the
        # module can be rebuilt from a checkpoint without re-passing them.
        self.save_hyperparameters()
        self.lr = lr
        self.n_classes = n_classes
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.loss_func = nn.CrossEntropyLoss()

        self.c = in_channels
        self.h = height
        self.w = width

        self.model = self._build_model(self.c, self.h, self.w)

    def _build_model(self, c, h, w):
        """Build the Conv->ReLU->Flatten->Linear stack for per-sample inputs of shape (c, h, w)."""
        pad = self.kernel_size // 2
        # Stride-1 Conv2d output size is size + 2*pad - kernel + 1. That equals the input
        # size only for odd kernels; an even kernel yields size + 1, so the Linear layer
        # must be sized from the actual conv output rather than from (h, w).
        out_h = h + 2 * pad - self.kernel_size + 1
        out_w = w + 2 * pad - self.kernel_size + 1
        return nn.Sequential(
            nn.Conv2d(c, self.n_filters, self.kernel_size, padding=pad),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(self.n_filters * out_h * out_w, self.n_classes),
        )

    def get_model(self, x):
        """Return a fresh (untrained) layer stack sized for the sample batch `x`.

        `x` is read as (N, C, H, W); only C, H and W are used.
        """
        _, c, h, w = x.shape  # comes in with batch dim
        return self._build_model(c, h, w)

    def forward(self, x):
        """Return class logits, one row of `n_classes` values per input sample."""
        return self.model(x)

    def training_step(self, batch, idx):
        """Compute, log and return the cross-entropy loss for one (inputs, labels) batch."""
        x, y = batch
        # Call the module rather than .forward so any registered hooks run.
        logits = self(x)
        loss = self.loss_func(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def _shared_eval_step(self, batch, stage):
        """Log `<stage>_loss` and `<stage>_acc` for one (inputs, labels) batch and return the loss."""
        x, y = batch
        logits = self(x)
        loss = self.loss_func(logits, y)
        # Fraction of samples whose highest-scoring logit matches the label.
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, idx):
        """Evaluate one validation batch (used when fit() is given val_dataloaders)."""
        return self._shared_eval_step(batch, "val")

    def test_step(self, batch, idx):
        """Evaluate one test batch (used by trainer.test, e.g. with test_loader)."""
        return self._shared_eval_step(batch, "test")

    def configure_optimizers(self):
        """Adam over all module parameters at learning rate `self.lr`."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
        

In [7]:
# Take one (inputs, labels) batch from the training loader for the shape checks that follow.
x, y = next(iter(train_loader))

In [8]:
# Expected (batch, channels, height, width); CNN's defaults assume 1x28x28 inputs.
x.shape

torch.Size([32, 1, 28, 28])

In [9]:
# Shape sanity check on one batch: expect one row of n_classes logits per image.
m = CNN()
with torch.no_grad():  # pure shape check; no need to record an autograd graph
    logits = m(x)  # call the module (not .forward) so hooks run
assert logits.shape == (x.shape[0], m.n_classes)
logits.shape

torch.Size([32, 10])

In [10]:
# Seed Python, NumPy and torch RNGs so weight initialisation and the training loader's
# shuffle order are reproducible under Restart & Run All.
pl.seed_everything(42)
model = CNN()
trainer = pl.Trainer(max_epochs=2)
trainer.fit(model=model, train_dataloaders=train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type             | Params
-----------------------------------------------
0 | loss_func | CrossEntropyLoss | 0     
1 | model     | Sequential       | 125 K 
-----------------------------------------------
125 K     Trainable params
0         Non-trainable params
125 K     Total params
0.502     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]