In [None]:
import datetime
import itertools
import math
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import os
import pdb
import sys
import xarray as xr
from tqdm import tqdm
from PIL import Image

# Make the project's `src/` packages (modeling, preparation, utils) importable.
# The location can be overridden with the CAIPI_SRC_DIR environment variable; the
# membership guard keeps re-runs of this cell from stacking duplicate sys.path entries.
SRC_DIR = os.environ.get('CAIPI_SRC_DIR', '/home/quahb/caipi_denoising/src')
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# TensorFlow/XLA read these variables when they initialise, so they are set here,
# before the `import tensorflow` below.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
# Point XLA at the CUDA toolkit directory shipped in the conda package cache.
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/home/quahb/.conda/pkgs/cuda-nvcc-12.1.105-0'
# Use the cudaMallocAsync-based GPU allocator instead of TF's default allocator.
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'
# Suppress INFO and WARNING messages from TensorFlow's C++ runtime.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras
from modeling.DiffusionModel import build_model, DiffusionModel
from preparation.prepare_tf_dataset import np_to_tfdataset
from preparation.data_io import load_dataset
from preparation.preprocessing_pipeline import fourier_transform, inverse_fourier_transform, low_pass_filter, rescale_magnitude
from utils.dct import dct2, idct2
from utils.GaussianDiffusion import GaussianDiffusion
from utils.vizualization_tools import plot2, plot4, plot_slices, plot_patches

# Automatically reload edited modules (e.g. the project code under src/) before each
# cell executes, so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Hyperparameters

In [None]:
# --- Optimisation / schedule -------------------------------------------------
num_epochs = 501
learning_rate = 2e-4
total_timesteps = 1000  # number of diffusion timesteps handed to GaussianDiffusion
train_batch_size = inference_batch_size = 16

# --- Input geometry ------------------------------------------------------------
img_size, img_channels = 384, 1

# --- U-Net architecture --------------------------------------------------------
image_embedding = True
norm_groups = 8  # Number of groups used in GroupNormalization layer
num_res_blocks = 2  # Number of residual blocks
first_conv_channels = 64
channel_multiplier = [1, 2, 4, 8]
# Channel width per level: the base width scaled by each multiplier.
widths = [m * first_conv_channels for m in channel_multiplier]
# One attention flag per entry of `widths` (same length as channel_multiplier).
has_attention = [False, False, True, True]

# Train Network

## 1. Dataset

In [None]:
# Load the unaccelerated full-magnitude image set and keep at most the first
# `max_train_images` images for training.
# NOTE(review): the meaning of the positional arguments `2` and `'full'` is defined by
# preparation.data_io.load_dataset — confirm and consider passing them by keyword.
dataset_dir = '/home/quahb/caipi_denoising/data/datasets/unaccelerated/full_magnitude/images/'
max_train_images = 8000

images = load_dataset(dataset_dir, 2, 'full')
images = images[:max_train_images]
# Fail fast rather than letting model.fit run zero steps on an empty dataset.
assert images.shape[0] > 0, f'no training images loaded from {dataset_dir}'
print(images.shape)

# np_to_tfdataset is given the batch size, so presumably yields batches of
# `train_batch_size` — TODO confirm. prefetch overlaps input preparation with training.
tf_images = np_to_tfdataset(images, batch_size=train_batch_size)
train_ds = tf_images.prefetch(tf.data.AUTOTUNE)

## 2. Build Network

In [None]:
# Resume control: 0 starts from freshly initialised weights; N > 0 reloads the
# weights stored under the epoch-N file names below (and fit resumes at epoch N).
load_epoch = 0
diffusion_model_name = f'diffusion_models/diffusion_ep{load_epoch}.hd5'
ema_diffusion_model_name = f'diffusion_models/ema_diffusion_ep{load_epoch}.hd5'
# NOTE(review): tf.keras only infers the HDF5 format from a `.h5`/`.keras` suffix, so
# these `.hd5` paths are presumably read/written as TF checkpoints — confirm this
# matches how DiffusionModel.save_model writes them.

# Both U-Nets are built from a single kwargs dict so their architectures cannot drift
# apart; the weight copy (set_weights) below requires identical layer structure.
network_kwargs = dict(
    img_size=img_size,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    activation_fn=keras.activations.swish,
    image_embedding=image_embedding,
)

# Data-parallel training across the listed GPUs; variables, the diffusion wrapper and
# the optimizer are all created inside the strategy scope.
gpus = ['/GPU:0', '/GPU:1', '/GPU:2', '/GPU:3']
strategy = tf.distribute.MirroredStrategy(devices=gpus)
with strategy.scope():
    # Trainable network and its EMA counterpart (presumably updated by DiffusionModel).
    network = build_model(**network_kwargs)
    ema_network = build_model(**network_kwargs)

    if load_epoch == 0:
        ema_network.set_weights(network.get_weights())  # Initially the weights are the same
    else:
        network.load_weights(diffusion_model_name)
        ema_network.load_weights(ema_diffusion_model_name)

    # Get an instance of the Gaussian Diffusion utilities
    gdf_util = GaussianDiffusion(timesteps=total_timesteps)

    # Wrap both networks and the diffusion utilities into the trainable Keras model.
    model = DiffusionModel(
        network=network,
        ema_network=ema_network,
        gdf_util=gdf_util,
        timesteps=total_timesteps,
        image_embedding=image_embedding
    )

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    )

model.network.summary()

In [None]:
# Date this training run was started.
# NOTE(review): not referenced in the visible cells — confirm it is still needed.
s_date = datetime.date.today()

# `train_ds` is a tf.data.Dataset, so batching is already done by the dataset; the
# Keras docs say not to pass `batch_size` for dataset inputs, so it is omitted here.
# `initial_epoch` keeps the epoch counter consistent when resuming from `load_epoch`.
model.fit(
    train_ds,
    epochs=num_epochs,
    initial_epoch=load_epoch,
    callbacks=[
        # At the end of every epoch: plot images, then save the model
        # (both presumably via DiffusionModel helpers taking (epoch, logs)).
        keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        keras.callbacks.LambdaCallback(on_epoch_end=model.save_model),
    ],
)

In [None]:
# Final visual check: plot_images laid out as 5 rows x 4 columns.
model.plot_images(num_cols=4, num_rows=5)