In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBar

input_size = 28 * 28   # flattened 28x28 MNIST image -> 784 features
hidden_size = 500      # width of the single hidden layer
num_classes = 10       # one logit per MNIST digit class
num_epochs = 2         # passed to Trainer(max_epochs=...)
batch_size = 100       # read directly by the dataloader methods
learning_rate = 1e-3   # optimizer step size

In [None]:
class LitNeuralNet(pl.LightningModule):
    """Two-layer MLP classifier for flattened MNIST images, trained with Lightning.

    Args:
        input_size: Number of input features per sample. Each image batch is
            reshaped to ``(-1, input_size)`` before the first linear layer.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
        lr: Adam learning rate (defaults to 0.001).

    Note:
        The dataloader methods read the notebook-level ``batch_size`` global.
    """

    def __init__(self, input_size, hidden_size, num_classes, lr=0.001):
        # Zero-arg super() does not depend on the global name `LitNeuralNet`,
        # which a later cell in this notebook rebinds to a different class.
        super().__init__()
        self.input_size = input_size
        self.lr = lr
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)
        # Per-batch validation losses; averaged, logged and cleared in
        # on_validation_epoch_end.
        self.validation_step_outputs = []

    def forward(self, x):
        """Return class logits for a flattened batch ``x`` of shape (N, input_size)."""
        out = self.l1(x)
        out = self.relu(out)
        return self.l2(out)

    def _shared_step(self, batch):
        """Flatten the images in ``batch`` and return the cross-entropy loss."""
        x, y = batch
        # Use the configured input size (rather than a hardcoded 28*28) so the
        # reshape always agrees with the first linear layer.
        x = x.reshape(-1, self.input_size)
        return F.cross_entropy(self(x), y)

    def training_step(self, batch, batch_idx):
        """Compute and log the per-step training loss."""
        loss = self._shared_step(batch)
        self.log("loss", loss, prog_bar=True, on_step=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters, using the learning rate given at construction."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        """Shuffled MNIST training split (downloaded to ./data if missing)."""
        train_dataset = torchvision.datasets.MNIST(root='data', train=True,
                                                   transform=transforms.ToTensor(), download=True)

        train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size,
                                                   shuffle=True, num_workers=4)
        return train_loader

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and buffer it for the epoch-end average."""
        loss = self._shared_step(batch)
        self.validation_step_outputs.append(loss)
        return loss

    def val_dataloader(self):
        """Unshuffled MNIST test split used for validation."""
        val_dataset = torchvision.datasets.MNIST(root='data', train=False,
                                                 transform=transforms.ToTensor(), download=True)

        val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size,
                                                 shuffle=False, num_workers=4)
        return val_loader

    def on_validation_epoch_end(self):
        """Log the mean validation loss of the epoch and clear the buffer.

        Returns:
            The mean loss tensor, or ``None`` if no validation batches ran.
        """
        if not self.validation_step_outputs:
            return None
        avg_loss = torch.stack(self.validation_step_outputs).mean()
        self.validation_step_outputs.clear()  # free memory
        # Lightning discards this hook's return value, so the average must be
        # logged explicitly to reach the progress bar / TensorBoard.
        self.log("val_loss", avg_loss, prog_bar=True)
        return avg_loss
    

# Other useful Trainer flags: `deterministic`, `gradient_clip_val`, ...
# (In Lightning 2.x, learning-rate search uses `Tuner(trainer).lr_find(model)`
# instead of an `auto_lr_find` Trainer flag.)
trainer = Trainer(
    max_epochs=num_epochs,
    fast_dev_run=False,  # set True for a quick one-batch smoke test
)
model = LitNeuralNet(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
)
trainer.fit(model)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBar

input_size = 28 * 28   # flattened 28x28 MNIST image -> 784 features
hidden_size = 50       # width of the single hidden layer (smaller than before)
num_classes = 10       # one logit per MNIST digit class
num_epochs = 20        # passed to Trainer(max_epochs=...)
batch_size = 100       # read directly by the dataloader methods
learning_rate = 1e-3   # passed to the model as its Adam learning rate

