# Atelier 1 : Denoising parfaitement supervisé

In [None]:
# Pour commencer : aller dans « Exécution » > « Modifier le type d'exécution » et vérifier
# que CPU est bien coché (on n'a pas besoin de plus pour l'instant)

# Imports des bibliothèques utiles
# pour l'IA
import torch
# pour les maths
import numpy as np
# pour afficher des images et des courbes
import matplotlib.pyplot as plt

In [None]:
# Shell escape: clone the nanopiero/spatialisation_grele repository into the
# current working directory (creates ./spatialisation_grele).
! git clone https://github.com/nanopiero/spatialisation_grele

Cloning into 'spatialisation_grele'...
remote: Enumerating objects: 3, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 3 (delta 0), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (3/3), done.


In [None]:
# Shell escape: list the content of the freshly cloned directory.
! ls spatialisation_grele

apprentissage.ipynb  presentation.ipynb


## A. Découverte du problème

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Le 25/04/2024
@author: lepetit
#fonctions utiles pour l'atelier PREAC
"""

import torch
import numpy as np
from random import randint
import matplotlib.pyplot as plt

import torch
import numpy as np
from random import randint
import matplotlib.pyplot as plt
import os


In [65]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Le 03/06/2024
@author: lepetit
# fonctions utiles pour la génération
# de données à fusionner
"""
from random import randint
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import copy

##############################
########## with JIT ##########
##############################
from numba import jit
from numpy.random import randint


@jit(nopython=True)
def pseudo_meshgrid(size):
  """Build two float32 index grids of shape (1, size, size).

  Returns a pair (a, b) such that a[0, i, j] == j and b[0, i, j] == i.
  """
  # Each index is repeated `size` times, so the value is constant along the
  # last axis once reshaped.
  row_index = np.arange(0, size).repeat(size).reshape(1, size, size)
  # Swapping the two spatial axes gives the complementary grid.
  col_index = np.transpose(row_index, (0, 2, 1))
  return col_index.astype(np.float32), row_index.astype(np.float32)

