In [None]:
# Names assigned, in this order, to the columns of the attribution tables
# built in the "Interpretation" section of this notebook.
labels = [
    # physical / biogeochemical layers
    'bathy-gebco',
    'salinity3d',
    'wave-height',
    'surface-wind-u',
    'surface-wind-v',
    'oxygen',
    'ph',
    'fsle',
    'fsle-orientation',
    'geos-current-u',
    'geos-current-v',
    'eke',
    'chlorophyll-occi',
    'sst-mur',
    'mixed-layer-thickness',
    # plankton groups
    'diatoms',
    'dinophytes',
    'haptophytes',
    'green-algae',
    'prochlorophytes',
    'prokaryotes',
    # "-15" / "-5" variants of chlorophyll and SST
    'chlorophyll-occi-15',
    'sst-mur-15',
    'sst-mur-5',
    'chlorophyll-occi-5',
    # ocean basin / hemisphere entries
    'Atlantic',
    'Indian',
    'Pacific',
    'North hemisphere',
]

import random
import numpy as np
import pandas as pd
import geopandas as gpd
from shapely import wkt
from pathlib import Path
import csv

from hydra import compose, initialize
from multi38 import predict, test, last_checkpoint, Multi38DataModule

from matplotlib import pyplot as plt
%matplotlib widget

# Shapefile looked up inside the installed geopandas package; it is used further
# down to mask continents out of the exported rasters.
# NOTE(review): recent geopandas releases may no longer bundle this dataset —
# confirm the file exists for the installed version.
natural_earth_path = Path(gpd.__file__).parent / 'datasets/naturalearth_lowres/naturalearth_lowres.shp'

# Root folder of the input tables (species.csv, dataset CSVs, stats.npy).
data_path = Path('../data/')

# Indicate your checkpoint path and name here
ckpt_path = Path("../outputs/multi38/")
ckpt_name = '2023-05-22_11-12-13/checkpoint-epoch=09--val_f1=0.6592.ckpt'
# Run folder name (the part of `ckpt_name` before the first '/') plus '_';
# used as a prefix for every output file/folder written by this notebook.
ckpt_ref = ckpt_name.split('/')[0] + '_' 

# Load the "hparams" config stored in the `logs/` folder next to the checkpoint,
# then record which checkpoint file to use.
# NOTE(review): `initialize` presumably fails if this cell is run twice in the
# same kernel (Hydra keeps global state) — confirm, restart the kernel if so.
initialize(config_path=ckpt_path / Path(ckpt_name).parent / 'logs/', version_base="1.1")
cfg = compose(config_name="hparams")
cfg.other.ckpt_name = ckpt_name

In [None]:
# Read the species table once: the first row is a header and is skipped;
# column 0 holds the species identifiers, column 1 the species names.
# `newline=''` is the mode the csv module asks for when reading files.
with (data_path / 'species.csv').open(newline='') as file:
    species_rows = list(csv.reader(file))[1:]

species = [row[1] for row in species_rows]
species_ids = [row[0] for row in species_rows]

# Make predictions on test data

In [None]:
import geopandas as gpd

# Load the dataset table named in the config, indexed by observation id.
dataset_table = pd.read_csv(data_path / cfg.data.dataset_name, index_col='id')

# Keep the rows whose `subset` column is 'test'; `id` becomes a regular column again.
test_ids = dataset_table.index[dataset_table["subset"] == 'test']
df = dataset_table.loc[test_ids].reset_index()

# Parse the WKT strings into shapely geometries and build a WGS84 GeoDataFrame.
df['geometry'] = df['geometry'].apply(wkt.loads)
gdf = gpd.GeoDataFrame(df, crs='epsg:4326')

# Ground-truth class index = position of each row's species code in `species_ids`.
groundtruth = np.array([species_ids.index(code) for code in df['species'].tolist()])

### Calculate predictions

