## CNN model for time-series forecasting of the Oa08 reflectance band in the sea surface colour dataset from Sentinel-3A
### Author: Smita Chakraborty, RISE

In [1]:
##libraries for ML tasks
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, LSTM, TimeDistributed, Reshape
from tensorflow.keras.models import Model
from matplotlib import pyplot as plt
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import datetime
import tifffile as tiff
from tensordict import TensorDict
import pandas as pd

## libraries for s3 bucket connection: read input and write output
import sys
sys.path.append('../.')
import boto3
from dotenv import load_dotenv
import os
from utils import boto3_connect
from itertools import product
import rasterio as rio
from rasterio import windows
from io import BytesIO

2024-10-23 12:58:28.594512: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-23 12:58:28.613364: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-23 12:58:28.619001: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-23 12:58:28.633641: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


ModuleNotFoundError: No module named 'eumartools'

In [None]:
#read tiff files
def read_tiff(img, outx):
    """Append the min-max normalised first band of a raster to ``outx``.

    Parameters
    ----------
    img : object exposing ``read(band_index)``
        Only ``img.read(1)`` is called; presumably an open rasterio dataset
        (TODO confirm against the caller).
    outx : list
        Accumulator. It is mutated in place and also returned.

    Returns
    -------
    list
        ``outx`` itself. It is left unchanged when every pixel of the band
        is NaN.
    """
    pix_val = img.read(1)

    # Skip tiles without a single valid pixel. The test must use ``not``:
    # ``np.all`` returns a ``numpy.bool_``, which is never *identical* to the
    # built-in ``False``, so an ``is False`` comparison rejects every tile.
    if not np.all(np.isnan(pix_val)):
        lo = np.nanmin(pix_val)
        hi = np.nanmax(pix_val)

        # Min-max scaling to [0, 1]; NaN pixels stay NaN.
        # NOTE(review): a tile whose valid pixels are all equal has hi == lo,
        # so the division is 0/0 and the whole tile becomes NaN -- decide
        # how such tiles should be treated.
        norm = (pix_val - lo) / (hi - lo)

        outx.append(norm)

    return outx


In [None]:
## Read the DEDL / S3 credentials exposed through the .env file (parent directory)
load_dotenv()

USERNAME = os.environ.get('DEDL_USERNAME')
PASSWORD = os.environ.get('DEDL_PASSWORD')
ACCESS_KEY = os.environ.get('S3_ACCESS_KEY')
SECRET_KEY = os.environ.get('S3_SECRET_KEY')

# Connect with the project helper from utils (presumably returns a boto3 S3
# resource -- TODO confirm) and take a handle on the project bucket
s3 = boto3_connect(ACCESS_KEY, SECRET_KEY)
s3buc = s3.Bucket('algaestorm')

In [None]:
#for obj in s3buc.objects.filter(Prefix="eodata/split_img/"):
#    print(obj.key)

In [None]:
# Builds the list of per-tile tensors for the first acquisition date.
# NOTE(review): the same loop is repeated for two more dates further down;
# a shared helper taking the prefix would remove the duplication.
outx = []

for obj in s3buc.objects.filter(Prefix="eodata/split_img/S3A_OL_2_WFR____20190602T084749_20190602T085049_20190603T201514_0179_045_221_1980_MAR_O_NT_002.SEN3/"):
    # Download one tile into memory and open it as a raster
    body = obj.get()['Body'].read()
    filelike = BytesIO(body)
    with rio.open(filelike, mode='r') as im:

        # Acquisition date tag; it is not used further in this cell, so a
        # missing tag yields None instead of aborting the whole load
        date = im.tags().get("date")

        x_dim = im.width
        y_dim = im.height

        # Only full 1000 x 1000 tiles are kept; reject the others before
        # paying for the band decode
        if (x_dim, y_dim) != (1000, 1000):
            continue

        val = im.read(1)

        # True when the tile has at least one valid (non-NaN) pixel
        chk = not np.all(np.isnan(val))
        if not chk:
            continue

        # Min-max normalisation to [0, 1]; NaN pixels stay NaN
        norm = (val - np.nanmin(val))/(np.nanmax(val) - np.nanmin(val))

        # Convert to a PyTorch tensor
        tensor = torch.from_numpy(norm)

        # Add a trailing channel axis: (height, width) -> (height, width, 1).
        # The batch axis is added when the list is stacked.
        tensor = tensor.unsqueeze(2)

        outx.append(tensor)

