In [1]:
from sklearn import datasets

import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.data import DataLoader

import lightning.pytorch as pl

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [76]:
# Define the autoencoder class
class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 32),
            nn.ReLU(True),
            nn.Linear(32, 18)
        )
        self.decoder = nn.Sequential(
            nn.Linear(18, 32),
            nn.ReLU(True),
            nn.Linear(32, 28 * 28)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    
class AutoEncoderTrainer(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = AutoEncoder()

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.flatten(start_dim=1)
        x_hat = self.model(x)
        loss = F.mse_loss(x_hat, x)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        x = x.flatten(start_dim=1)
        x_hat = self.model(x)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss, on_step=True, on_epoch=True, sync_dist=True)
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.flatten(start_dim=1)
        x_hat = self.model(x)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss, on_step=True, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer
        
    def predict_step(self, batch, batch_idx):
        return self(batch)

In [142]:
from torchmetrics import ExplainedVariance, Accuracy

# Define a Conv Classifier
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, (3, 3)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Flatten(),
            nn.LazyLinear(50),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(50, 10)
        )

    def forward(self, x):
        return self.model(x)

class ConvNetTrainer(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = ConvNet()

        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        predictions = torch.argmax(logits, dim=1)
        batch_value = self.accuracy(predictions, y)
        self.log("train_acc", batch_value)
        return loss

    def on_train_epoch_end(self):
        self.accuracy.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        test_loss = F.cross_entropy(logits, y)
        test_accuracy = torch.sum(torch.argmax(logits, dim=1) == y) / len(y)
        self.log("test_loss", test_loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_acc", test_accuracy, on_step=False, on_epoch=True, sync_dist=True)
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        val_loss = F.cross_entropy(logits, y)
        val_accuracy = torch.sum(torch.argmax(logits, dim=1) == y) / len(y)
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_acc", val_accuracy, on_step=False, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer
        
    def predict_step(self, batch, batch_idx):
        return self(batch)

In [138]:
net = ConvNet()
X, y = next(iter(train_loader))
logits = net(X)
F.cross_entropy(logits, y)
logits.shape

torch.Size([64, 10])

In [116]:
import torch.utils.data as data
from torchvision.datasets import MNIST
from torchvision import transforms

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str='.'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    # automatically called after prepare_data; don't assign state
    def setup(self, stage: str):
        if stage == "fit":
            train_dataset = MNIST(self.data_dir, train=True, transform=self.transform)
            train_set_size = int(len(train_dataset) * 0.8)
            valid_set_size = len(train_dataset) - train_set_size
            self.train_dataset, self.val_dataset = data.random_split(train_dataset, [train_set_size, valid_set_size])
        else:
            self.test_dataset = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=64)

    def predict_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=64)

In [136]:
# load recent checkpoint and eval
model = ConvNetTrainer.load_from_checkpoint(r"lightning_logs/version_27/checkpoints/epoch=3-step=3000.ckpt")
#model = AutoEncoderTrainer()
trainer = pl.Trainer()

# [!] predict method not working 
trainer.test(model, datamodule=MNISTDataModule())

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test_loss_epoch': 0.05850910022854805,
  'test_acc_epoch': 0.9817000031471252}]

In [143]:
from lightning.pytorch.profilers import AdvancedProfiler
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import DeviceStatsMonitor

## early stopping hook performed after each validation step by default
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=3, verbose=False, mode="max")
## advanced profiler to stream to file
# profiler = AdvancedProfiler(dirpath=".", filename="perf_logs")

# train model
model = ConvNetTrainer()
model.train()
trainer = pl.Trainer(
        default_root_dir=".",
        max_epochs=5,
        accelerator='gpu',
        devices='auto',
        callbacks=[early_stop_callback, DeviceStatsMonitor()],
        profiler='simple',      # 'advanced' to see time spent in each function; stream to a file later using AdvancedProfiler
        fast_dev_run=False,      # run a quick epoch to test for bugs
)
trainer.fit(model, datamodule=MNISTDataModule())

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params
------------------------------------------------
0 | model    | ConvNet            | 830   
1 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
830       Trainable params
0         Non-trainable params
830       Total params
0.003     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

FIT Profiler Report

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                         	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                          	|  -              	|  157850         

In [None]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

In [None]:
from lightning.pytorch.utilities.model_summary import ModelSummary

summary = ModelSummary(model, max_depth=-1)
print(summary)

In [None]:
# subclass EarlyStopping and override the methods
class MyEarlyStopping(EarlyStopping):
    def on_validation_end(self, trainer, pl_module):
        # override this to disable early stopping at the end of val loop
        pass

    def on_train_end(self, trainer, pl_module):
        # instead, do it at the end of training loop
        self._run_early_stopping_check(trainer)