# 5.3

prompt:

CIFAR-10 圖片分類：使用預訓練的 VGG16 / VGG19（PyTorch Lightning），在 Colab 上執行

安裝 pytorch-lightning：`!pip install pytorch-lightning`

In [15]:
import torch
import torchvision
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import vgg16, vgg19
from pytorch_lightning import LightningModule, Trainer

# Shared preprocessing for both CIFAR-10 splits:
#   * Resize(224) upsamples the square 32x32 images to the 224x224 input VGG expects.
#   * Normalize uses the ImageNet channel mean/std that the pretrained weights expect.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def _cifar10_split(train):
    """Return the CIFAR-10 train or test split stored under ./data (downloaded if missing)."""
    return datasets.CIFAR10(root='./data', train=train, download=True, transform=transform)


train_dataset = _cifar10_split(train=True)
test_dataset = _cifar10_split(train=False)

# Only the training split is shuffled; the test split keeps a fixed order.
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

class VGGModel(LightningModule):
    """Fine-tune a pretrained torchvision VGG16/VGG19 for image classification.

    The backbone is loaded with torchvision's default pretrained weights. Its
    last classifier layer is replaced by a freshly initialised ``Linear`` layer
    that produces ``num_classes`` logits.

    Args:
        model_type: Backbone to load, either ``'vgg16'`` or ``'vgg19'``.
        num_classes: Number of output classes (10 for CIFAR-10).
        lr: Learning rate for the Adam optimizer.

    Raises:
        ValueError: If ``model_type`` is not a supported backbone.
    """

    def __init__(self, model_type='vgg16', num_classes=10, lr=0.0001):
        if model_type == 'vgg16':
            builder, weights = vgg16, torchvision.models.VGG16_Weights.DEFAULT
        elif model_type == 'vgg19':
            builder, weights = vgg19, torchvision.models.VGG19_Weights.DEFAULT
        else:
            # Fail fast here instead of crashing later with an opaque
            # AttributeError when ``self.model`` is accessed.
            raise ValueError(
                f"model_type must be 'vgg16' or 'vgg19', got {model_type!r}"
            )

        super().__init__()
        self.lr = lr
        self.model = builder(weights=weights)

        # Replace the pretrained output layer with one predicting ``num_classes``.
        # Reuse its input width rather than hard-coding it.
        in_features = self.model.classifier[6].in_features
        self.model.classifier[6] = torch.nn.Linear(in_features, num_classes)

        # Per-batch validation results, aggregated in on_validation_epoch_end.
        self.val_losses = []
        self.val_accs = []
        self.val_batch_sizes = []

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return (and log) the cross-entropy loss for one training batch."""
        images, labels = batch
        outputs = self(images)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        # Without this the training loss is invisible in the progress bar/logger.
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Record the loss, accuracy and size of one validation batch."""
        images, labels = batch
        outputs = self(images)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()

        # Keep the batch size alongside the per-batch means so the epoch
        # average can be weighted correctly when the final batch is smaller.
        self.val_losses.append(loss.detach())
        self.val_accs.append(acc.detach())
        self.val_batch_sizes.append(labels.size(0))

    def on_validation_epoch_end(self):
        """Log the sample-weighted mean validation loss/accuracy and reset buffers."""
        if not self.val_losses:
            # No validation batches ran; torch.stack would raise on an empty list.
            return

        losses = torch.stack(self.val_losses)
        accs = torch.stack(self.val_accs)
        weights = torch.tensor(
            self.val_batch_sizes, dtype=losses.dtype, device=losses.device
        )
        total = weights.sum()

        # Weighting each batch mean by its size yields the exact per-sample
        # mean over the whole validation set.
        avg_loss = (losses * weights).sum() / total
        avg_acc = (accs * weights).sum() / total

        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('val_acc', avg_acc, prog_bar=True)

        # Reset for the next validation epoch.
        self.val_losses.clear()
        self.val_accs.clear()
        self.val_batch_sizes.clear()

    def configure_optimizers(self):
        """Optimise all backbone parameters with Adam at ``self.lr``."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)



# Choose the backbone: 'vgg16' here, or 'vgg19'.
model = VGGModel(model_type='vgg16')

# Train on one GPU when CUDA is available, otherwise on the CPU.
# ``devices`` must be a positive count for either accelerator, because
# Lightning's CPU accelerator rejects ``devices=0``.
use_gpu = torch.cuda.is_available()
trainer = Trainer(
    max_epochs=5,
    accelerator="gpu" if use_gpu else "cpu",
    devices=1,
)

# Train the model, validating after each epoch.
# NOTE: the CIFAR-10 *test* split is used as the validation set, so the
# reported val metrics are not an untouched held-out estimate.
trainer.fit(model, train_loader, test_loader)


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type | Params | Mode 
---------------------------------------
0 | model | VGG  | 134 M  | train
---------------------------------------
134 M     Trainable params
0         Non-trainable params
134 M     Total params
537.206   Total estimated model params size (MB)
42        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
