# New sample prediction

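This notebook provides an example of an end-to-end pipeline for new and real-time predictions: FITS download, calibration, PNG scaling, ROI cropping, diffusion-model sampling and visualisation.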

# Table of Contents
<!-- - [New sample prediction](#New-sample-prediction) -->
- [Init](#Init)
- [Download Fits](#Download-Fits)
- [Preprocessing and downsampling](#Preprocessing-and-downsampling)
  - [Instrument correction and exposure normalisation](#Instrument-correction-and-exposure-normalisation)
  - [Scaling and downsampling](#Scaling-and-downsampling)
- [ROI Crops](#ROI-Crops)
- [Predictions](#Predictions)
- [Viz](#Viz)


# Init

In [6]:
from pathlib import Path
from PIL import Image 
import numpy as np

import pandas as pd
import datetime
from tqdm import tqdm 
import os
import requests

# import astropy.time
import astropy.units as u
import sunpy
# from sunpy.net import Fido
# from sunpy.net import attrs as a

# from aiapy.calibrate import degradation
from aiapy.calibrate.util import get_correction_table
from aiapy.calibrate import correct_degradation

import astropy.units as u
from astropy.coordinates import SkyCoord
from sunpy.coordinates import frames
from sunpy.physics.differential_rotation import solar_rotate_coordinate

from functools import reduce

import torch

In [2]:
# Usage example
NRT              = True                                   # use near-real-time images (for the fits downloading)
sample_date      = pd.to_datetime("2024-10-08 23:57:00")  # reference time of the sample
POS_ARCSEC_START = {'13848X1': (32, 112)}   # the ROI id is arbitrary, (x,y) is the ROI center HPC position at (sample_date - 10H)
wavelengths      = ['0193', '0211', '0094'] # wavelength used in input
folder           = '/Users/greg/Google Drive/Mi unidad/Projects/Forecast/Results/V2V/new_samples'

# Stage switches: each processing cell below only runs when its flag is set.
DOWNLOAD_FITS = True
CORRECT_FITS  = True
PROCESS_PNG   = True
PROCESS_CROPS = True
NUM_SIMUS     = 20      # number of diffusion samples; None skips the prediction cell
machine       = 'loc'   # 'COLAB' switches `folder` to the Google Drive layout

cadence_minutes = 120

# Frame timestamps: from 10 h before to 12 h after sample_date, every cadence_minutes.
start_date = sample_date - datetime.timedelta(hours=10)
end_date   = sample_date + datetime.timedelta(hours=12)

date_range = pd.date_range(start=start_date,
                           end=end_date,
                           freq=datetime.timedelta(minutes=cadence_minutes))

if machine.upper() == 'COLAB':
  if not Path('/content/drive').exists():
    from google.colab import drive
    drive.mount('/content/drive')
  folder = '/content/drive/MyDrive/Projects/Forecast/Results/V2V/new_samples'

# Create the output folder only once `folder` is final, so that on Colab the
# Drive folder is created rather than the local one.
os.makedirs(folder, exist_ok=True)
  

# Download Fits

In [245]:
if DOWNLOAD_FITS:
  # Fetch one FITS file per (timestamp, wavelength) from the JSOC AIA synoptic
  # archive into {folder}/fits/aia/YYYY/MM/DD/. Files already on disk are
  # skipped, so the cell can be re-run safely.
  REQUEST_TIMEOUT_S = 60  # never hang forever on a stalled connection
  for current_date in tqdm(date_range):
      year = current_date.year
      month = current_date.month
      day = current_date.day
      hour = current_date.hour
      minute = current_date.minute

      path = f'{folder}/fits/aia/{year}/{month:02d}/{day:02d}'
      os.makedirs(path, exist_ok=True)

      for wavelength in wavelengths:
          # NRT file names append '00' seconds to HHMM; the synoptic archive uses HHMM only.
          if NRT:
            url = f'http://jsoc.stanford.edu/data/aia/synoptic/nrt/{year}/{month:02d}/{day:02d}/H{hour:02d}00/AIA{year}{month:02d}{day:02d}_{hour:02d}{minute:02d}00_{wavelength}.fits'
          else:
            url = f'http://jsoc.stanford.edu/data/aia/synoptic/{year}/{month:02d}/{day:02d}/H{hour:02d}00/AIA{year}{month:02d}{day:02d}_{hour:02d}{minute:02d}_{wavelength}.fits'

          filename = f'{path}/{year}{month:02d}{day:02d}_{hour:02d}{minute:02d}_{wavelength}.fits'
          if os.path.exists(filename):
              continue

          try:
              response = requests.get(url, timeout=REQUEST_TIMEOUT_S)
          except requests.RequestException as e:
              # One failed request must not abort the remaining downloads.
              print(e)
              print(url)
              continue

          if response.status_code == 200:
              # Write to a temporary name and rename once complete, so an interrupted
              # write never leaves a truncated file that later runs would skip.
              tmp_filename = f'{filename}.part'
              try:
                  with open(tmp_filename, 'wb') as f:
                      f.write(response.content)
                  os.replace(tmp_filename, filename)
              except Exception as e:
                  print(e)
          else:
              print(response.status_code, ' : ', response.reason)
              print(url)

100%|██████████| 12/12 [00:06<00:00,  1.74it/s]


# Preprocessing and downsampling

## Instrument correction and exposure normalisation

In [246]:
if CORRECT_FITS:
  import sunpy.map  # `import sunpy` alone does not import the map subpackage

  # Apply the AIA degradation correction and divide by the exposure time,
  # writing the result to {folder}/fits/aia_corr/YYYY/MM/DD/ as float32.
  ctMissingSource = 0
  ctErr = 0
  table = get_correction_table()  # fetched once, reused for every file
  for current_date in tqdm(date_range):
      year = current_date.year
      month = current_date.month
      day = current_date.day
      hour = current_date.hour
      minute = current_date.minute
      rel_dir = f'{year}/{month:02d}/{day:02d}'
      stem = f'{year}{month:02d}{day:02d}_{hour:02d}{minute:02d}'

      for wavelength in wavelengths:
          fn_src = f'{folder}/fits/aia/{rel_dir}/{stem}_{wavelength}.fits'
          if not os.path.exists(fn_src):
              print('WARNING , no : ', fn_src)
              ctMissingSource += 1
              continue

          path_dest = f'{folder}/fits/aia_corr/{rel_dir}'
          os.makedirs(path_dest, exist_ok=True)
          fn_dest = f'{path_dest}/{stem}_{wavelength}.fits'
          if os.path.exists(fn_dest):
              continue

          try:
              img = sunpy.map.Map(fn_src)
              img = correct_degradation(img, correction_table=table)
              exptime = float(img.meta['exptime'])
              if not exptime > 0:  # also rejects NaN
                  raise ValueError(f'invalid exposure time {exptime} in {fn_src}')
              # Build a new map rather than overwriting the private `_data`
              # attribute, and record in the header that the data are now
              # exposure-normalised (prevents a second normalisation later).
              meta = img.meta.copy()
              meta['exptime'] = 1.0
              img = sunpy.map.Map(img.data.astype('float32') / exptime, meta)
              img.save(fn_dest)
          except Exception as e:
              print(current_date)
              print(wavelength)
              print(e)
              ctErr += 1

  # Counts are divided by the number of wavelengths, i.e. reported in timestamps.
  print('Missing sources : ', ctMissingSource/len(wavelengths))
  print('Errors dest     : ', ctErr/len(wavelengths))

100%|██████████| 12/12 [00:01<00:00,  8.34it/s]

Missing sources :  0.0
Errors dest     :  0.0





## Scaling and downsampling

In [247]:
if PROCESS_PNG:
  import sunpy.map  # `import sunpy` alone does not import the map subpackage

  dim = 1024 # if 1024 no downsampling performed
  # Saturation levels come from the pixel statistics of the training period
  # only (timestamps strictly before this date).
  test_split_date = datetime.datetime(2020, 1, 1)
  saturation_pctl = 99.9
  # NOTE(review): absolute local path; unlike `folder`, it is not switched by `machine`.
  stats_folder = '/Users/greg/Google Drive/Mi unidad/Projects/Forecast/Datasets/Meta'
  stats = {}
  saturation = {}
  print('SATURATION VALUES : ')
  for w in wavelengths:
    # file with images pixel statistics for the whole time-period used in training
    fn_stat = f'{stats_folder}/stats_aia_corr_{w}.csv'
    stats_w = pd.read_csv(fn_stat)
    stats_w['timestamp'] = pd.to_datetime(stats_w['timestamp'], format='%Y-%m-%d %H:%M:%S')
    stats_w = stats_w.set_index('timestamp').sort_index()
    stats[w] = stats_w[stats_w.index < test_split_date]
    # Saturation = chosen percentile of the per-image maxima (NaNs ignored).
    saturation[w] = np.nanpercentile(stats[w]['max'], saturation_pctl)
    print(w, saturation[w])

  for current_date in tqdm(date_range):
      year = current_date.year
      month = current_date.month
      day = current_date.day
      hour = current_date.hour
      minute = current_date.minute
      rel_dir = f'{year}/{month:02d}/{day:02d}'
      stem = f'{year}{month:02d}{day:02d}_{hour:02d}{minute:02d}'

      for wavelength in wavelengths:
          fn_src = f'{folder}/fits/aia_corr/{rel_dir}/{stem}_{wavelength}.fits'
          if not os.path.exists(fn_src):
              print('WARNIN  : missing , ', fn_src)
              continue

          path_dest = f'{folder}/png/{dim}/{wavelength}/{rel_dir}'
          os.makedirs(path_dest, exist_ok=True)
          fn_dest = f'{path_dest}/{stem}_{wavelength}.png'
          if Path(fn_dest).exists():
              continue

          try:
              img = sunpy.map.Map(fn_src)
              if dim != 1024:
                  img = Image.fromarray(img.data)
                  img = img.resize((dim, dim), resample=Image.BILINEAR)
                  img = np.array(img)
              else:
                  img = img.data
                  print(current_date, wavelength, img.mean(), img.std(), img.min(), img.max())
          except Exception as e:
              print(-2, e)
              continue

          try:
              sat = saturation[wavelength]
              # NaN pixels would otherwise become arbitrary values in the uint8 cast.
              img = np.nan_to_num(img, nan=0.0)
              img[img<0] = 0
              img[img>sat] = sat
              # Log-scale [0, sat] onto [0, 255].
              img = np.log(1+img)
              img = 255 * img / np.log(1+sat)
              img = np.round(img)
              img[img>255] = 255 # 256 values
              img = img.astype('uint8')
          except Exception as e:
              print(-3, e)
              continue

          try:
              Image.fromarray(img).save(fn_dest)
          except Exception as e:
              print(-4, e)

SATURATION VALUES : 
0193 81139.19094000038
0211 8179.276461000085
0094 6099.484649600037


  0%|          | 0/12 [00:00<?, ?it/s]

2024-10-08 13:57:00 0193 291.4839 412.0286 -1.0143008 9364.596
2024-10-08 13:57:00 0211 136.08289 207.21364 -0.8569284 5496.607
2024-10-08 13:57:00 0094 1.3100001 2.598818 -0.45325023 221.11455


100%|██████████| 12/12 [00:02<00:00,  4.81it/s]


# ROI Crops

In [256]:
if PROCESS_CROPS:

  # dimensions from source
  resolution = 2.4 # (arcsec/pix)
  dimPatches = 256  # crop side length in source pixels
  # downsampling
  downsize   = 128  # output side length in pixels; None keeps dimPatches

  path_dest = f'{folder}/ar'
  samples_df_path = f'{path_dest}/samples.csv'
  path_src = f'{folder}/png/1024'
  os.makedirs(path_dest, exist_ok=True)
  wl_tag = 'x'.join(wavelengths)  # e.g. '0193x0211x0094'

  # TO DO : optimise the loop on date to open each full-disk image only once
  for arnum in tqdm(POS_ARCSEC_START.keys()):
    start_pos_hpc = POS_ARCSEC_START[arnum]
    # Track the ROI centre by differentially rotating the starting
    # helioprojective position to every timestamp of the sequence.
    start_coord = SkyCoord(start_pos_hpc[0] * u.arcsec, start_pos_hpc[1] * u.arcsec,
                           obstime=date_range[0], observer='earth', frame=frames.Helioprojective)
    hpc_coords = {date_range[0]: start_coord}
    hpc_coords.update({t: solar_rotate_coordinate(start_coord, time=t) for t in date_range[1:]})

    # Register the sample in samples.csv (appended only if its id is new).
    sample = pd.DataFrame({'id' : [str(arnum)+'_'+ str(sample_date)],
                           'HARPNUM' : [arnum],
                           'T' : [sample_date],
                           'label' : [arnum[-2:]]}, index=[0])
    if not os.path.exists(samples_df_path):
      sample.to_csv(samples_df_path, index=False)
    else:
      samples = pd.read_csv(samples_df_path)
      if sample.loc[0, 'id'] not in samples['id'].values:
        samples = pd.concat([samples, sample], ignore_index=True, axis=0)
        samples.to_csv(samples_df_path, index=False)

    arfolder = f'{path_dest}/{wl_tag}/{arnum}'
    os.makedirs(arfolder, exist_ok=True)

    for timestamp, coord in hpc_coords.items():
      center_hpc_x = coord.Tx.value
      center_hpc_y = coord.Ty.value

      year = timestamp.year
      month = timestamp.month
      day = timestamp.day
      hour = timestamp.hour
      minute = timestamp.minute
      stem = f'{year}{month:02d}{day:02d}_{hour:02d}{minute:02d}'

      fn_srcs = [f'{path_src}/{wavelength}/{year}/{month:02d}/{day:02d}/{stem}_{wavelength}.png' for wavelength in wavelengths]
      fn_dest = f'{arfolder}/{arnum}_{stem}_{wl_tag}.png'

      missing = [fn for fn in fn_srcs if not os.path.exists(fn)]
      for fn_src in missing:
        print('WARNING  : missing , ', fn_src)
      # Skip crops already on disk and timestamps lacking any channel
      # (opening a missing file would otherwise abort the whole loop).
      if missing or os.path.exists(fn_dest):
        continue

      # One grayscale PNG per wavelength -> one RGB channel each
      # (Image.merge('RGB', ...) requires exactly three wavelengths).
      bands = []
      for fn_src in fn_srcs:
        with Image.open(fn_src) as band:
          bands.append(band.copy())
      full_disk = Image.merge('RGB', bands)

      imW, imH = full_disk.size

      width = dimPatches # row['width_arcsec_Avg'] / resolution
      height = dimPatches # row['width_arcsec_Avg'] / resolution

      # HPC offsets (arcsec) -> pixel indices, assuming the Sun centre lies at
      # the image centre and `resolution` arcsec per pixel.
      cX = int(np.round(center_hpc_x / resolution + imW/2))
      cY = int(np.round(center_hpc_y / resolution + imH/2))

      left = int(np.round(cX - width/2))
      right = int(np.round(cX + width/2))
      top = int(np.round(cY - height/2))
      bottom = int(np.round(cY + height/2))

      ar = full_disk.crop((left, top, right, bottom))
      if downsize is not None:
        ar = ar.resize((downsize, downsize), resample=Image.BILINEAR)
      ar.save(fn_dest)
            

100%|██████████| 1/1 [00:03<00:00,  3.66s/it]


# Predictions

In [257]:
# Registry of ROI samples appended to by the ROI Crops cell (columns id, HARPNUM, T, label).
path_ars        = f'{folder}/ar'
samples_df_path = path_ars + '/samples.csv'
samples = pd.read_csv(samples_df_path)
samples

Unnamed: 0,id,HARPNUM,T,label
0,13842X9_2024-10-03 10:21:00,13842X9,2024-10-03 10:21:00,X9
1,13842X9_2024-10-03 10:24:00,13842X9,2024-10-03 10:24:00,X9
2,13848X1_2024-10-08 23:57:00,13848X1,2024-10-08 23:57:00,X1


In [250]:
if NUM_SIMUS is not None:
  from diffusion import Diffusion_cond
  from module import PaletteModelVideo
  from dataloaders import PairedVideosDataset
  from torch.utils.data import DataLoader
  import gc
  import torch

  MIXED_PREC = True  # sample under torch.autocast on the GPU

  # Predict only the most recently registered sample (last row of samples.csv).
  pred_sample = samples.iloc[-1:]

  # NOTE(review): Colab/Drive checkpoint path, used whatever `machine` is set to.
  proj_folder = '/content/drive/MyDrive/Projects/Forecast/Results/V2V/corona_v2v_128_2halfD'
  ema_model_path = f"{proj_folder}/models/ema_best.pt"
  model_tag = ema_model_path.split('/')[-2] + '_' + ema_model_path.split('/')[-1].split('.')[0].split('_')[-1]

  # NOTE(review): outputs go to {PRED_FOLDER}/<HARPNUM>/, but the Viz cell reads
  # from {folder}/preds/<HARPNUM>/ema_best/ -- align the two layouts.
  PRED_FOLDER = f'{folder}/preds'
  os.makedirs(PRED_FOLDER, exist_ok=True)

  ds = PairedVideosDataset(
      dataframe            = pred_sample,
      root_dir             = path_ars,
      time_interval_min    = 120,
      num_frames           = 6,
      wavelength           = '0193x0211x0094',
      target_channel_index = 2,
      real_time            = False
  )
  loader = DataLoader(
    ds,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    drop_last=False
  )
  device = 'cuda'
  model = PaletteModelVideo(
      c_in=4,
      c_out=1,
      image_size=128,
      time_dim=256,
      device=device,
      latent=False,
      frames = 6,
      bottleneck_3D = False,
      small = True,
      extra_att = False
  ).to(device)
  diffusion = Diffusion_cond(img_size=128, device=device, img_channel=1, num_frames=6)
  # Inference only: freeze the weights. No deep copy is needed since the
  # freshly built model is not used for anything else.
  model = model.eval().requires_grad_(False)
  model_checkpoint = torch.load(ema_model_path, map_location=device)
  model.load_state_dict(model_checkpoint['model_state'])
  model.eval()

  def to_uint8_video(video):
    """Undo the dataset normalisation and quantise a video tensor to uint8.

    Assumes `video` is (channels, frames, H, W) -- TODO confirm; the permute
    returns a (frames, H, W, channels) array, the layout the Viz cell loads.
    """
    video = ds.Normalisation.reverse_transform_exctracted_chanel(video.cpu()).permute(1, 2, 3, 0).cpu().numpy()
    return np.clip((255*video), 0, 255).astype('uint8')

  with torch.no_grad():
    for sim_index in tqdm(range(NUM_SIMUS)):
      print('\n Starting simu : ', sim_index)
      for video_input, video_target, label, time in tqdm(loader):
          sampled_images = None  # release the previous samples before freeing the cache
          torch.cuda.empty_cache()
          gc.collect()

          video_in = video_input.to(device).float()
          target = video_target.to(device).float()
          # `time` presumably holds the sample ids '<HARPNUM>_<T>' (cf. the 'id'
          # column of samples.csv), one string per batch element. Iterate over
          # the batch, not over the characters of its first id.
          arnum = [t.split('_')[0] for t in time]
          time_tags = [pd.to_datetime(t.split('_')[-1]).strftime('%Y_%m_%d_%H%M') for t in time]
          out_dir = f"{PRED_FOLDER}/{arnum[0]}"
          prefix = f"{out_dir}/{label[0]}_{arnum[0]}_{time_tags[0]}"

          pred_path = f"{prefix}_sim_{sim_index:03}.npy"
          if not os.path.exists(pred_path):
            if MIXED_PREC:
              with torch.autocast(device_type='cuda'):
                sampled_images = diffusion.sample(model, y=video_in, labels=None)
            else:
              sampled_images = diffusion.sample(model, y=video_in, labels=None)
            sample = to_uint8_video(sampled_images[0])
            os.makedirs(out_dir, exist_ok=True)
            np.save(pred_path, sample)

          # Input and ground truth are stored once, alongside the first simulation.
          if sim_index == 0:
            pred_input = f"{prefix}_input.npy"
            pred_target = f"{prefix}_target.npy"
            if not os.path.exists(pred_target):
              os.makedirs(out_dir, exist_ok=True)
              # Only input channel index 2 is kept (matches target_channel_index=2).
              np.save(pred_input, to_uint8_video(video_in[0, 2:3]))
              np.save(pred_target, to_uint8_video(target[0]))

# Viz

In [4]:
# Reload the ROI sample registry written by the ROI Crops cell.
path_ars = f'{folder}/ar'
samples_df_path = f'{path_ars}/samples.csv'
samples = pd.read_csv(samples_df_path)
samples

Unnamed: 0,id,HARPNUM,T,label
0,13842X9_2024-10-03 10:21:00,13842X9,2024-10-03 10:21:00,X9
1,13842X9_2024-10-03 10:24:00,13842X9,2024-10-03 10:24:00,X9
2,13848X1_2024-10-08 23:57:00,13848X1,2024-10-08 23:57:00,X1


In [None]:
MAKE_GIFS = True

if MAKE_GIFS:
  import matplotlib.pyplot as plt
  import matplotlib.animation as animation
  from glob import glob

  descale = False
  sample = samples.iloc[-1]   # most recently registered sample
  model_tag = 'ema_best'
  # NOTE(review): the prediction cell writes to {folder}/preds/<HARPNUM>/ (no
  # model_tag sub-folder); move the .npy files here or align the two paths.
  pred_paths = f"{folder}/preds/{sample['HARPNUM']}/{model_tag}"
  label = sample['label']
  # Defined here so the plot title does not depend on variables left over from earlier cells.
  arnum = str(sample['HARPNUM'])
  sample_tag = arnum + '_' + pd.to_datetime(sample['T']).strftime('%Y_%m_%d_%H%M')

  def descaling(scaled_image, max = 6099):
    """Invert the 8-bit log scaling used when the PNGs were written.

    PNG values were computed as 255 * log(1 + x) / log(1 + sat); this maps them
    back to x. The default 6099 is close to the printed 0094 saturation value.
    """
    log_sat = np.log(1+max)
    image = np.exp(log_sat * scaled_image / 255) - 1
    return image

  def save_and_plot_v2v(video_input, video_target, video_prediction, save_path, title="", fps=4, target_channel = 2,
                        vmin = None,
                        vmax = None
                        ):
    """
    Creates, displays and optionally saves an animation of three videos side by side.

    Args:
        video_input: conditioning (past) video, indexed first by frame.
        video_target: ground-truth future video, indexed first by frame.
        video_prediction: sampled (predicted) future video, indexed first by frame.
        save_path: GIF path to write, or None to only display.
        title: Optional title for the entire figure.
        fps: Frames per second for the output GIF.
        target_channel: if not None, keep only this channel of a
            (frames, height, width, channels) `video_input`.
        vmin, vmax: colour limits shared by the three panels.

    Returns:
        The matplotlib FuncAnimation driving the figure.
    """
    if target_channel is not None:  # channel 0 is a valid choice
      video_input = video_input[:, :, :, target_channel:target_channel+1]
    # Number of frames
    frames = video_input.shape[0]

    # Create a figure with three subplots
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))

    # Set the initial frame in all subplots
    im1 = ax1.imshow(video_input[0], origin='lower', vmin=vmin, vmax=vmax)
    ax1.set_title('12 previous hours')

    im2 = ax2.imshow(video_target[0], origin='lower', vmin=vmin, vmax=vmax)
    ax2.set_title('True next 12h')

    im3 = ax3.imshow(video_prediction[0], origin='lower', vmin=vmin, vmax=vmax)
    ax3.set_title('Predicted next 12h')

    # Add a big title in the middle of all subplots
    fig.suptitle(title)

    def update(frame_idx):
        """Updates the plot for the given frame index."""
        im1.set_array(video_input[frame_idx])
        im2.set_array(video_target[frame_idx])
        im3.set_array(video_prediction[frame_idx])
        return [im1, im2, im3]

    def render_frame(frame_idx):
        """Draw one frame and return the canvas as an RGB PIL image."""
        update(frame_idx)
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb() and already has
        # the canvas' true pixel shape (including HiDPI scaling).
        return Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert('RGB')

    # Create the animation object
    ani = animation.FuncAnimation(fig, update, frames=frames, interval=1000 // fps)

    if save_path is not None:
      images = [render_frame(i) for i in range(frames)]
      images[0].save(save_path, save_all=True, append_images=images[1:], duration=1000 // fps, loop=0)
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # called once per simulation: do not let figures accumulate
    return ani

  def load_video(npy_path):
    """Load a saved (frames, H, W, C) video as a float (C, frames, H, W) tensor, descaled if requested."""
    video = np.transpose(np.load(npy_path), (3, 0, 1, 2)).astype('float32')
    if descale:
      video = descaling(video)
    return torch.from_numpy(video)

  input_video = load_video(f'{pred_paths}/{label}_{sample_tag}_input.npy')
  target_video = load_video(f'{pred_paths}/{label}_{sample_tag}_target.npy')
  # Shared colour limits so the three panels are directly comparable.
  vmin = min(input_video.min().item(), target_video.min().item())
  vmax = max(input_video.max().item(), target_video.max().item())

  save_path = f"{pred_paths}/gifs"
  if descale:
    save_path += '_descaled'

  for simu_path in sorted(glob(f"{pred_paths}/*sim*.npy")):
    sim_num = simu_path.split('sim_')[-1].split('.npy')[0]
    save_path_full = f"{save_path}/simu_{sim_num}.gif"
    if os.path.exists(save_path_full):
      continue
    simu = load_video(simu_path)
    os.makedirs(save_path, exist_ok=True)

    save_and_plot_v2v(input_video[0], target_video[0], simu[0],
                      save_path=save_path_full,
                      title=f"Label : {label} , AR : {arnum}, Time : {sample['T']}",
                      fps=3,
                      target_channel=None,
                      vmin=vmin,
                      vmax=vmax
                      )
    print(f'SIMU {sim_num} : ')
    print(input_video.max(), target_video.max(), simu.max())
    print(input_video.mean(), target_video.mean(), simu.mean())