In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers.wandb import WandbLogger


In [None]:
class LightningMNISTClassifier(pl.LightningModule):
    """Fully connected MNIST classifier that also owns its data pipeline.

    Architecture: flatten -> Linear(784, 128) -> ReLU -> Linear(128, 256)
    -> ReLU -> Linear(256, 10) -> log_softmax. The module downloads MNIST,
    builds a train/validation split, serves the dataloaders and creates an
    Adam optimizer.

    All constructor arguments are optional and default to the values that
    were previously hard-coded, so ``LightningMNISTClassifier()`` keeps the
    same configuration.

    Args:
        data_dir: Directory MNIST is downloaded to / read from. ``None`` means
            the current working directory at construction time.
        batch_size: Batch size for the train, validation and test loaders.
        learning_rate: Learning rate passed to Adam.
        val_size: Number of training images held out for validation.
        split_seed: Seed for the train/validation split, so the split is the
            same every time it is built.
    """

    # Per-channel mean/std passed to transforms.Normalize for every split.
    MNIST_MEAN = (0.1307,)
    MNIST_STD = (0.3081,)

    def __init__(self, data_dir=None, batch_size=1000, learning_rate=1e-3,
                 val_size=5000, split_seed=42):
        super().__init__()
        self.data_dir = os.getcwd() if data_dir is None else data_dir
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.val_size = val_size
        self.split_seed = split_seed

        # MNIST samples are [1, 28, 28]; forward() flattens them to 784 features.
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

        # Filled in once by _ensure_split() (via setup() or the dataloaders).
        self.mnist_train = None
        self.mnist_valid = None

    def forward(self, x):
        """Return class log-probabilities for a batch of images.

        Args:
            x: Image batch whose first dimension is the batch. All remaining
                dimensions are flattened and must total 28 * 28 values per
                sample (e.g. ``(batch, 1, 28, 28)``).

        Returns:
            Tensor of shape ``(batch, 10)`` holding log-softmax scores.
        """
        # torch.flatten (unlike .view) also accepts non-contiguous inputs.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        return torch.log_softmax(self.layer_3(x), dim=-1)

    def cross_entropy_loss(self, logits, labels):
        """Mean negative log-likelihood of ``labels``.

        ``logits`` is the log-softmax output of ``forward`` (log-probabilities,
        not raw logits). NLL on log-probabilities equals cross-entropy on the
        raw scores, which is why this is named ``cross_entropy_loss``.
        """
        return nn.functional.nll_loss(logits, labels)

    def accuracy(self, logits, labels):
        """Fraction of samples whose arg-max class equals the label (0-d tensor)."""
        predictions = torch.argmax(logits, dim=-1)
        return (predictions == labels).float().mean()

    def training_step(self, train_batch, batch_idx):
        """Compute and log the training loss for one batch."""
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        # self.log routes through whatever logger the Trainer has (or none),
        # instead of assuming self.logger is a wandb logger.
        self.log('train_loss', loss)
        return {'loss': loss}

    def validation_step(self, valid_batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        x, y = valid_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('valid_loss', loss)
        self.log('valid_acc', self.accuracy(logits, y))
        return {'valid_loss': loss}

    def test_step(self, test_batch, batch_idx):
        """Compute and log the test loss and accuracy for one batch."""
        x, y = test_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        acc = self.accuracy(logits, y)
        self.log('test_loss', loss)
        self.log('test_acc', acc)
        return {'test_loss': loss, 'test_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses into ``avg_valid_loss``.

        Lightning calls this at the end of each validation epoch with the list
        of dicts returned by ``validation_step``. The value is registered via
        ``self.log`` so callbacks monitoring ``avg_valid_loss`` (checkpointing,
        early stopping) can find it.
        """
        if not outputs:
            # torch.stack would fail on an empty list; nothing to aggregate.
            return {}
        avg_loss = torch.stack([x['valid_loss'] for x in outputs]).mean()
        self.log('avg_valid_loss', avg_loss)
        return {'avg_valid_loss': avg_loss}

    def _mnist_transform(self):
        """Tensor conversion + normalisation shared by all splits."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.MNIST_MEAN, self.MNIST_STD),
        ])

    def prepare_data(self):
        """Download the MNIST train and test sets into ``self.data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train/validation split before dataloaders are requested."""
        self._ensure_split()

    def _ensure_split(self):
        """Create the seeded train/validation split once; later calls are no-ops."""
        if self.mnist_train is not None and self.mnist_valid is not None:
            return
        full_train = MNIST(self.data_dir, train=True, download=False,
                           transform=self._mnist_transform())
        train_size = len(full_train) - self.val_size
        # A dedicated seeded generator keeps the split fixed across rebuilds,
        # so validation images never leak into training.
        generator = torch.Generator().manual_seed(self.split_seed)
        self.mnist_train, self.mnist_valid = random_split(
            full_train, [train_size, self.val_size], generator=generator)

    def train_dataloader(self):
        """Shuffled loader over the training part of the split."""
        self._ensure_split()
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Loader over the held-out validation part of the split."""
        self._ensure_split()
        return DataLoader(self.mnist_valid, batch_size=self.batch_size)

    def test_dataloader(self):
        """Loader over the official MNIST test set, normalised like training."""
        mnist_test = MNIST(self.data_dir, train=False, download=False,
                           transform=self._mnist_transform())
        return DataLoader(mnist_test, batch_size=self.batch_size)

    def configure_optimizers(self):
        """Adam over all parameters with ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


