In [1]:
from nucli_train.data_management.builders import build_data
from nucli_train.models.builders import build_model, MODEL_REGISTRY, MODEL_BUILDERS_REGISTRY

from nucli_train.training import Trainer
from nucli_train.nets.builders import NETWORK_REGISTRY, ARCHITECTURE_BUILDERS_REGISTRY

# Show the keys currently held by the network and architecture-builder registries.
# NOTE(review): `_dict` is a private attribute of the registry objects; prefer a
# public accessor if the registry class offers one.
print(list(NETWORK_REGISTRY._dict.keys()))
print(list(ARCHITECTURE_BUILDERS_REGISTRY._dict.keys()))

import numpy as np
import yaml

[]
['unet', 'decoder']


  @torch.cuda.amp.autocast(enabled=False)  # keep Dice in fp32 for stability


In [2]:
import sys
# Print the path of the Python interpreter running this kernel, to confirm which
# environment the notebook is executing in.
print(sys.executable)

/home/ybr/miniconda3/envs/ssl-3d/bin/python


In [1]:
import sys

from ssl_3d.preprocess.preprocessor import PreprocessorBlosc2


ModuleNotFoundError: No module named 'nucli_train.data_management.create_dataset'

In [None]:
class MAEPreprocessor(PreprocessorBlosc2):
    """
    Preprocessor specialisation used for the MAE experiment in this notebook.

    It forwards all constructor keyword arguments to ``PreprocessorBlosc2`` and
    overrides three filename-based hooks.
    """

    # Tracer substrings looked for in a filename, in priority order.
    _TRACER_TAGS = ("fdg", "psma")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def exclude_condition(self, nifti_filename):
        """Return True when the filename ends with ``0000.nii.gz``.

        Presumably the base class drops files for which this returns True — confirm
        against ``PreprocessorBlosc2``.
        """
        excluded_suffix = "0000.nii.gz"
        return nifti_filename.endswith(excluded_suffix)

    def identify_tracer(self, nifti_filename):
        """Return the first tracer tag found in the filename, or ``"unknown"``.

        The match is a case-sensitive substring test; ``"fdg"`` is checked before
        ``"psma"``.
        """
        for tag in self._TRACER_TAGS:
            if tag in nifti_filename:
                return tag
        return "unknown"

    def identify_center(self, nifti_filename):
        """Return the constant center label ``"autopet_center"`` for every file."""
        return "autopet_center"
    


# Keyword arguments forwarded to MAEPreprocessor(**kwargs).
# NOTE(review): the two paths below are absolute paths on one developer's machine;
# they must be adapted before running elsewhere.
kwargs = dict(
    # --- dataset identity and locations ---
    dataset_name="autopet_2024",
    dataset_path="/Users/adammesbahi/Desktop/nucli-ssl/data/autopet_2024",
    exp_name="MAE_adam_experiment",
    nifti_input_rootdir="/Users/adammesbahi/Desktop/nucli-ssl/data/autopet_2024/imagesTr",
    nifti_target_rootdir=None,
    # --- dataset fractions and resampling ---
    percentage_dataset=1.0,
    train_val_percentage=0.8,
    spacing=(1.0, 1.0, 1.0),
    # --- data loading ---
    batch_size_train=2,
    batch_size_val=2,
    num_workers_train=1,
    num_workers_val=1,
    global_eval_interval=10,
    patch_size=(64, 64, 64),
    shuffle_pick=True,
    # Same key as the evaluator registered under 'save-preds-MIM' in this notebook.
    validation_evaluator="save-preds-MIM",
)


In [None]:
# Instantiate the preprocessor with the configuration dict defined above.
# NOTE(review): later cells read `my_preprocessor.nucli_train_path`, so the base class
# presumably sets that attribute during construction — confirm.
my_preprocessor = MAEPreprocessor(**kwargs)

In [None]:
from nucli_train.models.image_translation import ImageTranslationModel

import torch
import mlflow

