<a href="https://colab.research.google.com/github/javigallego4/TFG/blob/main/Evaluaci%C3%B3n_de_modelos_ad_hoc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

from IPython.display import clear_output, display_html
import gc; gc.enable()
import warnings
import os
from pathlib import Path
from tqdm import tqdm

# Basic libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy as sc
from scipy import stats
import random
import cv2

# Preprocessing libraries
from sklearn.preprocessing import *
import cv2
import albumentations as A

# Library for .tiff files
!pip install tifffile
import tifffile as tiff

# Timm Library
!pip install timm
import timm

# PyTorch
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image

# MaskRCNN class imports
from typing import Any, Callable, Optional
from torchvision.models.detection.mask_rcnn import _resnet_fpn_extractor
from torch import nn
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import misc as misc_nn_ops
from torchvision.transforms._presets import ObjectDetection

from torchvision.models.detection.mask_rcnn import MaskRCNN, MaskRCNNPredictor
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models import *
from torchvision.models.detection.mask_rcnn import _resnet_fpn_extractor

# Deep Lab V3 Backbones
from torchvision.models.segmentation.deeplabv3 import *
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, DeepLabV3_ResNet101_Weights

# Metric (mAP)
!pip install torchmetrics
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Weights and biases
!pip install wandb
import wandb

# Memory usage
import gc
def gc_collect():
    gc.collect()
    torch.cuda.empty_cache()

warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

os.environ["WANDB_API_KEY"] = '5bf911e7e682da23240c68fb146a222bf0475f7c'
wandb.login() # 5bf911e7e682da23240c68fb146a222bf0475f7c

clear_output()
print('Number of CPUs: ', os.cpu_count())

DEBUG = False
LOG_IMAGES = False
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
train_validation_test_split = pd.read_csv('/content/gdrive/MyDrive/train_validation_test_split.csv', index_col = False)

"""Nos creamos los índices para los conjuntos de entrenamiento, validación y test."""

X_train = train_validation_test_split[train_validation_test_split['set'] == 'train'].image.values
X_test = train_validation_test_split[train_validation_test_split['set'] == 'test'].image.values
X_val = train_validation_test_split[train_validation_test_split['set'] == 'val'].image.values

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

# Obtenemos todos los nºs de polígonos que nos dan en la siguiente ruta.

BASE_PATH = "/content/gdrive/MyDrive/juniper/WV3/"
polygon_numbers = os.listdir(BASE_PATH)
polygon_numbers = pd.Series(polygon_numbers).str.split('_', n = 2, expand = True)[1]
polygon_numbers = list(polygon_numbers)
polygon_numbers = sorted(polygon_numbers)

# Guardamos en arrays cada una de las imágenes y máscaras.

def load_images(polygon_numbers):
  for polygon_number in polygon_numbers:
      # Panchromatic Images
      p_images.append(tiff.imread(BASE_PATH + "polygon_{}/panchromatic.tif".format(polygon_number)))
      p_masks.append(tiff.imread(BASE_PATH + "polygon_{}/mask_panchromatic.tif".format(polygon_number)))

      # Multispectral Images
      # Hacemos un permute para poner las imágenes en el formato PyTorch
      m_images.append(tiff.imread(BASE_PATH + "polygon_{}/multispectral.tif".format(polygon_number)))
      m_masks.append(tiff.imread(BASE_PATH + "polygon_{}/mask_multispectral.tif".format(polygon_number)))

"""## Pansharpening"""

def histogram_match(pan, band):
    """
    Performs histogram matching between the panchromatic image and the multispectral band given.

    Parameters:
    - pan: torch tensor of shape (height, width)
    - band: torch tensor of shape (height, width)

    Returns:
    - matched_panchromatic: histogram matched PAN imagery
    """

    # Fórmula UGR: https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&cad=rja&uact=8&ved=2ahUKEwicxYCXrYv9AhXRhqQKHYoUDFQQFnoECAkQAQ&url=https%3A%2F%2Fccia.ugr.es%2Fvip%2Ffiles%2Fbooks%2Fpaper_amro_mateos.pdf&usg=AOvVaw3wn01QiErGCJLtNZUg-oKe
    matched_panchromatic = (pan - pan.mean()) * (band.std() / pan.std()) + band.mean()

    return matched_panchromatic

