## Deep Learning Project - Group 67

This notebook briefly explains how to recreate the results of the experiments described in the paper. Each of the three implemented models (FNO, Diffusion Model, VAE) is located in its own directory and can be run individually with the commands below.


**Things to keep in mind:**

The TCV dataset is large: it essentially consists of 501 images of 512 x 512 pixels. The dataloader uses pinned CPU memory, and only the current batch is moved to the GPU. Even so, the commands below and the models described in the report are fairly large, so we recommend requesting enough GPU memory.
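The loading strategy above can be sketched with a standard PyTorch DataLoader. This is a minimal illustration, not the project's dataloader: the small random TensorDataset and the (input, next-frame target) pairing are hypothetical stand-ins for the TCV data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the TCV data: 16 frames instead of the real 501
# frames of 512 x 512, so the sketch runs on any machine.
frames = torch.randn(16, 1, 512, 512)
dataset = TensorDataset(frames[:-1], frames[1:])  # hypothetical (input, next-frame target) pairs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# pin_memory keeps CPU-side batches in page-locked memory, which speeds up
# host-to-GPU copies; only enable it when a GPU is actually present.
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    pin_memory=torch.cuda.is_available())

for x, y in loader:
    # Only the current batch is copied to the GPU; the dataset stays on the CPU.
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    # ... forward and backward pass would go here ...
    break

print(x.shape)  # torch.Size([4, 1, 512, 512])
```

Because only one batch lives on the GPU at a time, GPU memory is dominated by the model and its activations rather than by the dataset size.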

The packages used in this study are listed in requirements.txt; please install them first.

In [2]:
%%bash
python -m venv .venv_report
source .venv_report/bin/activate 
pip install -r requirements.txt


## 1. FNO

Train the FNO with the following command. First 'cd' into the FNO folder, or change the command to call python FNO/main.py instead. Depending on your compute resources, you may need to reduce the batch size, the model size and/or the sequence length.

Note: In the commands below, images and metrics from training and inference are logged to wandb. To save logs locally, omit the --use_wandb flag; the code then falls back to a TensorBoard logger that saves to the specified dirpath.
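The FNO options are plain command-line flags. As a hypothetical sketch (the parser in FNO/main.py may define them differently), the snippet below mirrors a subset of them, shows how a comma-separated value such as --n_modes 30,30,8 can be parsed into a tuple, and shows how leaving out --use_wandb selects local TensorBoard logging. The reduced values are illustrative only, not tuned settings.

```python
import argparse


def int_tuple(text):
    """Parse a comma-separated string such as '30,30,8' into (30, 30, 8)."""
    return tuple(int(v) for v in text.split(","))


# Hypothetical parser mirroring part of the FNO command line.
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["train", "inference"], default="train")
parser.add_argument("--seq_len", type=int, default=4)
parser.add_argument("--spatial_resolution", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--n_modes", type=int_tuple, default=(30, 30, 8))
parser.add_argument("--hidden_channels", type=int, default=96)
parser.add_argument("--n_layers", type=int, default=4)
parser.add_argument("--use_wandb", action="store_true")


def logger_backend(args):
    # Without --use_wandb, fall back to local TensorBoard logging.
    return "wandb" if args.use_wandb else "tensorboard"


# A smaller configuration for a limited GPU: smaller batch, fewer Fourier
# modes and hidden channels, logged locally (no --use_wandb).
args = parser.parse_args([
    "--batch_size", "2",
    "--n_modes", "16,16,4",
    "--hidden_channels", "48",
])
print(args.n_modes, logger_backend(args))  # (16, 16, 4) tensorboard
```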

In [None]:
%%bash
# cd FNO
python main.py --mode train \
    --seq_len 4 \
    --spatial_resolution 128 \
    --batch_size 4 \
    --max_epochs 200 \
    --learning_rate 2e-4 \
    --loss_function L2 \
    --n_modes 30,30,8 \
    --hidden_channels 96 \
    --n_layers 4 \
    --early_stopping \
    --patience 20 \
    --num_predictions_to_log 1 \
    --enable_inference_image_logging \
    --run_inference_after_train \
    --use_wandb


### FNO inference

To run inference, replace 'dirpath' in the --checkpoint argument with the path to the trained model checkpoint and run this command:

In [3]:
%%bash
python main.py --mode inference --checkpoint 'dirpath' --batch_size 1 --seq_len 4 --spatial_resolution 128 --use_wandb


## 2. Diffusion

Train the model with the sample command below. As with the FNO, logs and visualizations are produced automatically during training; omit --use_wandb to log locally with TensorBoard.

Here, diffusion_timesteps, diffusion_sampling_timesteps, diffusion_dim, diffusion_dim_mults and diffusion_flash_attn correspond to the model parameters timesteps, sampling_timesteps, dim, dim_mults and flash_attn, respectively. As above, spatial_resolution sets the image resolution.
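The flag-to-parameter mapping above can be written as a small helper. This is a hypothetical sketch: diffusion_model_kwargs is not part of the project, and how main.py actually passes these values to the model may differ. The last line assumes the usual U-Net convention in which dim_mults scales the base width dim at each resolution level.

```python
def diffusion_model_kwargs(timesteps, sampling_timesteps, dim, dim_mults, flash_attn):
    """Map the --diffusion_* CLI values onto the model parameter names.

    Hypothetical helper illustrating the mapping described in the text.
    """
    return {
        "timesteps": int(timesteps),
        "sampling_timesteps": int(sampling_timesteps),
        "dim": int(dim),
        # "1,2,4,8" -> (1, 2, 4, 8)
        "dim_mults": tuple(int(m) for m in dim_mults.split(",")),
        "flash_attn": bool(flash_attn),
    }


# Values from the training command below.
kwargs = diffusion_model_kwargs(1000, 1000, 64, "1,2,4,8", True)
print(kwargs["dim_mults"])  # (1, 2, 4, 8)

# Assuming the usual U-Net convention, feature widths per level are dim * mult.
print([kwargs["dim"] * m for m in kwargs["dim_mults"]])  # [64, 128, 256, 512]
```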

In [None]:
%%bash
# cd Diffusion
python main.py --mode train \
 --model "diffusion" \
 --max_epochs 100 \
 --diffusion_timesteps 1000 \
 --diffusion_sampling_timesteps 1000  \
 --dataset_name "DiffusionDataset" \
 --batch_size 4 \
 --spatial_resolution 256 \
 --num_predictions_to_log 1 \
 --diffusion_dim 64 \
 --diffusion_dim_mults "1,2,4,8" \
 --diffusion_flash_attn \
 --use_wandb

### Model testing / inference example:

To sample data for visualization, first set checkpoint_path in test_diffusion_model.py to the path of your trained model, then run test_diffusion_model.py.
The script runs inference on the provided test loader and saves the predictions as .pt files.
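For orientation, a minimal sketch of an inference-and-save step of this kind, assuming a PyTorch test loader that yields (x, y) pairs. The random test set and the identity model are hypothetical stand-ins, not the trained diffusion model, and test_diffusion_model.py may batch and sample differently. The file names match those loaded in the visualization step below; the sketch writes them to a temporary directory so that real prediction files are not overwritten.

```python
import tempfile
from pathlib import Path

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins: a tiny random test set and an identity "model".
test_loader = DataLoader(
    TensorDataset(torch.randn(8, 1, 64, 64), torch.randn(8, 1, 64, 64)),
    batch_size=4,
)
model = torch.nn.Identity()

xs, ys, y_hats = [], [], []
model.eval()
with torch.no_grad():
    for x, y in test_loader:
        y_hat = model(x)  # the real script would sample from the diffusion model here
        xs.append(x.cpu())
        ys.append(y.cpu())
        y_hats.append(y_hat.cpu())

# Concatenate the batches and save them under the expected file names.
out_dir = Path(tempfile.mkdtemp())
torch.save(torch.cat(xs), out_dir / "x_tensor_test.pt")
torch.save(torch.cat(ys), out_dir / "y_tensor_test.pt")
torch.save(torch.cat(y_hats), out_dir / "y_hat_tensor_test.pt")
```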


In [None]:
%%bash
python test_diffusion_model.py

### Visualizations

Once the prediction files have been generated with the step above, plot them with the code below. The same code is also available in VisualizeDiffusionModelResult.ipynb.

In [None]:
from matplotlib import pyplot as plt
import torch

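def visualize_images(input_tensor, target_tensor, prediction_tensor, n=1):
    """Plot input, target and prediction side by side for the first n samples."""
    # Guard against requesting more samples than were saved
    for i in range(min(n, len(input_tensor))):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Get the first channel of sample i as NumPy arrays
        input_img = input_tensor[i][0].detach().cpu().numpy()
        target_img = target_tensor[i][0].detach().cpu().numpy()
        prediction_img = prediction_tensor[i][0].detach().cpu().numpy()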

        # Input images
        im1 = axes[0].imshow(input_img)
        axes[0].set_title('Input Image')
        axes[0].axis('off')

        # Target images
        im2 = axes[1].imshow(target_img)
        axes[1].set_title('Target Image')
        axes[1].axis('off')

        # Predicted images
        im3 = axes[2].imshow(prediction_img)
        axes[2].set_title('Predicted Image')
        axes[2].axis('off')

        # Add individual color bars
        cbar1 = fig.colorbar(im1, ax=axes[0], orientation='vertical', fraction=0.02, pad=0.04)
        cbar1.set_label('Intensity (Input)')

        cbar2 = fig.colorbar(im2, ax=axes[1], orientation='vertical', fraction=0.02, pad=0.04)
        cbar2.set_label('Intensity (Target)')

        cbar3 = fig.colorbar(im3, ax=axes[2], orientation='vertical', fraction=0.02, pad=0.04)
        cbar3.set_label('Intensity (Prediction)')

        plt.show()

In [None]:
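# map_location='cpu' lets tensors saved on a GPU be loaded on a CPU-only machine
x = torch.load('x_tensor_test.pt', map_location='cpu')
y = torch.load('y_tensor_test.pt', map_location='cpu')
y_hat = torch.load('y_hat_tensor_test.pt', map_location='cpu')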

In [None]:
visualize_images(x, y, y_hat, n=8)

## 3. VAE

To train the VAE, run the following command.

NOTE: Adjust the --data-dir and --output CLI arguments to your own paths first.

In [None]:
%%bash
# cd VAE/conditional_diffusion/
python 34_train_multiscale_vae.py \
    --data-dir /dtu/blackhole/1b/223803/tcv_data \
    --variables n phi \
    --epochs 200 \
    --batch-size 8 \
    --lr 1e-3 \
    --kl-weight 1e-5 \
    --latent-dim 256 \
    --base-channels 32 \
    --loss-type elbow \
    --output /dtu/blackhole/1b/223803/runs/vae_kl_low \
    --use-amp \
    --patience 20
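The command down-weights the KL term with --kl-weight 1e-5. As a sketch of what such a weighting typically means, assuming a diagonal Gaussian encoder, a standard-normal prior and an MSE reconstruction term (the actual elbow objective in 33_model_multiscale_vae_elbow.py may differ), the training loss becomes reconstruction + kl_weight × KL. The function name and tensor shapes below are hypothetical.

```python
import torch
import torch.nn.functional as F


def weighted_elbo_loss(recon, target, mu, logvar, kl_weight=1e-5):
    """Negative ELBO with a down-weighted KL term (beta-VAE style).

    Assumptions: MSE reconstruction and a diagonal Gaussian posterior
    N(mu, exp(logvar)) regularised towards N(0, I).
    """
    recon_loss = F.mse_loss(recon, target, reduction="mean")
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)): summed over latent dims, averaged over the batch
    kl = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1 - logvar, dim=1).mean()
    return recon_loss + kl_weight * kl, recon_loss, kl


# Placeholder tensors; latent_dim=256 as in the command above.
batch, latent_dim = 8, 256
mu = torch.zeros(batch, latent_dim)
logvar = torch.zeros(batch, latent_dim)
target = torch.randn(batch, 2, 64, 64)  # 2 channels, e.g. the n and phi variables
recon = target + 0.1 * torch.randn_like(target)

loss, recon_loss, kl = weighted_elbo_loss(recon, target, mu, logvar)
print(kl.item())  # 0.0, since the posterior equals the prior here
```

With a weight as small as 1e-5, the objective is dominated by reconstruction quality, and the KL term acts only as a mild regulariser on the latent space.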

To perform inference and plot visualizations, run the following files in the given order:

- 1_extract_tcv_data.py
- 2_extract_probe_data.py
- 33_model_multiscale_vae_elbow.py
- 34_train_multiscale_vae.py
- 35_inference_multiscale_vae.py
- 36_model_temporal_probe_encoder.py
- 37_train_probe_reconstruction.py
- 38_inference_probe_reconstruction.py
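The ordered list above can be driven by a small script. This is a hypothetical sketch: it assumes it is run from VAE/conditional_diffusion/ and that the scripts accept the arguments given in EXTRA_ARGS, whose paths are placeholders. Add the remaining training flags from the command above, or drop 34_train_multiscale_vae.py from the list if the model is already trained.

```python
import subprocess
import sys

# Order taken from the list above; run from VAE/conditional_diffusion/.
PIPELINE = [
    "1_extract_tcv_data.py",
    "2_extract_probe_data.py",
    "33_model_multiscale_vae_elbow.py",
    "34_train_multiscale_vae.py",
    "35_inference_multiscale_vae.py",
    "36_model_temporal_probe_encoder.py",
    "37_train_probe_reconstruction.py",
    "38_inference_probe_reconstruction.py",
]

# Hypothetical per-script arguments; the paths are placeholders to adjust.
EXTRA_ARGS = {
    "34_train_multiscale_vae.py": ["--data-dir", "/path/to/tcv_data",
                                   "--output", "/path/to/run"],
}


def build_commands(scripts, extra_args=None):
    """Return one command per script, using the current Python interpreter."""
    extra_args = extra_args or {}
    return [[sys.executable, script, *extra_args.get(script, [])] for script in scripts]


def run_pipeline(scripts, extra_args=None, dry_run=True):
    """Run the scripts in order; check=True stops at the first failing script."""
    commands = build_commands(scripts, extra_args)
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
    return commands


# Dry run: only prints the commands. Pass dry_run=False to execute them.
commands = run_pipeline(PIPELINE, EXTRA_ARGS)
```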