In [None]:
# Run the project's `predict` on the current config.
# Presumably returns one row of per-class scores per test observation (the
# rows are compared one-to-one with `groundtruth` in the next cells) — confirm.
all_predictions = predict(cfg)
# Predicted class = column index of the highest score in each row.
predictions = np.argmax(all_predictions, axis = 1)

In [None]:
# Top-1 score: fraction of observations whose predicted class equals the ground truth

top1_hits = groundtruth == predictions
top1_hits.sum() / groundtruth.shape[0]

In [None]:
# Top-k scores for k = 1..10, plus k = 38

n_obs = groundtruth.shape[0]
truth_column = groundtruth[:, np.newaxis]

for k in [*range(1, 11), 38]:
    # Column indices of the k largest scores of each row (in no particular order)
    top_k = np.argpartition(all_predictions, -k, axis=1)[:, -k:]
    hit_rate = (truth_column == top_k).sum() / n_obs
    print(k, f"{hit_rate:.2%}")

### Confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Confusion matrix over the 38 class indices, normalised per true class
# (normalize='true') and rounded to two decimals for display.
# The array is called `conf_mat` rather than `cm` because later cells import
# `matplotlib.cm` under the name `cm`; sharing the name makes those cells fail
# when this one is re-run after them.
conf_mat = confusion_matrix(groundtruth, predictions, labels=range(38), normalize = 'true')
conf_mat = conf_mat.round(2)

fig, ax = plt.subplots(figsize=(16, 16))
disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=species)
disp.plot(xticks_rotation = 'vertical', colorbar = False, text_kw={'fontsize': 'x-small'},ax=ax)
plt.tight_layout()
#plt.show()
#plt.savefig(data_path / f"../outputs/confusion_matrix/{ckpt_ref}cm.png")

# Make predictions for new data

### World

In [None]:
# Point the config at datasets/world_tiled.csv and run the prediction on it.
# This rebinds `all_predictions`: the scores computed earlier in the notebook are
# replaced, so the export cell below saves whichever prediction ran last.
cfg.data.dataset_name = "datasets/world_tiled.csv"
all_predictions = predict(cfg)

### WIO

In [None]:
# Point the config at datasets/wio-tiled.csv and run the prediction on it.
# This rebinds `all_predictions`: any scores computed by a previous cell are
# replaced, so the export cell below saves whichever prediction ran last.
cfg.data.dataset_name = "datasets/wio-tiled.csv"
all_predictions = predict(cfg)

### Export to csv

In [None]:
# Re-read the table that `cfg.data.dataset_name` currently points at, so the
# scores in `all_predictions` can be attached to it row by row.
df = pd.read_csv(data_path / cfg.data.dataset_name, index_col = 'id')

# One column of scores per species, aligned on the table's index.
results = pd.DataFrame(all_predictions, columns=species, index = df.index)
full = df.merge(results, left_index=True, right_index=True)
# Column index of the highest score = predicted species index.
full['best-species'] = np.argmax(all_predictions,axis=1)

# Name the export after the dataset that was actually predicted. A fixed
# "world" file name would store WIO scores under the world name whenever the
# WIO cell ran last, and would never produce the
# wio-predictions/{ckpt_ref}wio_predictions.csv file read by the GeoTIFF cell.
region = 'wio' if 'wio' in Path(str(cfg.data.dataset_name)).stem.lower() else 'world'
export_dir = data_path / f"../outputs/{region}-predictions"
export_dir.mkdir(parents=True, exist_ok=True)
full.to_csv(export_dir / f"{ckpt_ref}{region}_predictions.csv")

# Export outputs to images or rasters

In [None]:
import netCDF4 as nc

import csv
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.interpolate import griddata
import rasterio
import rasterio.mask
import geopandas as gpd
from rasterio.transform import from_origin
from cftime import date2num
from tqdm import tqdm

### Geotiff

In [None]:
input_csv_path = data_path / f"../outputs/wio-predictions/{ckpt_ref}wio_predictions.csv"
data_value_range = [0,1]
map_bounds = [20, 80, -60, 0]   # [lon_min, lon_max, lat_min, lat_max]
res = 0.02                      # pixel size, in the same units as lon/lat
folder = Path(f"../outputs/wio-rasters/{ckpt_ref}rasters/")

