# Imports

In [1]:
%%time
import pytorch_lightning as pl
import torch as th

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, GPUStatsMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.metrics.functional.classification import accuracy
from pytorch_lightning import seed_everything
import torch.nn.functional as F
import torchvision
import timm
from timm import create_model
from torchvision import transforms
import os
import pandas as pd

# data module
from src.dataset import DataModule

# model
from src.model import Model

# config file
from src.config import Config

CPU times: user 1.32 s, sys: 243 ms, total: 1.56 s
Wall time: 1.4 s


# Data module and setup

In [2]:
# Seed the global RNGs through Lightning so stochastic parts of the run
# (augmentation, data splitting, weight init — presumably) are reproducible.
# A bare call is used instead of `_ = ...`: assigning to `_` in IPython stops
# it from updating `_` with the last output for the rest of the session.
# The trailing ';' suppresses the echoed seed value.
seed_everything(seed=Config.seed_val);

Global seed set to 2021


In [3]:
# Gather the settings declared directly on the Config class into a plain dict
# for the model. Any attribute whose name contains '__' is skipped, which
# drops the class machinery entries such as __module__, __doc__ and __dict__.
config_dict = {
    key: value
    for key, value in Config.__dict__.items()
    if '__' not in key
}

In [4]:
%%time

# Labels table for the training images.
train_df = pd.read_csv(os.path.join(Config.data_dir, 'train.csv'))

# Deterministic preprocessing shared by the validation and test splits.
# - No random ops: val_acc drives ModelCheckpoint and EarlyStopping, so random
#   augmentation here would make model selection noisy, and randomness at
#   test time would make predictions irreproducible.
# - Same geometry as training (resize, then center-crop to img_h x img_w),
#   rather than squeezing the whole image straight to the crop size, so the
#   model is evaluated at the scale/field of view it was trained on.
eval_transform = transforms.Compose([
    transforms.Resize(size=(Config.resize, Config.resize)),
    transforms.CenterCrop(size=(Config.img_h, Config.img_w)),
])

data_transform = {
    # Random augmentation is applied to the training split only.
    'train': transforms.Compose([
        transforms.Resize(size=(Config.resize, Config.resize)),
        transforms.RandomHorizontalFlip(p=.7),
        transforms.RandomVerticalFlip(p=.3),
        transforms.RandomRotation(degrees=25),
        transforms.CenterCrop(size=(Config.img_h, Config.img_w)),
        transforms.ColorJitter(brightness=(0.4, 1), contrast=.2, saturation=0, hue=0),
        transforms.GaussianBlur(kernel_size=3),
    ]),
    'validation': eval_transform,
    'test': eval_transform,
}

# validation_split=.2 holds out 20% of the labelled images for validation;
# train_frac=1 presumably means "use all of the remaining training rows" —
# confirm against src.dataset.DataModule.
dm = DataModule(config=Config,
                train_data_dir=Config.train_data_dir,
                test_data_dir=Config.test_data_dir,
                train_df=train_df,
                data_transform=data_transform,
                validation_split=.2,
                train_frac=1)
dm.setup()

[INFO] Training on 17117
[INFO] Validating on 4280
CPU times: user 59.7 ms, sys: 30.8 ms, total: 90.5 ms
Wall time: 94.9 ms


In [5]:
%%time
# Build the LightningModule, passing the filtered Config settings (config_dict)
# as its configuration.
model = Model(config=config_dict)

CPU times: user 300 ms, sys: 28.2 ms, total: 328 ms
Wall time: 256 ms


In [6]:
%%time

# Run-length knobs, kept here instead of as inline magic numbers.
# MAX_EPOCHS is a short-run override of Config.num_epochs; set
# MAX_EPOCHS = Config.num_epochs for a full training run.
MAX_EPOCHS = 5
# Set to True for Lightning's quick debug run of the train/val loops.
FAST_DEV_RUN = False

# Checkpoint on the best validation accuracy. The base model name is filled in
# now by the f-string; the doubled braces leave literal {val_acc:.5f} and
# {val_loss:.5f} placeholders for ModelCheckpoint to fill at save time.
ckpt_cb = ModelCheckpoint(
    monitor='val_acc',
    mode='max',
    dirpath=Config.models_dir,
    filename=f'{Config.base_model}-leaf_disease_classifier-{{val_acc:.5f}}-{{val_loss:.5f}}',
)

# Log GPU memory/utilization, fan speed and temperature alongside the metrics.
gpu_stats = GPUStatsMonitor(
    memory_utilization=True,
    gpu_utilization=True,
    fan_speed=True,
    temperature=True,
)

# Stop once val_acc has not improved for `patience` epochs.
# NOTE: with MAX_EPOCHS = 5 and patience = 4 this can fire at the earliest
# after the last epoch, so it cannot shorten a 5-epoch run.
es = EarlyStopping(
    monitor='val_acc',
    patience=4,
    mode='max',
)

Logger = TensorBoardLogger(
    save_dir=Config.logs_dir,
    name='cassava_leaf_disease',
)

Callbacks = [es, ckpt_cb, gpu_stats]

trainer = pl.Trainer(
    gpus=-1,  # all visible GPUs
    max_epochs=MAX_EPOCHS,
    precision=16,  # native mixed precision
    callbacks=Callbacks,
    logger=Logger,
    fast_dev_run=FAST_DEV_RUN,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


CPU times: user 34.9 ms, sys: 11.8 ms, total: 46.6 ms
Wall time: 41.8 ms


# Training phase

In [7]:
%%time
# Train the model on the DataModule set up above. The trailing ';' suppresses
# the return value of fit, which otherwise renders as a stray output cell.
trainer.fit(model=model, datamodule=dm);


  | Name                  | Type         | Params
-------------------------------------------------------
0 | train_transforms      | Sequential   | 0     
1 | validation_transforms | Sequential   | 0     
2 | encoder               | EfficientNet | 9.1 M 
3 | classifier            | Linear       | 5.0 K 
4 | dropout               | Dropout      | 0     
-------------------------------------------------------
9.1 M     Trainable params
0         Non-trainable params
9.1 M     Total params


Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



CPU times: user 21min 39s, sys: 1min 33s, total: 23min 13s
Wall time: 25min 42s


1

In [8]:
%load_ext tensorboard

In [9]:
%tensorboard --logdir ../logs

Reusing TensorBoard on port 6006 (pid 161949), started 0:30:11 ago. (Use '!kill 161949' to kill it.)