Imports

In [None]:
import torch
import argparse
import logging
# from tensorboardX import SummaryWriter
import os
import numpy as np

from io import BytesIO
from PIL import Image
from torch.utils.data import Dataset
import random
from re import split
import torch.utils.data

Config

In [None]:
opt = {
    "name": "denoising_v1",
    # Selects the beta_schedule entry used in the Train cell; the dataset cell
    # skips building the train loader when this is "test".
    "phase": "train",
    "gpu_ids": [0, 1],
    "debug": False,
    "enable_wandb": False,
    "log_wandb_ckpt": False,
    "log_eval": False,
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": None
    },
    # Per-phase dataloader settings, iterated by the dataset cell
    # (`for phase, dataset_opt in opt['datasets'].items()`). The train entry
    # supplies the keys create_dataloader reads: batch_size, use_shuffle,
    # num_workers.
    # TODO(review): batch_size / num_workers are placeholder starting values —
    # tune them for the available GPU memory and CPU cores.
    "datasets": {
        "train": {
            "batch_size": 4,
            "use_shuffle": True,
            "num_workers": 2
        },
        "test": {
            "batch_size": 1,
            "use_shuffle": False,
            "num_workers": 1
        }
    },
    "model": {
        "which_model_G": "sr3",
        "finetune_norm": False,
        "unet": {
            # presumably the 3-channel condition concatenated with the 3-channel
            # noisy target, since diffusion.conditional is True — TODO confirm
            # against the UNet implementation.
            "in_channel": 6,
            "out_channel": 3,
            "inner_channel": 64,
            "norm_groups": 16,
            "channel_multiplier": [1, 2, 4, 8, 16],
            "attn_res": [],
            "res_blocks": 1,
            "dropout": 0
        },
        "beta_schedule": {
            "train": {
                "schedule": "linear",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "test": {
                "schedule": "linear",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            # NOTE(review): PairwiseDataset yields single 2-D slices with no
            # channel axis, so channels=3 / image_size=512 may not match the
            # fMRI data — confirm before training.
            "image_size": 512,
            "channels": 3,
            "conditional": True
        }
    },
    "train": {
        "n_iter": 1000000,
        "val_freq": 1e4,
        "save_checkpoint_freq": 1e4,
        "print_freq": 50,
        "optimizer": {
            "type": "adam",
            "lr": 3e-6
        },
        "ema_scheduler": {
            "step_start_ema": 5000,
            "update_ema_every": 1,
            "ema_decay": 0.9999
        }
    },
    "wandb": {
        "project": "denoising_v1"
    }
}

Logger

In [None]:
# # logging
# torch.backends.cudnn.enabled = True
# torch.backends.cudnn.benchmark = True

# Logger.setup_logger(None, opt['path']['log'],
#                     'train', level=logging.INFO, screen=True)
# Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
# logger = logging.getLogger('base')
# logger.info(Logger.dict2str(opt))
# tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])

# # Initialize WandbLogger
# if opt['enable_wandb']:
#     import wandb
#     wandb_logger = WandbLogger(opt)
#     wandb.define_metric('validation/val_step')
#     wandb.define_metric('epoch')
#     wandb.define_metric("validation/*", step_metric="val_step")
#     val_step = 0
# else:
#     wandb_logger = None

Dataset creation

In [None]:
class PairwiseDataset(Dataset):
    """Paired (noisy, ground-truth) 2-D slices taken from stacked .npy volumes.

    Every file is loaded eagerly and reshaped to ``(X, Y, -1)``: all axes after
    the first two are flattened into a single slice axis. The files are then
    concatenated along that axis, so item ``i`` is the ``i``-th 2-D slice
    across all files, in the order the paths were given.
    """

    def __init__(self, noisy_images_paths: list, gt_images_paths: list):
        """Initialize fMRI dataset for denoising.

        Args:
            noisy_images_paths (list): Paths to noisy fMRI volumes (.npy files).
            gt_images_paths (list): Paths to ground-truth fMRI volumes (.npy
                files), in the same order as ``noisy_images_paths``.

        Raises:
            ValueError: If both path lists are non-empty and the stacked noisy
                and ground-truth arrays differ in shape. In that case the
                slices would not pair up one-to-one.
        """
        self.noisy_data = self._load_and_stack(noisy_images_paths)
        self.gt_data = self._load_and_stack(gt_images_paths)

        # Pairs are matched purely by slice index, so both stacks must line up
        # exactly; otherwise items would silently mix unrelated slices or fail
        # with an IndexError much later, inside the DataLoader.
        if (len(self.noisy_data) > 0 and len(self.gt_data) > 0
                and self.noisy_data.shape != self.gt_data.shape):
            raise ValueError(
                'Noisy and ground-truth data must have the same shape after '
                'stacking, got {} and {}.'.format(
                    self.noisy_data.shape, self.gt_data.shape))

        self.data_len = self.noisy_data.shape[2] if len(self.noisy_data) > 0 else 0

    @staticmethod
    def _load_and_stack(paths):
        """Load each .npy file, flatten its trailing axes, and concatenate on axis 2.

        Returns an empty list when ``paths`` is empty or None, matching the
        "no data" value the dataset has always used.
        """
        if not paths:
            return []
        stacked = []
        for path in paths:
            # Per the original notes, each file holds a 4-D (x, y, z, t) array;
            # the reshape folds z and t (or any trailing axes) into one slice axis.
            data = np.load(path)
            stacked.append(np.reshape(data, (data.shape[0], data.shape[1], -1)))
        return np.concatenate(stacked, axis=2)

    def __len__(self):
        """Return the total number of 2-D slices across all noisy files."""
        return self.data_len

    def __getitem__(self, index):
        """Return the ``index``-th slice pair as a dict with keys 'GT', 'Noisy', 'Index'."""
        noisy_image = self.noisy_data[:, :, index]
        gt_image = self.gt_data[:, :, index]
        return {'GT': gt_image, 'Noisy': noisy_image, 'Index': index}

In [None]:
def create_dataloader(dataset, dataset_opt, phase):
    '''Create a DataLoader for ``dataset`` configured for ``phase``.

    Args:
        dataset: Map-style dataset (e.g. PairwiseDataset).
        dataset_opt (dict): Per-phase options. The 'train' phase reads
            'batch_size', 'use_shuffle' and 'num_workers'; evaluation phases
            do not consult it.
        phase (str): 'train', or 'val' / 'test' for evaluation.

    Returns:
        torch.utils.data.DataLoader

    Raises:
        NotImplementedError: If ``phase`` is not one of the values above.
    '''
    if phase == 'train':
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=dataset_opt['batch_size'],
            shuffle=dataset_opt['use_shuffle'],
            num_workers=dataset_opt['num_workers'],
            pin_memory=True)
    if phase in ('val', 'test'):
        # Evaluation loaders: one sample per batch, fixed order.
        return torch.utils.data.DataLoader(
            dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
    # '{}' rather than '{:s}' so a non-string phase still raises
    # NotImplementedError instead of a TypeError from str.format.
    raise NotImplementedError(
        'Dataloader [{}] is not found.'.format(phase))

In [None]:
# Train data paths: every training shard k (1..3) is its own Kaggle input
# dataset "fmri-train-k" holding a noisy volume and its ground truth.
noisy_images_paths_train = [
    f'/kaggle/input/fmri-train-{shard}/data/noisy_func_train_{shard}.npy'
    for shard in range(1, 4)
]
gt_images_paths_train = [
    f'/kaggle/input/fmri-train-{shard}/data/gt_func_train_{shard}.npy'
    for shard in range(1, 4)
]

# Test data paths: a single held-out shard.
noisy_images_paths_test = ['/kaggle/input/fmri-test/data/noisy_func_test.npy']
gt_images_paths_test = ['/kaggle/input/fmri-test/data/gt_func_test.npy']

In [None]:
# dataset
# Build a Dataset + DataLoader for every phase configured under opt['datasets'].
if 'datasets' not in opt:
    raise KeyError(
        "opt['datasets'] is missing: add per-phase dataloader settings "
        "(batch_size, use_shuffle, num_workers) to the config cell.")
for phase, dataset_opt in opt['datasets'].items():
    if phase == 'train' and opt['phase'] != 'test':
        train_set = PairwiseDataset(noisy_images_paths_train, gt_images_paths_train)
        train_loader = create_dataloader(
            train_set, dataset_opt, phase)
    elif phase == 'test':
        test_set = PairwiseDataset(noisy_images_paths_test, gt_images_paths_test)
        # create_dataloader's evaluation branch is keyed 'val' (batch_size=1,
        # no shuffle); passing 'test' there raised NotImplementedError.
        test_loader = create_dataloader(
            test_set, dataset_opt, 'val')
# logger.info('Initial Dataset Finished')

Model loading

In [None]:
# model
# Build the diffusion model wrapper from the full config dict `opt`. The Train
# cell reads `begin_step` / `begin_epoch` from it and drives training through
# `feed_data` / `optimize_parameters`.
# NOTE(review): `Model` is never imported in this notebook, so this line raises
# a NameError as written. It is presumably the SR3 `model` package
# (`import model as Model`) — confirm and add the import to the Imports cell.
diffusion = Model.create_model(opt)
# logger.info('Initial Model Finished')

In [None]:
# Train
# Run optimisation steps until `n_iter` is reached, logging every `print_freq` steps.
current_step = diffusion.begin_step
current_epoch = diffusion.begin_epoch
n_iter = opt['train']['n_iter']

# The Logger cell is commented out, so `logger`, `tb_logger` and `wandb_logger`
# may not exist. Fall back to the 'base' logger and treat the optional sinks
# as absent, instead of hitting a NameError at the first log step.
logger = logging.getLogger('base')
if not logger.hasHandlers():
    # Without a handler, INFO messages would be dropped; print them to stderr.
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
tb_logger = globals().get('tb_logger')
wandb_logger = globals().get('wandb_logger')

if opt['path']['resume_state']:
    logger.info('Resuming training from epoch: {}, iter: {}.'.format(
        current_epoch, current_step))

diffusion.set_new_noise_schedule(
    opt['model']['beta_schedule'][opt['phase']], schedule_phase=opt['phase'])
if opt['phase'] == 'train':
    if len(train_loader) == 0:
        # An empty loader never advances current_step, so the while-loop
        # below would spin forever, only incrementing current_epoch.
        raise ValueError(
            'train_loader is empty: check the training data paths and batch size.')
    while current_step < n_iter:
        current_epoch += 1
        for train_data in train_loader:
            current_step += 1
            if current_step > n_iter:
                break
            diffusion.feed_data(train_data)
            diffusion.optimize_parameters()
            # log
            if current_step % opt['train']['print_freq'] == 0:
                logs = diffusion.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}> '.format(
                    current_epoch, current_step)
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

                if wandb_logger:
                    wandb_logger.log_metrics(logs)

    #         # validation
    #         if current_step % opt['train']['val_freq'] == 0:
    #             avg_psnr = 0.0
    #             idx = 0
    #             result_path = '{}/{}'.format(opt['path']
    #                                             ['results'], current_epoch)
    #             os.makedirs(result_path, exist_ok=True)

    #             diffusion.set_new_noise_schedule(
    #                 opt['model']['beta_schedule']['val'], schedule_phase='val')
    #             for _,  val_data in enumerate(val_loader):
    #                 idx += 1
    #                 diffusion.feed_data(val_data)
    #                 diffusion.test(continous=False)
    #                 visuals = diffusion.get_current_visuals()
    #                 sr_img = Metrics.tensor2img(visuals['SR'])  # uint8
    #                 hr_img = Metrics.tensor2img(visuals['HR'])  # uint8
    #                 lr_img = Metrics.tensor2img(visuals['LR'])  # uint8
    #                 fake_img = Metrics.tensor2img(visuals['INF'])  # uint8

    #                 # generation
    #                 Metrics.save_img(
    #                     hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
    #                 Metrics.save_img(
    #                     sr_img, '{}/{}_{}_sr.png'.format(result_path, current_step, idx))
    #                 Metrics.save_img(
    #                     lr_img, '{}/{}_{}_lr.png'.format(result_path, current_step, idx))
    #                 Metrics.save_img(
    #                     fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))
    #                 tb_logger.add_image(
    #                     'Iter_{}'.format(current_step),
    #                     np.transpose(np.concatenate(
    #                         (fake_img, sr_img, hr_img), axis=1), [2, 0, 1]),
    #                     idx)
    #                 avg_psnr += Metrics.calculate_psnr(
    #                     sr_img, hr_img)

    #                 if wandb_logger:
    #                     wandb_logger.log_image(
    #                         f'validation_{idx}', 
    #                         np.concatenate((fake_img, sr_img, hr_img), axis=1)
    #                     )

    #             avg_psnr = avg_psnr / idx
    #             diffusion.set_new_noise_schedule(
    #                 opt['model']['beta_schedule']['train'], schedule_phase='train')
    #             # log
    #             logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
    #             logger_val = logging.getLogger('val')  # validation logger
    #             logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
    #                 current_epoch, current_step, avg_psnr))
    #             # tensorboard logger
    #             tb_logger.add_scalar('psnr', avg_psnr, current_step)

    #             if wandb_logger:
    #                 wandb_logger.log_metrics({
    #                     'validation/val_psnr': avg_psnr,
    #                     'validation/val_step': val_step
    #                 })
    #                 val_step += 1

    #         if current_step % opt['train']['save_checkpoint_freq'] == 0:
    #             logger.info('Saving models and training states.')
    #             diffusion.save_network(current_epoch, current_step)

    #             if wandb_logger and opt['log_wandb_ckpt']:
    #                 wandb_logger.log_checkpoint(current_epoch, current_step)

    #     if wandb_logger:
    #         wandb_logger.log_metrics({'epoch': current_epoch-1})

    # # save model
    # logger.info('End of training.')