def pansharpen_image(multispectral_image, panchromatic_image, method):
    """
    Pansharpens the given MS image based on different techniques.

    Parameters:
    - multispectral_image: torch tensor of shape (channels, height, width)
    - panchromatic_image: torch tensor of shape (height, width)
    - method: type of pansharpening technique

    Returns:
    - sharpened_img: torch tensor with same shape as input multispectral image
    """

    if method == 'Simple Mean':
      multispectral_image = torchvision.transforms.Resize((panchromatic_image.shape[0], panchromatic_image.shape[1]))(multispectral_image)
      sharpened_img = torch.randn(multispectral_image.shape[0], multispectral_image.shape[1], multispectral_image.shape[2])
      for i in range(multispectral_image.shape[0]):
        # Histogram Matching for each band
        matched_panchromatic = histogram_match(panchromatic_image, multispectral_image[i, :, :])
        sharpened_img[i, :, :] = 0.5 * (multispectral_image[i, :, :] + matched_panchromatic)

    if method == 'Brovey':
        multispectral_image = torchvision.transforms.Resize((panchromatic_image.shape[0], panchromatic_image.shape[1]))(multispectral_image)
        sharpened_img = torch.randn(multispectral_image.shape[0], multispectral_image.shape[1], multispectral_image.shape[2])

        M = 0
        for i in range(multispectral_image.shape[0]):
          M += multispectral_image[i, :, :]
        M /= multispectral_image.shape[0]

        for i in range(multispectral_image.shape[0]):
          # Histogram Matching for each band
          matched_panchromatic = histogram_match(panchromatic_image, multispectral_image[i, :, :])
          sharpened_img[i, :, :] = (matched_panchromatic / M) * multispectral_image[i, :, :]

    if method == 'HSV':
      multispectral_image = torchvision.transforms.Resize((panchromatic_image.shape[0], panchromatic_image.shape[1]))(multispectral_image)
      # Forward Transform
      red = multispectral_image[4, :, :]
      green = multispectral_image[2, :, :]
      blue = multispectral_image[1, :, :]

      I = 0.577 * (red+green+blue)
      v1 = -0.408 * (red+green) + 0.816 * blue
      v2 = -0.707 * (red+green) + 1.703 * blue

      H = torch.atan(v2/v1)
      S = torch.sqrt(torch.pow(v1,2) + torch.pow(v2,2))

      # Histogram Matching
      matched_panchromatic = histogram_match(panchromatic_image, I)

      # Reverse Transformation
      new_red = 0.577 * matched_panchromatic - 0.408 * v1 - 0.707 * v2
      new_green = 0.577 * matched_panchromatic - 0.408 * v1 - 0.816 * v2
      new_blue = 0.577 * matched_panchromatic - 0.816 * v1

      sharpened_img = multispectral_image
      sharpened_img[4, :, :] = new_red
      sharpened_img[2, :, :] = new_green
      sharpened_img[1, :, :] = new_blue

    if method == 'IHS1':
      multispectral_image = torchvision.transforms.Resize((panchromatic_image.shape[0], panchromatic_image.shape[1]))(multispectral_image)
      # Forward Transform
      red = multispectral_image[4, :, :]
      green = multispectral_image[2, :, :]
      blue = multispectral_image[1, :, :]

      I = 1/np.sqrt(3) * (red+green+blue)
      v1 = -1/np.sqrt(6) * (red+green) + 2/np.sqrt(6) * blue
      v2 = 1/np.sqrt(2) * (-red+green)

      H = torch.atan(v2/v1)
      S = torch.sqrt(torch.pow(v1,2) + torch.pow(v2,2))

      # Histogram Matching
      matched_panchromatic = histogram_match(panchromatic_image, I)

      # Reverse Transformation
      v1 = S * torch.cos(H)
      v2 = S * torch.sin(H)

      new_red = 1/np.sqrt(3) * matched_panchromatic -1/np.sqrt(6) * v1 - 1/np.sqrt(2) * v2
      new_green = 1/np.sqrt(3) * matched_panchromatic -1/np.sqrt(6) * v1 + 1/np.sqrt(2) * v2
      new_blue = 1/np.sqrt(3) * matched_panchromatic + 2/np.sqrt(6) * v1

      sharpened_img = multispectral_image
      sharpened_img[4, :, :] = new_red
      sharpened_img[2, :, :] = new_green
      sharpened_img[1, :, :] = new_blue

    if method == 'IHS2':
      multispectral_image = torchvision.transforms.Resize((panchromatic_image.shape[0], panchromatic_image.shape[1]))(multispectral_image)
      # Forward Transform
      red = multispectral_image[4, :, :]
      green = multispectral_image[2, :, :]
      blue = multispectral_image[1, :, :]

      I = 1/3 * (red+green+blue)
      v1 = -1/np.sqrt(6) * (red+green) + 2/np.sqrt(6) * blue
      v2 = 1/np.sqrt(6) * red - 2 /np.sqrt(6) * green

      H = torch.atan(v2/v1)
      S = torch.sqrt(torch.pow(v1,2) + torch.pow(v2,2))

      # Histogram Matching
      matched_panchromatic = histogram_match(panchromatic_image, I)

      # Reverse Transformation
      v1 = S * torch.cos(2*np.pi*H)
      v2 = S * torch.sin(2*np.pi*H)

      new_red = 1 * matched_panchromatic -0.204124 * v1 - 0.612372 * v2
      new_green = 1 * matched_panchromatic -0.204124 * v1 + 0.612372 * v2
      new_blue = 1 * matched_panchromatic + 0.408248 * v1

      sharpened_img = multispectral_image
      sharpened_img[4, :, :] = new_red
      sharpened_img[2, :, :] = new_green
      sharpened_img[1, :, :] = new_blue

    if method == 'IHS3':
      multispectral_image = torchvision.transforms.Resize((panchromatic_image.shape[0], panchromatic_image.shape[1]))(multispectral_image)
      # Forward Transform
      red = multispectral_image[4, :, :]
      green = multispectral_image[2, :, :]
      blue = multispectral_image[1, :, :]

      I = 1/3 * (red+green+blue)
      v1 = -1/np.sqrt(6) * (red+green) + 2/np.sqrt(6) * blue
      v2 = 1/np.sqrt(6) * red - 1/np.sqrt(6) * green

      H = torch.atan(v2/v1)
      S = torch.sqrt(torch.pow(v1,2) + torch.pow(v2,2))

      # Histogram Matching
      matched_panchromatic = histogram_match(panchromatic_image, I)

      # Reverse Transformation
      new_red = 1 * matched_panchromatic -1/np.sqrt(6) * v1 +3/np.sqrt(6) * v2
      new_green = 1 * matched_panchromatic -1/np.sqrt(6) * v1 -3/np.sqrt(6) * v2
      new_blue = 1 * matched_panchromatic + 2/np.sqrt(6) * v1

      sharpened_img = multispectral_image
      sharpened_img[4, :, :] = new_red
      sharpened_img[2, :, :] = new_green
      sharpened_img[1, :, :] = new_blue

    if method == 'IHS4':
      multispectral_image = torchvision.transforms.Resize((panchromatic_image.shape[0], panchromatic_image.shape[1]))(multispectral_image)
      # Forward Transform
      red = multispectral_image[4, :, :]
      green = multispectral_image[2, :, :]
      blue = multispectral_image[1, :, :]

      I = 1/3 * (red+green+blue)
      v1 = 1/np.sqrt(6) * (red+green) - 2/np.sqrt(6) * blue
      v2 = 1/np.sqrt(2) * red - 1/np.sqrt(2) * green

      H = torch.atan(v1/v2)
      S = torch.sqrt(torch.pow(v1,2) + torch.pow(v2,2))

      # Histogram Matching
      matched_panchromatic = histogram_match(panchromatic_image, I)

      # Reverse Transformation
      new_red = 1/3 * matched_panchromatic +1/np.sqrt(6) * v1 + 1/np.sqrt(2) * v2
      new_green = 1/3 * matched_panchromatic +1/np.sqrt(6) * v1 -1/np.sqrt(2) * v2
      new_blue = 1/3 * matched_panchromatic - 1/np.sqrt(2) * v1

      sharpened_img = multispectral_image
      sharpened_img[4, :, :] = new_red
      sharpened_img[2, :, :] = new_green
      sharpened_img[1, :, :] = new_blue

    if method == 'IHS5':
      multispectral_image = torchvision.transforms.Resize((panchromatic_image.shape[0], panchromatic_image.shape[1]))(multispectral_image)
      # Forward Transform
      red = multispectral_image[4, :, :]
      green = multispectral_image[2, :, :]
      blue = multispectral_image[1, :, :]

      I = 1/3 * (red+green+blue)
      v1 = 1/np.sqrt(6) * (red+green) - 2/np.sqrt(6) * blue
      v2 = 1/np.sqrt(2) * red - 1/np.sqrt(2) * green

      H = torch.atan(v2/v1)
      S = torch.sqrt(torch.pow(v1,2) + torch.pow(v2,2))

      # Histogram Matching
      matched_panchromatic = histogram_match(panchromatic_image, I)

      # Reverse Transformation
      new_red = 1 * matched_panchromatic +1/np.sqrt(6) * v1 + 1/np.sqrt(2) * v2
      new_green = 1 * matched_panchromatic +1/np.sqrt(6) * v1  -1/np.sqrt(2) * v2
      new_blue = 1 * matched_panchromatic - 2/np.sqrt(6) * v1

      sharpened_img = multispectral_image
      sharpened_img[4, :, :] = new_red
      sharpened_img[2, :, :] = new_green
      sharpened_img[1, :, :] = new_blue

    if method == 'IHS6':
      multispectral_image = torchvision.transforms.Resize((panchromatic_image.shape[0], panchromatic_image.shape[1]))(multispectral_image)
      # Forward Transform
      red = multispectral_image[4, :, :]
      green = multispectral_image[2, :, :]
      blue = multispectral_image[1, :, :]

      I = 1/3 * (red+green+blue)
      v1 = -2/np.sqrt(6) * (red+green) + 2/np.sqrt(6) * blue
      v2 = 1/np.sqrt(2) * red - 1/np.sqrt(2) * green

      H = torch.atan(v2/v1)
      S = torch.sqrt(torch.pow(v1,2) + torch.pow(v2,2))

      # Histogram Matching
      matched_panchromatic = histogram_match(panchromatic_image, I)

      # Reverse Transformation
      new_red = 1 * matched_panchromatic -1/np.sqrt(2) * v1 + 1/np.sqrt(2) * v2
      new_green = 1 * matched_panchromatic -1/np.sqrt(2) * v1  -1/np.sqrt(2) * v2
      new_blue = 1 * matched_panchromatic + np.sqrt(2) * v1

      sharpened_img = multispectral_image
      sharpened_img[4, :, :] = new_red
      sharpened_img[2, :, :] = new_green
      sharpened_img[1, :, :] = new_blue

    if method == 'HLS':
      multispectral_image = torchvision.transforms.Resize((panchromatic_image.shape[0], panchromatic_image.shape[1]))(multispectral_image)
      # Forward Transform
      red = multispectral_image[4, :, :]
      green = multispectral_image[2, :, :]
      blue = multispectral_image[1, :, :]

      I = 1/3 * (red+green+blue)
      v1 = 1/np.sqrt(6) * (red+green) - 2/np.sqrt(6) * blue
      v2 = 1/np.sqrt(2) * red - 1/np.sqrt(2) * green

      H = torch.atan(v1/v2)
      S = torch.sqrt(torch.pow(v1,2) + torch.pow(v2,2))

      # Histogram Matching
      matched_panchromatic = histogram_match(panchromatic_image, I)

      # Reverse Transformation
      new_red = 1 * matched_panchromatic +1/np.sqrt(6) * v1 + 1/np.sqrt(2) * v2
      new_green = 1 * matched_panchromatic +1/np.sqrt(6) * v1  -1/np.sqrt(2) * v2
      new_blue = 1 * matched_panchromatic - 2/np.sqrt(6) * v1

      sharpened_img = multispectral_image
      sharpened_img[4, :, :] = new_red
      sharpened_img[2, :, :] = new_green
      sharpened_img[1, :, :] = new_blue

    return sharpened_img