@jit(nopython=True)
def generate_cell_and_hail_characteristics(nsteps, size, pseudo_size, centered,
                                           relative_advection_speed,
                                           smajor_axis_k=None,
                                           eccentricity_k=None,
                                           theta_increment=None,
                                           intensity=None
                                           ):
    """Draw the random trajectory and shape of one elliptical cell over time.

    The cell is fully specified at a reference time step ``k`` and then
    evolved linearly (centre, semi-major axis, eccentricity, orientation)
    towards both ends of the sequence.

    Args:
        nsteps: number of time steps of the sequence.
        size: spatial size of the image, used for the centre position and
            for the normalisation of ``hail_size``.
        pseudo_size: not used in the body of this function.
        centered: if True, the cell is placed at the image centre at step
            ``k`` and gets a zero advection speed.
        relative_advection_speed: length-2 array subtracted from the drawn
            advection speed when ``centered`` is False. When ``centered`` is
            True it is overwritten by a fresh draw that is neither used nor
            returned.
            NOTE(review): the redrawn value looks like it was meant to be
            returned to the caller - confirm.
        smajor_axis_k: semi-major axis at step ``k``; drawn uniformly in
            [5, 20) when None.
        eccentricity_k: eccentricity at step ``k``; drawn as
            0.2 + 0.8 * U[0, 1) when None.
        theta_increment: per-step rotation; drawn from N(0, pi/nsteps)
            when None.
        intensity: cell intensity; drawn uniformly in [0.3, 0.8) when None.

    Returns:
        Tuple ``(abs_centers, ord_centers, smajor_axis, eccentricity, theta,
        theta_increment, intensity, hail_steps, hail_size, radius)``. Except
        for the scalars ``theta_increment`` and ``intensity``, every item is
        a length-``nsteps`` array.
    """
    # Choose k, the reference time step at which the cell is specified
    if centered:
      k = np.random.randint(nsteps//4, nsteps-nsteps//4)
    else:
      k = np.random.randint(2, nsteps-2)

    # Generate the kth major_axis, minor_axis,  center, rotation, focus, and radius
    if smajor_axis_k is None:
      smajor_axis_k = np.random.uniform(5, 20)

    if centered:
      center_k = np.array([size // 2, size //2 ]).astype(np.float32)
    else:
      # Random centre whose coordinates lie between 3*smajor_axis_k and size
      center_k = (3*smajor_axis_k + (size - 3*smajor_axis_k)*np.random.random(2)).astype(np.float32)

    if eccentricity_k is None:
      eccentricity_k = 0.2 + 0.8 * np.random.rand()

    theta_k = np.pi * np.random.rand()

    # Generate advection speed and radius increment
    if centered:
      # The centred cell does not move; the redrawn relative speed stays local.
      relative_advection_speed = 2*np.random.normal(0, 3, 2).astype(np.float32)
      advection_speed = np.zeros(2).astype(np.float32)
    else:
      advection_speed = (2*np.random.normal(0, 3, 2) - relative_advection_speed).astype(np.float32)

    smajor_axis_increment = np.random.normal(0, 1/nsteps)
    eccentricity_increment = 2/nsteps * np.random.rand()

    if theta_increment is None:
      theta_increment = np.random.normal(0, np.pi/nsteps)

    # Fill centers and radii arrays (linear evolution around step k)
    arange_nsteps = np.arange(nsteps).astype(np.float32)
    abs_centers = center_k[0] + (arange_nsteps - k) * advection_speed[0]
    ord_centers = center_k[1] + (arange_nsteps - k) * advection_speed[1]

    smajor_axis = smajor_axis_k + (arange_nsteps - k) * smajor_axis_increment
    # Negative semi-major axes are clipped to 0
    smajor_axis[smajor_axis <= 0] = 0.
    eccentricity = eccentricity_k  +  (arange_nsteps - k) * eccentricity_increment
    # Non-positive eccentricities are sign-flipped, large ones are capped at 0.9
    eccentricity[eccentricity <= 0] = -1*eccentricity[eccentricity <= 0]
    eccentricity[eccentricity >= 0.9] = 0.9
    theta = theta_k  +  (np.arange(nsteps).astype(np.float32) - k) * theta_increment

    # Get intensity of the cell and hail characteristics:
    ratio_radius = 0.3
    if intensity is None:
      intensity = np.random.uniform(0.3,0.8)

    radius = ratio_radius * eccentricity * smajor_axis
    # Hail occurs at the steps where the cell is intense enough, rotates with a
    # non-positive increment and has a radius strictly between 2 and 4.
    hail_steps = (intensity > 0.5) * (theta_increment <= 0) * (radius > 2) * (radius < 4)
    hail_size = 10/(size**2) * np.pi * smajor_axis**2 * np.sqrt(1 - eccentricity**2) # hail size prop to the area

    return abs_centers, ord_centers, smajor_axis, eccentricity, \
           theta, theta_increment, intensity, hail_steps, hail_size, radius


@jit(nopython=True)
def closest_nonzero_index(x: np.ndarray) -> int:
    """Return the index of the non-zero entry of ``x`` closest to its middle.

    Args:
        x: 1-D array (e.g. a boolean mask over time steps).

    Returns:
        The index of the non-zero element nearest to ``x.shape[0] // 2``
        (the smallest such index in case of a tie), or -1 when ``x`` has no
        non-zero element.
    """
    mid_index = x.shape[0] // 2

    # Positions of all non-zero entries
    nonzero_indices = np.nonzero(x)[0]

    # Sentinel value: callers must check for -1 before indexing with it
    if len(nonzero_indices) == 0:
        return -1

    # np.argmin returns the first minimum, hence the tie-breaking rule above
    distances = np.abs(nonzero_indices - mid_index)
    return int(nonzero_indices[np.argmin(distances)])

import numpy as np
from numba import njit



#@njit
def select_random_nonzero_pixel(x):
    """Pick the coordinates of one non-zero entry of ``x`` uniformly at random.

    Returns the index vector (one coordinate per dimension of ``x``) of the
    selected entry. When ``x`` contains no non-zero entry, an all-zero array
    with the same shape and dtype as ``x`` is returned instead.
    """
    candidates = np.argwhere(x != 0)
    n_candidates = len(candidates)

    if n_candidates == 0:
        # Nothing to pick from: fall back to zeros shaped like the input
        return np.zeros_like(x)

    pick = np.random.randint(n_candidates)
    return candidates[pick]

#@jit(nopython=True)
def ground_truth_to_reports(two_circles, centered, freq_reports, toss=False):
  """Turn a hail ground truth into a sparse map of reports.

  Each pixel of ``two_circles`` is kept independently with probability
  ``freq_reports``; the kept values are then summed over the time axis.

  Args:
      two_circles: ground-truth array indexed as (step, row, col).
      centered: whether the cell is the centred one.
      freq_reports: probability that a ground-truth pixel is reported.
      toss: when True together with ``centered``, at least one report is
          guaranteed as soon as the ground truth contains hail.

  Returns:
      2-D float32 array: the reported values summed over the steps.
  """
  forced_report = None
  if centered and toss:
    # Pick the hail step closest to the middle of the sequence and one of its
    # hail pixels. This is done before the binomial draw so that the sequence
    # of random draws is the same as before.
    when_hail = np.sum(two_circles, axis=(1,2)) > 0
    where_report = closest_nonzero_index(when_hail)
    # -1 means "no hail at any step": nothing can be reported then.
    if where_report >= 0:
      random_index = select_random_nonzero_pixel(two_circles[where_report])
      forced_report = (where_report, random_index[0], random_index[1])

  reports = np.random.binomial(1, freq_reports, two_circles.shape).astype(np.float32)
  reports *= two_circles

  if forced_report is not None:
    # Write the guaranteed report into `reports`; writing it back into
    # `two_circles` would have no effect on the returned value.
    reports[forced_report] = two_circles[forced_report]

  return reports.sum(axis=0)



#@jit(nopython=True)
def simu_moving_ellipse(image, reports, a, b,
                        pseudo_size,
                        stratification=None,
                        centered=False, relative_advection_speed=np.zeros(2),
                        add_target=True,
                        freq_occurrence=0.5,
                        freq_reports=0.5):
  """Add one moving elliptical cell to ``image`` and its hail reports to ``reports``.

  Args:
      image: array of shape (nsteps, nchannels, size, size); the cell is added
          to channel 0 (the array is modified in place and returned).
      reports: array the hail reports are accumulated into (in place).
      a, b: coordinate grids broadcastable against (nsteps, size, size), as
          returned by ``pseudo_meshgrid``.
      pseudo_size: forwarded to ``generate_cell_and_hail_characteristics``.
      stratification: None (plain random draw) or 'occurrence' (with
          probability ``freq_occurrence`` the cell is drawn so that hail can
          occur at the reference step, otherwise it is redrawn until no step
          has hail). 'size' is reserved but not implemented.
      centered: forwarded to the cell generator and to the report sampler.
      relative_advection_speed: forwarded to the cell generator and returned
          unchanged. The default array is shared between calls and must not be
          mutated by callers.
      add_target: when False, no hail ground truth / report is produced.
      freq_occurrence: probability used by the 'occurrence' stratification.
      freq_reports: probability that a ground-truth pixel is reported.

  Returns:
      ``(image, reports, relative_advection_speed)``.

  Raises:
      NotImplementedError: if ``stratification == 'size'``.
      ValueError: for any other unknown ``stratification``.
  """

  nsteps, nchannels, size, _ = image.shape

  if stratification is None:
    abs_centers, ord_centers, smajor_axis, eccentricity, \
      theta, theta_increment, intensity, hail_steps, hail_size, radius = \
      generate_cell_and_hail_characteristics(nsteps, size, pseudo_size, centered, relative_advection_speed)

  elif stratification == 'occurrence':
    toss = (np.random.rand(1) < freq_occurrence).item()
    if toss:
      # Rejection sampling: redraw the shape at the reference step until its
      # hail radius lies strictly between 2 and 4.
      ratio_radius = 0.3
      hail_step_k = False
      while not hail_step_k:
        smajor_axis_k = np.random.uniform(5, 20)
        eccentricity_k = 0.2 + 0.8 * np.random.rand()
        enhancer_hail_k = ratio_radius * eccentricity_k * smajor_axis_k
        hail_step_k = (enhancer_hail_k < 4) * (enhancer_hail_k > 2)
      # Non-positive rotation increment and intensity above 0.5: the two other
      # conditions required for hail.
      theta_increment = - np.abs(np.random.normal(0, np.pi/nsteps))
      intensity = np.random.uniform(0.5, 0.8)

      abs_centers, ord_centers, smajor_axis, eccentricity, \
        theta, theta_increment, intensity, hail_steps, hail_size, radius = \
        generate_cell_and_hail_characteristics(nsteps, size, pseudo_size, centered,
                                               relative_advection_speed,
                                               smajor_axis_k=smajor_axis_k,
                                               eccentricity_k=eccentricity_k,
                                               theta_increment=theta_increment,
                                               intensity=intensity
                                               )
    else:
      # Redraw the cell until none of its steps carries hail.
      hail_steps_sum = 1.
      while hail_steps_sum != 0.:
        abs_centers, ord_centers, smajor_axis, eccentricity, \
          theta, theta_increment, intensity, hail_steps, hail_size, radius = \
          generate_cell_and_hail_characteristics(nsteps, size, pseudo_size, centered,
                                                relative_advection_speed)
        hail_steps_sum = hail_steps.sum()

  elif stratification == 'size':
    # This branch used to be a bare `pass`, which crashed further down on
    # undefined variables; fail early with an explicit message instead.
    raise NotImplementedError("stratification='size' is not implemented")

  else:
    raise ValueError(
        "unknown stratification %r (expected None or 'occurrence')" % (stratification,))

  # Make the cells: an ellipse is the set of points whose summed distance to
  # the two foci is below a threshold.
  delta_abs_interfocus = eccentricity * smajor_axis * np.cos(theta)
  delta_ord_interfocus = eccentricity * smajor_axis * np.sin(theta)

  abs_focus1 = abs_centers + delta_abs_interfocus
  ord_focus1 = ord_centers + delta_ord_interfocus
  abs_focus2 = abs_centers - delta_abs_interfocus
  ord_focus2 = ord_centers - delta_ord_interfocus

  square_distances_to_focus1 = (a - abs_focus1.reshape((nsteps, 1, 1)))**2 + \
                         (b - ord_focus1.reshape((nsteps, 1, 1)))**2
  square_distances_to_focus2 = (a - abs_focus2.reshape((nsteps, 1, 1)))**2 + \
                         (b - ord_focus2.reshape((nsteps, 1, 1)))**2

  sum_distances = np.sqrt(square_distances_to_focus1) + np.sqrt(square_distances_to_focus2)
  ellipses =  1. * (sum_distances < 1.25*2*smajor_axis.reshape(nsteps, 1, 1))

  # apply a random intensity
  ellipses *= intensity
  image[:,0] = image[:,0] + ellipses

  # Make ground truth: two discs centred on the foci, with a value decreasing
  # from the centre of each disc to its edge.
  if (intensity > 0.5) and (theta_increment <= 0) and add_target:

    radius = radius.reshape((nsteps, 1, 1))
    den = 1/radius**2
    two_circles = (radius**2 - square_distances_to_focus1) * den * (square_distances_to_focus1 < radius**2) \
               + (radius**2 - square_distances_to_focus2) * den * (square_distances_to_focus2 < radius**2)
    # Keep only the steps where hail occurs
    two_circles *= hail_steps.reshape((nsteps, 1, 1))
    # hail size prop to the area :
    two_circles *= hail_size.reshape((nsteps, 1, 1))
    # `toss=True` is passed unconditionally here; it is not the local `toss`
    # drawn by the 'occurrence' stratification.
    reports += ground_truth_to_reports(two_circles, centered, freq_reports, toss=True)


  return image, reports, relative_advection_speed



@jit(nopython=True)
def resize_channel(channel, new_size):
    """Resize a 2-D array to (new_size, new_size) by separable linear interpolation.

    Args:
        channel: 2-D array.
        new_size: number of rows and columns of the output.

    Returns:
        float64 array of shape (new_size, new_size). The corners of the output
        coincide with the corners of ``channel``.

    The previous body passed a 2-D array as ``fp`` (and 2-D query points) to
    ``np.interp``, which only handles 1-D data; the interpolation is therefore
    applied one row / one column at a time.
    """
    n_rows, n_cols = channel.shape
    source = channel.astype(np.float64)

    x = np.linspace(0, 1, n_rows)
    y = np.linspace(0, 1, n_cols)
    x_new = np.linspace(0, 1, new_size)
    y_new = np.linspace(0, 1, new_size)

    # First pass: interpolate along the columns of each row. The intermediate
    # result is stored transposed so that the second pass reads contiguous rows.
    along_cols = np.empty((new_size, n_rows), dtype=np.float64)
    for i in range(n_rows):
        along_cols[:, i] = np.interp(y_new, y, source[i, :])

    # Second pass: interpolate along the rows for each new column.
    resized = np.empty((new_size, new_size), dtype=np.float64)
    for j in range(new_size):
        resized[:, j] = np.interp(x_new, x, along_cols[j, :])

    return resized

#@jit(nopython=True)
def spatialized_gt(ndiscs=5, size=64, pseudo_size=None,
                   nsteps=60, stratification=None,
                   centered=False, freq_occurrence=0.5,
                  freq_reports=0.5):
  """Simulate a sequence of ``ndiscs`` moving cells and the associated hail reports.

  Returns ``(image, reports)`` where ``image`` is a float32 array of shape
  (nsteps, 2, size, size) and ``reports`` a float32 array of shape
  (2, size, size). ``pseudo_size`` defaults to ``size``.

  * ``centered=True``: the first cell is centred and drawn with the requested
    ``stratification``; the other cells are drawn without stratification,
    without target and with the default frequencies.
  * ``centered=False``: the first cell always uses the 'occurrence'
    stratification; the other cells use the requested ``stratification``.
  """
  pseudo_size = size if pseudo_size is None else pseudo_size

  image = np.zeros((nsteps, 2, size, size)).astype(np.float32)
  reports = np.zeros((2, size, size)).astype(np.float32)
  a, b = pseudo_meshgrid(size)

  freqs = dict(freq_occurrence=freq_occurrence, freq_reports=freq_reports)
  n_other_cells = ndiscs - 1

  if not centered:
    image, reports, _ = simu_moving_ellipse(image, reports, a, b,
                                            pseudo_size,
                                            stratification="occurrence",
                                            centered=False,
                                            **freqs)
    for _ in range(n_other_cells):
      image, reports, _ = simu_moving_ellipse(image, reports, a, b,
                                              pseudo_size,
                                              stratification,
                                              **freqs)
    return image, reports

  # Centred case: the secondary cells move relative to the first one.
  image, reports, first_cell_speed = simu_moving_ellipse(image, reports, a, b,
                                                         pseudo_size,
                                                         stratification,
                                                         centered=True,
                                                         **freqs)
  for _ in range(n_other_cells):
    image, reports, _ = simu_moving_ellipse(image, reports, a, b,
                                            pseudo_size,
                                            centered=False,
                                            relative_advection_speed=first_cell_speed,
                                            add_target=False)
  return image, reports





##############################
########## Datasets ##########
##############################

from torch.utils.data import Dataset

class HailDataset(Dataset):
    """Synthetic dataset: every item is a freshly simulated sequence and its hail reports.

    Samples are generated on the fly by ``spatialized_gt``; the requested
    index is ignored, so two accesses to the same index give different samples.
    """

    def __init__(self, length_dataset=6400, centered=True, freq_reports=0.01):
        """
        Args:
            length_dataset: nominal number of items, returned by ``len()``.
            centered: if True, 64x64 images with a centred cell and the
                'occurrence' stratification; otherwise 172x172 images without
                stratification.
            freq_reports: probability that a hail pixel is reported.
        """
        self.length_dataset = length_dataset
        self.centered = centered
        # Read by __getitem__: it has to be stored on the instance.
        self.freq_reports = freq_reports
        self.ndiscs = 4
        self.pseudo_size = 172
        self.nsteps = 12

        if centered:
          self.stratification ='occurrence'
          self.size_image = 64
        else:
          self.stratification = None
          self.size_image = 172


    def __len__(self):
        return self.length_dataset

    def __getitem__(self, idx):
        """Return a new random ``(image, reports)`` pair (``idx`` is ignored)."""
        image, reports = spatialized_gt(ndiscs=self.ndiscs,
                               size=self.size_image,
                               pseudo_size=self.pseudo_size,
                               nsteps=self.nsteps,
                               stratification=self.stratification,
                               centered=self.centered,
                               freq_reports=self.freq_reports
                               )

        # The reports are the target of the task: return them with the image
        # instead of discarding them.
        return image, reports





##############################
########## on GPU   ##########
##############################


def generate_indices_rows_and_columns(images, npoints):
  """Draw ``npoints`` distinct pixel positions per batch item.

  ``images`` is unpacked as (bs, nsteps, S, _). Returns ``(indices, rows,
  cols)``: ``rows`` and ``cols`` have shape (bs, npoints) and ``indices``
  (flat position ``row * S + col``) is repeated along a time axis, giving
  shape (bs, nsteps, npoints).
  """
  batch_size, n_frames, side, _ = images.shape

  # Uniform weights: every pixel is equally likely, sampled without replacement
  uniform = torch.ones(side**2).expand(batch_size, -1).to(images.device)
  flat_idx = torch.multinomial(uniform, num_samples=npoints, replacement=False)

  # Flat position -> (row, column)
  row_idx = flat_idx // side
  col_idx = flat_idx % side

  # The same positions are used at every time step
  tiled_idx = flat_idx.unsqueeze(dim=1).repeat([1, n_frames, 1])
  return tiled_idx, row_idx, col_idx


def indices_to_sampled_values(images, indices):
  """Read the values of ``images`` at the flat pixel positions ``indices``.

  ``images`` is unpacked as (bs, nsteps, S, _) and ``indices`` holds, for
  each batch item and time step, flat positions inside one frame. The result
  has the shape of ``indices``.
  """
  batch_size, n_frames, side, _ = images.shape
  # One flat vector of side*side pixels per frame
  frames_as_vectors = images.view(batch_size, n_frames, side * side)
  return frames_as_vectors.gather(2, indices)


def get_point_measurements(rows, cols, sampled_values, S=64):
  """Prepend normalised pixel-centre coordinates to the sampled time series.

  ``rows`` / ``cols`` have shape (bs, npoints) and ``sampled_values``
  (bs, nsteps, npoints). The result has shape (bs, 2 + nsteps, npoints):
  index 0 along dim 1 is x (from the column), index 1 is y (from the row,
  with the vertical axis flipped), then the values.
  """
  half_pixel = 1/(2*S)
  # Pixel centres in [0, 1]; y decreases when the row index increases
  ys = (1 - rows.float()/S) - half_pixel
  xs = cols.float()/S + half_pixel

  coordinates = torch.stack((xs, ys), dim=1)
  return torch.cat((coordinates, sampled_values), dim=1)


def _scatter_to_dense_grid(images, indices, sampled_values):
  """Write ``sampled_values`` at ``indices`` into a grid shaped like ``images``, filled with -0.1 elsewhere."""
  bs, nsteps, S, _ = images.shape
  dense = -0.1 * torch.ones(images.numel(), device=images.device)
  # Rank of frame (i, j) in images.flatten(), counted in frames:
  # i * nsteps + j for batch item i and time step j.
  frame_rank = (torch.arange(bs, device=images.device) * nsteps).view(bs, 1) \
             + torch.arange(nsteps, device=images.device).view(1, nsteps)
  # Flat position of every sampled pixel in images.flatten()
  flat_positions = S**2 * frame_rank.unsqueeze(-1) + indices
  dense[flat_positions.flatten()] = sampled_values.flatten()
  return dense.view(bs, nsteps, S, S)


def point_gt(images, ind_row_col_sampval=None, npoints=10, use_fcn=False, split=None): # split: (n0, n1, ..., nr), with npoints = sum of the ni
  """Sample point measurements from a batch of image sequences.

  Args:
      images: tensor unpacked as (bs, nsteps, S, _).
      ind_row_col_sampval: optional ``(indices, rows, cols, sampled_values)``
          to reuse; when None, ``npoints`` positions are drawn at random.
      npoints: number of points drawn per batch item (ignored when
          ``ind_row_col_sampval`` is given).
      use_fcn: if True, also build a dense (bs, nsteps, S, S) grid holding the
          sampled values at their positions and -0.1 everywhere else.
      split: optional iterable of group sizes; the points are then split into
          consecutive groups of these sizes.

  Returns:
      Without ``split``: ``(point_measurements, dense_grid_or_None,
      (indices, rows, cols))``.
      With ``split``: a list with one such triple per group.
  """
  bs, nsteps, S, _ = images.shape

  if ind_row_col_sampval is None:
      indices, rows, cols = generate_indices_rows_and_columns(images, npoints)
      sampled_values = indices_to_sampled_values(images, indices)
  else:
      indices, rows, cols, sampled_values = ind_row_col_sampval

  if split is None:
    point_measurements = get_point_measurements(rows, cols, sampled_values, S)
    dense_grid = None
    if use_fcn:
      dense_grid = _scatter_to_dense_grid(images, indices, sampled_values)
    return point_measurements, dense_grid, (indices, rows, cols)

  # splitting: consecutive groups of points along the last axis
  groups = []
  pos = 0
  for group_size in split:
    group = slice(pos, pos + group_size)
    group_indices = indices[:, :, group]
    group_rows = rows[:, group]
    group_cols = cols[:, group]
    group_values = sampled_values[:, :, group]

    point_measurements = get_point_measurements(group_rows, group_cols, group_values, S)
    dense_grid = None
    if use_fcn:
      dense_grid = _scatter_to_dense_grid(images, group_indices, group_values)

    groups.append((point_measurements, dense_grid, (group_indices, group_rows, group_cols)))
    pos += group_size

  return groups


In [4]:
import matplotlib.gridspec as gridspec

def plot_images(images, ground_truth, reports):
    """Display 12 frames and the 12 matching ground-truth maps on a 6x5 grid.

    Rows alternate between 4 frames of ``images`` and the 4 corresponding maps
    of ``ground_truth``; the fifth (wider) column is left empty.

    Args:
        images: indexable with 0..11, each item being a 2-D array shown in
            grey levels between 0 and 1.
        ground_truth: same layout as ``images``.
        reports: currently unused; kept so that existing calls keep working.
    """
    # Set up the figure with GridSpec
    fig = plt.figure(figsize=(18, 20))
    gs = gridspec.GridSpec(6, 5, width_ratios=[1, 1, 1, 1, 2])  # Last column twice as wide

    # Flat list of axes, row-major: axs[row * 5 + col]
    axs = [fig.add_subplot(gs[i, j]) for i in range(6) for j in range(5)]

    # Hide all primary spines and ticks
    for ax in axs:
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(axis='both', which='both', left=False, right=False, top=False, bottom=False, labelleft=False, labelbottom=False)

    # Frames on the even rows, ground truth on the odd rows just below
    for i in range(3):
        for j in range(4):
            frame = 4*i + j

            ax = axs[2*i*5 + j]
            ax.imshow(images[frame], cmap='gray', aspect=1, vmin=0, vmax=1)
            ax.axis('off')

            ax = axs[(2*i + 1)*5 + j]
            ax.imshow(ground_truth[frame], cmap='gray', aspect=1, vmin=0, vmax=1)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
%%timeit -n 1 -r 1
for _ in range(1000):
  image = spatialized_gt(ndiscs=4, size=64, pseudo_size=172, nsteps=12, stratification='occurrence', centered=True, freq_occurrence=f)
  i += image[:,1].max()>0
  # print(image[:,1].max()>0)
print(i/1000)

In [None]:
S = 172
# in the dataset :
# image = spatialized_gt(ndiscs=4, size=S, nsteps=12)
# image = spatialized_gt(ndiscs=4, size=64, pseudo_size=172, nsteps=12, stratification='occurrence', centered=True, freq_reports=0.01)
image, reports = spatialized_gt(ndiscs=4, size=172, pseudo_size=172, nsteps=12, stratification=None, centered=False, freq_reports=0.01)
device = torch.device('cpu')
images = torch.tensor(image).unsqueeze(0).float().to(device)

# plot_images shows 12 two-dimensional frames and 12 matching maps:
# - channel 0 of `image` holds the simulated cells, one frame per time step;
# - `reports` has no time axis, so its first map is repeated for every step
#   (np.repeat needs an explicit number of repetitions).
cell_frames = image[:, 0]
report_frames = np.repeat(reports[:1], image.shape[0], axis=0)
plot_images(cell_frames, report_frames, 0)
