In [1]:
# package needed for reading dcm xray files
import sys
!{sys.executable} -m pip install pytorch-lightning

Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/75/ac/ac03f1f3fa950d96ca52f07d33fdbf5add05f164c1ac4eae179231dfa93d/pytorch_lightning-0.7.5-py3-none-any.whl (233kB)
[K     |█▍                              | 10kB 26.1MB/s eta 0:00:01[K     |██▉                             | 20kB 3.1MB/s eta 0:00:01[K     |████▏                           | 30kB 3.7MB/s eta 0:00:01[K     |█████▋                          | 40kB 2.9MB/s eta 0:00:01[K     |███████                         | 51kB 3.2MB/s eta 0:00:01[K     |████████▍                       | 61kB 3.8MB/s eta 0:00:01[K     |█████████▉                      | 71kB 4.2MB/s eta 0:00:01[K     |███████████▎                    | 81kB 4.3MB/s eta 0:00:01[K     |████████████▋                   | 92kB 4.8MB/s eta 0:00:01[K     |██████████████                  | 102kB 4.7MB/s eta 0:00:01[K     |███████████████▍                | 112kB 4.7MB/s eta 0:00:01[K     |████████████████▉               | 

In [0]:
# download and unzip dataset
!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip hymenoptera_data.zip

In [0]:
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms, models
import os


In [0]:
# pretrained model
output_classes = 2

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, output_classes)

In [0]:
class LightningMNISTClassifier(pl.LightningModule):

    def __init__(self):
        super(LightningMNISTClassifier, self).__init__()

        self.model = model_ft

    def forward(self, x):
        batch_size, channels, width, height = x.size()

        x = self.model(x)
      
        # probability distribution over labels
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)

        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}


    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        # called at the end of the validation epoch
        # outputs is an array with what you returned in validation_step for each batch
        # outputs = [{'loss': batch_0_loss}, {'loss': batch_1_loss}, ..., {'loss': batch_n_loss}] 
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def prepare_data(self):

        data_transforms = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }

        data_dir = 'hymenoptera_data'
        image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                                  data_transforms[x])
                          for x in ['train', 'val']}
        self.data_train = image_datasets['train']
        self.data_val = image_datasets['val']

    def train_dataloader(self):
        return DataLoader(self.data_train, batch_size=64, num_workers=3)

    def val_dataloader(self):
        return DataLoader(self.data_val, batch_size=64, num_workers=3)

    # def test_dataloader(self):
        # return DataLoader(self,data_test, batch_size=64)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


In [0]:
from pytorch_lightning.loggers import TensorBoardLogger

# default logger used by trainer
logger = TensorBoardLogger(
    save_dir=os.getcwd(),
    name='lightning_logs'
)

In [28]:
from pytorch_lightning.callbacks import ModelCheckpoint

# default used by the Trainer
checkpoint_callback = ModelCheckpoint(
    filepath=os.getcwd(),
    save_top_k=True,
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix=''
)



In [29]:

# train
model = LightningMNISTClassifier()

# you can also set this to find best lr :   auto_lr_find=True
# set this while debugging: fast_dev_run=True
# set maximum epochs: max_epochs=1000
# to automatically find the max batch size that can fit gpu set: auto_scale_batch_size='binsearch'
trainer = pl.Trainer(auto_scale_batch_size='binsearch', logger=logger, checkpoint_callback=checkpoint_callback)

trainer.fit(model)

INFO:lightning:GPU available: True, used: False
INFO:lightning:
   | Name                        | Type              | Params
--------------------------------------------------------------
0  | model                       | ResNet            | 11 M  
1  | model.conv1                 | Conv2d            | 9 K   
2  | model.bn1                   | BatchNorm2d       | 128   
3  | model.relu                  | ReLU              | 0     
4  | model.maxpool               | MaxPool2d         | 0     
5  | model.layer1                | Sequential        | 147 K 
6  | model.layer1.0              | BasicBlock        | 73 K  
7  | model.layer1.0.conv1        | Conv2d            | 36 K  
8  | model.layer1.0.bn1          | BatchNorm2d       | 128   
9  | model.layer1.0.relu         | ReLU              | 0     
10 | model.layer1.0.conv2        | Conv2d            | 36 K  
11 | model.layer1.0.bn2          | BatchNorm2d       | 128   
12 | model.layer1.1              | BasicBlock        | 73 K  
13 | 

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00000: val_loss reached 19.39974 (best 19.39974), saving model to /content/epoch=0_v0.ckpt as top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00001: val_loss reached 16.61771 (best 16.61771), saving model to /content/epoch=1.ckpt as top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00002: val_loss reached 2.24247 (best 2.24247), saving model to /content/epoch=2.ckpt as top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00003: val_loss reached 1.48076 (best 1.48076), saving model to /content/epoch=3.ckpt as top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00004: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00005: val_loss reached 0.83626 (best 0.83626), saving model to /content/epoch=5.ckpt as top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00006: val_loss reached 0.68852 (best 0.68852), saving model to /content/epoch=6.ckpt as top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00007: val_loss reached 0.61935 (best 0.61935), saving model to /content/epoch=7.ckpt as top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00008: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00009: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00010: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00011: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00012: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00013: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00014: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00015: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00016: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00017: val_loss  was not in top True


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:Detected KeyboardInterrupt, attempting graceful shutdown...





1