In [1]:
pip install -U 'mlflow>=2.22.0'

Collecting mlflow>=2.22.0
  Downloading mlflow-2.22.0-py3-none-any.whl.metadata (30 kB)
Collecting mlflow-skinny==2.22.0 (from mlflow>=2.22.0)
  Downloading mlflow_skinny-2.22.0-py3-none-any.whl.metadata (31 kB)
Collecting alembic!=1.10.0,<2 (from mlflow>=2.22.0)
  Downloading alembic-1.16.1-py3-none-any.whl.metadata (7.3 kB)
Collecting docker<8,>=4.0.0 (from mlflow>=2.22.0)
  Downloading docker-7.1.0-py3-none-any.whl.metadata (3.8 kB)
Collecting graphene<4 (from mlflow>=2.22.0)
  Downloading graphene-3.4.3-py2.py3-none-any.whl.metadata (6.9 kB)
Collecting gunicorn<24 (from mlflow>=2.22.0)
  Downloading gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)
Collecting databricks-sdk<1,>=0.20.0 (from mlflow-skinny==2.22.0->mlflow>=2.22.0)
  Downloading databricks_sdk-0.55.0-py3-none-any.whl.metadata (39 kB)
Collecting opentelemetry-api<3,>=1.9.0 (from mlflow-skinny==2.22.0->mlflow>=2.22.0)
  Downloading opentelemetry_api-1.34.0-py3-none-any.whl.metadata (1.5 kB)
Collecting opentelemetry-sdk

In [2]:
pip install lightning

Collecting lightning
  Downloading lightning-2.5.1.post0-py3-none-any.whl.metadata (39 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.7.2-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl.metadata (20 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata

In [3]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
#from wide_resnet import WideResnetLit
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger

L.seed_everything(42)
torch.set_float32_matmul_precision("medium")

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
import torchmetrics


class BasicBlock(nn.Module):
    """Basic ResNet block for CIFAR with post-activation ordering (conv → BN → ReLU)."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_planes,
            out_channels=planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            in_channels=planes,
            out_channels=planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.downsample = None
        if stride != 1 or in_planes != planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_planes,
                    out_channels=planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes),
            )

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out, inplace=True)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = F.relu(out, inplace=True)
        return out


class WideResNet(nn.Module):
    """
    WideResNet for CIFAR-style inputs. Implements WRN-32-k by default:
    - depth = 6n + 2, here n = 5 → depth = 32
    - widen_factor = k = 4 → widths [16*k, 32*k, 64*k] in the three groups
    """

    def __init__(self, depth=32, widen_factor=4, num_classes=10):
        super(WideResNet, self).__init__()
        assert (depth - 2) % 6 == 0, "Depth should be 6n+2 for CIFAR ResNet variants"
        n = (depth - 2) // 6

        self.in_planes = 16

        self.conv1 = nn.Conv2d(
            in_channels=3, # CHANGE
            out_channels=self.in_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(self.in_planes)

        widths = [16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        self.layer1 = self._make_layer(planes=widths[0], num_blocks=n, stride=1)
        self.layer2 = self._make_layer(planes=widths[1], num_blocks=n, stride=2)
        self.layer3 = self._make_layer(planes=widths[2], num_blocks=n, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(widths[2] * BasicBlock.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, planes, num_blocks, stride):
        """
        Create one group (layer) consisting of `num_blocks` BasicBlocks.
        The first block may downsample with stride > 1; subsequent blocks use stride=1.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(BasicBlock(self.in_planes, planes, stride=s))
            self.in_planes = planes * BasicBlock.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


class WideResnetLit(L.LightningModule):
    """
    PyTorch Lightning module wrapping the WideResNet model for CIFAR-10.
    Includes training/validation/test steps, optimizer, and LR scheduler.
    """

    def __init__(
        self,
        depth: int = 32,
        widen_factor: int = 4,
        num_classes: int = 10,
        lr: float = 0.1,
        momentum: float = 0.9,
        weight_decay: float = 5e-4,
    ):
        super(WideResnetLit, self).__init__()
        self.save_hyperparameters()

        self.model = WideResNet(
            depth=depth,
            widen_factor=widen_factor,
            num_classes=num_classes,
        )
        self.criterion = nn.CrossEntropyLoss()

        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        loss = self.criterion(logits, targets)

        preds = torch.argmax(logits, dim=1)
        acc = self.train_acc(preds, targets)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)
        self.log(
            "lr", self.optimizers().param_groups[0]["lr"], on_step=True, on_epoch=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        loss = self.criterion(logits, targets)

        preds = torch.argmax(logits, dim=1)
        acc = self.val_acc(preds, targets)

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        loss = self.criterion(logits, targets)

        preds = torch.argmax(logits, dim=1)
        acc = self.test_acc(preds, targets)

        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
            momentum=self.hparams.momentum,
        )
        total_steps = self.trainer.estimated_stepping_batches
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=1.0,
            total_steps=total_steps,
            div_factor=10,
            final_div_factor=23,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
                "strict": False,
            },
        }