In [None]:
# Validation / evaluation handling (not considered for now; moving it into a separate notebook might make sense — still to be decided)

# else:
#     logger.info('Begin Model Evaluation.')
#     avg_psnr = 0.0
#     avg_ssim = 0.0
#     idx = 0
#     result_path = '{}'.format(opt['path']['results'])
#     os.makedirs(result_path, exist_ok=True)
#     for _,  val_data in enumerate(val_loader):
#         idx += 1
#         diffusion.feed_data(val_data)
#         diffusion.test(continous=True)
#         visuals = diffusion.get_current_visuals()

#         hr_img = Metrics.tensor2img(visuals['HR'])  # uint8
#         lr_img = Metrics.tensor2img(visuals['LR'])  # uint8
#         fake_img = Metrics.tensor2img(visuals['INF'])  # uint8

#         sr_img_mode = 'grid'
#         if sr_img_mode == 'single':
#             # single img series
#             sr_img = visuals['SR']  # uint8
#             sample_num = sr_img.shape[0]
#             for iter in range(0, sample_num):
#                 Metrics.save_img(
#                     Metrics.tensor2img(sr_img[iter]), '{}/{}_{}_sr_{}.png'.format(result_path, current_step, idx, iter))
#         else:
#             # grid img
#             sr_img = Metrics.tensor2img(visuals['SR'])  # uint8
#             Metrics.save_img(
#                 sr_img, '{}/{}_{}_sr_process.png'.format(result_path, current_step, idx))
#             Metrics.save_img(
#                 Metrics.tensor2img(visuals['SR'][-1]), '{}/{}_{}_sr.png'.format(result_path, current_step, idx))

