In [1]:
import pandas as pd

import numpy as np

import tensorflow as tf
from tensorflow import keras

import matplotlib.pyplot as plot

from functools import partial

# Sentinel passed as num_parallel_reads / num_parallel_calls in get_dataset.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Data dimensions.
# NOTE(review): the meanings below (other than n_steps) are inferred from the
# names and from the axis comments in prepare_sample — confirm against the
# code that wrote the TFRecords.
n_steps = 8 # Number of time steps
n_lat = 10 # presumably grid size along latitude
n_lon = 6 # presumably grid size along longitude
n_bands = 12 # presumably number of reflectance bands kept by prepare_sample
n_depths = 3 # presumably number of depth levels in the T/S/u/v/w fields
n_windspeed = 10 # presumably number of wind-speed bins
n_winddir = 12 # presumably number of wind-direction bins
n_currents_var = 3 # presumably the three current components (u, v, w)
n_TS_var = 2 # presumably the two variables temperature (T) and salinity (S)

In [48]:
def parse_tfrecord_fn(example_proto):
    """Decode one serialized example into a dict of tensors.

    Args:
        example_proto: a scalar string tensor holding one serialized example.

    Returns:
        dict keyed by feature name. The entries 'u', 'v', 'w', 'rhos', 'T',
        'S' and 'wind' are stored as serialized tensors and are decoded here
        to float64 tensors; 'total_depth', 'sample_depth' and 'DO' are float32
        scalars; 'julian_day' and 'year' are int64 scalars.
    """
    # Features written to the record as serialized tensors (byte strings).
    serialized_keys = ('u', 'v', 'w', 'rhos', 'T', 'S', 'wind')

    # Schema of a single record: byte-string blobs first, then plain scalars.
    feature_description = {
        key: tf.io.FixedLenFeature([], tf.string) for key in serialized_keys
    }
    feature_description.update({
        'total_depth': tf.io.FixedLenFeature([], tf.float32),
        'sample_depth': tf.io.FixedLenFeature([], tf.float32),
        'julian_day': tf.io.FixedLenFeature([], tf.int64),
        'year': tf.io.FixedLenFeature([], tf.int64),
        'DO': tf.io.FixedLenFeature([], tf.float32),
    })

    example = tf.io.parse_single_example(example_proto, feature_description)

    # Replace each byte-string blob with the float64 tensor it encodes.
    for key in serialized_keys:
        example[key] = tf.io.parse_tensor(example[key], tf.float64)

    return example

def replacenan(t):
    """Return ``t`` with every NaN entry replaced by zero; other entries are unchanged."""
    nan_positions = tf.math.is_nan(t)
    zero_fill = tf.zeros_like(t)
    return tf.where(nan_positions, zero_fill, t)

def compute_mask(t):
    """Return a boolean validity mask for ``t``: True where the entry is not NaN, False where it is NaN."""
    is_missing = tf.math.is_nan(t)
    return tf.math.logical_not(is_missing)