In [None]:
class LitNeuralNet(pl.LightningModule):
    """Two-layer MLP classifier for flattened MNIST images, with accuracy metrics.

    Args:
        input_size: Number of input features per sample. Image batches are
            reshaped to ``(-1, input_size)`` before entering the model.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
        lr: Adam learning rate.

    Note:
        The dataloader methods read the notebook-level ``batch_size`` global.
    """

    def __init__(self, input_size, hidden_size, num_classes, lr):
        super().__init__()

        self.input_size = input_size
        self.lr = lr

        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

        # Buffers reserved for epoch-level aggregation; nothing fills them yet.
        self.train_step_outputs = []
        self.train_step_labels = []

        self.val_step_outputs = []
        self.val_step_labels = []

        self.test_step_outputs = []
        self.test_step_labels = []

        # Counts for exact test accuracy; reset in on_test_epoch_start and
        # logged once in on_test_epoch_end.
        self.test_correct = 0
        self.test_total = 0

    def forward(self, x):
        """Return class logits for a flattened batch ``x`` of shape (N, input_size)."""
        return self.model(x)

    # ---------------------------------------------------------------- steps

    def _shared_step(self, batch):
        """Flatten ``batch``, run the model, and return ``(loss, accuracy)``."""
        x, y = batch
        x = x.reshape(-1, self.input_size)
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Keep accuracy as a tensor: `.item()` would force a device->host sync
        # on every step.
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Log per-step training loss and accuracy; return the loss to optimize."""
        loss, acc = self._shared_step(batch)
        self.log("train loss", loss, prog_bar=True, on_step=True)
        self.log("train acc", acc, prog_bar=True, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy; Lightning averages them over the epoch."""
        loss, acc = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Accumulate correct/total counts; return ``(pred_labels, labels)``."""
        x, y = batch
        x = x.reshape(-1, self.input_size)
        pred_labels = self(x).argmax(dim=1)
        self.test_correct += (pred_labels == y).sum().item()
        self.test_total += len(y)
        return pred_labels, y

    def configure_optimizers(self):
        """Adam over all parameters, using the learning rate given at construction."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    # ---------------------------------------------------------------- hooks

    def on_test_epoch_start(self):
        """Reset test counters so repeated ``trainer.test`` calls don't accumulate."""
        self.test_correct = 0
        self.test_total = 0

    def on_test_epoch_end(self):
        """Log the exact accuracy over the whole test set."""
        # Logging the running ratio from test_step would make Lightning average
        # the running values (over-weighting early batches); log it once here.
        if self.test_total:
            self.log("test_acc", self.test_correct / self.test_total, prog_bar=True)

    # ---------------------------------------------------------------- data

    def _mnist_loader(self, train, shuffle):
        """Build a DataLoader over an MNIST split, downloading to ./data if missing."""
        dataset = torchvision.datasets.MNIST(
            root='data', train=train, transform=transforms.ToTensor(), download=True
        )
        return torch.utils.data.DataLoader(
            dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4
        )

    def train_dataloader(self):
        """Shuffled MNIST training split."""
        return self._mnist_loader(train=True, shuffle=True)

    def val_dataloader(self):
        """Unshuffled MNIST test split, used for validation."""
        # This loader may be requested before train_dataloader (e.g. for the
        # sanity check), so it must be able to download the data itself.
        return self._mnist_loader(train=False, shuffle=False)

    def test_dataloader(self):
        """Unshuffled MNIST test split."""
        return self._mnist_loader(train=False, shuffle=False)

                                
model = LitNeuralNet(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
    lr=learning_rate,
)

# Other useful Trainer flags: `deterministic`, `gradient_clip_val`, ...
trainer = Trainer(
    max_epochs=num_epochs,
    fast_dev_run=False,  # set True for a quick one-batch smoke test
)
trainer.fit(model)

# Because `model` is passed in, this evaluates the weights currently in memory
# (i.e. after the last training epoch); it does NOT reload a "best" checkpoint.
# For that, configure a monitored ModelCheckpoint and call
# `trainer.test(ckpt_path="best")`.
trainer.test(model)

In [None]:
# !lightning --version

In [None]:
# Load the TensorBoard notebook extension; it provides the %tensorboard magic used below.
%load_ext tensorboard

In [None]:
# Show TensorBoard inline for runs under ./lightning_logs, which is where
# Lightning's default logger writes when the Trainer is given no logger.
%tensorboard --logdir lightning_logs