In [1]:
import torch
import numpy as np

from ssunet.config import (
    PathConfig,
    SingleVolumeConfig,
    SplitParams,
    ModelConfig,
    LoaderConfig,
    TrainConfig,
    load_yaml,
    load_config,
)
from ssunet.dataloader import BinomDataset
import torch.utils.data as dt

import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    EarlyStopping,
    DeviceStatsMonitor,
)
# Import the logger from the same `pytorch_lightning` namespace as the Trainer.
# `lightning.pytorch` is a separate package; a logger from it fails the
# Trainer's isinstance checks (e.g. in `Trainer.log_dir`).
from pytorch_lightning.loggers import TensorBoardLogger

from ssunet.models import SSUnet

# Seed Python/NumPy/torch (and dataloader workers) so runs are reproducible.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Synthetic all-ones volume used as a smoke-test input. Allocate as float32
# directly: `np.ones(...).astype(np.float32)` first builds a 1 GiB float64 array.
example_data = np.ones((512, 512, 512), dtype=np.float32)
# NOTE(review): the meaning of the two positional 32s is not visible here
# (presumably patch/window sizes) — confirm against SingleVolumeConfig.
data_config = SingleVolumeConfig(example_data, 32, 32)
split_params = SplitParams()
model_config = ModelConfig()
loader_config = LoaderConfig()
train_config = TrainConfig()

train_data = BinomDataset(data_config, split_params=split_params)

train_loader = dt.DataLoader(train_data, **loader_config.to_dict)
test_name = train_config.name
# Log under the run name from TrainConfig (previously the literal "test_name"
# was passed, so every run landed in the same directory).
logger = TensorBoardLogger(save_dir="./model_dir", name=test_name)


# Single-device CUDA trainer. ModelCheckpoint keeps weights-only checkpoints of
# the 2 best epochs by lowest "val_loss" (assumes SSUnet logs "val_loss" —
# confirm); LearningRateMonitor records the learning rate once per epoch.
trainer = pl.Trainer(
    # --- hardware / numerics ---
    accelerator="cuda",
    devices=[train_config.device_number],
    precision=train_config.precision,  # type: ignore
    # --- optimisation ---
    max_epochs=train_config.epochs,
    gradient_clip_val=1,
    # --- output locations, logging, validation, profiling ---
    default_root_dir="./model_dir",
    logger=logger,  # type: ignore
    log_every_n_steps=20,
    limit_val_batches=20,
    profiler="simple",
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=2,
            save_weights_only=True,
        ),
        LearningRateMonitor("epoch"),
        # Currently disabled: EarlyStopping("val_loss", patience=25),
        # DeviceStatsMonitor()
    ],
)

# Sanity check: draw one batch and report the shape of its element at index 1
# (presumably the network input — confirm against BinomDataset).
print(
    f"input_size: {tuple(next(iter(train_loader))[1].shape)}"
)

model = SSUnet(model_config)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


input_size: (20, 1, 32, 32, 32)


In [3]:
# Validation reuses the training loader: no separate validation data is built
# above, so the monitored "val_loss" is computed on the training data.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type           | Params
------------------------------------------------
0 | down_convs   | ModuleList     | 3.5 M 
1 | up_convs     | ModuleList     | 1.5 M 
2 | conv_final   | Sequential     | 25    
3 | _psnr_metric | InferenceModel | 0     
4 | _ssim_metric | InferenceModel | 0     
------------------------------------------------
5.0 M     Trainable params
0         Non-trainable params
5.0 M     Total params
20.008    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

c:\Users\HEQ\mambaforge\envs\GAP3D\Lib\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
