In [12]:
from torchvision.models import get_model

In [15]:
# Build a ConvNeXt-Base network through the torchvision model registry,
# configured with num_classes=10 (no weights argument is passed).
classifier = get_model(name='convnext_base', num_classes=10)

{'__name__': 'torchvision.models',
 '__doc__': None,
 '__package__': 'torchvision.models',
 '__loader__': <_frozen_importlib_external.SourceFileLoader at 0x7fbc92ccb970>,
 '__spec__': ModuleSpec(name='torchvision.models', loader=<_frozen_importlib_external.SourceFileLoader object at 0x7fbc92ccb970>, origin='/opt/conda/lib/python3.10/site-packages/torchvision/models/__init__.py', submodule_search_locations=['/opt/conda/lib/python3.10/site-packages/torchvision/models']),
 '__path__': ['/opt/conda/lib/python3.10/site-packages/torchvision/models'],
 '__file__': '/opt/conda/lib/python3.10/site-packages/torchvision/models/__init__.py',
 '__cached__': '/opt/conda/lib/python3.10/site-packages/torchvision/models/__pycache__/__init__.cpython-310.pyc',
 '__builtins__': {'__name__': 'builtins',
  '__doc__': "Built-in functions, exceptions, and other objects.\n\nNoteworthy: None is the `nil' object; Ellipsis represents `...' in slices.",
  '__package__': '',
  '__loader__': _frozen_importlib.Builti

In [17]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.1.4-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.2/57.2 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.8.0 (from lightning)
  Downloading lightning_utilities-0.10.1-py3-none-any.whl.metadata (4.8 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.3.0.post0-py3-none-any.whl.metadata (20 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.1.4-py3-none-any.whl.metadata (21 kB)
Collecting aiohttp!=4.0.0a0,!=4.0.0a1 (from fsspec[http]<2025.0,>=2022.5.0->lightning)
  Downloading aiohttp-3.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.4 kB)
Collecting aiosignal>=1.1.2 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2025.0,>=2022.5.0->lightning)
  Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)
Collecting frozenlist>=1.1.1 (from aiohttp!=4.0.0a0,!=4.0.0a1-

In [None]:
from torch.optim.lr_scheduler import _LRScheduler

class LinearWarmUpMultiStepDecay(_LRScheduler):
    """Linear warm-up followed by multi-step decay of the learning rate.

    For the first ``warmup_steps`` scheduler steps the learning rate of every
    parameter group is interpolated linearly from ``warmup_start_lr`` up to the
    group's base learning rate. Afterwards the base learning rate is multiplied
    by ``gamma`` once for every milestone that has been reached.

    Args:
        optimizer: wrapped optimizer.
        milestones: iterable of step indices at which the LR is decayed.
        gamma: multiplicative decay factor applied at each milestone.
        warmup_steps: number of scheduler steps spent warming up.
        warmup_start_lr: absolute learning rate used at step 0 of the warm-up.
        last_epoch: index of the last step; ``-1`` starts a fresh schedule.
        verbose: forwarded to the base scheduler only when truthy.
    """

    def __init__(self, optimizer, milestones, gamma=0.1, warmup_steps=5, warmup_start_lr=0, last_epoch=-1, verbose=False):
        # These attributes are assigned before the base constructor is invoked
        # because the PyTorch base class performs an initial step that calls
        # get_lr().
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr

        # `verbose` is deprecated in recent PyTorch releases; forwarding it only
        # when explicitly requested avoids a deprecation warning for the default
        # and keeps the call valid if the base class stops accepting it.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for the current step."""
        step = self.last_epoch

        if step < self.warmup_steps:
            # Linear warm-up in absolute LR space: warmup_start_lr at step 0,
            # reaching base_lr at step == warmup_steps.
            alpha = step / self.warmup_steps
            start_lr = self.warmup_start_lr
            return [start_lr + alpha * (base_lr - start_lr) for base_lr in self.base_lrs]

        # Multi-step decay: a milestone counts as soon as the step index reaches
        # it (same convention as torch.optim.lr_scheduler.MultiStepLR).
        num_decays = sum(milestone <= step for milestone in self.milestones)
        factor = self.gamma ** num_decays
        return [base_lr * factor for base_lr in self.base_lrs]

In [None]:
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import torch
from torch.nn import functional as F
from torchvision.models import get_model

class LitModel(pl.LightningModule):
    """Lightning wrapper around a torchvision classifier trained with cross-entropy.

    Args:
        model: model name passed to ``torchvision.models.get_model``.
        num_classes: value forwarded as ``num_classes`` to the model builder.
        initial_lr: learning rate given to the Adam optimizer.
        milestones: scheduler steps at which the learning rate is decayed.
        warmup_steps: number of scheduler steps used for linear warm-up.
    """

    # NOTE: the list default for `milestones` is kept for interface
    # compatibility; it is only read, never mutated, inside this class.
    def __init__(self, model, num_classes, initial_lr=0.02, milestones=[30, 60, 90], warmup_steps=5):
        super().__init__()
        self.classifier = get_model(model, num_classes=num_classes)
        # Stores the constructor arguments on self.hparams (used in
        # configure_optimizers below).
        self.save_hyperparameters()

    def forward(self, x):
        """Run the wrapped classifier on ``x`` and return its output."""
        # The network is stored as `self.classifier`; there is no `self.model`.
        return self.classifier(x)

    def shared_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one ``(inputs, targets)`` batch."""
        x, y = batch
        logits = self.classifier(x)
        loss = F.cross_entropy(logits, y)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for the batch and log it as ``train_loss``."""
        loss = self.shared_step(batch, batch_idx)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for the batch and log it as ``val_loss``."""
        loss = self.shared_step(batch, batch_idx)
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Create the Adam optimizer and the warm-up / multi-step LR scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.initial_lr)
        scheduler = LinearWarmUpMultiStepDecay(
            optimizer,
            milestones=self.hparams.milestones,
            warmup_steps=self.hparams.warmup_steps,
        )
        return [optimizer], [scheduler]

class LitDataModule(pl.LightningDataModule):
    """Skeleton data module; every hook is still a placeholder.

    TODO: implement dataset preparation and the dataloaders.
    """

    def __init__(self, ):
        # The base initializer must run so the attributes Lightning expects on
        # a data module are created.
        super().__init__()

    def setup(self, stage=None):
        """Placeholder for per-stage dataset setup.

        ``stage`` is accepted because Lightning passes it to this hook; it
        defaults to ``None`` so existing zero-argument calls keep working.
        """
        pass

    def prepare_data(self):
        """Placeholder for one-off data preparation (not implemented yet)."""
        pass

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Placeholder; currently returns ``None`` instead of a dataloader."""
        pass

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Placeholder; currently returns ``None`` instead of a dataloader."""
        pass

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Placeholder; currently returns ``None`` instead of a dataloader."""
        pass