#         Metrics.save_img(
#             hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
#         Metrics.save_img(
#             lr_img, '{}/{}_{}_lr.png'.format(result_path, current_step, idx))
#         Metrics.save_img(
#             fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))

#         # generation
#         eval_psnr = Metrics.calculate_psnr(Metrics.tensor2img(visuals['SR'][-1]), hr_img)
#         eval_ssim = Metrics.calculate_ssim(Metrics.tensor2img(visuals['SR'][-1]), hr_img)

#         avg_psnr += eval_psnr
#         avg_ssim += eval_ssim

#         if wandb_logger and opt['log_eval']:
#             wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img, eval_psnr, eval_ssim)

#     avg_psnr = avg_psnr / idx
#     avg_ssim = avg_ssim / idx

#     # log
#     logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
#     logger.info('# Validation # SSIM: {:.4e}'.format(avg_ssim))
#     logger_val = logging.getLogger('val')  # validation logger
#     logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}, ssim：{:.4e}'.format(
#         current_epoch, current_step, avg_psnr, avg_ssim))

#     if wandb_logger:
#         if opt['log_eval']:
#             wandb_logger.log_eval_table()
#         wandb_logger.log_metrics({
#             'PSNR': float(avg_psnr),
#             'SSIM': float(avg_ssim)
#         })