#### Deep Learning Project - Group 67
########################################################################################
########################################################################################

##### This script explains the usage of each of the models
##### Each of our models (FNO, Diffusion and VAE) is placed in its own folder

########################################################################################
########################################################################################

#### 0. Prerequisites
########################################################################################
##### Copy all of these folders to the HPC
##### Install all the necessary libraries

########################################################################################
########################################################################################

#### 1. FNO
########################################################################################

### Commands to run on the HPC terminal:
#### Note: explanations of what each argument does are given alongside the commands

########################################################################################

#### Model Training with sample training code:
###### Note: Visualizations are automatically done during training. Plots are shown in wandb.

###### cd FNO
###### python main.py --mode train \
######     --seq_len 4 \
######     --spatial_resolution 128 \
######     --batch_size 4 \
######     --max_epochs 200 \
######     --learning_rate 2e-4 \
######     --loss_function L2 \
######     --n_modes 30,30,8 \
######     --hidden_channels 96 \
######     --n_layers 4 \
######     --early_stopping \
######     --patience 20 \
######     --num_predictions_to_log 1 \
######     --enable_inference_image_logging \
######     --run_inference_after_train \
######     --use_wandb \
######     --wandb_project plasma-simulation
######
###### Argument notes (listed here rather than inline, because bash does not allow a comment or trailing space after a line-continuation backslash):
######   --spatial_resolution               training image resolution
######   --loss_function L2                 L2 loss, essentially an MSE loss; H1 loss, a weighted combination of L2 and gradient losses, can be used instead
######   --n_modes                          number of Fourier modes in each dimension (x, y, z), where z is time, so we advise a z value equal to or smaller than seq_len
######   --hidden_channels                  number of hidden channels used in the linear layers of the FNO layers
######   --n_layers                         number of FNO layers
######   --patience                         early-stopping patience
######   --num_predictions_to_log           number of inference predictions to log during training
######   --enable_inference_image_logging   enable logging of inference images during training
######   --run_inference_after_train        run inference after training
######   --use_wandb                        use wandb for logging; if not set, TensorBoard is used
######   Optional extra arguments:
######   --timestep_to_show 70              log a specific prediction timestep during inference on the test loader (70 was used for late plasma dynamics)
######   --ablation_study                   run an ablation study, training multiple models with different seq_len values to compare performance

########################################################################################
### Model testing / inference example:
###### python main.py --mode inference --checkpoint 'dirpath' --batch_size 1 --seq_len 4 --spatial_resolution 128 --use_wandb

########################################################################################
########################################################################################
#### 2. Diffusion

########################################################################################

#### Model Training with sample training code:
###### Note: Loss Visualizations are automatically done during training. Plots are shown in wandb.

###### cd Diffusion
###### python main.py --mode train --model "diffusion" --max_epochs 100 --diffusion_timesteps 1000 --diffusion_sampling_timesteps 1000  --dataset_name "DiffusionDataset" --batch_size 4 --spatial_resolution 256 --num_predictions_to_log 1 --diffusion_dim 64 --diffusion_dim_mults "1,2,4,8" --diffusion_flash_attn --use_wandb

###### Here diffusion_timesteps, diffusion_sampling_timesteps, diffusion_dim, diffusion_dim_mults and diffusion_flash_attn correspond to the model parameters timesteps, sampling_timesteps, dim, dim_mults and flash_attn, respectively. spatial_resolution is an integer giving the width and height of the input and output frames.

#####################################################################################
### Model testing / inference example:
###### To get sample data for visualization, set checkpoint_path in test_diffusion_model.py, then run:
###### python test_diffusion_model.py

########################################################################################

### Visualizations

###### The files generated can then be visualized by the code in Report/VisualizeDiffusionModelResult.ipynb, which is copied over here:

In [None]:
from matplotlib import pyplot as plt
import torch
import numpy as np

In [None]:
def visualize_images(input_tensor, target_tensor, prediction_tensor, n=1):
    """Plot input, target and predicted images side by side for the first ``n`` samples.

    For each sample ``i`` one figure with three panels is drawn. Each panel shows
    ``tensor[i][0]`` of the corresponding tensor via ``imshow`` and has its own
    colour bar. Presumably the tensors are laid out as (batch, channel/frame, H, W)
    — TODO confirm against the files written by test_diffusion_model.py.

    Args:
        input_tensor: Model inputs, indexed as ``input_tensor[i][0]``.
        target_tensor: Ground-truth images, indexed the same way.
        prediction_tensor: Model predictions, indexed the same way.
        n: Number of samples to plot. Clamped to the smallest first-dimension
            length of the three tensors, so a too-large ``n`` plots every
            available sample instead of raising IndexError.
    """
    # (tensor, panel title, colour-bar label) for the three columns of each figure.
    panels = (
        (input_tensor, 'Input Image', 'Intensity (Input)'),
        (target_tensor, 'Target Image', 'Intensity (Target)'),
        (prediction_tensor, 'Predicted Image', 'Intensity (Prediction)'),
    )
    n = min(n, len(input_tensor), len(target_tensor), len(prediction_tensor))

    for i in range(n):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (tensor, title, cbar_label) in zip(axes, panels):
            # detach() so tensors that still track gradients can be converted to NumPy.
            img = tensor[i][0].detach().cpu().numpy()
            im = ax.imshow(img)
            ax.set_title(title)
            ax.axis('off')
            cbar = fig.colorbar(im, ax=ax, orientation='vertical', fraction=0.02, pad=0.04)
            cbar.set_label(cbar_label)

        plt.show()
        # Release the figure so many calls do not accumulate open figures
        # under non-interactive backends.
        plt.close(fig)

In [None]:
# Load the sample tensors generated by test_diffusion_model.py.
# map_location='cpu' lets tensors that were saved from a GPU be loaded on a
# CPU-only machine; visualize_images only works with CPU copies anyway.
x = torch.load('x_tensor_test.pt', map_location='cpu')
y = torch.load('y_tensor_test.pt', map_location='cpu')
y_hat = torch.load('y_hat_tensor_test.pt', map_location='cpu')

# Check that the shapes are as expected
print(x.shape)
print(y.shape)
print(y_hat.shape)

In [None]:
# Show input / target / prediction side by side for the first 8 test samples.
visualize_images(input_tensor=x, target_tensor=y, prediction_tensor=y_hat, n=8)

########################################################################################
########################################################################################
#### 3. VAE

########################################################################################

### Note: the following commands require access to the referenced paths on the HPC:

########################################################################################

#### Model Training code:

###### cd VAE/conditional_diffusion/
###### python 34_train_multiscale_vae.py \
######	--data-dir /dtu/blackhole/1b/223803/tcv_data \
######	--variables n phi \
######	--epochs 200 \
######	--batch-size 8 \
######	--lr 1e-3 \
######	--kl-weight 1e-5 \
######	--latent-dim 256 \
######	--base-channels 32 \
######	--loss-type elbow \
######	--output /dtu/blackhole/1b/223803/runs/vae_kl_low \
######	--use-amp \
######	--patience 20

#####################################################################################
### Model inference & data visualization:
##### To perform inference and data visualization, run the following files in this order:
- 1_extract_tcv_data.py
- 2_extract_probe_data.py
- 33_model_multiscale_vae_elbow.py
- 34_train_multiscale_vae.py
- 35_inference_multiscale_vae.py
- 36_model_temporal_probe_encoder.py
- 37_train_probe_reconstruction.py
- 38_inference_probe_reconstruction.py