def pansharpening(method):
  """
  Performs pansharpening to every image in the dataset.
  """

  for i in range(len(p_images)):
    img = pansharpen_image(torch.from_numpy(m_images[i]).permute(2,0,1), torch.from_numpy(p_images[i]), method)
    m_images[i] = img.permute(1,2,0).numpy()

"""## Scaling"""

def scale_panchromatic_image(image, transformer = MinMaxScaler()):
  '''Returns input panchromatic image with its values being scaled to the [0,1] interval. '''

  img = transformer.fit_transform(image, )
  return img

def scale_multispectral_image(image, bands = 8, transformer = MinMaxScaler()):
  '''Returns input multispectral image with its values being scaled to the [0,1] interval. '''

  b0 = image[:,:,0]
  b1 = image[:,:,1]
  b2 = image[:,:,2]
  if bands == 8:
    b3 = image[:,:,3]
    b4 = image[:,:,4]
    b5 = image[:,:,5]
    b6 = image[:,:,6]
    b7 = image[:,:,7]

  # As before, we perform some scaling first
  sc = transformer
  b0 = sc.fit_transform(b0)
  b1 = sc.fit_transform(b1)
  b2 = sc.fit_transform(b2)
  if bands == 8:
    b3 = sc.fit_transform(b3)
    b4 = sc.fit_transform(b4)
    b5 = sc.fit_transform(b5)
    b6 = sc.fit_transform(b6)
    b7 = sc.fit_transform(b7)

  if bands == 8: img = np.dstack([b0, b1, b2, b3, b4, b5, b6, b7])
  else: img = np.dstack([b0, b1, b2])
  return img

def scaling():
  '''Pipeline function for value scaling. '''

  for i in range(len(p_images)):
    p_images[i] = scale_panchromatic_image(p_images[i])
    m_images[i] = scale_multispectral_image(m_images[i])

"""## Data Augmentation"""

