In [3]:
import logging
import sys
from itertools import islice
logger = logging.getLogger(__name__)

# --------- midaGAN imports ----------
try:
    import midaGAN
except ImportError:
    # Fall back to importing midaGAN from a source checkout when it is not
    # installed as a package.
    # NOTE(review): '../..' is resolved against the current working directory,
    # so this presumably assumes the notebook is launched from two levels below
    # the repository root — confirm before running from elsewhere.
    logger.warning("midaGAN not installed as a package, importing it from the local directory.")
    sys.path.append('../..')
    import midaGAN

from midaGAN.trainer import Trainer
from midaGAN.conf.builders import build_training_conf
from midaGAN.utils import communication, environment

from omegaconf import OmegaConf
from midaGAN.conf import init_config
from midaGAN.data import build_dataset


In [4]:
# Experiment config to load; edit this to inspect a different experiment.
CONFIG_PATH = "/workspace/data/midaGAN/projects/maastro_cbct_to_ct/experiments/loader_test.yaml"

communication.init_distributed()  # inits distributed mode if run with torch.distributed.launch

conf = init_config(CONFIG_PATH)
# `DictConfig.pretty()` is deprecated in OmegaConf 2.0 and removed in 2.1;
# `OmegaConf.to_yaml` emits the same YAML dump of the config.
print(OmegaConf.to_yaml(conf))

dataset = build_dataset(conf)


batch_size: 1
project_dir: ../../projects/maastro_cbct_to_ct
seed: 0
use_cuda: true
mixed_precision: true
opt_level: O1
dataset:
  name: CBCTtoCTDataset
  root: /workspace/ibro/lung1_ct_cbct_nrrd_resampled/
  shuffle: true
  num_workers: 16
  patch_size:
  - 48
  - 224
  - 224
  hounsfield_units_range:
  - -1000
  - 2000
  focal_region_proportion: 0.2
gan:
  is_train: true
  name: PiCycleGAN
  loss_type: lsgan
  norm_type: instance
  weight_init_type: normal
  weight_init_gain: 0.02
  pool_size: 50
generator:
  name: Vnet3D
  in_channels: 1
  use_memory_saving: false
  use_inverse: true
  first_layer_channels: 16
  down_blocks:
  - 2
  - 2
  - 3
  up_blocks:
  - 3
  - 3
  - 3
  is_separable: false
n_iters: 20000
n_iters_decay: 20000
discriminator:
  name: PatchGAN3D
  in_channels: 1
  ndf: 64
  n_layers: 2
optimizer:
  beta1: 0.5
  lr_D: 0.0002
  lr_G: 0.0004
  lambda_A: 25.0
  lambda_B: 25.0
  lambda_identity: 0.0
  lambda_inverse: 0.0
  proportion_ssim: 0.84
  ssim_type: SSIM
logging

In [8]:
N_PREVIEW = 5  # number of dataset samples whose tensor shapes are printed

# Draw exactly the first N_PREVIEW samples and print the shape of the 'A' and
# 'B' tensors in each. `data`, `A` and `B` keep their names so later cells can
# still inspect the last sample.
for data in islice(dataset, N_PREVIEW):
    A = data['A']
    print(A.shape)
    B = data['B']
    print(B.shape)
    

torch.Size([1, 48, 224, 224])
torch.Size([1, 48, 224, 224])
torch.Size([1, 48, 224, 224])
torch.Size([1, 48, 224, 224])
torch.Size([1, 48, 224, 224])
torch.Size([1, 48, 224, 224])
torch.Size([1, 48, 224, 224])
torch.Size([1, 48, 224, 224])
torch.Size([1, 48, 224, 224])
torch.Size([1, 48, 224, 224])