In [None]:
# Stack the per-tile tensors along a new leading (batch) axis
result = torch.stack(outx, dim=0)

In [None]:
# Check the shape of the stacked tensor. Conv2D accepts a 4-D tensor laid out as
# (batch size, img_height, img_width, channels).
# Read the shape straight from the tensor: tf.shape() would first convert the
# whole torch tensor into a TensorFlow tensor just to report its dimensions.
result.shape

In [None]:
# Builds the list of per-tile tensors for the second acquisition date
outx_2 = []

for obj in s3buc.objects.filter(Prefix="eodata/split_img/S3A_OL_2_WFR____20190603T082138_20190603T082438_20190604T171448_0179_045_235_1980_MAR_O_NT_002.SEN3/"):
    # Download one tile into memory and open it as a raster
    body = obj.get()['Body'].read()
    filelike = BytesIO(body)
    with rio.open(filelike, mode='r') as im:

        x_dim = im.width
        y_dim = im.height

        # Only full 1000 x 1000 tiles are kept; reject the others before
        # paying for the band decode
        if (x_dim, y_dim) != (1000, 1000):
            continue

        val = im.read(1)

        # True when the tile has at least one valid (non-NaN) pixel
        chk = not np.all(np.isnan(val))
        if not chk:
            continue

        # Min-max normalisation to [0, 1]; NaN pixels stay NaN
        norm = (val - np.nanmin(val))/(np.nanmax(val) - np.nanmin(val))

        # Convert to a PyTorch tensor
        tensor = torch.from_numpy(norm)

        # Add a trailing channel axis: (height, width) -> (height, width, 1).
        # The batch axis is added when the list is stacked.
        tensor = tensor.unsqueeze(2)

        outx_2.append(tensor)

In [None]:
# Stack the second-date tiles along a new leading (batch) axis
result2 = torch.stack(outx_2, dim=0)

In [None]:
# Check the shape of the stacked tensor. Conv2D accepts a 4-D tensor laid out as
# (batch size, img_height, img_width, channels).
# Read the shape straight from the tensor: tf.shape() would first convert the
# whole torch tensor into a TensorFlow tensor just to report its dimensions.
result2.shape

In [None]:
#This approach is discouraged: the flattened tensor is so large that it is inefficient

    #Define the model with a Dense layer
    #rescaling and conv2D as this exceeds memory
#model = tf.keras.Sequential([
 #   layers.Flatten(input_shape=result.shape),
 #   layers.Dense(units=1000000, activation='linear')
#])


In [None]:
# NaNs are replaced with zero for now: they keep the model from learning
# even when some of the pixel values are non-zero
def replace_nan_with_zero(tensor):
    """Return ``tensor`` with every NaN entry replaced by 0 (via ``tf.where``)."""
    is_missing = tf.math.is_nan(tensor)
    zeros = tf.zeros_like(tensor)
    return tf.where(is_missing, zeros, tensor)


# Clean both training tensors; each name is rebound to the cleaned version
result, result2 = replace_nan_with_zero(result), replace_nan_with_zero(result2)

In [None]:
# Keep only the first 6 entries along the batch axis to reduce the batch size
reduced_result = result[:6, ...]

# Split the tensor into two parts along the batch dimension if want to use both/all parts of a tensor from 
# a single day
#part1, part2 = tf.split(result, num_or_size_splits=2, axis=0)