In [None]:
# Open the W&B run explicitly. NOTE(review): the WandbLogger created in the next
# cell presumably attaches to this active run rather than starting a new one —
# confirm for the installed Lightning/wandb versions.
wandb.init(project='MNIST-test', name='MNIST-32-bit-adam-0.001')

In [None]:
# Lightning-side logger using the same project and run name as the wandb.init call.
wandb_logger = WandbLogger(project='MNIST-test', name='MNIST-32-bit-adam-0.001')

In [None]:
# Record the architecture in the W&B run config. The config object must be
# updated in place: `wandb.config = {...}` only rebinds the module attribute
# and never reaches the run.
wandb.config.update({
    'input_shape': (1, 28, 28),
    'layer_1_size': 128,
    'layer_2_size': 256,
    'layer_3_size': 10,
})

# Seed the global RNGs so weight init and data shuffling are reproducible.
pl.seed_everything(42)

model = LightningMNISTClassifier()
# Download MNIST first so the eager dataloader call below also works on a
# fresh machine (train_dataloader reads the data with download=False).
# NOTE(review): the eager train_dataloader() call exists because
# val_dataloader relies on state created by train_dataloader; it can be
# dropped once the module builds its split in setup().
model.prepare_data()
model.train_dataloader()

# callbacks
# Project-relative checkpoint directory instead of a user-specific absolute path.
CHECKPOINT_DIR = os.path.join(os.getcwd(), 'checkpoints')
ckpt_cb = pl.callbacks.ModelCheckpoint(
    monitor='avg_valid_loss',
    mode='min',
    dirpath=CHECKPOINT_DIR,
    # Must reference a metric that is actually logged; `val_loss` never was.
    filename='mnist-{epoch:03d}-{avg_valid_loss:.3f}',
)
es_cb = pl.callbacks.early_stopping.EarlyStopping(
    monitor='avg_valid_loss',
    mode='min',
)
trainer = pl.Trainer(
    max_epochs=100,
    logger=wandb_logger,
    # Fall back to CPU instead of crashing on machines without CUDA.
    gpus=1 if torch.cuda.is_available() else 0,
    callbacks=[ckpt_cb, es_cb],
)

In [None]:
# Run the training loop with the logger and callbacks attached to `trainer`.
trainer.fit(model=model)