def HorizontalFlip():
  '''Performs horizontal flipping. '''

  for i in range(len(polygon_numbers)):
    m_images.append(cv2.flip(m_images[i], 1))
    p_masks.append(cv2.flip(p_masks[i], 1))

def VerticalFlip():
  '''Performs vertical flipping. '''

  for i in range(len(polygon_numbers)):
    m_images.append(cv2.flip(m_images[i], 0))
    p_masks.append(cv2.flip(p_masks[i], 0))

def VHFlip():
  '''Performs both horizontal and vertical flipping. '''

  for i in range(len(polygon_numbers)):
    m_images.append(cv2.flip(m_images[i], -1))
    p_masks.append(cv2.flip(p_masks[i], -1))

def Rotation90():
  '''Performs a 90 degrees rotation on the images'''

  for i in range(len(polygon_numbers)):

    # Transpose the image
    image = image.transpose(1,0,2)
    # Flip the image vertically
    image = cv2.flip(image, 1)
    m_images.append(cv2.flip(m_images[i], -1))
    p_masks.append(cv2.flip(p_masks[i], -1))

# source: https://www.kaggle.com/safavieh/image-augmentation-using-skimage
import random
import pylab as pl

def random_crop(img, mask, crop_height, crop_width):

    height, width = img.shape[0], img.shape[1]

    # Calculate aspect ratio
    aspect_ratio = float(width / height)

    # Calculate the maximum width and height that can be cropped while maintaining the aspect ratio
    max_crop_width = int(aspect_ratio * crop_height)
    max_crop_height = int(crop_width / aspect_ratio)

    # Choose a random starting point for the crop
    start = random.randint(0,10)
    x = random.randint(start, width - max_crop_width)
    y = random.randint(start, height - max_crop_height)

    # Crop the image and mask
    cropped_img = img[y:y+max_crop_height, x:x+max_crop_width]
    cropped_mask = mask[y:y+max_crop_height, x:x+max_crop_width]

    # Resize the cropped image and mask to the desired size
    # Interpolation. Possible values: cv2.INTER_LINEAR, cv2.INTER_NEAREST, cv2.INTER_AREA, cv2.INTER_CUBIC
    resized_image = cv2.resize(cropped_img, (width, height), interpolation=cv2.INTER_LINEAR)
    resized_mask = cv2.resize(cropped_mask, (width, height), interpolation=cv2.INTER_LINEAR)

    return resized_image, resized_mask

def RandomCropping():
    ''' Applying 10 random croppings to all the images.'''

    for i in range(len(polygon_numbers)):
      for j in range(10):
          width = random.randint(40, 60)
          aspect_ratio = m_images[i].shape[0] / m_images[i].shape[1]
          img, mask = random_crop(m_images[i], p_masks[i], int(width*aspect_ratio), width)
          m_images.append(img)
          p_masks.append(mask)

def data_augmentation():
  '''Performs data augmentation over images and masks. '''

  HorizontalFlip()
  VerticalFlip()
  VHFlip()
  RandomCropping()

"""## Padding"""

def pad_images(imgs, msks, border):
  '''Pipeline function for images' padding. '''

  border_type = border
  images, masks = [], []
  for i in range(len(imgs)):
    images.append(cv2.copyMakeBorder(imgs[i], 128 - imgs[i].shape[0], 0, 80 - imgs[i].shape[1], 0, border_type))
    masks.append(cv2.copyMakeBorder(msks[i], 128 - msks[i].shape[0], 0, 80 - msks[i].shape[1], 0, border_type))

  return images, masks

"""## Preprocessing Pipeline Main Function"""

def preprocessing_pipeline(method, border_type = cv2.BORDER_CONSTANT, scale = True):
  for i in range(len(m_images)):
    m_images[i] = m_images[i].astype(float)
    p_images[i] = p_images[i].astype(float)

  if scale: scaling()
  pansharpening(method)
  data_augmentation()
  imgs, masks = pad_images(m_images, p_masks, border_type)

  for i in range(len(imgs)):
    masks[i] = masks[i].astype(int)

  return imgs, masks

"""# PyTorch Dataset"""

class PanchromaticDataset(torch.utils.data.Dataset):
    def __init__(self, images, masks):
      super().__init__()
      self.images = images
      self.masks = masks

    def __getitem__(self, i):
      return torch.from_numpy(self.images[i].astype(float)), torch.from_numpy(self.masks[i].astype(float))

    def __len__(self):
      return len(self.images)

class MultispectralDataset(torch.utils.data.Dataset):
  def __init__(self, images, masks):
    super().__init__()
    self.images = images
    self.masks = masks

  def __getitem__(self, i):
    img, mask = torch.from_numpy(self.images[i].astype(float)).permute(2,0,1), torch.from_numpy(self.masks[i].astype(float))
    return torch.nan_to_num(img), torch.nan_to_num(mask)

  def __len__(self):
      return len(self.images)

"""# Model"""

"""### Model Implementation"""

import torchvision.models as models
from torchvision.models.detection import maskrcnn_resnet50_fpn, maskrcnn_resnet50_fpn_v2

"""# Model

### Model Implementation
"""

from torchvision.models import *
from torch.nn.modules import batchnorm

class ConvBlock(nn.Module):
    ''' Basic block for performing two convolution operations. '''

    def __init__(self, in_channels, out_channels, pad = 1, activation_function = nn.ReLU(inplace=True)):
        super(ConvBlock, self).__init__()
        conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride = 1, padding=pad)
        batch1 = nn.BatchNorm2d(out_channels)

        conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride = 1, padding=pad)
        batch2 = nn.BatchNorm2d(out_channels)

        self.conv_block = nn.Sequential(conv1, batch1, activation_function, conv2, batch2, activation_function)

    def forward(self, x):
        x = self.conv_block(x)
        return x

class unetDown(nn.Module):
    ''' Encoder block of the U-Net architecture. '''

    def __init__(self, in_channels, out_channels, pad = 1):
        super(unetDown, self).__init__()
        self.conv_block = ConvBlock(in_channels, out_channels, pad)
        self.pooling = nn.MaxPool2d(kernel_size = 2, stride = 2)

    def forward(self, x, indices = None):
        skip_connection = self.conv_block(x)
        x = self.pooling(skip_connection)
        return x, skip_connection