In [None]:
# Main model: two stacked Conv2D layers
img_height = 1000
img_width = 1000

model = tf.keras.Sequential()
# 16 filters, 3 x 3 kernel size
model.add(layers.Conv2D(filters=16, kernel_size=(3, 3), activation='relu', padding='same'))
# 1 filter, 1 x 1 kernel size
model.add(layers.Conv2D(filters=1, kernel_size=(1, 1), activation='relu', padding='same'))

In [None]:
# Configure the model: Adam optimiser, mean-squared-error loss
model.compile(loss='mean_squared_error', optimizer='adam')

# Forward pass on the reduced batch. The model has not been fitted at this
# point in the notebook, so this is the output of the initial weights.
predicted_values = model.predict(reduced_result)

# Shape of the predicted array
print(predicted_values.shape)

In [None]:
# Builds the list of per-tile tensors for the third acquisition date (the target)
pred = []

for obj in s3buc.objects.filter(Prefix="eodata/split_img/S3A_OL_2_WFR____20190605T091015_20190605T091315_20190606T200225_0179_045_264_1980_MAR_O_NT_002.SEN3/"):
    # Download one tile into memory and open it as a raster
    body = obj.get()['Body'].read()
    filelike = BytesIO(body)
    with rio.open(filelike, mode='r') as im:

        # Acquisition date tag; it is not used further in this cell, so a
        # missing tag yields None instead of aborting the whole load
        date = im.tags().get("date")

        x_dim = im.width
        y_dim = im.height

        # Only full 1000 x 1000 tiles are kept; reject the others before
        # paying for the band decode
        if (x_dim, y_dim) != (1000, 1000):
            continue

        val = im.read(1)

        # True when the tile has at least one valid (non-NaN) pixel
        chk = not np.all(np.isnan(val))
        if not chk:
            continue

        # Min-max normalisation to [0, 1]; NaN pixels stay NaN
        p_norm = (val - np.nanmin(val))/(np.nanmax(val) - np.nanmin(val))

        # Convert to a PyTorch tensor
        p_tensor = torch.from_numpy(p_norm)

        # Add a trailing channel axis: (height, width) -> (height, width, 1).
        # The batch axis is added when the list is stacked.
        p_tensor = p_tensor.unsqueeze(2)

        pred.append(p_tensor)

In [None]:
# Stack the third-date tiles along a new leading (batch) axis
pred_result = torch.stack(pred, dim=0)

# Keep only the first 6 entries along the batch axis to reduce the batch size
reduced_predresult = pred_result[:6, ...]

# Original shape of the target tensor, then the reduced one
print(pred_result.shape)
print(reduced_predresult.shape)

In [None]:
# Zero out the NaNs of the target tensor. replace_nan_with_zero is the function
# defined in the earlier NaN-cleaning cell; it is not redefined here so the
# notebook keeps a single definition.
reduced_predresult = replace_nan_with_zero(reduced_predresult)


def count_nans(tensor):
    """Return the number of NaN entries in ``tensor`` as a scalar int32 tensor."""
    nan_mask = tf.math.is_nan(tensor)
    # Cast the boolean mask to integers so the entries can be summed
    return tf.reduce_sum(tf.cast(nan_mask, tf.int32))


nan_count_1 = count_nans(predicted_values)
nan_count_2 = count_nans(reduced_predresult)

# Both counts should be zero
print("Number of nan in pred:", nan_count_1)
print("Number of nan in target:", nan_count_2)

In [None]:
# Root mean squared error between the unseen target tensor and the prediction.
# update_state expects (y_true, y_pred): the target goes first. The value is
# the same either way because the squared error is symmetric.
rmse = tf.keras.metrics.RootMeanSquaredError()
rmse.update_state(reduced_predresult, predicted_values)
rmse.result()

### Training loop and plotting the loss function

In [None]:
# 'model' is the CNN defined above; 'reduced_result' is the training batch.
# Note that the batch is used both as input and as target in fit().

