# Training a simple CNN model in Tensorflow for Tornado Detection

This notebook steps through how to train a simple CNN model using a subset of TorNet.

This will not produce a model with any skill; it simply provides a working end-to-end example of how to set up a data loader and then build, fit, and evaluate a model.


In [None]:
import sys
# Uncomment if tornet isn't installed in your environment or in your path already
#sys.path.append('../')  

import os
import glob
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from tornet.data.tf.loader import create_tf_dataset 
from tornet.data.constants import ALL_VARIABLES

In [None]:
# Create basic dataloader for the training split.
# This option loads directly from netcdf files, and will be slow and IO bound.
# To speed up training, either
#     build as a tensorflow_dataset (see tornet/data/tfds/tornet/README.md),
#     cache the dataset first, or
#     use tf.data.Dataset.load on a pre-saved dataset

# Location of tornet; fail early with an actionable message if it is not set
if 'TORNET_ROOT' not in os.environ:
    raise RuntimeError('Environment variable TORNET_ROOT must point to the TorNet data directory')
data_root = os.environ['TORNET_ROOT']

# Get training data from 2018
data_type = 'train'
years = [2018,]

catalog_path = os.path.join(data_root, 'catalog.csv')
if not os.path.exists(catalog_path):
    raise RuntimeError('Unable to find catalog.csv at '+data_root)

# Select the requested split and years, then shuffle with a fixed seed so the
# sample order is reproducible across runs.
train_catalog = (
    pd.read_csv(catalog_path, parse_dates=['start_time', 'end_time'])
    .loc[lambda c: c['type'] == data_type]
    .loc[lambda c: c.start_time.dt.year.isin(years)]
    .sample(frac=1, random_state=1234)
)
# An empty selection would otherwise produce an empty dataset that only fails later, during fit
if train_catalog.empty:
    raise RuntimeError(f'No {data_type!r} samples for years {years} in {catalog_path}')

train_files = [os.path.join(data_root, f) for f in train_catalog.filename]

ds = create_tf_dataset(train_files, variables=ALL_VARIABLES, n_frames=1)

# (Optional) Save data for faster reloads (makes copy of data!)
# ds.save('tornet_sample.tfdataset')


In [None]:
# If the data was saved with ds.save(...), just load that dataset instead
#ds = tf.data.Dataset.load('tornet_sample.tfdataset')

In [None]:
# If data was registered in tensorflow_dataset, use that
# env variable TFDS_DATA_DIR should point to location of this resaved dataset
#import tensorflow_datasets as tfds
#import tornet.data.tfds.tornet.tornet_dataset_builder # registers 'tornet'

#data_type='train'
#years = [2018,]
#ds = tfds.load('tornet',split='+'.join(['%s-%d' % (data_type,y) for y in years]))

In [None]:
# Both names refer to the same module (tornet.data.preprocess); the `tfpp`
# alias is kept because later cells use it as well.
import tornet.data.preprocess as pp
from tornet.data import preprocess as tfpp

# Preprocess

# add 'coordinates' variable used by CoordConv layers
# NOTE(review): the simple CNN below only takes ALL_VARIABLES as inputs, so this
# is presumably unused by it — kept for CoordConv-style models; confirm.
ds = ds.map(lambda d: pp.add_coordinates(d, include_az=False, backend=tf))

# Take only last time frame
ds = ds.map(pp.remove_time_dim)

# Split sample into inputs,label
ds = ds.map(tfpp.split_x_y)

# (Optional) add sample weights
# weights={'wN':1.0,'w0':1.0,'w1':1.0,'w2':2.0,'wW':0.5}
# ds = ds.map(lambda x,y:  tfpp.compute_sample_weight(x,y,**weights) )

# Batch first, then prefetch. Prefetch belongs at the end of the pipeline so
# whole batches are prepared ahead of the training step, rather than single
# samples that still have to be batched on demand.
ds = ds.batch(32)
ds = ds.prefetch(tf.data.AUTOTUNE)



In [None]:
# Create a simple CNN model
# This normalizes data, concatenates along channel, and applies a Conv2D.
# The final Conv2D emits raw logits (no activation) because the loss and
# metrics in this notebook use from_logits=True. A ReLU there would clamp every
# logit to >= 0, i.e. every predicted probability to >= 0.5.
import keras
from tornet.data.constants import CHANNEL_MIN_MAX

input_vars = ALL_VARIABLES # which variables to use

# Keras uses channels-last layout: (batch, rows, cols, channels).
# Each variable is (120, 240, 2); the trailing size-2 axis holds one channel
# per sweep, which is what the per-sweep normalization below broadcasts over.
inputs = {v:keras.Input(shape=(120,240,2),name=v) for v in input_vars}