class unetUp(nn.Module):
    ''' Decoder block of the U-Net architecture. '''

    def __init__(self, in_channels, out_channels, pad=1):
        super(unetUp, self).__init__()
        self.upsample = nn.ConvTranspose2d(in_channels-out_channels, in_channels-out_channels, kernel_size=2, stride=2)
        self.conv_block = ConvBlock(in_channels, out_channels, pad)

    def forward(self, x, skip_connection):
        x = self.upsample(x)
        #print('nnnn')
        #print(x.shape)
        #print(skip_connection.shape)
        x = torch.cat([x, skip_connection], dim=1)
        x = self.conv_block(x)
        return x

class multispectralUnet(nn.Module):
    ''' Full architecture of the proposed network. '''

    def __init__(self, n_filters_1, n_filters_2, n_filters_3, n_filters_4):
        super(multispectralUnet, self).__init__()

        # ====== Encoder ======

        self.down1 = unetDown(8, n_filters_1)
        self.down2 = unetDown(n_filters_1, n_filters_2)
        self.down3 = unetDown(n_filters_2, n_filters_3)

        # ====== BottleNeck =====
        self.bottleneck = ConvBlock(n_filters_3, n_filters_4)

        # ===== Decoder =====

        self.up3 = unetUp(n_filters_3 + n_filters_4, n_filters_3)
        self.up2 = unetUp(n_filters_2 + n_filters_3, n_filters_2)
        self.up1 = unetUp(n_filters_1 + n_filters_2, n_filters_1)

        # Final Convolution
        self.conv_last = nn.Sequential(nn.Conv2d(n_filters_1, 1, kernel_size=1), nn.Tanh())

    def forward(self, x):
        x, skip1 = self.down1(x)
        x, skip2 = self.down2(x)
        x, skip3 = self.down3(x)

        x = self.bottleneck(x)

        x = self.up3(x, skip3)
        x = self.up2(x, skip2)
        x = self.up1(x, skip1)
        x = self.conv_last(x)
        return x

    def predict(self,x):
        x = self.forward(x)
        return x

"""# Training Pipeline

### Helper Functions
"""

from torchmetrics import F1Score, Precision, Recall

def optimal_f1(predictions, targets):
    '''With this function we obtain the optimal threshold, given our model predictions. '''
    thres = np.linspace(torch.min(predictions).item(), torch.max(predictions).item(), 201)
    f1_score = F1Score(task="binary", num_classes=2).to(device)
    f1s = torch.Tensor([f1_score(predictions > thr, targets) for thr in thres])
    idx = torch.argmax(f1s)
    return f1s[idx], thres[idx]

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

def validate_one_epoch(validation_loader, model, thr):
    #metric = MeanAveragePrecision(iou_type = 'bbox')
    loss_epoch = 0
    bce_loss = nn.BCEWithLogitsLoss().to(device)
    f1_epoch, precision_epoch, recall_epoch = 0.0, 0.0, 0.0
    f1_score = F1Score(task="binary", num_classes=2).to(device)
    precision_score = Precision(task="binary", num_classes=2).to(device)
    recall_score = Recall(task="binary", num_classes=2).to(device)

    # Don't update weights
    with torch.no_grad():

      # Loop over minibatches
      for imgs, masks in tqdm(validation_loader):
          model.train()
          imgs = imgs.to(device, dtype=torch.float32)
          masks = masks.to(device)

          # Make predictions and obtain losses
          with torch.cuda.amp.autocast():
            outputs = model(imgs)
            loss = bce_loss(outputs.squeeze()[:, 29:, :], (masks[:, 29:, :]).to(device))

          # Track losses
          loss_epoch += loss.detach().item()

          # Track metric
          model.eval()
          preds = model.predict(imgs).squeeze()

          f1_epoch += f1_score(preds[:, 29:, :] > thr, masks.squeeze()[:, 29:, :])
          recall_epoch += recall_score(preds[:, 29:, :] > thr, masks.squeeze()[:, 29:, :])
          precision_epoch += precision_score(preds[:, 29:, :] > thr, masks.squeeze()[:, 29:, :])

    return loss_epoch/len(validation_loader), f1_epoch/len(validation_loader), precision_epoch/len(validation_loader), recall_epoch/len(validation_loader)

def show_predictions(loader, model, thr):
  # Ejemplo de la máscara predicha
  # Don't update weights
    with torch.no_grad():
      for imgs, masks in tqdm(loader):
        model.eval()

        # Create MaskRCNN inputs and targets
        inputs = [imgs[i].to(device, dtype = torch.float32) for i in range(imgs.shape[0])]

        # Track metric
        pred = model.predict(inputs)

        for i in range(imgs.shape[0]):
          predicted_mask = torch.sum(pred[i]['masks'], dim=0)

          fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8,4))
          axes[0].imshow(masks[i])
          axes[1].imshow((predicted_mask > thr).cpu().squeeze())
          plt.show()

        del inputs, pred
        gc_collect()

def validate_model(verbose=True):
    torch.manual_seed(42)
    # Init W&B
    run = wandb.init(entity="javigallego4", project="Bachelor Thesis", group = 'Baseline - Ensembling')

    # Preprocessing Pipeline (No initial scaling)
    m_images2, p_masks2 = preprocessing_pipeline('Simple Mean', cv2.BORDER_REFLECT_101, False)

    # Datasets and train-test-split
    multispectral_dataset = MultispectralDataset(m_images2, p_masks2, 'Vegetation Indexes')
    n = len(multispectral_dataset)
    training_data = torch.utils.data.Subset(multispectral_dataset, X_train)
    validation_data = torch.utils.data.Subset(multispectral_dataset, X_val)
    test_data = torch.utils.data.Subset(multispectral_dataset, X_test)
    gc_collect()

    print('Dataset created')

    # Load artifact
    run_name = 'young-sweep-49'
    artifact = run.use_artifact('javigallego4/Bachelor Thesis/{}:v0'.format(run_name), type='model')
    artifact_dir = artifact.download()
    # Load the model
    model = maskrcnn_model('resnext101_32x8d').to(device)
    model.load_state_dict(torch.load('/content/artifacts/{}:v0/{}.pt'.format(run_name, run_name)))
    best_thr = 0.589

    print('Model loaded')

    batch = 4
    trainloader = torch.utils.data.DataLoader(training_data, batch_size = batch, num_workers = 2, shuffle = True, pin_memory=True)
    validationloader = torch.utils.data.DataLoader(validation_data, batch_size = batch, num_workers = 2, pin_memory=True)
    testloader = torch.utils.data.DataLoader(test_data, batch_size = batch, num_workers = 2, pin_memory=True)

    print('Preprocessing done')

    # Evaluate
    #val_f1, val_precision, val_recall = get_validation_scores(validationloader, model, best_thr)
    #test_f1, test_precision, test_recall = get_validation_scores(testloader, model, best_thr)

    # Print loss
    #print(f'val_f1 {val_f1:.5f}, val_precision {val_precision:.5f}, val_recall {val_recall:.5f}')
    #print(f'test_f1 {test_f1:.5f}, test_precision {test_precision:.5f}, test_recall {test_recall:.5f}')

    show_predictions(testloader, model, best_thr)

    run.finish()

