In [1]:
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from matplotlib import pyplot as plt
from pytorch_lightning.loggers import TensorBoardLogger
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import CIFAR10



In [2]:
# Sanity check: report whether PyTorch can see a CUDA device before training.
cuda_available = torch.cuda.is_available()
print(cuda_available)

True


In [45]:
class CIFARCLassifier(pl.LightningModule):
    """Small LeNet-style CNN that classifies CIFAR-10 images into 10 classes.

    Written against the pre-1.0 PyTorch Lightning hook API: the ``*_epoch_end``
    hooks return a dict whose ``'log'`` entry is forwarded to the logger.

    Args:
        lr: Learning rate for Adam (default 0.001).
        batch_size: Mini-batch size for both the train and validation loaders (default 128).
        data_dir: Root directory for the CIFAR-10 files (default ``"data/"``).
    """

    def __init__(self, lr=1e-3, batch_size=128, data_dir="data/"):
        super().__init__()
        self.lr = lr
        self.batch_size = batch_size
        self.data_dir = data_dir

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of images to unnormalized logits over the 10 classes.

        Assumes 3-channel 32x32 inputs (CIFAR-10): two conv(5x5) + pool(2x2)
        stages shrink 32x32 to 5x5, which is what fc1's ``16 * 5 * 5`` expects.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Keep the batch dimension; a wrong input size now fails loudly in fc1
        # instead of being silently reshaped into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def prepare_data(self):
        """Download CIFAR-10 if needed and build the train/validation datasets.

        Images are converted to tensors in [0, 1] and then normalized with
        mean 0.5 / std 0.5 per channel, mapping them to [-1, 1].

        NOTE: Lightning recommends assigning dataset attributes in ``setup()``,
        because ``prepare_data()`` may run on only one process in distributed
        training. The attributes are kept here because the notebook calls
        ``prepare_data()`` directly.
        """
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # download=True: without it CIFAR10 raises when the files are not
        # already present under data_dir.
        self.cifar_train = CIFAR10(root=self.data_dir, train=True, download=True, transform=transform)
        self.cifar_val = CIFAR10(root=self.data_dir, train=False, download=True, transform=transform)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training split."""
        return torch.utils.data.DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the validation (test) split."""
        return torch.utils.data.DataLoader(self.cifar_val, batch_size=self.batch_size, shuffle=False)

    def _shared_step(self, batch):
        """Return the batch loss, correct-prediction count and sample count for one batch."""
        data, label = batch
        logits = self(data)
        loss = F.cross_entropy(logits, label)
        correct = (logits.argmax(dim=1) == label).sum().item()
        return {'loss': loss, 'correct': correct, 'total': label.size(0)}

    @staticmethod
    def _aggregate(outputs):
        """Return (mean batch loss, accuracy in percent) over one epoch's step outputs."""
        # Mean of per-batch losses: a smaller final batch is weighted like a full one.
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        # Divide by the number of samples actually seen. Over a full epoch this
        # equals the dataset size, and it stays correct for partial runs such
        # as the validation sanity check.
        seen = sum(x['total'] for x in outputs)
        accuracy = 100 * (sum(x['correct'] for x in outputs) / float(seen))
        return avg_loss, accuracy

    def training_step(self, train_batch, batch_idx):
        """Run one training batch and return its loss and correct-prediction count."""
        return self._shared_step(train_batch)

    def training_epoch_end(self, outputs):
        """Log the mean training loss and training accuracy for the finished epoch."""
        avg_loss, accuracy = self._aggregate(outputs)
        logs = {'train_loss': avg_loss, 'train_accuracy': accuracy}
        return {'loss': avg_loss, 'log': logs}

    def validation_step(self, validation_batch, batch_idx):
        """Run one validation batch and return its loss and correct-prediction count."""
        return self._shared_step(validation_batch)

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and validation accuracy for the finished epoch."""
        avg_loss, accuracy = self._aggregate(outputs)
        logs = {'validation_loss': avg_loss, 'validation_accuracy': accuracy}
        return {'loss': avg_loss, 'log': logs}

    def configure_optimizers(self):
        """Use Adam over all model parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)





In [46]:
# Fix the torch RNG so weight initialization and DataLoader shuffling repeat across runs.
SEED = 0
torch.manual_seed(SEED)

logger = TensorBoardLogger('tb_logs', name="cifar10_documentation_model")

model = CIFARCLassifier()
# Build the train/validation datasets before handing the model to the Trainer.
model.prepare_data()
# Fall back to CPU when no CUDA device is visible instead of failing on gpus=[0].
trainer = pl.Trainer(
    max_epochs=10,
    logger=logger,
    gpus=[0] if torch.cuda.is_available() else None,
)

trainer.fit(model)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type      | Params
------------------------------------
0 | conv1 | Conv2d    | 456   
1 | pool  | MaxPool2d | 0     
2 | conv2 | Conv2d    | 2 K   
3 | fc1   | Linear    | 48 K  
4 | fc2   | Linear    | 10 K  
5 | fc3   | Linear    | 850   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..



1