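# DeadTree Training Notebook

Trains a semantic-segmentation model (U-Net with a ResNet-34 encoder, 4 input channels, 2 classes) on the deadtrees dataset with PyTorch Lightning, logging runs to Weights & Biases.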

In [None]:
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers.wandb import WandbLogger

import hydra
from omegaconf import DictConfig

from deadtrees.network.segmodel import SemSegment
from deadtrees.data.deadtreedata import DeadtreesDataModule
from deadtrees.visualization.helper import show

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Report the runtime environment so results can be tied to library versions.
cuda_available = torch.cuda.is_available()
torch_version = torch.__version__
lightning_version = pl.__version__

print(f"NVIDIA Cuda available:           {cuda_available}")
print(f"PyTorch Version:                 {torch_version}")
print(f"PyTorch Lightning Version:       {lightning_version}")

## Instantiate DataModule

In [None]:
# All three splits use the same DataLoader settings; each split receives its
# own copy so the datamodule never shares one dict object between splits.
loader_conf = {'batch_size': 16, 'num_workers': 2}

datamodule = DeadtreesDataModule(
    "../data/dataset/train/",
    pattern="train-balanced-000*.tar",
    # extra shard sets mixed into training; presumably batch_size_extra gives
    # the per-batch sample count drawn from each pattern_extra entry — confirm
    pattern_extra=["train-negativesamples-000*.tar", "train-randomsamples-000*.tar"],
    batch_size_extra=[1, 7],
    train_dataloader_conf=dict(loader_conf),
    val_dataloader_conf=dict(loader_conf),
    test_dataloader_conf=dict(loader_conf),
)
# in_channels / classes must match the network_config used for the model
datamodule.setup(in_channels=4, classes=2)

## Instantiate the Model

In [None]:
# Optimisation settings handed to SemSegment.
train_config = DictConfig({
    "learning_rate": 0.0003,
    "run_test": False,
})

# Network definition: a U-Net with a ResNet-34 encoder using `imagenet` weights.
# `in_channels` and `classes` must agree with the values given to
# `datamodule.setup(...)` earlier in the notebook.
network_config = DictConfig({
    # model definitions
    "architecture": "unet",
    "encoder_name": "resnet34",
    "encoder_depth": 5,
    "encoder_weights": "imagenet",
    # data specific settings
    "classes": 2,
    "in_channels": 4,
})

model = SemSegment(train_config, network_config)
# trailing `;` keeps the cell from echoing the call's return value
model.summarize(max_depth=1);

## Instantiate a Trainer

In [None]:
# define some callbacks
# Checkpointing and early stopping watch the same validation metric; sharing
# the settings keeps the two callbacks from drifting apart.
monitor_kwargs = dict(monitor="val/total_loss", mode="min")

# keep only the single best checkpoint (lowest val/total_loss)
model_checkpoint = ModelCheckpoint(
    **monitor_kwargs,
    save_top_k=1,
    dirpath="checkpoints/",
    filename="{epoch:02d}",
)

# stop after 10 consecutive validation checks without improvement
early_stopping = EarlyStopping(**monitor_kwargs, patience=10)

# Weights & Biases logger: online run, local files under the current directory.
# NOTE(review): group="" passes an empty string rather than None; confirm this
# is intended (None means "no run group").
wandb_logger = WandbLogger(
    project="deadtrees",
    offline=False,
    job_type="train",
    group="",
    save_dir=".",
)

In [None]:
# Upper bound on training epochs.
# NOTE(review): with EarlyStopping(patience=10), at least 11 epochs would be
# needed for early stopping to trigger, so at this value it is effectively inert.
EPOCHS: int = 10

In [None]:
# Lightning rejects both `gpus=1` and 16-bit precision on a machine without
# CUDA, so fall back to CPU with full (32-bit) precision in that case.
# GPU behaviour is unchanged.
use_gpu = torch.cuda.is_available()

trainer = Trainer(
    gpus=1 if use_gpu else 0,
    min_epochs=1,
    max_epochs=EPOCHS,
    precision=16 if use_gpu else 32,
    progress_bar_refresh_rate=10,
    # abort the run if the loss or any parameter becomes NaN/inf
    terminate_on_nan=True,
    callbacks=[model_checkpoint, early_stopping],
    logger=[wandb_logger],
)

## Some experiments

> **NOTE:**
> The experiments in this section are currently not working; skip this section for now. **TODO:** fix or remove them.

## Train the model

In [None]:
# Run training; the datamodule provides the train and validation loaders.
trainer.fit(model, datamodule=datamodule)

## Test the model

In [None]:
# Evaluate the best checkpoint kept by `model_checkpoint` (lowest
# val/total_loss) on the test split. Passing ckpt_path and datamodule
# explicitly avoids relying on Lightning's version-dependent defaults.
trainer.test(datamodule=datamodule, ckpt_path="best")