Number of CPUs:  4


In [2]:
#p_images, m_images, p_masks, m_masks = [], [], [], []
#fit_all_images(polygon_numbers)
#validate_model()

In [3]:
def get_validation_masks(validation_loader, model):
    masks_list = []
    with torch.no_grad():
      # Loop over minibatches
      for imgs, masks in tqdm(validation_loader):
          model.eval()

          # Create MaskRCNN inputs and targets
          inputs = imgs.to(device, dtype=torch.float32)
          pred = model.predict(inputs)

          for i in range(inputs.shape[0]):
            masks_list.append(pred[i, :, 29:, :].squeeze())

          del inputs, pred
          gc_collect()

    masks_pred = torch.stack(masks_list, 0)
    return masks_pred

def get_ensembling_scores(verbose=True):
    torch.manual_seed(42)
    # Init W&B
    run = wandb.init(entity="javigallego4", project="Bachelor Thesis", group = 'Baseline - Ensembling')

    # Preprocessing Pipeline (No initial scaling)
    m_images2, p_masks2 = preprocessing_pipeline('Simple Mean', cv2.BORDER_REFLECT_101, False)

    # Datasets and train-test-split
    multispectral_dataset = MultispectralDataset(m_images2, p_masks2, 'Vegetation Indexes')
    n = len(multispectral_dataset)
    training_data = torch.utils.data.Subset(multispectral_dataset, X_train)
    validation_data = torch.utils.data.Subset(multispectral_dataset, X_val)
    test_data = torch.utils.data.Subset(multispectral_dataset, X_test)
    gc_collect()

    print('Dataset created')

    batch = 16
    trainloader = torch.utils.data.DataLoader(training_data, batch_size = batch, num_workers = 2, shuffle = True, pin_memory=True)
    validationloader = torch.utils.data.DataLoader(validation_data, batch_size = batch, num_workers = 2, pin_memory=True)
    testloader = torch.utils.data.DataLoader(test_data, batch_size = batch, num_workers = 2, pin_memory=True)

    print('Preprocessing done')

    # Load artifact
    masks_list = []
    best_runs = ['young-sweep-49','eternal-sweep-52','grateful-sweep-43']
    backbones = ['resnext101_32x8d','resnet101','resnext101_32x8d']
    thresholds = [0.589, 0.5856, 0.5737]
    models = []
    #weights = []
    for i, run_name in enumerate(best_runs):
      # Load artifact
      artifact = run.use_artifact('javigallego4/Bachelor Thesis/{}:v0'.format(run_name), type='model')
      artifact_dir = artifact.download()

      # Load the model
      model = maskrcnn_model(backbones[i]).to(device)
      model.load_state_dict(torch.load('/content/artifacts/{}:v0/{}.pt'.format(run_name, run_name)))
      models.append(model)

      # Evaluate
      #model, thresholds = models[best_runs[i]]
      masks_list.append(get_validation_masks(validationloader, model))

      del model, artifact
      gc_collect()

    ground_truth_masks = []
    validationloader = torch.utils.data.DataLoader(validation_data, batch_size = 1, num_workers = 2, pin_memory=True)
    for imgs, masks in tqdm(validationloader):
      ground_truth_masks.append(masks.squeeze().to(device))

    stacked_masks = torch.stack(masks_list)
    for i in range(100):
      fig, axes = plt.subplots(nrows = 1, ncols = 5, figsize=(20,4))
      axes[0].imshow(stacked_masks[0][i].cpu())
      axes[0].set_title('Modelo 1')
      axes[1].imshow(stacked_masks[1][i].cpu())
      axes[1].set_title('Modelo 2')
      axes[2].imshow(stacked_masks[2][i].cpu())
      axes[2].set_title('Modelo 3')
      axes[4].imshow(ground_truth_masks[i].cpu())
      axes[4].set_title('Mascara original')
      plt.show()

    print(stacked_masks.shape)
    for w3 in np.arange(0.01, 1.0, 0.01):
      for w2 in np.arange(0.01, 1.0, 0.01):
        for w1 in np.arange(0.01, 1.0, 0.01):
            weights = [w1, w2, w3]
            tensor_weights = torch.tensor(weights)
            expanded_weights = tensor_weights[:, None, None, None].to(device)
            print('Pesos shape ', expanded_weights.shape)
            weighted_sum = torch.sum(stacked_masks * expanded_weights, dim=0)
            final_masks = weighted_sum / torch.sum(tensor_weights)
            print(final_masks.shape)

            #ground_truth_masks = torch.stack(masks_list, dim = 0)
            #print(ground_truth_masks.shape)

            f1_score = F1Score(task="binary", num_classes=2).to(device)
            precision_score = Precision(task="binary", num_classes=2).to(device)
            recall_score = Recall(task="binary", num_classes=2).to(device)
            precision_epoch, f1_epoch, recall_epoch = 0, 0, 0

            optimal_thr = 0
            for i in range(len(ground_truth_masks)):
                f1_output, thr = optimal_f1(final_masks[i], ground_truth_masks[i])
                f1_epoch += f1_output
                optimal_thr += thr

            wandb.log({
                'w1': w1,
                'w2': w2,
                'w3': w3,
                'best_f1': f1_epoch/len(validationloader),
                'best_thr': optimal_thr/len(validationloader)
            })

            print(f1_epoch/len(validationloader))
            print(optimal_thr/len(validationloader))



    run.finish()

