In [27]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

In [2]:
class CIFAR10():  #@save
    """CIFAR-10 train/validation/test splits that share one preprocessing pipeline.

    After construction the instance exposes three datasets:

    * ``train`` - random subset of the official training set,
    * ``val``   - the held-out remainder of the official training set,
    * ``test``  - the official test set.

    Args:
        root: Directory handed to torchvision as the dataset root; the data
            is requested with ``download=True``.
        resize: Target size passed to ``transforms.Resize``.
        val_fraction: Share of the official training set held out for
            validation. Must lie strictly between 0 and 1.
        seed: Seed of the generator that drives the train/validation split,
            so the same seed always yields the same partition.

    Raises:
        ValueError: If ``val_fraction`` is not strictly between 0 and 1.
    """

    def __init__(self, root, resize=(224, 224), val_fraction=0.2, seed=42):
        if not 0.0 < val_fraction < 1.0:
            raise ValueError(
                f"val_fraction must be strictly between 0 and 1, got {val_fraction!r}")

        # Same preprocessing for every split: resize, convert to a tensor,
        # then shift/scale each channel with mean 0.5 and std 0.5.
        trans = transforms.Compose([transforms.Resize(resize),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        full_train = torchvision.datasets.CIFAR10(
            root=root, train=True, transform=trans, download=True)

        # Hold out `val_fraction` of the training data for validation; the
        # training subset receives whatever is left so the sizes always add up.
        valid_set_size = int(len(full_train) * val_fraction)
        train_set_size = len(full_train) - valid_set_size

        # A dedicated, seeded generator keeps the split reproducible without
        # touching the global RNG state.
        split_generator = torch.Generator().manual_seed(seed)
        self.train, self.val = torch.utils.data.random_split(
            full_train, [train_set_size, valid_set_size], generator=split_generator)

        self.test = torchvision.datasets.CIFAR10(
            root=root, train=False, transform=trans, download=True)
        
dataset = CIFAR10(root="./data/")

# One loader per split, all with batch size 64. Only the training loader
# reshuffles; validation and test keep a fixed order.
train_dataloader, val_dataloader, test_dataloader = (
    torch.utils.data.DataLoader(split, batch_size=64, shuffle=reshuffle)
    for split, reshuffle in (
        (dataset.train, True),
        (dataset.val, False),
        (dataset.test, False),
    )
)

# Report how many examples ended up in each split.
for label, split in (
    ("training", dataset.train),
    ("validation", dataset.val),
    ("test", dataset.test),
):
    print(f"Number of {label} examples: {len(split)}")

Files already downloaded and verified
Files already downloaded and verified
Number of training examples: 40000
Number of validation examples: 10000
Number of test examples: 10000


In [25]:
%pip install pytorch-lightning
import pytorch_lightning as pl
import torch.nn as nn
from torchmetrics.functional import accuracy

# The model is passed as an argument to the `LightModel` class.
class LightModel(pl.LightningModule):
    """Lightning wrapper that trains any classifier with cross-entropy and Adam.

    Args:
        model: Module applied to the input batch; its output is treated as
            class logits (one row per example).
        lr: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, model, lr=1e-5):
        super().__init__()
        self.model = model
        self.lr = lr

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _loss_and_accuracy(self, batch):
        """Return ``(cross-entropy loss, accuracy)`` for one ``(inputs, targets)`` batch.

        Accuracy is the fraction of examples whose arg-max logit equals the
        target, so it does not depend on a hard-coded number of classes.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        return loss, acc

    # `batch_idx` is accepted (and ignored) by every step hook so the three
    # signatures agree and the hooks work whether or not the running
    # Lightning version passes the batch index.
    def training_step(self, batch, batch_idx=None):
        """Log ``train_loss`` and return the loss for backpropagation."""
        loss, _ = self._loss_and_accuracy(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx=None):
        """Log ``val_loss`` and ``val_acc`` (shown in the progress bar); return the loss."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx=None):
        """Log ``test_acc`` and ``test_loss`` for one test batch."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("test_acc", acc)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's parameters."""
        optimizer = torch.optim.Adam(self.parameters(), self.lr)
        return optimizer


# Fully connected baseline: flatten the image, then Linear layers with the
# widths below, separated by ReLU activations.
mlp_widths = (3 * 224 * 224, 512, 256, 10)

mlp_layers = [nn.Flatten()]
for fan_in, fan_out in zip(mlp_widths, mlp_widths[1:]):
    mlp_layers.append(nn.Linear(fan_in, fan_out))
    mlp_layers.append(nn.ReLU())
mlp_layers.pop()  # the final Linear emits raw logits, so drop the trailing ReLU

arch = nn.Sequential(*mlp_layers)

mlp = LightModel(arch)

Note: you may need to restart the kernel to use updated packages.


In [None]:
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import Trainer

# Stop once `val_loss` has failed to drop by at least 0.001 for 5 epochs in a row.
mlp_stopping_options = {
    "monitor": 'val_loss',  # metric to watch
    "mode": 'min',          # lower is better for a loss
    "min_delta": 0.001,     # smallest change that counts as an improvement
    "patience": 5,          # epochs without improvement before stopping
}
early_stopping = EarlyStopping(**mlp_stopping_options)

# Train the MLP for at most 50 epochs, validating on the held-out split.
trainer = Trainer(max_epochs=50, callbacks=[early_stopping])
trainer.fit(
    mlp,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 77.2 M
-------------------------------------
77.2 M    Trainable params
0         Non-trainable params
77.2 M    Total params
308.819   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Run the test loop for the MLP on the test dataloader
# (its `test_step` logs `test_acc` and `test_loss`).
trainer.test(mlp, dataloaders=test_dataloader)

In [None]:
from torchvision.models import vgg16
# Load torchvision's VGG-16 with its default pretrained weights.
vgg16_model = vgg16(progress=True, weights="DEFAULT")

# Freeze every existing parameter; the classifier head created below is new,
# so its parameters remain trainable.
vgg16_model.requires_grad_(False)

# Replace the stock classifier with a small three-layer head ending in 10 outputs.
vgg16_head_layers = [
    nn.Flatten(),
    nn.Linear(25088, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 10),
]
vgg16_model.classifier = nn.Sequential(*vgg16_head_layers)

print(vgg16_model)

vgg16_lightmodel = LightModel(vgg16_model, lr=0.001)

# Fresh callback instance for this run: stop after 5 epochs without a
# `val_loss` improvement of at least 0.001.
vgg16_stopping_options = {
    "monitor": 'val_loss',
    "mode": 'min',
    "min_delta": 0.001,
    "patience": 5,
}
early_stopping = EarlyStopping(**vgg16_stopping_options)

In [None]:
# Fine-tune the VGG-16 wrapper for at most 20 epochs with early stopping.
trainer = Trainer(max_epochs=20, callbacks=[early_stopping])
trainer.fit(
    vgg16_lightmodel,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

# Then run the test loop for it on the test dataloader.
trainer.test(model=vgg16_lightmodel, dataloaders=test_dataloader)