# input_csv_path = data_path / f"../outputs/world-predictions/{ckpt_ref}world_predictions.csv"
# data_value_range = [0,1]
# map_bounds = [-180, 180, -90, 90]
# res = 0.1
# folder = Path(f"../outputs/world-rasters/{ckpt_ref}rasters/")

full = pd.read_csv(input_csv_path, index_col = 'id')

# Create the output tree, including missing parents; a no-op when it exists.
folder.mkdir(parents=True, exist_ok=True)

# Everything below is identical for every species and date, so it is computed
# once instead of inside the loops.
dates = list(full['date'].unique())

# Output grid: pixel-centre coordinates, expressed in pixel units
lon_range, lat_range = map_bounds[1] - map_bounds[0], map_bounds[3] - map_bounds[2]
xs = np.linspace(0.5, lon_range / res - 0.5, int(lon_range / res))
ys = np.linspace(0.5, lat_range / res - 0.5, int(lat_range / res))
X, Y = np.meshgrid(xs, ys)

# Affine transform anchored at the top-left corner (lon_min, lat_max)
transform = from_origin(map_bounds[0], map_bounds[3], res, res)

# Single geometry merging every polygon of the shapefile, used to blank land pixels
continents = gpd.read_file(natural_earth_path).unary_union

for s in tqdm(species):

    ref = s.replace(' ','_')
    subfolder = folder / ref
    subfolder.mkdir(exist_ok=True)

    data_column_name = s

    for d in dates:

        day_df = full[full['date'] == d]
        output_raster_path = subfolder / f"{ref}-{d}.tiff"

        # Convert point data to (integer-truncated) grid coordinates
        x = ((day_df['lon'] - map_bounds[0]) / res).to_numpy(dtype=int)
        y = ((day_df['lat'] - map_bounds[2]) / res).to_numpy(dtype=int)
        values = day_df[data_column_name].to_numpy()

        # Interpolate to grid; cells the interpolation cannot fill get -1
        band = griddata((x, y), values, (X, Y), method = 'cubic', fill_value = -1)

        # Rescale to the 1-255 byte range (0 is reserved for nodata).
        # `[normed]` adds the leading band axis; flipping axis 1 reverses the
        # row order so the first row matches the top edge used by `transform`.
        # NOTE(review): the -1 fill value is clipped to 0 here and therefore
        # stored as 1 (lowest valid value), not as nodata — confirm intended.
        normed = (band - data_value_range[0]) / (data_value_range[1] - data_value_range[0])
        data = np.floor(254*np.flip(np.clip([normed],0,1), axis=1))
        data = data.astype(np.uint8) + 1

        # Save as a single-band GeoTIFF tagged EPSG:4326; the context manager
        # closes the file even if the write fails.
        with rasterio.open(output_raster_path, 'w', driver='GTiff',
                           height = data.shape[1], width = data.shape[2],
                           dtype=str(data.dtype),
                           count=1,
                           crs='epsg:4326',
                           transform=transform,
                           nodata=0,
                           compress='lzw') as dst:
            dst.write(data)

        # Mask continents: invert=True masks the pixels inside the geometry
        with rasterio.open(output_raster_path, driver='GTiff') as src:
            out_image, out_transform = rasterio.mask.mask(src, [continents], invert=True)

        with rasterio.open(output_raster_path, 'r+', driver='GTiff') as dst:
            dst.transform = out_transform
            dst.write(out_image)


### PNG

In [None]:
# Export rasters as PNG

import cv2
from matplotlib import cm

in_folder = Path(f"../outputs/wio-rasters/{ckpt_ref}rasters/")
out_folder = Path(f"../outputs/wio-png/{ckpt_ref}png/")

# Create the output tree, including missing parents; a no-op when it exists.
out_folder.mkdir(parents=True, exist_ok=True)

