This notebook serves the same function as `main.py`: it loads the YAML configuration, builds the data module and the model, and runs training.

In [1]:
%autoreload 2

config_file = '../binary_config.yaml'

max_epochs = 20

In [2]:
import sys
sys.path.append("..")

import logging
import os
from argparse import ArgumentParser
from datetime import datetime as dt

import pytorch_lightning as pl
import torch
import yaml
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import transforms

from src.models.LightningBaseModel import LightningModel
from src.utils import CONFIG
from src.utils.DataLoader import PlanktonDataLoader

In [3]:
def load_config(path=None):
    """Read a YAML configuration file and merge its contents into ``CONFIG``.

    Args:
        path: Location of the YAML file. When omitted, the notebook-level
            ``config_file`` is used, which matches the previous behaviour of
            calling ``load_config()`` without arguments.
    """
    if path is None:
        path = config_file

    # safe_load restricts parsing to plain YAML data types.
    with open(os.path.abspath(path), "r") as stream:
        config_dict = yaml.safe_load(stream)

    # The file is already closed here; only the parsed values are needed.
    CONFIG.update(config_dict)

In [4]:
# Populate CONFIG from the YAML file before anything reads from it.
load_config()

# The notebook runs from the notebooks directory, so the configured log
# directory has to be resolved one level further up.
CONFIG.tensorboard_logger_logdir = os.path.join(
    '..',
    CONFIG.tensorboard_logger_logdir,
)

In [5]:
# Seed the random number generators up front so that runs are repeatable.
torch.manual_seed(CONFIG.random_seed)
pl.seed_everything(CONFIG.random_seed)

# Debug mode switches on verbose logging and autograd anomaly detection;
# otherwise only warnings and above are shown.
log_level = logging.DEBUG if CONFIG.debug_mode else logging.WARNING
logging.basicConfig(level=log_level, format='%(name)s %(funcName)s %(levelname)s %(message)s')
if CONFIG.debug_mode:
    torch.autograd.set_detect_anomaly(True)

# Logged at WARNING level so it is shown under both log levels configured above.
logging.warning(CONFIG.__dict__)  # prints the whole config used for that run

# Pad every image, centre-crop it to a square of final_image_size and
# convert the result to a tensor.
transform = transforms.Compose([
    transforms.Pad(CONFIG.final_image_size),
    transforms.CenterCrop([CONFIG.final_image_size, CONFIG.final_image_size]),
    transforms.ToTensor(),
])

data_module = PlanktonDataLoader.from_argparse_args(CONFIG, transform=transform)
data_module.setup()

# Only the first training batch is needed: its inputs are handed to the model
# below as the example input array.
example_input, _, _ = next(iter(data_module.train_dataloader()))

# When training on GPU, add a GPU monitor so that GPU utilization shows up in
# the logs. `not CONFIG.gpus` also covers `gpus: null`, which the previous
# `== 0` comparison treated as a GPU run.
if not CONFIG.gpus:
    callbacks = None
else:
    callbacks = [pl.callbacks.GPUStatsMonitor()]

# logging to tensorboard; the timestamp keeps the experiment names of separate runs apart:
experiment_name = f"{CONFIG.experiment_name}_{dt.now().strftime('%d%m%YT%H%M%S')}"
test_tube_logger = pl_loggers.TestTubeLogger(save_dir=CONFIG.tensorboard_logger_logdir,
                                             name=experiment_name,
                                             create_git_tag=False,
                                             log_graph=True)

# initializes a callback to save the 5 best model weights measured by the
# lowest "NLL Validation" value, plus the most recent checkpoint:
checkpoint_callback = ModelCheckpoint(monitor="NLL Validation",
                                      save_top_k=5,
                                      mode='min',
                                      save_last=True,
                                      dirpath=os.path.join(CONFIG.checkpoint_file_path, experiment_name),
                                      )

# The model receives the label information from the data module, the example
# batch, and every configuration entry as keyword arguments.
model = LightningModel(class_labels=data_module.unique_labels,
                       all_labels=data_module.all_labels,
                       example_input_array=example_input,
                       **CONFIG.__dict__)

Global seed set to 42
Load new data: 100%|██████████| 21/21 [00:00<00:00, 26.57it/s]
Traceback (most recent call last):
  File "/gpfs/home/greenber/anaconda3/envs/plankton/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/gpfs/home/greenber/anaconda3/envs/plankton/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/gpfs/home/greenber/anaconda3/envs/plankton/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/gpfs/home/greenber/anaconda3/envs/plankton/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


In [6]:
# The notebook-level epoch count takes precedence over the configuration file.
CONFIG.max_epochs = max_epochs

# Register the checkpoint callback through `callbacks`: handing a
# ModelCheckpoint instance to `checkpoint_callback=` is deprecated since
# PyTorch Lightning 1.1 and rejected by later releases. `callbacks` may be
# None (CPU run), so fall back to an empty list before appending.
trainer_callbacks = list(callbacks or []) + [checkpoint_callback]

trainer = pl.Trainer.from_argparse_args(CONFIG,
                                        callbacks=trainer_callbacks,
                                        logger=[test_tube_logger],
                                        # keep checkpointing enabled regardless of what CONFIG carries
                                        checkpoint_callback=True,
                                        log_every_n_steps=CONFIG.log_interval,
                                        )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [7]:
# Report how many files ended up in each data split.
n_train = len(data_module.train_data.files)
n_valid = len(data_module.valid_data.files)
n_test = len(data_module.test_data.files)
print(f'{n_train} training, {n_valid} validation, {n_test} testing images')

# Report how many class labels the model holds.
n_classes = len(model.class_labels)
print(f'{n_classes} classes')

143350 training, 40958 validation, 20479 testing images
2 classes


In [None]:
# Start training; the data module supplies the dataloaders.
trainer.fit(model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

  | Name          | Type       | Params | In sizes          | Out sizes
-----------------------------------------------------------------------------
0 | model         | Sequential | 25.6 M | [16, 3, 128, 128] | [16, 2]  
1 | loss_func     | NLLLoss    | 0      | ?                 | ?        
2 | accuracy_func | Accuracy   | 0      | ?                 | ?        
-----------------------------------------------------------------------------
25.6 M    Trainable params
0         Non-trainable params
25.6 M    Total params
102.236   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]