In [4]:
#gc_collect()
#p_images, m_images, p_masks, m_masks = [], [], [], []
#fit_all_images(polygon_numbers)
#get_ensembling_scores()

In [5]:
# todas son listas
def hill_climbing(predicted_masks, ground_truth_masks, run_names):
  '''Performs hill climbing ensembling technique. '''

  i = 0
  STOP = False
  current_best_ensemble = (0.78 * (0.63 * (0.79 * (0.74 * (0.5 * predicted_masks[0] + 0.5 * predicted_masks[4]) + 0.26 * predicted_masks[3]) + 0.21 * predicted_masks[2]) + 0.37 * predicted_masks[1]) \
                          + 0.22 * predicted_masks[7])  # El mejor ensamblado de las máscaras
  del predicted_masks[0], run_names[0], run_names[4], predicted_masks[4], run_names[3], predicted_masks[3], run_names[2], predicted_masks[2], run_names[1], predicted_masks[1], \
      predicted_masks[7], run_names[7]
  history = [optimal_f1(current_best_ensemble, ground_truth_masks)[0]]
  print('Initial best F1 Score: ', history[0])

  while not STOP:
    i+=1
    # Calculamos el F1 Score del ensamblado de máscaras actual
    potential_new_best_f1_score, potential_new_best_thr = optimal_f1(current_best_ensemble, torch.tensor(ground_truth_masks))
    k_best, wgt_best = None, None
    for k in range(len(predicted_masks)):
        print('=== {} ==='.format(k))
        for wgt in np.arange(-0.51,0.51,0.01):
            # Añadimos las máscaras de un modelo al ensamblado
            potential_ensemble = (1-wgt) * current_best_ensemble + wgt * predicted_masks[k]
            # Calculamos el nuevo F1 Score
            f1_score, thr = optimal_f1(potential_ensemble, torch.tensor(ground_truth_masks))
            # En caso de ser mejor, actualizamos valores
            if f1_score > potential_new_best_f1_score:
                potential_new_best_f1_score = f1_score
                k_best, wgt_best = k, wgt
                best_thr = thr

    if k_best is not None:
          # Si hemos conseguido aumentar el F1 Score, actualizamos el ensamblado de máscaras
          current_best_ensemble = (1-wgt_best) * current_best_ensemble + wgt_best * predicted_masks[k_best]
          # Si ya no quedan más máscaras entonces paramos
          if len(predicted_masks)==0:
              STOP = True
          # !!!!! Actualizar el print.
          print(f'Iteration: {i}, Model added: {run_names[k_best]}, Best weight: {wgt_best:.2f}, Best Thr: {potential_new_best_thr}, Best F1: {potential_new_best_f1_score:.5f}')
          history.append(potential_new_best_f1_score)
          del predicted_masks[k_best], run_names[k_best]
    # Si ningún ensamblado nuevo ha conseguido boostear el F1 Score => paramos
    else:
        STOP = True

  return history

In [None]:
torch.manual_seed(42)
# Init W&B
run = wandb.init(entity="javigallego4", project="Bachelor Thesis", group = 'Baseline - Ensembling')

# Preprocessing Pipeline (No initial scaling)
p_images, m_images, p_masks, m_masks = [], [], [], []
load_images(polygon_numbers)
m_images, p_masks = preprocessing_pipeline('HLS', cv2.BORDER_CONSTANT, False)

# Datasets and train-test-split
multispectral_dataset = MultispectralDataset(m_images, p_masks)
n = len(multispectral_dataset)
validation_data = torch.utils.data.Subset(multispectral_dataset, np.concatenate([X_val, X_train]))
gc_collect()
validationloader = torch.utils.data.DataLoader(validation_data, batch_size = 16, num_workers = 2, pin_memory=True, shuffle=False)
print('Dataset created')

# Load artifact
predicted_masks = []
best_runs = ['morning-sweep-28','scarlet-sweep-159','effortless-sweep-175','cool-sweep-26','amber-sweep-21','zesty-sweep-9','driven-sweep-20','winter-sweep-24','bright-sweep-31','treasured-sweep-18','dark-sweep-66','golden-sweep-20','sleek-sweep-24',\
             'drawn-sweep-64', 'revived-sweep-9', 'atomic-sweep-78', 'electric-sweep-72', 'logical-sweep-10','dulcet-sweep-152','firm-sweep-129', 'honest-sweep-155', 'glamorous-sweep-124', 'kind-sweep-113', 'golden-sweep-174']
n_filters_1 = [128, 96, 96, 512, 64, 64, 512, 512, 64, 128, 64, 128, 256, 512, 32, 16, 256, 32, 128, 96, 96, 32, 512, 32]
n_filters_2 = [512, 64, 96, 512, 512, 512, 32, 256, 256, 256, 128, 512, 128, 128, 16, 64, 64, 16, 128, 256, 64, 256, 128, 96]
n_filters_3 = [512, 512, 512, 64, 16, 512, 512, 64, 96, 128, 512, 64, 512, 128, 32, 512, 256, 32, 512, 512, 512, 512, 128, 512]
n_filters_4 = [16, 256, 96, 64, 16, 32, 256, 16, 64, 96, 256, 128, 256, 256, 512, 96, 96, 256, 96, 16, 64, 16, 16, 96]

#best_thr = [-0.1547, -0.1433, -0.1027]
for i, run_name in enumerate(best_runs):
    # Load artifact
    artifact = run.use_artifact('javigallego4/Bachelor Thesis/{}:v0'.format(run_name), type='model')
    artifact_dir = artifact.download()

    # Load the model
    model = multispectralUnet(n_filters_1[i], n_filters_2[i], n_filters_3[i], n_filters_4[i]).to(device)
    model.load_state_dict(torch.load('/content/artifacts/{}:v0/{}.pt'.format(run_name, run_name)))

    # Evaluate
    predicted_masks.append(get_validation_masks(validationloader, model))
    print('Predicted masks appended')

