In [1]:
from bts.mmformer.train import train_epoch
from bts.mmformer.mmformer.mmformer import Model
from bts.data.porcupine_dataset import PorcupineDataset

import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

from monai.transforms import (
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    RandFlipd,
    RandSpatialCropd,
    Resized,
    Transform,
)
from bts.data.utils import UnsqueezeDatad
from bts.data.dataset import ConvertToMultiChannelBasedOnEchidnaClassesd

In [2]:
dataset_path = "/home/desmin/data/porcupine_dataset/train"
# Spatial size fed to the model; RandSpatialCropd and Resized must agree on it.
ROI_SIZE = [128, 128, 128]

train_transforms = Compose([
    LoadImaged(["image"], reader="NrrdReader"),
    LoadImaged(["label"], reader="NumpyReader"),
    # NOTE(review): presumably adds a leading channel axis to both arrays — confirm in bts.data.utils.
    UnsqueezeDatad(["image", "label"]),
    RandSpatialCropd(
        keys=["image", "label"],
        roi_size=ROI_SIZE,
        random_size=False,
    ),
    # A random crop cannot enlarge a volume, so any axis shorter than ROI_SIZE is
    # upsampled here. The label must use nearest-neighbour interpolation so class
    # ids stay integral; the image keeps MONAI's default "area" mode.
    Resized(
        keys=["image", "label"],
        spatial_size=ROI_SIZE,
        mode=("area", "nearest"),
    ),
])

train_dataset = PorcupineDataset(dataset_root_path=dataset_path, transform=train_transforms)
# Shuffle so each epoch visits the training volumes in a different order.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=True)

In [3]:
# Build mmFormer and resume from a previous training checkpoint.
# NOTE(review): num_cls=4 presumably is the number of segmentation classes — confirm against the label encoding.
class_num = 4

mmformer = Model(num_cls=class_num)
# DataParallel exposes the wrapped model as `.module`, so its state_dict keys carry a
# "module." prefix; the checkpoint's "state_dict" must have been saved from a wrapped model.
mmformer_parallel = torch.nn.DataParallel(mmformer).cuda()

weights_path = "/home/desmin/grad_project/bbm47980_bts/bts/mmformer/mmformer/out/model_last.pth"
# Deserialize on the CPU first: load_state_dict copies the tensors into the CUDA
# parameters anyway. This avoids a second GPU-resident copy of the weights and does
# not depend on which device the checkpoint was originally saved from.
checkpoint = torch.load(weights_path, map_location="cpu")
mmformer_parallel.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [4]:
# Snapshot GPU memory and utilisation now that the model has been moved to CUDA.
!nvidia-smi

Tue Jun 13 04:26:37 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro P5000        Off  | 00000000:3B:00.0 Off |                  Off |
| 26%   38C    P0    44W / 180W |   1145MiB / 16276MiB |     10%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Quadro P5000        Off  | 00000000:5E:00.0 Off |                  Off |
| 26%   31C    P8    11W / 180W |      2MiB / 16278MiB |      0%      Defaul

In [5]:
# Adam over every parameter of the DataParallel-wrapped model.
LEARNING_RATE = 1e-4
optim = Adam(params=mmformer_parallel.parameters(), lr=LEARNING_RATE)

In [6]:
# Single source of truth for the epoch count (used by both the loop and the log line).
NUM_EPOCHS = 10
CHECKPOINT_PATH = "mmformer_model_weights.pth"

history_all = []  # mean training loss of each epoch, in order
for epoch_idx in range(NUM_EPOCHS):
    epoch_num = epoch_idx + 1  # 1-based epoch number for train_epoch and the log
    history = train_epoch(
        model=mmformer_parallel,
        loader=train_loader,
        optimizer=optim,
        epoch=epoch_num,
    )
    mean_loss = history["Mean Train Loss"]
    history_all.append(mean_loss)

    # Overwrite the checkpoint after every epoch with the unwrapped module's weights
    # (keys without the DataParallel "module." prefix).
    # NOTE(review): this is a bare state_dict, not {"state_dict": ...} with prefixed keys,
    # so it cannot be reloaded by the checkpoint-loading cell above without adapting it.
    torch.save(mmformer_parallel.module.state_dict(), CHECKPOINT_PATH)
    print(f"EPOCH {epoch_num}/{NUM_EPOCHS}: {mean_loss}")

Epoch 1:   0%|          | 0/169 [00:00<?, ?it/s]

EPOCH 1/10: 0.5239266915434211


Epoch 2:   0%|          | 0/169 [00:00<?, ?it/s]

EPOCH 2/10: 0.48026738355498344


Epoch 3:   0%|          | 0/169 [00:00<?, ?it/s]

EPOCH 3/10: 0.4571548505912166


Epoch 4:   0%|          | 0/169 [00:00<?, ?it/s]

EPOCH 4/10: 0.4574287608265877


Epoch 5:   0%|          | 0/169 [00:00<?, ?it/s]

EPOCH 5/10: 0.4476693458193858


Epoch 6:   0%|          | 0/169 [00:00<?, ?it/s]

EPOCH 6/10: 0.45022268064276


Epoch 7:   0%|          | 0/169 [00:00<?, ?it/s]

EPOCH 7/10: 0.4391207581264733


Epoch 8:   0%|          | 0/169 [00:00<?, ?it/s]

EPOCH 8/10: 0.4289081849642759


Epoch 9:   0%|          | 0/169 [00:00<?, ?it/s]

EPOCH 9/10: 0.43721846486689775


Epoch 10:   0%|          | 0/169 [00:00<?, ?it/s]

EPOCH 10/10: 0.43343812214023264
