<a href="https://colab.research.google.com/github/nanopiero/fusion/blob/main/notebooks/fcns/training_A2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## A2. Radar + CMLs -> pluvios (rain gauges) 1 min + 60 min [xrl_yg1g60]

In [None]:
# Clone the project repository; it provides the local `utile_fusion` module imported in the next cell.
# NOTE(review): the next cell does os.chdir('fusion'); re-running this cell afterwards in the same
# runtime would clone a nested fusion/fusion copy — run it once per fresh runtime.
! git clone https://github.com/nanopiero/fusion.git

Cloning into 'fusion'...
remote: Enumerating objects: 192, done.[K
remote: Counting objects: 100% (185/185), done.[K
remote: Compressing objects: 100% (88/88), done.[K
remote: Total 192 (delta 116), reused 148 (delta 96), pack-reused 7[K
Receiving objects: 100% (192/192), 9.78 MiB | 10.53 MiB/s, done.
Resolving deltas: 100% (116/116), done.


In [None]:
# Imports of the useful libraries
# for deep learning
import torch
# for numerics
import numpy as np
# for displaying images and curves
import matplotlib.pyplot as plt

from random import randint
import os
import time

# Local project files: the cloned repository must be the working directory so that
# `utile_fusion` is importable. The guard keeps this cell re-runnable: once we are
# already inside 'fusion', a second os.chdir('fusion') would raise FileNotFoundError.
if os.path.basename(os.getcwd()) != 'fusion':
    os.chdir('fusion')
import utile_fusion
# import importlib
# importlib.reload(utile_fusion)

# Generator functions used at the whole-image scale
from utile_fusion import spatialized_gt, create_cmls_filter
# Import loading tools
from utile_fusion import FusionDataset
from torch.utils.data import DataLoader
# Functions used at the batch scale, on the GPU
from utile_fusion import indices_to_sampled_values, get_point_measurements, point_gt, segment_gt, make_noisy_images
# Import cost functions
from utile_fusion import QPELoss_fcn, compute_metrics
# Visualization functions
from utile_fusion import set_tensor_values2, plot_images, plot_images_10pts_20seg, plot_results_10pts_20seg


In [None]:
# Mount Google Drive under /content/drive (checkpoint storage lives there).
from google.colab import drive
drive.mount(mountpoint='/content/drive')

# A2. Radar + CMLs -> pluvios (rain gauges) 1 min + 60 min [xrl_yg1g60]

In [None]:
# Base configuration (changed in part B.):
npoints = 10          # number of pseudo rain-gauge point measurements drawn by point_gt
npairs = 20           # presumably number of pseudo-CML segments (cf. plot_*_10pts_20seg) — TODO confirm
nsteps = 60           # time steps per sample; the UNet outputs 3 blocks of nsteps channels
ndiscs = 5            # not used in the visible cells — presumably a FusionDataset parameter; TODO confirm
size_image = 64       # image side length, passed to get_point_measurements
length_dataset = 6400 # not used in the visible cells — presumably the FusionDataset length; TODO confirm

# Fall back to the CPU when no GPU is available instead of failing later on .to(device).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed the global RNGs (presumably used by the pseudo-observation generators) for reproducibility.
import random
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Small UNet and its optimizer.
from utile_fusion import UNet

# Third UNet argument (its meaning is defined in utile_fusion.UNet).
size = 3 * nsteps
# Output channels: three blocks of nsteps channels (sliced in the evaluation cell)
# plus one extra channel (purpose defined by QPELoss_fcn — TODO confirm).
ch_out = size + 1
# Input channels: must equal channels(noisy_images) + channels(segment_measurements_fcn),
# which are concatenated on dim=1 before the forward pass — TODO confirm the 72 breakdown.
ch_in = 72

model = UNet(ch_in, ch_out, size)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-3)

In [None]:
# Training objective. With cumuls_1h=True the criterion also yields a separate
# "regression_loss_1h" term (unpacked in the training loop).
criterion = QPELoss_fcn(cumuls_1h=True)

# FCN baseline: request the `*_fcn` measurement variants from segment_gt / point_gt.
use_fcn = True

# Per-epoch training statistics, appended by the training loop.
train_losses = []
# Best validation losses so far, starting at +inf (not updated in the visible cells).
best_loss = [float('inf')] * 2

In [None]:
# Checkpoint location on Google Drive.
path = r'/content/drive/MyDrive/rainCell/fusion/models/checkpoint_fcn_A2_xrl_yg1g60.pt'

# Set to True to resume training from the checkpoint at `path` instead of starting from scratch.
RESUME = False
if RESUME:
    # weights_only=False: the checkpoint's train_losses hold numpy confusion matrices,
    # which the weights-only unpickler rejects. This unpickles arbitrary objects —
    # only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    last_epoch = checkpoint['epoch']
    train_losses = checkpoint['train_losses']
    # best_loss = checkpoint['best_loss']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])


In [None]:
# Train the FCN: inputs = pseudo radar + pseudo CMLs, targets = pseudo rain-gauge
# measurements. Per-epoch statistics are appended to `train_losses`.
# NOTE(review): `loader` is not built in any visible cell (FusionDataset and DataLoader
# are imported but unused) — presumably a DataLoader over FusionDataset(...) from a
# deleted cell; TODO restore it so the notebook survives Restart & Run All.
n_epochs = 100

model.train()
for epoch in range(n_epochs):

  running_regression_loss = 0.0
  running_regression_loss_1h = 0.0
  running_segmentation_loss = 0.0
  train_confusion_matrix = np.zeros((2, 2), dtype=int)
  for images, pairs, filters in loader:

    # ground truth (not usable as a model input; only used to simulate observations)
    images = images.clone().detach().float().to(device)

    # pseudo CMLs
    pairs = pairs.clone().detach().float().to(device)
    filters = filters.clone().detach().float().to(device)

    # for transformers:
    # segment_measurements = segment_gt(images, pairs, filters)
    _, segment_measurements_fcn = segment_gt(images, pairs, filters,
                                             use_fcn=use_fcn)

    # pseudo rain gauges
    _, point_measurements_fcn, _ = point_gt(images, npoints=npoints,
                                            use_fcn=use_fcn)

    # pseudo radar
    noisy_images = make_noisy_images(images)

    # prepare inputs (channel-wise concatenation) and targets
    inputs = torch.cat([noisy_images, segment_measurements_fcn], dim=1)
    targets = point_measurements_fcn

    optimizer.zero_grad()  # Zero the gradients
    outputs = model(inputs)  # Forward pass

    regression_loss, regression_loss_1h, segmentation_loss, loss, batch_cm = criterion(model.p, outputs, targets)
    loss.backward()  # Backward pass
    optimizer.step()  # Update the weights

    # Accumulate plain Python floats. If the criterion returns loss tensors, summing
    # them directly would keep each batch's autograd history alive for the whole epoch.
    # float() accepts Python numbers and 0-dim tensors alike.
    running_regression_loss += float(regression_loss)
    running_regression_loss_1h += float(regression_loss_1h)
    running_segmentation_loss += float(segmentation_loss)
    train_confusion_matrix += batch_cm

    del inputs, targets, outputs, loss, noisy_images, images, pairs, filters
    torch.cuda.empty_cache()

  # Calculating average training loss
  train_regression_loss = running_regression_loss / len(loader)
  train_regression_loss_1h = running_regression_loss_1h / len(loader)
  train_segmentation_loss = running_segmentation_loss / len(loader)
  train_losses.append((epoch, train_regression_loss, train_regression_loss_1h, train_segmentation_loss, train_confusion_matrix))
  print(f'Epoch {epoch + 1}/{n_epochs}')
  print(f'Training, Regression Loss: {train_regression_loss:.4f}, Regression Loss 1h: {train_regression_loss_1h:.4f}, Segmentation Loss:{train_segmentation_loss:.4f}' )
  print("Train Confusion Matrix:")
  print(train_confusion_matrix)
  accuracy, csi, sensitivity, specificity, false_alarm_ratio = compute_metrics(train_confusion_matrix)
  print(f'Accuracy: {accuracy:.4f}, CSI: {csi:.4f}, Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}, False Alarm Ratio: {false_alarm_ratio:.4f}')
  print('\n')


In [None]:
# Save a training checkpoint (model, optimizer and loss history) to `path`.
# Timing note: 100 epochs took 2h28.
# Disabled by default so that "Run All" never overwrites an existing checkpoint.
SAVE_CHECKPOINT = False
if SAVE_CHECKPOINT:
    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        # 'scheduler': scheduler.state_dict(),
        'train_losses': train_losses,
        }
    torch.save(checkpoint, path)

In [None]:
# Load a trained checkpoint and apply it to `model` and `optimizer`.
# Relative path: a leading '/' would make it absolute ('/..' is just '/', i.e. '/models/...').
# NOTE(review): resolved against the current working directory (the cloned repo after
# os.chdir); confirm where the checkpoint actually lives.
path = r'../models/checkpoint_fcn_A2_xrl_yg1g60.pt'
# weights_only=False: train_losses holds numpy confusion matrices, which the weights-only
# unpickler (the torch.load default since torch 2.6) rejects. Only load trusted files.
checkpoint = torch.load(path, map_location=device, weights_only=False)
last_epoch = checkpoint['epoch']
train_losses = checkpoint['train_losses']
# best_loss = checkpoint['best_loss']
model_weights = checkpoint['model']
optimizer_state_dict = checkpoint['optimizer']
# Apply the loaded states; otherwise `model` keeps its current (possibly untrained) weights.
model.load_state_dict(model_weights)
optimizer.load_state_dict(optimizer_state_dict)

In [None]:
# Learning curve: training CSI per epoch, computed from the confusion matrix
# stored as the last element of each train_losses entry.
CSI_INDEX = 1  # compute_metrics returns (accuracy, csi, sensitivity, specificity, false_alarm_ratio)
csi_values = [compute_metrics(entry[-1])[CSI_INDEX] for entry in train_losses]

fig, ax = plt.subplots()
ax.plot(csi_values)
ax.set(title='Training CSI', xlabel='epoch', ylabel='CSI')
plt.show()

In [None]:
# Plot the outputs on one batch: simulated truth and observations versus the model's
# prediction, and observations re-simulated from that prediction.
# NOTE(review): this draws from the same `loader` as training (no held-out set).
model.eval()

with torch.no_grad():

  # first batch only
  images, pairs, filters = next(iter(loader))

  # ground truth (not usable as a model input)
  images = images.clone().detach().float().to(device)

  # pseudo CMLs
  pairs = pairs.clone().detach().float().to(device)
  filters = filters.clone().detach().float().to(device)

  # segment measurements; the non-FCN variant is kept for plotting
  segment_measurements, segment_measurements_fcn = segment_gt(images, pairs, filters, use_fcn=use_fcn)

  # pseudo rain gauges; sampling indices/positions are reused for the prediction below
  point_measurements, point_measurements_fcn, (indices, rows, cols) = \
                      point_gt(images, npoints=npoints, use_fcn=use_fcn)

  # pseudo radar
  noisy_images = make_noisy_images(images)

  inputs = torch.cat([noisy_images, segment_measurements_fcn], dim=1)
  outputs = model(inputs)

  # Channels [0, nsteps) vs [nsteps, 2*nsteps) are compared per pixel and step
  # (presumably no-rain vs rain scores, hence "rnr"); [2*nsteps, 3*nsteps) is
  # masked by that decision to form the predicted field.
  mask_rnr = outputs[:, :nsteps, ...] < outputs[:, nsteps:2*nsteps, ...]
  images_pred = (mask_rnr * outputs[:, 2*nsteps:3*nsteps, ...]).detach()

  # Re-simulate CML and gauge measurements from the predicted field, using the
  # same links and gauge positions as for the truth.
  segment_measurements_pred, _ = segment_gt(images_pred,
                                            pairs,
                                            filters,
                                            use_fcn=use_fcn)
  sampled_values_pred = indices_to_sampled_values(images_pred, indices)
  point_measurements_pred = get_point_measurements(rows, cols,
                                                   sampled_values_pred,
                                                   size_image)


k = 0  # index of the sample within the batch to display
# Summed CML filters, added to the scaled fields (presumably to overlay the links).
links_overlay = filters[k, ...].cpu().numpy().sum(axis=0)
# Every 5th time step starting at 4 (4, 9, ..., < nsteps) for the rain-mask panel.
steps_to_show = torch.arange(4, nsteps, 5)

plot_results_10pts_20seg(3*images[k, ...].cpu().numpy() + links_overlay,
                         noisy_images[k, ...].cpu().numpy(),
                         point_measurements[k, ...].cpu().numpy(),
                         segment_measurements[k, ...].cpu().numpy(),
                         3*images_pred[k, ...].cpu().numpy() + links_overlay,
                         (images_pred[k, steps_to_show, ...] > 0).long().cpu().numpy(),
                         point_measurements_pred[k, ...].cpu().numpy(),
                         segment_measurements_pred[k, ...].cpu().numpy())