for s in tqdm(species):

    ref = s.replace(' ','_')
    in_subfolder = in_folder / ref
    out_subfolder = out_folder / ref
    out_subfolder.mkdir(exist_ok=True)

    for f in in_subfolder.glob('*.tiff'):

        out_file = Path(out_subfolder, f.stem + '.png')

        # Only the read needs the file open
        with rasterio.open(f, driver='GTiff') as src:
            data = src.read(1)

        # Map raster bytes back to [0, 1]; byte 0 (nodata) becomes negative -> NaN
        scaled = (np.float32(data) -1) / 254
        scaled[scaled<0] = np.nan

        # Colour with the 'Blues' colormap and paint nodata pixels opaque black
        im2 = getattr(cm, 'Blues')(scaled)
        im2[np.isnan(scaled)] =  np.array([0,0,0,1])
        # Swap the first and third channels: matplotlib returns RGBA, while
        # OpenCV writes images in BGR(A) channel order
        im3 = cv2.cvtColor(np.float32(255*im2), cv2.COLOR_BGRA2RGBA)

        ## Add date
        h, w, _ = im3.shape
        # File stems are "<ref>-<date>". Locate the date by its '2021' prefix;
        # when the stem has no '2021', fall back to the position right after
        # "<ref>-" instead of slicing from find()'s -1.
        date_index = f.stem.find('2021')
        if date_index < 0:
            date_index = len(ref) + 1
        date = f.stem[date_index:date_index+10]
        im3 = cv2.putText(im3, date, (int(0.02*w), int(0.06*h)), fontFace = cv2.FONT_HERSHEY_SIMPLEX, fontScale=4, color=(255,255,255,255), thickness=10)

        cv2.imwrite(str(out_file), im3)

### Make gifs