# Normalize inputs
norm_layers = []
for v in input_vars:
    min_max = np.array(CHANNEL_MIN_MAX[v]) # [2,] = (min, max)

    # choose mean,var to get approximate [-1,1] scaling:
    # (x - center) / half_range maps [min, max] onto [-1, 1]
    var=((min_max[1]-min_max[0])/2)**2 # scalar
    var=np.array(2*[var,])    # [n_sweeps,]
    offset=(min_max[0]+min_max[1])/2    # scalar
    offset=np.array(2*[offset,]) # [n_sweeps,]

    norm_layers.append(
        keras.layers.Normalization(mean=offset, variance=var,
                                   name='Normalized_%s' % v)
    )

# Concatenate normed inputs along channel dimension
x=keras.layers.Concatenate(axis=-1,name='Concatenate1')(
        [l(inputs[v]) for l,v in zip(norm_layers,input_vars)]
        )

# Replace background (nan) with -3
x=keras.layers.Lambda(lambda x: tf.where(tf.math.is_nan(x),-3.0,x),name='ReplaceNan')(x)

# Processing
x = keras.layers.Conv2D(32,3,padding='same',activation='relu')(x)
# add more..
# Linear 1x1 conv: per-pixel tornado logit (no activation; see note at top)
x = keras.layers.Conv2D(1,1,padding='same',name='TornadoLikelihood')(x)
# Sample-level logit = max over all pixels
y = keras.layers.GlobalMaxPool2D(name='GlobalMaxPool')(x)

model = keras.Model(inputs=inputs,outputs=y,name='TornadoDetector')

model.summary()

In [None]:
# Compile: binary cross-entropy with an Adam optimizer.
# from_logits=True means the loss treats the model output as a logit and
# applies the sigmoid internally.
loss = keras.losses.BinaryCrossentropy(from_logits=True)
opt = keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=opt, loss=loss)

In [None]:
# Train
# Each epoch is capped at 10 batches purely to keep this demo fast; remove
# steps_per_epoch to let each epoch run until the dataset is exhausted.
model.fit(
    ds,
    epochs=3,
    steps_per_epoch=10,
)

In [None]:
# Build a test set
# Basic loader: same catalog-based selection as for training, but for the
# 'test' split. Reuses `data_root` from the training-data cell.
data_type = 'test'
years = [2018]

catalog_path = os.path.join(data_root, 'catalog.csv')
if not os.path.exists(catalog_path):
    raise RuntimeError('Unable to find catalog.csv at '+data_root)

# Select the requested split and years, then shuffle with a fixed seed so the
# subset seen by a step-limited evaluate() is reproducible.
test_catalog = (
    pd.read_csv(catalog_path, parse_dates=['start_time', 'end_time'])
    .loc[lambda c: c['type'] == data_type]
    .loc[lambda c: c.start_time.dt.year.isin(years)]
    .sample(frac=1, random_state=1234)
)
# Fail with a clear message instead of building an empty dataset
if test_catalog.empty:
    raise RuntimeError(f'No {data_type!r} samples for years {years} in {catalog_path}')

test_files = [os.path.join(data_root, f) for f in test_catalog.filename]

ds_test = create_tf_dataset(test_files, variables=ALL_VARIABLES, n_frames=1)


In [None]:
# TFDS loader
# env variable TFDS_DATA_DIR should point to location of resaved dataset
#import tensorflow_datasets as tfds
#import tornet.data.tfds.tornet.tornet_dataset_builder # registers 'tornet'

#data_type='test'
#years = [2018,]
#ds_test = tfds.load('tornet',split='+'.join(['%s-%d' % (data_type,y) for y in years]))

In [None]:
# preprocess: same steps as the training pipeline
ds_test = ds_test.map(lambda d: pp.add_coordinates(d, include_az=False, backend=tf))
ds_test = ds_test.map(pp.remove_time_dim)
ds_test = ds_test.map(tfpp.split_x_y)
# Batch before prefetch so whole batches are prepared ahead of evaluation
ds_test = ds_test.batch(32)
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)


In [None]:
# Evaluate
import tornet.metrics.keras.metrics as km
metrics = [keras.metrics.AUC(from_logits=True,name='AUC'),
           km.BinaryAccuracy(from_logits=True,name='BinaryAccuracy'),
           ]
# Recompile only to attach metrics. Pass the model's current optimizer
# explicitly. compile() without `optimizer` falls back to Keras' default
# optimizer, which would silently change any later call to fit().
model.compile(optimizer=model.optimizer, loss=loss, metrics=metrics)

# steps=10 for demo purposes
model.evaluate(ds_test,steps=10)
