In [2]:
%load_ext autoreload
%autoreload 2

import sys
import lightning.pytorch as pl
import torch
from matplotlib import pyplot as plt
import numpy as np

sys.path.insert(1, sys.path[0] + '/..')
from src.data.datamodule import DataModule
from src.model.setup import setup_model
from src.misc.utils import set_seed_and_precision

from src.run import parse_option
args = parse_option(notebook=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


The cell below should mirror `main()` in `run.py`; keep the two in sync whenever either one changes.

In [6]:
# Smoke test: push one training batch and one validation batch through the
# full Lightning pipeline (fast_dev_run).
set_seed_and_precision(args)

datamodule = DataModule(
    dir='../data_dev',
    dataset='boxes',
    num_workers=args.num_workers,
    batch_size=args.batch_size,
)
model = setup_model(net=args.net)

# Trainer collaborators, created in the same order the Trainer's keyword
# arguments were originally evaluated: logger, accelerator choice, callbacks.
tb_logger = pl.loggers.TensorBoardLogger('../logs', name='test', version=args.version)
# CUDA when present, otherwise CPU.
# NOTE(review): Apple MPS is never picked here; the run log reports
# "GPU available: True (mps), used: False".
accelerator_name = 'gpu' if torch.cuda.is_available() else 'cpu'
progress_bar = pl.callbacks.TQDMProgressBar(refresh_rate=1000)

trainer = pl.Trainer(
    fast_dev_run=True,  # 1 batch per loop; logging and checkpointing are suppressed
    logger=tb_logger,
    max_epochs=args.max_epochs,
    log_every_n_steps=1,
    accelerator=accelerator_name,
    callbacks=[progress_bar],
    # Kept False because of max_pool3d_with_indices_backward_cuda, which has no
    # deterministic implementation in PyTorch.
    deterministic=False,
)

trainer.fit(model, datamodule=datamodule)

Global seed set to 42
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name | Type             | Params
------------------------------------------
0 | net  | UNet3D           | 5.0 M 
1 | loss | CrossEntropyLoss | 0     
------------------------------------------
5.0 M     Trainable params
0         Non-trainable params
5.0 M     Total params
20.089    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] 
x.shape   torch.Size([1, 128, 128, 128]) 
y.shape   torch.Size([1, 128, 128, 128]) 
y_hat.shape torch.Size([1, 128, 128, 128])
Epoch 0: 100%|██████████| 1/1 [00:13<00:00, 13.20s/it, train_loss=0.681, train_acc=0.000427]
x.shape   torch.Size([1, 128, 128, 128]) 
y.shape   torch.Size([1, 128, 128, 128]) 
y_hat.shape torch.Size([1, 128, 128, 128])
Epoch 0: 100%|██████████| 1/1 [00:17<00:00, 17.59s/it, train_loss=0.681, train_acc=0.000427, val_loss=0.411, val_acc=0.999]

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:17<00:00, 17.59s/it, train_loss=0.681, train_acc=0.000427, val_loss=0.411, val_acc=0.999]


### Dev: inspect a dataset sample and instantiate the network

In [29]:
# Rebuild the datamodule with explicit train/val splits and peek at the first
# training sample. Note that this rebinds `datamodule` from the training cell.
datamodule = DataModule(
    dir='../data_dev',
    dataset='boxes',
    num_workers=args.num_workers,
    batch_size=args.batch_size,
    splits=['train', 'val'],
)
datamodule.setup()

train_ds = datamodule.train_dataloader().dataset
example = train_ds[0]  # presumably an (input, target) pair; TODO confirm
# NOTE(review): this displayed (64, 64, 64), while the training cell printed
# x.shape (1, 128, 128, 128). Confirm whether a transform/collate resizes the
# volume, or whether the data changed between runs.
example[0].shape

torch.Size([64, 64, 64])

In [31]:
from src.model.models import UNet3D

In [32]:
# Standalone UNet3D for interactive inspection. The positional (1, 1) arguments
# are presumably input/output channel counts; TODO confirm against UNet3D's signature.
net = UNet3D(1, 1)