In [6]:
IMAGE_SIZE = 32
mean, std = [0.4465], [0.261]
# source: https://pytorch.org/vision/stable/transforms.html
transforms_train = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)


transforms_test = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

train_dataset = datasets.Imagenette(
    root="./data", split="train", download=True, transform=transforms_train
)
test_dataset = datasets.Imagenette(
    root="./data", split="val", download=True, transform=transforms_test
)

N = len(train_dataset)
num_val = int(0.2 * N)
indices = torch.randperm(N)[:num_val]
mask = torch.ones(N, dtype=torch.bool)
mask[indices] = False
train_indices = torch.nonzero(mask, as_tuple=False).squeeze(1)
validation_dataset = torch.utils.data.Subset(train_dataset, indices=indices)
train_dataset = torch.utils.data.Subset(train_dataset, indices=train_indices)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_dataset, batch_size=512, num_workers=1, persistent_workers=True
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=512, num_workers=1, persistent_workers=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=512, num_workers=1, persistent_workers=True
)

100%|██████████| 1.56G/1.56G [01:52<00:00, 13.9MB/s]


In [7]:
class DataModule(L.LightningDataModule):
    def __init__(self, train_loader, validation_loader, test_loader):
        super().__init__()
        self.train_loader = train_loader
        self.validation_loader = validation_loader
        self.test_loader = test_loader

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.validation_loader

    def test_dataloader(self):
        return self.test_loader


data = DataModule(train_loader, validation_loader, test_loader)
model = WideResnetLit(depth=32, weight_decay=10e-4)
logger = MLFlowLogger(experiment_name="WideResnet", save_dir="mlruns")
trainer = L.Trainer(
    max_epochs=50,
    logger=logger,
    callbacks=[
        ModelCheckpoint(
            monitor="val_acc",
            mode="max",
            dirpath="checkpoints/wide_resnet",
            filename="{epoch:02d}-{val_acc:.3f}",
        )
    ],
    precision="16-mixed",
    num_sanity_val_steps=0,
)
trainer.fit(model, datamodule=data)

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loading `train_dataloader` to estimate number of stepping batches.
INFO:lightning.pytorch.utilities.rank_zero:Loading `train_dataloader` to estimate number of stepping batches.
/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (15) is smaller than the logging interval Trainer(log_ever

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=50` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


Val accuracy is 0.110

In [8]:
best_ckpt = trainer.checkpoint_callback.best_model_path
print("Best checkpoint path:", best_ckpt)
trainer.test(model, datamodule=data, ckpt_path=best_ckpt)

INFO: Restoring states from the checkpoint path at /content/checkpoints/wide_resnet/epoch=01-val_acc=0.110.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/checkpoints/wide_resnet/epoch=01-val_acc=0.110.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /content/checkpoints/wide_resnet/epoch=01-val_acc=0.110.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/checkpoints/wide_resnet/epoch=01-val_acc=0.110.ckpt


Best checkpoint path: /content/checkpoints/wide_resnet/epoch=01-val_acc=0.110.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': nan, 'test_acc': 0.09834394603967667}]

In [9]:
#zip

In [10]:
from google.colab import files
import shutil
shutil.make_archive("mlruns", "zip", "mlruns")
files.download("mlruns.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>