In [1]:
from src.ssunet.datasets import BinomDataset, ValidationDataset
from src.ssunet.models import Bit2Bit
from src.ssunet.configs import load_config
from src.ssunet.constants import DEFAULT_CONFIG_PATH

# Read the project configuration from its default location.
config = load_config(DEFAULT_CONFIG_PATH)

# Load the raw inputs: `data` feeds the training dataset, while
# `validation_source` is the reference / ground-truth load used for validation
# metrics. The raw validation input has its own name so that `validation_data`
# is only ever bound to the dataset object and never changes type mid-cell.
data = config.path_config.load_data_only()
validation_source = config.path_config.load_reference_and_ground_truth()

# Dataset settings. The validation dataset uses the `validation_config` derived
# from the training data configuration.
# NOTE(review): the original comment says this variant disables augmentation;
# that happens inside the config module and is not visible here — confirm.
data_config = config.data_config
validation_config = data_config.validation_config

# Create training and validation datasets
training_data = BinomDataset(data, data_config, config.split_params)
validation_data = ValidationDataset(validation_source, validation_config)

# Build the network from the model section of the configuration.
model = Bit2Bit(config.model_config)

# Wrap both datasets in loaders produced by the same loader configuration.
loader_config = config.loader_config
training_loader = loader_config.loader(training_data)
validation_loader = loader_config.loader(validation_data)

# Pull a single batch and report the shape of its second element.
first_batch = next(iter(training_loader))
print(f"input_size: {tuple(first_batch[1].shape)}")

# Fit the model on the training loader, validating with the validation loader.
trainer = config.train_config.trainer
trainer.fit(model, training_loader, validation_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


input_size: (3, 1, 32, 256, 256)



  | Name        | Type                             | Params | Mode 
-------------------------------------------------------------------------
0 | psnr_metric | PeakSignalNoiseRatio             | 0      | train
1 | ssim_metric | StructuralSimilarityIndexMeasure | 0      | train
2 | down_convs  | ModuleList                       | 9.6 M  | train
3 | up_convs    | ModuleList                       | 5.1 M  | train
4 | conv_final  | Sequential                       | 33     | train
-------------------------------------------------------------------------
14.7 M    Trainable params
0         Non-trainable params
14.7 M    Total params
58.718    Total estimated model params size (MB)
100       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [2]:
# Write the trainer's current state to "model.ckpt" under the configured
# default root directory.
checkpoint_path = config.train_config.default_root_dir / "model.ckpt"
trainer.save_checkpoint(checkpoint_path)
