In [None]:
import pytorch_lightning as pl
from pytorch_lightning.profilers import PyTorchProfiler

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.profiler import schedule

from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from torchvision.transforms import transforms

In [None]:
# Preprocessing: resize every image to 224x224, then convert it to a float tensor
# (ToTensor scales pixel values to [0, 1]).
# Named `train_transform` rather than `transforms` so the imported `transforms`
# module is not shadowed; shadowing it made this cell fail when re-run.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# CIFAR-10 training split, downloaded on first use into the shared assets folder.
train_dataset = CIFAR10(
    root="../../assets/cifar10",
    train=True,
    download=True,
    transform=train_transform,
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [None]:
class LitModel(pl.LightningModule):
    """Lightning wrapper that trains an arbitrary classifier with cross-entropy.

    Args:
        model: the network to wrap. It is called on a batch of inputs and its
            output is treated as class logits (assumed shape (batch, num_classes)
            given the ``argmax(dim=1)`` below — TODO confirm).
        lr: learning rate handed to the Adam optimizer.
    """

    def __init__(self, model, lr=1e-3):
        # nn.Module bookkeeping must be initialised before submodules are assigned.
        super().__init__()
        self.model = model
        self.lr = lr

    def forward(self, x):
        """Delegate straight to the wrapped network."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss for an (inputs, targets) batch."""
        inputs, targets = batch
        loss = F.cross_entropy(self(inputs), targets)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log top-1 accuracy for an (inputs, targets) batch."""
        inputs, targets = batch
        predicted = self(inputs).argmax(dim=1)
        accuracy = predicted.eq(targets).float().mean()
        self.log("val_acc", accuracy)

    def configure_optimizers(self):
        """Optimise every parameter with Adam at the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [None]:
# Profile a short training run and write the report under ./profiler_logs.
log_dir = "./profiler_logs"

profiler = PyTorchProfiler(
    dirpath=log_dir,
    filename="profile_report",
    schedule=schedule(
        wait=2,     # skip first steps (startup noise)
        warmup=2,   # warm-up steps (profiler runs, results discarded)
        active=6,   # recorded steps
        repeat=1,   # run the wait/warmup/active cycle once
    ),
    profile_memory=True,
    record_shapes=True,
)

# max_steps must cover wait + warmup + active (2 + 2 + 6 = 10); 14 leaves headroom.
trainer = pl.Trainer(
    max_steps=14,
    accelerator="auto",
    devices=1,
    profiler=profiler,
    logger=False,
    enable_checkpointing=False,
    enable_model_summary=False,
    enable_progress_bar=False,
)

# Backbone trained from scratch: ResNet-18 with a 10-way head for CIFAR-10's classes.
# The data cell resizes inputs to 224x224 to suit this kind of network.
# The original cell left `model = ` empty (syntax error) and called LitModel
# without its required `model` argument.
model = resnet18(num_classes=10)
lit_model = LitModel(model, lr=1e-3)
trainer.fit(lit_model, train_loader)
