# DeadTree Train Notebook

In [4]:
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers.wandb import WandbLogger

import hydra
from omegaconf import DictConfig

from deadtrees.network.segmodel import SemSegment
from deadtrees.data.deadtreedata import DeadtreesDataModule
from deadtrees.visualization.helper import show

In [5]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Summarise the runtime environment (CUDA availability and library versions)
# as one line per item.
environment_report = (
    f"NVIDIA Cuda available:           {torch.cuda.is_available()}",
    f"PyTorch Version:                 {torch.__version__}",
    f"PyTorch Lightning Version:       {pl.__version__}",
)
print(*environment_report, sep="\n")

NVIDIA Cuda available:           True
PyTorch Version:                 1.8.1+cu102
PyTorch Lightning Version:       1.2.10


## Instantiate DataModule

In [7]:
# DataLoader settings for each split, kept together so they are easy to tune.
loader_settings = {
    "train": dict(batch_size=8, num_workers=4),
    "val": dict(batch_size=8, num_workers=2),
    "test": dict(batch_size=1, num_workers=1),
}

# Build the data module from the tar shards matching `pattern` in the dataset
# directory, then run its setup step.
datamodule = DeadtreesDataModule(
    "../data/dataset/train_balanced_short/",
    pattern="train-balanced-short-000*.tar",
    train_dataloader_conf=loader_settings["train"],
    val_dataloader_conf=loader_settings["val"],
    test_dataloader_conf=loader_settings["test"],
)
datamodule.setup()

Shard size: 64 (estimate base on file: ../data/dataset/train_balanced_short/train-balanced-short-000000.tar)


## Instantiate the Model

In [9]:
# Seed the random number generators before the model is built so that weight
# initialisation and the subsequent training run are reproducible.
SEED = 42
pl.seed_everything(SEED)

# Training hyper-parameters passed to the segmentation model.
train_config = DictConfig(
    dict(
        learning_rate=0.0001,
        run_test=False,
        tversky_beta=0.7,
    )
)

# Architecture settings passed to the segmentation model.
network_config = DictConfig(
    dict(
        num_classes=2,
        num_layers=5,
        features_start=64,
        bilinear=False,
    )
)

model = SemSegment(train_config, network_config)

## Instantiate a Trainer

In [10]:
# Both callbacks watch the same validation metric, where lower is better.
monitor_settings = dict(monitor="val/total_loss", mode="min")

# Keep only the single best checkpoint, named after its epoch.
model_checkpoint = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="{epoch:02d}",
    save_top_k=1,
    **monitor_settings,
)

# Stop training once the monitored metric has not improved for 10 checks.
early_stopping = EarlyStopping(patience=10, **monitor_settings)

# Weights & Biases logger (online mode), writing local files to the current directory.
wandb_logger = WandbLogger(
    save_dir=".",
    project="deadtrees",
    group="",
    job_type="train",
    offline=False,
)



In [11]:
# Upper bound on training epochs; passed to the Trainer as `max_epochs`.
EPOCHS: int = 100

In [12]:
# Train on one GPU with 16-bit mixed precision when CUDA is available;
# otherwise fall back to CPU with 32-bit precision, so the cell does not
# fail on machines without a GPU.
use_gpu = torch.cuda.is_available()

trainer = Trainer(
    gpus=1 if use_gpu else None,
    min_epochs=1,
    max_epochs=EPOCHS,
    precision=16 if use_gpu else 32,
    progress_bar_refresh_rate=10,
    terminate_on_nan=True,  # abort as soon as the loss or weights become NaN/inf
    callbacks=[model_checkpoint, early_stopping],
    logger=[wandb_logger],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


## Some experiments

> **NOTE:**
> This section is currently not working; skip it for now.

## Train the model

In [13]:
# Fit the model on the data module; early stopping may end the run before EPOCHS.
trainer.fit(model, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mcwerner[0m (use `wandb login --relogin` to force relogin)



  | Name                | Type                | Params
------------------------------------------------------------
0 | layers              | ModuleList          | 31.0 M
1 | binary_tversky_loss | BinaryTverskyLossV2 | 0     
2 | ce_loss             | CrossEntropyLoss    | 0     
------------------------------------------------------------
31.0 M    Trainable params
0         Non-trainable params
31.0 M    Total params
124.174   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1

## Test the model

In [14]:
# Run the test loop with the trainer used for fitting and show the returned
# list of metric dictionaries as the cell output.
test_results = trainer.test()
test_results

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/accuracy': 0.994203507900238,
 'test/binary_tversky_loss': 0.9592905044555664,
 'test/ce_loss': 0.02119305729866028,
 'test/dice_coeff': 0.5463292598724365,
 'test/total_loss': 0.49024176597595215}
--------------------------------------------------------------------------------


[{'test/dice_coeff': 0.5463292598724365,
  'test/accuracy': 0.994203507900238,
  'test/ce_loss': 0.02119305729866028,
  'test/binary_tversky_loss': 0.9592905044555664,
  'test/total_loss': 0.49024176597595215}]