# Compile the model
model.compile(optimizer='adam', loss='mean_squared_error')

# Number of training epochs
epochs = 30

# One loss value per epoch
loss_history = []

# Training loop: exactly one epoch per iteration. Passing epochs=epochs to
# fit() inside this loop would train for epochs * epochs passes while only
# the first loss of every call was recorded.
for epoch in range(epochs):
    history = model.fit(reduced_result, reduced_result, epochs=1, verbose=0)
    loss_history.append(history.history['loss'][0])

    # A second fit() on 'result2' can be added here to alternate between days

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss_history[-1]:.8f}")

# Predict pixel values using the trained model
predicted_values_lp = model.predict(reduced_result)

# Print final loss
print(f"Final loss: {loss_history[-1]:.8f}")

# Plot the loss history (plt comes from the import cell at the top)
plt.plot(loss_history)
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
# Save the figure to disk before showing it
plt.savefig('modelOa08_loss_result1_1_30.png')
plt.show()

In [None]:
# Root mean squared error between the unseen target tensor and the prediction
# of the trained model. update_state expects (y_true, y_pred): the target
# (reduced_predresult) goes first. The value is the same either way because
# the squared error is symmetric.
rmse = tf.keras.metrics.RootMeanSquaredError()
rmse.update_state(reduced_predresult, predicted_values_lp)
rmse.result()

`predicted_values_lp` is the prediction of the trained model.

`reduced_predresult` is the target.

`reduced_result` and `result2` are the training tensors (only `reduced_result` is used in the training loop above).

## Plotting the results and the new data

### Plotting a SNS heatmap visualization

In [None]:
# Seaborn heatmaps of the input tensor, one panel per batch entry
import seaborn as sns

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(20, 15))
fig.suptitle('Heatmaps of Tensor Slices', fontsize=16)

for i, panel in enumerate(axes.ravel()):
    tile = reduced_result[i, :, :, 0]
    sns.heatmap(tile, ax=panel, cmap='viridis', cbar=True)
    panel.set_title(f'Slice {i+1}')
    panel.set_xlabel('X')
    panel.set_ylabel('Y')

# Leave room at the top of the figure for the suptitle
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

In [None]:
# Seaborn heatmaps of the tensor predicted by the trained model, one panel per batch entry
import seaborn as sns

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(20, 15))
fig.suptitle('Heatmaps of Tensor Slices', fontsize=16)

for i, panel in enumerate(axes.ravel()):
    tile = predicted_values_lp[i, :, :, 0]
    sns.heatmap(tile, ax=panel, cmap='viridis', cbar=True)
    panel.set_title(f'Slice {i+1}')
    panel.set_xlabel('X')
    panel.set_ylabel('Y')

# Leave room at the top of the figure for the suptitle
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

In [None]:
# Seaborn heatmaps of the trained-model prediction, with the colour scale
# starting at the threshold value 0.0012 x \pi = 0.003768

fig, axes = plt.subplots(2, 3, figsize=(20, 15))
fig.suptitle('Heatmaps of Tensor Slices', fontsize=16)


for i, ax in enumerate(axes.flatten()):
    tile = predicted_values_lp[i, :, :, 0]
    # Upper colour limit: maximum of the tile that is actually drawn (it was
    # previously taken from the prediction of the untrained model)
    vmax = np.max(tile)
    # Lower colour limit fixed at the threshold
    sns.heatmap(tile, cmap='viridis', vmin=0.003768, vmax=vmax, ax=ax, cbar=True)
    ax.set_title(f'Slice {i+1}')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')

plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to accommodate suptitle
plt.show()

### 3D surface plots of the input and predicted tensors


In [None]:
from mpl_toolkits.mplot3d import Axes3D

# Sample every 10th pixel along both axes (downsampling for performance)
x, y = np.meshgrid(range(0, 1000, 10), range(0, 1000, 10))