# Define dataset helper functions
def prepare_sample(features):
    """Convert one parsed record into the ``(inputs, targets)`` pair for the model.

    Args:
        features: dict as returned by ``parse_tfrecord_fn``. Reads the keys
            'rhos', 'T', 'S', 'u', 'v', 'w', 'wind', 'julian_day',
            'sample_depth', 'total_depth' and 'DO'.

    Returns:
        A 2-tuple ``(inputs, targets)`` of dicts. ``inputs`` has the keys
        'reflectance', 'temperature_salinity', 'currents', 'wind',
        'ancillary' and 'jd'; ``targets`` has the single key 'do'.
    """

    def _steps_first(key):
        # Move the last axis (time step) of a rank-4 field to the front.
        return tf.transpose(features[key], perm=[3, 0, 1, 2])

    def _channel_and_flip(field):
        # Append a trailing singleton axis, then reverse axis 0 (the step axis).
        # Per the original note, the flip puts the earliest observation first.
        return tf.reverse(tf.expand_dims(field, axis=-1), [0])

    def _gridded(key):
        # Full treatment shared by the T, S, u, v and w fields.
        return _channel_and_flip(_steps_first(key))

    # Reflectance: lat,lon,bands,step --> step,lat,lon,bands.
    # Keep 12 bands: the first 11 plus the 4th from the end
    # (original note: this removes the swir and 859 bands).
    reflectance = _steps_first('rhos')
    leading_bands = reflectance[:, :, :, :11]
    extra_band = reflectance[:, :, :, -4:-3]
    reflectance = tf.concat((leading_bands, extra_band), axis=-1)
    reflectance = replacenan(_channel_and_flip(reflectance))

    # Temperature and salinity (lat,lon,depth,step on disk), joined along the
    # second-to-last axis, NaNs zeroed.
    temp_sal = replacenan(tf.concat([_gridded('T'), _gridded('S')], axis=-2))

    # Current components u, v, w get the same treatment.
    currents = replacenan(
        tf.concat([_gridded('u'), _gridded('v'), _gridded('w')], axis=-2)
    )

    # n_windspeed,n_winddir,n_steps --> n_steps,n_windspeed,n_winddir, then flip steps.
    # Unlike the other fields, wind is not passed through replacenan.
    wind_hist = tf.reverse(tf.transpose(features['wind'], perm=[2, 0, 1]), [0])

    day_of_year = tf.cast(features['julian_day'], tf.float32)
    # NOTE(review): orb_phase is already a sine of the day angle, so the two
    # encodings below are sin(sin(.)) and cos(sin(.)). Left unchanged on purpose:
    # the saved model presumably expects the preprocessing it was trained with —
    # confirm before "fixing" this.
    orb_phase = tf.math.sin(tf.constant(2*np.pi/365.25)*day_of_year)
    day_encodings = [tf.math.sin(orb_phase), tf.math.cos(orb_phase)]

    ancillary = [features['sample_depth'], features['total_depth'], day_of_year]

    inputs = {
        'reflectance': reflectance,
        'temperature_salinity': temp_sal,
        'currents': currents,
        'wind': wind_hist,
        'ancillary': ancillary,
        'jd': day_encodings,
    }
    targets = {'do': features['DO']}
    return inputs, targets

In [49]:
def get_dataset(filenames, batch_size, deterministic=True):
    """Build a batched ``tf.data`` pipeline of ``(inputs, targets)`` from TFRecord files.

    Args:
        filenames: TFRecord file name(s), forwarded to ``tf.data.TFRecordDataset``.
        batch_size: number of samples per batch.
        deterministic: when True (default) elements are produced in a
            reproducible order even though reads and maps run in parallel.
            Pass False to trade ordering for throughput (e.g. for training).
            Keep True whenever the dataset is iterated more than once and the
            passes must line up, such as pairing ``model.predict`` output with
            labels collected in a second pass.

    Returns:
        A ``tf.data.Dataset`` yielding batches of the structure returned by
        ``prepare_sample``.
    """
    # Reads from multiple files are interleaved automatically.
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)

    # Decode the serialized examples, then shape them into model inputs/targets.
    dataset = dataset.map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
    dataset = dataset.map(prepare_sample, num_parallel_calls=AUTOTUNE)

    dataset = dataset.batch(batch_size)

    # The ordering option was previously constructed but never attached to the
    # dataset, so it had no effect. Attach it explicitly so the ordering
    # contract is stated in code rather than left to library defaults.
    order_options = tf.data.Options()
    order_options.experimental_deterministic = deterministic
    dataset = dataset.with_options(order_options)

    return dataset

In [50]:
# Calculate HypoxAI predicted DO

# Load data
# Testset path
test_pattern = 'path_to_test/*.tfrecords'
# Sort so the file order (and therefore the sample order) is the same on every run;
# the glob call itself documents no ordering.
test_fnames = sorted(tf.io.gfile.glob(test_pattern))
if not test_fnames:
    # Fail here with a clear message instead of an obscure error from
    # model.predict / np.concatenate on an empty dataset further down.
    raise FileNotFoundError(f'No TFRecord files match {test_pattern!r}')

# Load testset
test_dataset = get_dataset(test_fnames, 64)

# Load model
path_to_model = 'path_to_model/ts_uvw_wind_trial_3'
model = tf.keras.models.load_model(path_to_model)

# Apply model
test_predicted = model.predict(test_dataset)

In [17]:
# Get field measured DO

# Walk the test set once more and keep only the 'do' target of each batch.
# This relies on test_dataset yielding batches in the same order as it did for
# model.predict, so that measured and predicted values line up.
DO = [targets['do'] for _, targets in test_dataset.as_numpy_iterator()]

# Join the per-batch arrays into one array.
test_measured = np.concatenate(DO)