def create_blocky_mask(tensor_size, block_size, sparsity_factor=0.75, rng_seed: None | int = None) -> torch.Tensor:
    """
    This function creates the mask used in MAE. 
    The image returned is not the masked image but the mask we will multiply by the image. 
    """

    # Calculate the size of the smaller mask
    small_mask_size = tuple(size // block_size for size in tensor_size)

    # Create the smaller mask
    flat_mask = torch.ones(np.prod(small_mask_size))
    n_masked = int(sparsity_factor * flat_mask.shape[0])
    if rng_seed is None:
        mask_indices = torch.randperm(flat_mask.shape[0])[:n_masked]
    else:
        gen = torch.Generator.manual_seed(rng_seed)
        mask_indices = torch.randperm(flat_mask.shape[0], generator=gen)[:n_masked]
    flat_mask[mask_indices] = 0
    small_mask = torch.reshape(flat_mask, small_mask_size)
    return small_mask


class MIM(ImageTranslationModel):
    """
    Masked image modelling: the network receives a block-masked copy of the input
    and its output is compared against the unmasked input.
    """

    # Nominal patch size and masked fraction used to build the coarse mask grid.
    # With the default block size of 16 this gives an 8x8x8 grid, which is then
    # stretched to the actual input size.
    _MASK_PATCH_SIZE = (128, 128, 128)
    _MASK_PERCENTAGE = 0.75

    @staticmethod
    def mask_creation(
        batch_size: int,
        patch_size: tuple[int, int, int],
        mask_percentage: float,
        rng_seed: int | None = None,
        block_size: int = 16,
    ) -> torch.Tensor:
        """
        Creates a coarse masking tensor with 1s (no masking) and 0s (masking).

        One independent mask is drawn per batch element. The result has one entry
        per block, so it must be expanded to the input resolution before use.

        :param batch_size: number of masks to draw.
        :param patch_size: 3D shape from which the block grid is derived.
        :param mask_percentage: fraction of the blocks that should be masked.
        :param rng_seed: if given, the masks are reproducible for this seed; torch's
            global CPU RNG state is restored afterwards. If ``None``, the global RNG
            is used as-is.
        :param block_size: edge length of the blocks that should be masked.
        :return: tensor of shape ``(batch_size, 1, *patch_size // block_size)``.
        """

        def _draw_masks():
            return [create_blocky_mask(patch_size, block_size, mask_percentage) for _ in range(batch_size)]

        if rng_seed is None:
            masks = _draw_masks()
        else:
            # Seed the default CPU generator temporarily so that the masks differ
            # between batch elements yet are reproducible, then restore the previous
            # state so the rest of training is unaffected.
            saved_state = torch.get_rng_state()
            try:
                torch.default_generator.manual_seed(rng_seed)
                masks = _draw_masks()
            finally:
                torch.set_rng_state(saved_state)

        return torch.stack(masks)[:, None, ...]  # Add channel dimension

    def _mask_input(self, data: torch.Tensor) -> torch.Tensor:
        """Draw a random block mask, stretch it to ``data``'s spatial size and apply it."""
        mask = self.mask_creation(data.shape[0], self._MASK_PATCH_SIZE, self._MASK_PERCENTAGE).to(data.device)

        # Integer repeat factor per spatial axis. The stretched mask only matches
        # `data` when each spatial size is a multiple of the coarse mask size.
        for dim in (2, 3, 4):
            mask = mask.repeat_interleave(data.shape[dim] // mask.shape[dim], dim=dim)

        return data * mask

    def train_step(self, batch):
        """Mask ``batch['input']``, run the network and return the losses against the unmasked input."""
        data = batch['input'].cuda()
        masked_data = self._mask_input(data)

        output = self.network(masked_data)

        return self.get_losses(output, data)

    def validation_step(self, batch):
        """
        Same as ``train_step`` but without gradients for the forward pass.

        :return: dict with ``losses``, ``metrics``, and CPU copies of the
            predictions (``predictions``), the masked input (``input``) and the
            unmasked input (``original``).
        """
        data = batch['input'].cuda()
        masked_data = self._mask_input(data)

        with torch.no_grad():
            output = self.network(masked_data)

        losses = self.get_losses(output, data)

        # TODO: moving to CPU only makes sense when the images are going to be saved.
        outputs = output.detach().cpu()
        targets = data.detach().cpu()
        inputs = masked_data.detach().cpu()

        metrics = self.get_metrics(outputs, targets, inputs)

        return {"losses": losses, "metrics": metrics, "predictions": outputs, 'input' : inputs, 'original' : targets}

In [None]:
from nucli_train.nets import build_network
from nucli_train.models.losses import build_losses

@MODEL_BUILDERS_REGISTRY.register('MAE')
def build_MAE(cfg):
    """
    Build a ``MIM`` model from a config mapping; registered under the key 'MAE'.

    Reads ``cfg['args']['network']`` for the network config and
    ``cfg['args']['losses']`` for the loss config.
    """
    model_args = cfg['args']
    net = build_network(model_args['network'])
    loss_fns = build_losses(model_args['losses'])
    return MIM(net, loss_functions=loss_fns)

In [None]:
from nucli_train.val.evaluators import EVALUATORS_REGISTRY


import matplotlib.pyplot as plt

@EVALUATORS_REGISTRY.register('save-preds-MIM')
class SavePredictionMIM:
    """
    Evaluator that logs a figure of masked input / prediction / original to MLflow.

    Only the most recent batch passed to ``evaluate_batch`` is kept; ``log_epoch``
    plots up to ``_MAX_ROWS`` samples of that batch.
    """

    _MAX_ROWS = 3

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        # Filled by evaluate_batch; None until the first batch has been seen.
        self.masked = None
        self.original = None
        self.prediction = None

    def evaluate_batch(self, model_output, batch):
        """Remember the tensors of the latest batch (overwrites the previous one)."""
        self.masked = model_output['input']
        self.original = model_output['original']
        self.prediction = model_output['predictions']

    def log_epoch(self, epoch):
        """Plot the middle slice of the stored volumes and log the figure to MLflow."""
        # Nothing to plot if no batch was evaluated, or the batch was empty.
        if self.masked is None or self.masked.shape[0] == 0:
            return

        n_rows = min(self.masked.shape[0], self._MAX_ROWS)
        # squeeze=False keeps `axs` two-dimensional even when n_rows == 1.
        fig, axs = plt.subplots(n_rows, 3, squeeze=False)

        # Middle slice along the second spatial axis (index 32 for 64-voxel patches).
        slice_idx = self.masked.shape[3] // 2

        columns = (self.masked, self.prediction, self.original)
        for row in range(n_rows):
            for col, volume in enumerate(columns):
                ax = axs[row, col]
                ax.imshow(volume[row, 0, :, slice_idx, :].cpu().numpy(), cmap='gray', vmin=0.0, vmax=2.0)
                ax.set_axis_off()

        fig.tight_layout()
        mlflow.log_figure(fig, artifact_file=f"{self.dataset_name}/predictions/epoch_{epoch}.png")
        plt.close(fig)  # free the figure; one is created per epoch

In [None]:
# Build the model from its YAML config file.
# NOTE(review): absolute path on one developer's machine, repeated in the Trainer
# cell below — consider a single configurable constant.
model = build_model('/Users/adammesbahi/Desktop/nucli-ssl/nucli-ssl/examples/MAE/MIM_model.yaml')

In [None]:
from nucli_train.data_management.builders import build_data
import nucli_train.data_management.transformations 
# Show the data config path that is passed to build_data below.
print(my_preprocessor.nucli_train_path + "/main.yaml")

# Build the training data and validation loaders from that config file.
# NOTE(review): presumably main.yaml is written by the preprocessor — confirm.
train_data, val_loaders = build_data(my_preprocessor.nucli_train_path + "/main.yaml")

In [None]:
# End any MLflow run that is still active (e.g. from a previous execution of this
# cell) before a new Trainer is created.
mlflow.end_run()

# NOTE(review): model_cfg_path repeats the absolute path used in the build_model cell.
trainer = Trainer(model, train_data=train_data, val_loaders=val_loaders, run_name='base-deeper', experiment_name='MIM', save_interval=50, model_cfg_path='/Users/adammesbahi/Desktop/nucli-ssl/nucli-ssl/examples/MAE/MIM_model.yaml', data_cfg_path=my_preprocessor.nucli_train_path + "/main.yaml") 

In [None]:
# Start training. NOTE(review): 1000 is presumably the number of epochs — confirm
# against Trainer.run.
trainer.run(1000)