# One 3D surface per entry of the input batch
fig = plt.figure(figsize=(20, 15))
for i in range(6):
    panel = fig.add_subplot(2, 3, i+1, projection='3d')
    surface = reduced_result[i, ::10, ::10, 0]
    panel.plot_surface(x, y, surface, cmap='viridis')
    panel.set_title(f'3D Surface Plot of Slice {i+1}')
    panel.set_xlabel('X')
    panel.set_ylabel('Y')
    panel.set_zlabel('Value')

plt.tight_layout()
plt.show()

#fig = plt.figure(figsize=(12, 10))
#ax = fig.add_subplot(111, projection='3d')
#x, y = np.meshgrid(range(1000), range(1000))
#ax.plot_surface(x, y, result[0, :, :, 0], cmap='viridis')
#plt.title('3D Surface Plot of First Slice')
#plt.show()

In [None]:
# 3D surface plots of the trained-model prediction with a lower cut applied.
# The cut uses 0.003768 (0.0012 x pi), the same value the thresholded heatmap
# cell passes as vmin, instead of a second, differently rounded literal.
z_floor = 0.003768

# Sample every 10th pixel along both axes (downsampling for performance);
# the grid is identical for every panel, so it is built once
x, y = np.meshgrid(range(0, 1000, 10), range(0, 1000, 10))

fig = plt.figure(figsize=(20, 12))

for i in range(6):
    ax = fig.add_subplot(2, 3, i+1, projection='3d')

    # Apply the lower cut: values below the threshold are raised to it
    z = predicted_values_lp[i, ::10, ::10, 0]
    z = np.maximum(z, z_floor)

    # Plot the surface
    surf = ax.plot_surface(x, y, z, cmap='viridis')

    ax.set_title(f'3D Surface Plot of Slice {i+1}')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Value')

    # Start the z-axis at the threshold
    ax.set_zlim(bottom=z_floor)

    # Add a color bar
    fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5)

plt.tight_layout()
plt.show()

In [None]:
from mpl_toolkits.mplot3d import Axes3D

##Thresholds
##surface_TH = 0.0012 * np.pi (band for 665 nm)  ---> Oa08
##subsurface_TH = 0.00435 * np.pi * .9 (band for 560 nm) ---> Oa06

# Sample every 10th pixel along both axes (downsampling for performance)
x, y = np.meshgrid(range(0, 1000, 10), range(0, 1000, 10))

# One 3D surface per entry of the trained-model prediction
fig = plt.figure(figsize=(20, 15))
for i in range(6):
    panel = fig.add_subplot(2, 3, i+1, projection='3d')
    surface = predicted_values_lp[i, ::10, ::10, 0]
    panel.plot_surface(x, y, surface, cmap='viridis')
    panel.set_title(f'3D Surface Plot of Slice {i+1}')
    panel.set_xlabel('X')
    panel.set_ylabel('Y')
    panel.set_zlabel('Value')

plt.tight_layout()
plt.show()

In [None]:
from mpl_toolkits.mplot3d import Axes3D

##Thresholds
##surface_TH = 0.0012 * np.pi (band for 665 nm)  ---> Oa08
##subsurface_TH = 0.00435 * np.pi * .9 (band for 560 nm) ---> Oa06

# Sample every 10th pixel along both axes (downsampling for performance)
x, y = np.meshgrid(range(0, 1000, 10), range(0, 1000, 10))

# One 3D surface per entry of the NaN-cleaned target tensor
fig = plt.figure(figsize=(20, 15))
for i in range(6):
    panel = fig.add_subplot(2, 3, i+1, projection='3d')
    surface = reduced_predresult[i, ::10, ::10, 0]
    panel.plot_surface(x, y, surface, cmap='viridis')
    panel.set_title(f'3D Surface Plot of Slice {i+1}')
    panel.set_xlabel('X')
    panel.set_ylabel('Y')
    panel.set_zlabel('Value')

plt.tight_layout()
plt.show()