In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.models import vgg16, VGG16_Weights
import pytorch_lightning as pl

# Select the "medium" float32 matmul precision setting; per the PyTorch docs
# this permits lower-precision internal math for float32 matrix
# multiplications in exchange for speed on hardware that supports it.
torch.set_float32_matmul_precision("medium")

In [None]:
# Show the properties of CUDA device 0 in the cell output, but only when a
# CUDA device is actually usable.
if torch.cuda.is_available():
    print(torch.cuda.get_device_properties(device=0))

In [None]:
class MobileNetV2(pl.LightningModule):
    """Lightning module wrapping a pretrained VGG16 with a 10-way output layer.

    NOTE: despite the class name, the wrapped network is torchvision's
    ``vgg16``. Only the last classifier layer is replaced (by a freshly
    initialised ``nn.Linear`` with 10 outputs); every parameter of the network
    is handed to the optimizer.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg16(weights=VGG16_Weights.DEFAULT)
        # Keep the input width of the original last layer and swap in a
        # 10-output replacement.
        head_inputs = backbone.classifier[6].in_features
        backbone.classifier[6] = nn.Linear(head_inputs, 10)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

    def _batch_loss(self, batch):
        """Return the cross-entropy loss of the model on one ``(data, target)`` batch."""
        inputs, labels = batch
        criterion = nn.CrossEntropyLoss()
        return criterion(self(inputs), labels)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the loss for a training batch."""
        loss = self._batch_loss(batch)
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the loss for a validation batch."""
        loss = self._batch_loss(batch)
        self.log(
            "val_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def configure_optimizers(self):
        """Return an SGD optimizer (lr 0.001, momentum 0.9) over the network's parameters."""
        return optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Channel statistics documented for the pretrained torchvision VGG16 weights.
# The model is initialised from VGG16_Weights.DEFAULT, so inputs are normalised
# the same way the pretrained features expect instead of with a generic
# 0.5 / 0.5 scaling.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Single root for both splits: the CIFAR-10 archive holds the train and the
# test batches, so separate roots would download and unpack it twice.
CIFAR_ROOT = '/tmp/cifar10/'

BATCH_SIZE = 256
NUM_WORKERS = 8

# Training pipeline: random flips as augmentation, then tensor conversion and
# normalisation.
# NOTE(review): vertical flips are kept from the original recipe; confirm they
# are wanted for this dataset.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation pipeline: no augmentation, same normalisation as training.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_dataset = datasets.CIFAR10(root=CIFAR_ROOT, train=True, download=True, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

val_dataset = datasets.CIFAR10(root=CIFAR_ROOT, train=False, download=True, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

model = MobileNetV2()

# max_epochs=-1 removes the epoch limit, so training ends only when the
# early-stopping callback fires: "val_loss" must improve by at least 0.03
# within 3 consecutive validation checks.
early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", min_delta=0.03, patience=3, mode="min")
trainer = pl.Trainer(max_epochs=-1, callbacks=[early_stopping])
trainer.fit(model, train_loader, val_loader)