In [None]:
for s in tqdm(species):
    ref = s.replace(' ','_')
    !convert -resize 800x800 -delay 30 -loop 0 ../outputs/wio-png/{ckpt_ref}png/{ref}/*.png ../outputs/wio-gifs/{ckpt_ref}gifs/{ref}.gif

### Make figure for Prionace glauca

In [None]:
s = 'Prionace glauca'
ref = s.replace(' ','_')
folder = Path(f"../outputs/wio-png/{ckpt_ref}png") / ref


fig, axes = plt.subplots(5, 4, figsize = (16, 20))

# Hide the frame and ticks of every panel up front, so panels that receive no
# image stay blank whatever the number of PNG files.
for ax in axes.flat:
    ax.axis('off')

# Show every third frame (sorted by file name), one frame per panel in
# row-major order. `zip` stops at the last panel, so extra frames are ignored
# instead of indexing past the 5x4 grid.
frames = sorted(folder.glob('*.png'))[::3]
for ax, f in zip(axes.flat, frames):
    ax.imshow(plt.imread(f))

fig.tight_layout()
fig.savefig(Path("../outputs/") / "WIO_prionace_grid")

# Interpretation

In [None]:
# Initial imports
from scipy import stats
import torch
from multi38 import *
import seaborn as sns
from tqdm import tqdm
from random import sample
import pickle

from captum.attr import IntegratedGradients

import matplotlib.pyplot as plt
%matplotlib widget

# Three arrays stored in stats.npy — presumably medians and 1st/99th
# percentiles, judging by the names; confirm against the file's producer.
meds, perc1, perc99 = np.load(data_path / "stats.npy")
# Append two trailing length-1 axes to `meds`.
meds = meds[..., np.newaxis, np.newaxis]


## Calculate integrated gradients

In [None]:
# Full checkpoint location taken from the config.
# Note: this rebinds `ckpt_path`, which the first cell defined as a Path.
# NOTE(review): `+` assumes cfg.other.ckpt_path is a string ending with a path
# separator — confirm.
ckpt_path = cfg.other.ckpt_path + cfg.other.ckpt_name
# Use the exported world predictions table as the dataset, one sample per batch.
cfg.data.dataset_name = data_path.parent / f"outputs/world-predictions/{ckpt_ref}world_predictions.csv"
cfg.data.inference_batch_size = 1

# Restore the trained model from the checkpoint and switch it to eval mode.
model = ClassificationSystem.load_from_checkpoint(ckpt_path, model=cfg.model, **cfg.optimizer)
model.eval()

# Build the data module from the (modified) data config and fetch the dataset
# object used by the attribution cells below.
datamodule = Multi38DataModule(**cfg.data)
datamodule.setup(stage='predict')
ds = datamodule.get_dataset('test', datamodule.test_transform)
# Captum Integrated Gradients explainer wrapping the model.
ig = IntegratedGradients(model)

## Over whole dataset

In [None]:
full = pd.read_csv(data_path / f"../outputs/world-predictions/{ckpt_ref}world_predictions.csv", index_col = 'id')
all_obs_ids = list(ds.observation_ids)

out_folder = Path(f"../outputs/interpretation/{ckpt_ref}/")

# Create the output tree, including missing parents; a no-op when it exists.
out_folder.mkdir(parents=True, exist_ok=True)

glob_l = []

# Draw at most 1000 random observations; `sample` raises if asked for more
# items than the population holds.
# Note: no random seed is set, so the selection differs between runs.
n_samples = min(1000, len(all_obs_ids))

for k in sample(range(len(all_obs_ids)), n_samples):

    # Calculate integrated gradient and average on the tile
    x = ds[k][0]
    x.requires_grad_()
    # Attribute with respect to the class stored as 'best-species' for this observation
    target = int(full.loc[all_obs_ids[k], 'best-species'])
    attr = ig.attribute(x.unsqueeze(0), target = target).detach().numpy().squeeze()
    abs_attr = np.abs(attr)
    # Mean absolute attribution over axes 1 and 2, leaving one value per entry of axis 0
    averages = abs_attr.mean(axis=(1,2))
    glob_l.append(averages)

# Save statistics to csv file (skipped when nothing was sampled)

if glob_l:
    glob_table = np.vstack(glob_l)
    glob_df = pd.DataFrame(glob_table)
    glob_df.columns = labels
    glob_df.describe().T.to_csv(out_folder / "All.csv")

## By species

In [None]:
import seaborn as sns
from tqdm import tqdm
from random import sample

full = pd.read_csv(data_path / f"../outputs/world-predictions/{ckpt_ref}world_predictions.csv", index_col = 'id')
all_obs_ids = list(ds.observation_ids)

out_folder = Path(f"../outputs/interpretation/{ckpt_ref}/")

# Create the output tree, including missing parents; a no-op when it exists.
out_folder.mkdir(parents=True, exist_ok=True)

# Position of each observation id inside the dataset (first occurrence, like
# list.index). Built once, so that looking up the ids of a species does not
# rescan the whole list for every id.
position_of = {}
for position, obs_id in enumerate(all_obs_ids):
    position_of.setdefault(obs_id, position)

for i in tqdm(range(38)):
    s = species[i]
    # Observations whose predicted class is species i
    subdf = full[full['best-species'] == i]
    obs_ids = subdf.index
    obs_ids_index = [position_of[obs_id] for obs_id in obs_ids]

    l = []

    # At most 1000 randomly chosen observations per species
    for k in sample(obs_ids_index, min(1000,len(subdf))):

        # Calculate integrated gradient and average on the tile
        x = ds[k][0]
        x.requires_grad_()
        target = int(full.loc[all_obs_ids[k], 'best-species'])
        attr = ig.attribute(x.unsqueeze(0), target = target).detach().numpy().squeeze()
        abs_attr = np.abs(attr)
        # Mean absolute attribution over axes 1 and 2, leaving one value per entry of axis 0
        averages = abs_attr.mean(axis=(1,2))
        l.append(averages)

    # Species never predicted produce no file
    if len(l):

        # Save statistics to csv files

        table = np.vstack(l)
        df = pd.DataFrame(table)
        df.columns  = labels
        df.describe().T.to_csv(out_folder / f"{s}.csv")


## Plot results by species

In [None]:
# Human-readable variable names (24 entries).
# Presumably meant as display labels for the heatmap columns of the next
# cell; the visible code does not apply them — confirm before relying on it.
columns = [
    'bathymetry',
    'salinity',
    'wave height',
    'surface wind (u)',
    'surface wind (v)',
    'oxygen',
    'pH',
    'fsle (strength)',
    'fsle (orientation)',
    'geos current (u)',
    'geos current (v)',
    'chlorophyll',
    'sea surface temperature',
    'mixed layer thickness',
    'diatoms',
    'dinophytes',
    'haptophytes',
    'green algae',
    'prochlorophytes',
    'prokaryotes',
    'Atlantic Ocean',
    'Indian Ocean',
    'Pacific Ocean',
    'North hemisphere',
]

In [None]:
import seaborn as sns

# Rank the variables by the '50%' column of the whole-dataset summary and
# print the eight highest

summa = pd.read_csv(f"../outputs/interpretation/{ckpt_ref}/All.csv", index_col = 0)
ranked = summa.sort_values('50%', ascending = False)
topn = ranked.index[:8].tolist()
print(', '.join(topn))

# Collect the '50%' column of every per-species summary that exists on disk

series_dict = {}

for i in range(38):
    s = species[i]
    p = Path(f"../outputs/interpretation/{ckpt_ref}/{s}.csv")
    if not p.exists():
        continue
    series_dict[s] = pd.read_csv(p, index_col = 0)['50%']

# One row per species, one column per variable, without the columns listed below
dropped_columns = ['eke', 'chlorophyll-occi-15', 'sst-mur-15', 'sst-mur-5', 'chlorophyll-occi-5']
df = pd.concat(series_dict, axis = 1).T.drop(columns = dropped_columns)

# Create chart

fig, ax = plt.subplots(figsize=(10,12))
sns.heatmap(df, ax = ax, cbar=False, cmap='YlOrRd')
plt.tight_layout()

# Compare +2°C

In [None]:
import cv2
from matplotlib import cm

in_folder0 = Path(f"../outputs/world-rasters/{ckpt_ref}rasters/")
in_folder2 = Path(f"../outputs/world-rasters+2/{ckpt_ref}rasters/")
out_folder = Path(f"../outputs/world-png+2/{ckpt_ref}png/")

# Create the output tree, including missing parents; a no-op when it exists.
out_folder.mkdir(parents=True, exist_ok=True)

for s in tqdm(species):

    ref = s.replace(' ','_')
    in_subfolder0 = in_folder0 / ref
    in_subfolder2 = in_folder2 / ref
    out_subfolder = out_folder / ref
    out_subfolder.mkdir(exist_ok=True)

    # Iterate over the baseline rasters; each one is paired with the raster
    # of the same file name in the "+2" folder.
    for f in in_subfolder0.glob('*.tiff'):

        out_file = Path(out_subfolder, f.stem + '.png')

        with rasterio.open(in_subfolder0 / f.name, driver='GTiff') as src0, \
             rasterio.open(in_subfolder2 / f.name, driver='GTiff') as src2:
            data0 = np.float32(src0.read(1))
            data2 = np.float32(src2.read(1))

        # Byte 0 becomes NaN in both rasters, so it propagates into the difference
        data0[data0==0] = np.nan
        data2[data2==0] = np.nan

        # Centre the difference on 0.5, the midpoint of the diverging colormap.
        # NOTE(review): the divisor is 253 here while the PNG export cell
        # rescales with 254 — confirm which one is intended.
        diff = (data2 - data0) / 253 + 0.5

        # Colour with 'PiYG', paint NaN pixels opaque black, then swap the
        # first and third channels for OpenCV's BGR(A) channel order
        im2 = getattr(cm, 'PiYG')(diff)
        im2[np.isnan(diff)] =  np.array([0,0,0,1])
        im3 = cv2.cvtColor(np.float32(255*im2), cv2.COLOR_BGRA2RGBA)

        cv2.imwrite(str(out_file), im3)