ground_truth_masks = []
for imgs, masks in tqdm(validationloader):
    ground_truth_masks.append(masks.squeeze().to(device))

ground_truth_masks = torch.cat(ground_truth_masks, dim = 0)[:,29:,:]
hill_climbing(predicted_masks, ground_truth_masks, best_runs)

Dataset created


[34m[1mwandb[0m: Downloading large artifact morning-sweep-28:v0, 86.93MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.6
100%|██████████| 183/183 [01:07<00:00,  2.73it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:42<00:00,  4.31it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:42<00:00,  4.29it/s]


Predicted masks appended


[34m[1mwandb[0m: Downloading large artifact cool-sweep-26:v0, 79.49MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.2
100%|██████████| 183/183 [01:42<00:00,  1.78it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:43<00:00,  4.20it/s]


Predicted masks appended


[34m[1mwandb[0m: Downloading large artifact zesty-sweep-9:v0, 83.99MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.7
100%|██████████| 183/183 [00:41<00:00,  4.37it/s]


Predicted masks appended


[34m[1mwandb[0m: Downloading large artifact driven-sweep-20:v0, 72.89MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.4
100%|██████████| 183/183 [01:02<00:00,  2.95it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [01:17<00:00,  2.37it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:43<00:00,  4.18it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:44<00:00,  4.12it/s]


Predicted masks appended


[34m[1mwandb[0m: Downloading large artifact dark-sweep-66:v0, 50.79MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.1
100%|██████████| 183/183 [00:44<00:00,  4.15it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:43<00:00,  4.23it/s]


Predicted masks appended


[34m[1mwandb[0m: Downloading large artifact sleek-sweep-24:v0, 58.88MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.2
100%|██████████| 183/183 [00:47<00:00,  3.88it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [01:03<00:00,  2.89it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:41<00:00,  4.41it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:40<00:00,  4.54it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:44<00:00,  4.12it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:39<00:00,  4.62it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:42<00:00,  4.33it/s]


Predicted masks appended


[34m[1mwandb[0m: Downloading large artifact firm-sweep-129:v0, 51.09MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.8
100%|██████████| 183/183 [00:44<00:00,  4.13it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:41<00:00,  4.38it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:42<00:00,  4.34it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:54<00:00,  3.36it/s]


Predicted masks appended


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 183/183 [00:42<00:00,  4.35it/s]


Predicted masks appended


100%|██████████| 183/183 [00:02<00:00, 73.56it/s]


Initial best F1 Score:  tensor(0.9311)
=== 0 ===
=== 1 ===
=== 2 ===
=== 3 ===
=== 4 ===
=== 5 ===
=== 6 ===
=== 7 ===
=== 8 ===
=== 9 ===
=== 10 ===
=== 11 ===
=== 12 ===
=== 13 ===
=== 14 ===
=== 15 ===
=== 16 ===
=== 17 ===
Iteration: 1, Model added: scarlet-sweep-159, Best weight: -0.18, Best Thr: -0.0699996173381805, Best F1: 0.93274
=== 0 ===
=== 1 ===
=== 2 ===
=== 3 ===
=== 4 ===
=== 5 ===
=== 6 ===
=== 7 ===
=== 8 ===
=== 9 ===
=== 10 ===
=== 11 ===
=== 12 ===
=== 13 ===
=== 14 ===
=== 15 ===
=== 16 ===
Iteration: 2, Model added: winter-sweep-24, Best weight: -0.16, Best Thr: -0.09999963343143459, Best F1: 0.93418
=== 0 ===
=== 1 ===
=== 2 ===
=== 3 ===
=== 4 ===
=== 5 ===
=== 6 ===
=== 7 ===
=== 8 ===
=== 9 ===
=== 10 ===


In [None]:
torch.manual_seed(42)
# Init W&B
run = wandb.init(entity="javigallego4", project="Bachelor Thesis", group = 'Baseline - Ensembling')

# Preprocessing Pipeline (No initial scaling)
p_images, m_images, p_masks, m_masks = [], [], [], []
load_images(polygon_numbers)
m_images, p_masks = preprocessing_pipeline('IHS3', cv2.BORDER_CONSTANT, False)

# Datasets and train-test-split
multispectral_dataset = MultispectralDataset(m_images, p_masks)
n = len(multispectral_dataset)
validation_data = torch.utils.data.Subset(multispectral_dataset, X_val)
gc_collect()
validationloader = torch.utils.data.DataLoader(validation_data, batch_size = 16, num_workers = 2, pin_memory=True, shuffle=False)
print('Dataset created')

# Load artifact
predicted_masks = []
best_runs = ['vague-sweep-19', 'apricot-sweep-60','trim-sweep-18','curious-sweep-71','royal-sweep-53','decent-sweep-68', 'super-sweep-62','elated-sweep-58']
n_filters_1 = [96, 64, 96, 128, 32, 32, 128, 32]
n_filters_2 = [512, 96, 32, 128, 128, 32, 16, 32]
n_filters_3 = [64, 512, 512, 512, 256, 96, 64, 64]
n_filters_4 = [32, 32, 96, 256, 96, 32, 512, 512]

#best_thr = [-0.1547, -0.1433, -0.1027]
for i, run_name in enumerate(best_runs):
    # Load artifact
    artifact = run.use_artifact('javigallego4/Bachelor Thesis/{}:v0'.format(run_name), type='model')
    artifact_dir = artifact.download()

    # Load the model
    model = multispectralUnet(n_filters_1[i], n_filters_2[i], n_filters_3[i], n_filters_4[i]).to(device)
    model.load_state_dict(torch.load('/content/artifacts/{}:v0/{}.pt'.format(run_name, run_name)))

    # Evaluate
    predicted_masks.append(get_validation_masks(validationloader, model))
    print('Predicted masks appended')

ground_truth_masks = []
for imgs, masks in tqdm(validationloader):
    ground_truth_masks.append(masks.squeeze().to(device))

ground_truth_masks = torch.cat(ground_truth_masks, dim = 0)[:,29:,:]
hill_climbing(predicted_masks, ground_truth_masks, best_runs)