# Running the TorNet model again (this time with learning rate 10^(-4))

This notebook exists because our learning rate testing (comparing loss reduction after one epoch of training) suggested 10^-3
as the first learning rate to try. When we tried that learning rate (in a notebook we don't include in this repository),
training failed to converge for our batch size of 64 examples after just a couple of epochs. Even
after the first epoch, the training process had essentially stalled out.

With this notebook, we're testing whether a smaller learning rate of 10^-4 makes sense for validating
the architecture (and for training on our augmented dataset that includes mirrored examples).

The goal is to validate the results of the TorNet paper by achieving an AUC that's at least similar
to that of the paper (which reported around a .874 AUC on the test dataset when fully trained, and a .81 AUC
on the test dataset after around 2 epochs). If we see similar results (even if it takes more epochs, since our
learning rate of 10^-4 is 10x smaller than the original code's default of 10^-3), then we can be
reasonably confident that the results we get from the same architecture on the augmented dataset
accurately reflect whether the augmented dataset is helping or hurting compared to the
original dataset.

In [1]:
import sys

import os
import glob
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

# just the location for the input data
TORNET_DATA_INPUT_FOLDER = "/mnt/c/users/handypark/Documents/Grad_School_Courses/CS_230/tornet"

2024-11-18 11:44:20.968090: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-18 11:44:21.074631: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1731959061.120788   15103 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1731959061.135104   15103 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-18 11:44:21.241110: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
# Check that this TensorFlow build has CUDA support (i.e. that we have the GPU-enabled build rather than a CPU-only one).
tf.test.is_built_with_cuda()

# We tried experimental memory growth in some cases, but it didn't work out well (also crashed a lot).
"""
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)
"""


## TorNet's helper functions

Included below are the main TorNet helper functions we use to create the dataset when loading it into Tensorflow.
There weren't many changes here - these are mostly as is from the tornet repo.
They are annotated with new comments to show how they will be used later on.

In [3]:
"""
TorNet's data loading code, re-imported here manually for loading data into TensorFlow.
For some reason, trying to import the data loading code using `from tornet.data.tf.loader` wasn't working as expected,
so we re-copy that code over here to make use of it.
"""
from typing import List, Dict

from tornet.data.loader import query_catalog, read_file
from tornet.data.constants import ALL_VARIABLES
from tornet.data import preprocess as pp

def create_tf_dataset(files: List[str],
                      variables: List[str]=ALL_VARIABLES,
                      n_frames:int=1,
                      tilt_last: bool=True) -> tf.data.Dataset:
    """
    This is TorNet's main function for loading the data from the folder where it's all stored.
    
    As they stated, this creates a TF dataset object via the function read_file (which reads the NetCDF files
    one at a time as the dataset is iterated).
    """
    assert len(files)>0
    # grab one file to get keys, shapes, etc
    data = read_file(files[0],variables=variables,n_frames=n_frames, tilt_last=tilt_last)
    
    output_signature = { k:tf.TensorSpec(shape=data[k].shape,dtype=data[k].dtype,name=k) for k in data }
    def gen():
        for f in files:
            yield read_file(f,variables=variables,n_frames=n_frames, tilt_last=tilt_last)
    ds = tf.data.Dataset.from_generator(gen,
                                        output_signature=output_signature)
    return ds
    

def preproc(ds: tf.data.Dataset,
            weights:Dict=None,
            include_az:bool=False,
            select_keys:list=None,
            tilt_last:bool=True):
    """
    This is Tornet's preprocessing function for taking the raw dataset loaded from the files (in create_tf_dataset)
    and then doing a few things:

    - Remove the time dimension (since we only care about detection at a given time t)
    - Add coordinates (so that we can run CoordConv layers later)
    - Split the data into its inputs and label outputs
    - Adding weights (if we decide to weight the data at all)

    Once the preprocessing is done, the data is basically ready to be trained on.
    """
    
    # Remove time dimension
    ds = ds.map(pp.remove_time_dim)

    # Add coordinate tensors
    ds = ds.map(lambda d: pp.add_coordinates(d,include_az=include_az,tilt_last=tilt_last,backend=tf))

    # split into X,y
    ds = ds.map(pp.split_x_y)

    # Add sample weights
    if weights:
        ds = ds.map(lambda x,y:  pp.compute_sample_weight(x,y,**weights, backend=tf) )
    
        # select keys for input
        if select_keys is not None:
            ds = ds.map(lambda x,y,w: (pp.select_keys(x,keys=select_keys),y,w))
    else:
        if select_keys is not None:
            ds = ds.map(lambda x,y: (pp.select_keys(x,keys=select_keys),y))

    return ds

In [4]:
def make_tf_loader(data_root: str, 
            data_type:str='train', # or 'test'
            years: list=list(range(2013,2023)),
            batch_size: int=128, 
            weights: Dict=None,
            include_az: bool=False,
            random_state:int=1234,
            select_keys: list=None,
            tilt_last: bool=True,
            from_tfds: bool=False,
            tfds_data_version: str='1.1.0',
            num_epochs: int=3):
    """
    This TorNet library function is used to load the data into Tensorflow.
    We're going to use the `create_tf_dataset` function from above, 
    then we'll use `preproc` to preprocess it.

    One important note - we tried a bunch of different functions for shuffling 
    and batching, repeating the dataset, etc. to try to be able to run
    many epochs with one function call. It wasn't working.
    Even the `drop_remainder=True` that we've added here to ds.batch
    seems to have no effect on this: the training log still reports an
    `OUT_OF_RANGE: End of sequence` message at the end of each pass through the data.
    """
    
    if from_tfds: # fast loader
        import tensorflow_datasets as tfds
        import tornet.data.tfds.tornet.tornet_dataset_builder # registers 'tornet'
        ds = tfds.load('tornet:%s' % tfds_data_version ,split='+'.join(['%s-%d' % (data_type,y) for y in years]))
        # Assumes data was saved with tilt_last=True and converts it to tilt_last=False
        if not tilt_last:
            ds = ds.map(lambda d: pp.permute_dims(d,(0,3,1,2), backend=tf))
    else: # Load directly from netcdf files
        file_list = query_catalog(data_root, data_type, years, random_state)
        ds = create_tf_dataset(file_list,variables=ALL_VARIABLES,n_frames=1, tilt_last=tilt_last) 

    ds = preproc(ds,weights,include_az,select_keys,tilt_last)
    ds = ds.prefetch(tf.data.AUTOTUNE)

    # this has been adjusted to include drop_remainder=True
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
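
To make the effect of `drop_remainder=True` concrete, here is a small arithmetic sketch in plain Python (no TensorFlow needed). The helper names are hypothetical and ours, not part of the tornet repo. The training log further down reports 2682 steps per pass at a batch size of 64, which bounds the number of training examples. As far as we can tell, `drop_remainder` only decides whether the final short batch is emitted; a dataset built with `from_generator` still has unknown length, which would explain why the progress bar shows `2682/Unknown` and why the end-of-sequence message still appears.

```python
def steps_per_pass(num_examples: int, batch_size: int, drop_remainder: bool = True) -> int:
    """Number of batches produced by one pass over the data."""
    full_batches, leftover = divmod(num_examples, batch_size)
    if drop_remainder or leftover == 0:
        return full_batches
    return full_batches + 1  # the final, smaller batch is kept


def example_count_bounds(steps: int, batch_size: int) -> tuple:
    """Inclusive range of dataset sizes that give `steps` full batches when the remainder is dropped."""
    low = steps * batch_size
    return low, low + batch_size - 1


# 2682 steps at batch size 64, as reported by the training log in this notebook.
low, high = example_count_bounds(2682, 64)
print(low, high)  # 171648 171711
```

So at most 63 training examples are skipped in any one pass, and because the file order is reshuffled with a new seed each pass, it is not always the same examples that get dropped.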

## TorNet Baseline CNN Model Definition:

As stated in the comment for the code block below, this is just the TorNet model code that was used in the paper.
Our goal is to run this model which consists of:
- Normalizing the inputs
- Adding the coordinate information for CoordConv to work properly
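- Running 4 VGG blocks, each with two or three CoordConv2D layers followed by max pooling (applied to both the features and the coordinates)
- A head that produces the binary classification output: by default (`head='maxpool'`), 1x1 Conv2D layers followed by a global max pool; alternatively (`head='mlp'`), flatten followed by dense layers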

Later on, in our experiments to improve on this model, we'll likely still use some of this code
when it comes to preprocessing the data (from a normalization + CoordConv2D standpoint).

Notably, in the case of our data augmentation experiments, we'll only really need to change the input dataset
(and not any of the model code itself).
In comparison, for the YOLO transfer learning experiments, we'll likely not be able to use a good chunk of this
model code (but it might still be useful to use some of the layers made here, for instance).
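
As a quick sanity check on the four VGG blocks listed above, the sketch below traces the spatial size of the feature map through them. It assumes the default 120x240 input and the pooling settings used in the model code (pool size 2, stride 2, `padding='same'`), under which the output size is the input size divided by the stride, rounded up. The function names are hypothetical and only for illustration.

```python
import math


def pooled_size(size: int, pool: int = 2, stride: int = 2, padding: str = "same") -> int:
    """Output length of a pooling layer along one spatial axis."""
    if padding == "same":
        return math.ceil(size / stride)   # 'same' pads so that odd sizes round up
    return (size - pool) // stride + 1    # 'valid' drops the incomplete window


def trace_shapes(height: int = 120, width: int = 240, n_blocks: int = 4) -> list:
    """Spatial shape before the first block and after each block's pooling step."""
    shapes = [(height, width)]
    for _ in range(n_blocks):
        height, width = pooled_size(height), pooled_size(width)
        shapes.append((height, width))
    return shapes


print(trace_shapes())
# [(120, 240), (60, 120), (30, 60), (15, 30), (8, 15)]
```

With the default `maxpool` head, this means the final `heatmap` layer is an 8x15 grid of logits, and the global max pool picks the largest one as the output for the whole scene.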

In [5]:
"""
This is just the TorNet model code that was used in the paper.
Goal is to run this model which consists of:
- Normalizing the inputs
- Adding the coordinate information for CoordConv to work properly
- Running 4 VGG blocks which each have CoordConv2D and two MAXPOOL layers
- One last block of Conv2D layers and MAXPOOL to get the output probability
"""

from typing import Dict, List, Tuple
import numpy as np
import keras
from tornet.models.keras.layers import CoordConv2D, FillNaNs
from tornet.data.constants import CHANNEL_MIN_MAX, ALL_VARIABLES


def build_model(shape:Tuple[int]=(120,240,2),
                c_shape:Tuple[int]=(120,240,2),
                input_variables:List[str]=ALL_VARIABLES,
                start_filters:int=64,
                l2_reg:float=0.001,
                background_flag:float=-3.0,
                include_range_folded:bool=True,
                head='maxpool'):
    # Create input layers for each input_variables
    inputs = {}
    for v in input_variables:
        inputs[v]=keras.Input(shape,name=v)
    n_sweeps=shape[2]
    
    # Normalize inputs and concatenate along channel dim
    normalized_inputs=keras.layers.Concatenate(axis=-1,name='Concatenate1')(
        [normalize(inputs[v],v) for v in input_variables]
        )

    # Replace nan pixel with background flag
    normalized_inputs = FillNaNs(background_flag)(normalized_inputs)

    # Add channel for range folded gates 
    if include_range_folded:
        range_folded = keras.Input(shape[:2]+(n_sweeps,),name='range_folded_mask')
        inputs['range_folded_mask']=range_folded
        normalized_inputs = keras.layers.Concatenate(axis=-1,name='Concatenate2')(
               [normalized_inputs,range_folded])
        
    # Input coordinate information
    cin=keras.Input(c_shape,name='coordinates')
    inputs['coordinates']=cin

    x,c = normalized_inputs,cin
    
    x,c = vgg_block(x,c, filters=start_filters,   ksize=3, l2_reg=l2_reg, n_convs=2, drop_rate=0.1)   # (60,120)
    x,c = vgg_block(x,c, filters=2*start_filters, ksize=3, l2_reg=l2_reg, n_convs=2, drop_rate=0.1)  # (30,60)
    x,c = vgg_block(x,c, filters=4*start_filters, ksize=3, l2_reg=l2_reg, n_convs=3, drop_rate=0.1)  # (15,30)
    x,c = vgg_block(x,c, filters=8*start_filters, ksize=3, l2_reg=l2_reg, n_convs=3, drop_rate=0.1)  # (8,15), since 'same' pooling rounds 15 up
    #x,c = vgg_block(x,c, filters=8*start_filters, ksize=3, l2_reg=l2_reg, n_convs=3)  # (4,8)
    
    if head=='mlp':
        # MLP head
        x = keras.layers.Flatten()(x) 
        x = keras.layers.Dense(units = 4096, activation ='relu')(x) 
        x = keras.layers.Dense(units = 2024, activation ='relu')(x) 
        output = keras.layers.Dense(1)(x)
    elif head=='maxpool':
        # Per gridcell
        x = keras.layers.Conv2D(filters=512, kernel_size=1,
                          kernel_regularizer=keras.regularizers.l2(l2_reg),
                          activation='relu')(x)
        x = keras.layers.Conv2D(filters=256, kernel_size=1,
                          kernel_regularizer=keras.regularizers.l2(l2_reg),
                          activation='relu')(x)
        x = keras.layers.Conv2D(filters=1, kernel_size=1,name='heatmap')(x)
        # Max in scene
        output = keras.layers.GlobalMaxPooling2D()(x)

    return keras.Model(inputs=inputs,outputs=output)


def vgg_block(x,c, filters=64, ksize=3, n_convs=2, l2_reg=1e-6, drop_rate=0.0):

    for _ in range(n_convs):
        x,c = CoordConv2D(filters=filters,
                          kernel_size=ksize,
                          kernel_regularizer=keras.regularizers.l2(l2_reg),
                          padding='same',
                          activation='relu')([x,c])
    x = keras.layers.MaxPool2D(pool_size =2, strides =2, padding ='same')(x)
    c = keras.layers.MaxPool2D(pool_size =2, strides =2, padding ='same')(c)
    if drop_rate>0:
        x = keras.layers.Dropout(rate=drop_rate)(x)
    return x,c


def normalize(x,
              name:str):
    """
    Channel-wise normalization using known CHANNEL_MIN_MAX
    """
    min_max = np.array(CHANNEL_MIN_MAX[name]) # [2,]
    n_sweeps=x.shape[-1]
    
    # choose mean,var to get approximate [-1,1] scaling
    var=((min_max[1]-min_max[0])/2)**2 # scalar
    var=np.array(n_sweeps*[var,])    # [n_sweeps,]
    
    offset=(min_max[0]+min_max[1])/2    # scalar
    offset=np.array(n_sweeps*[offset,]) # [n_sweeps,]

    return keras.layers.Normalization(mean=offset,
                                      variance=var,
                                      name='Normalize_%s' % name)(x)
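
The `normalize` function above turns a known (min, max) range into the mean and variance that the Keras `Normalization` layer expects. Here is a minimal worked example of that arithmetic in plain Python. The range used is made up for illustration and is not taken from `CHANNEL_MIN_MAX`.

```python
import math


def min_max_to_mean_var(lo: float, hi: float) -> tuple:
    """Midpoint and squared half-range, as computed in `normalize` above."""
    mean = (lo + hi) / 2
    var = ((hi - lo) / 2) ** 2
    return mean, var


def normalize_value(x: float, lo: float, hi: float) -> float:
    """Apply (x - mean) / sqrt(var), the operation a Normalization layer performs."""
    mean, var = min_max_to_mean_var(lo, hi)
    return (x - mean) / math.sqrt(var)


# Hypothetical channel range, chosen only to make the numbers easy to follow.
lo, hi = -20.0, 60.0
print(min_max_to_mean_var(lo, hi))    # (20.0, 1600.0)
print(normalize_value(-20.0, lo, hi))  # -1.0
print(normalize_value(20.0, lo, hi))   # 0.0
print(normalize_value(60.0, lo, hi))   # 1.0
```

Values inside the range land in [-1, 1]; anything outside the recorded min/max simply falls outside that interval, which is why the code comment calls the scaling approximate.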

## Running the Model

Ok, we've got all of the TorNet model components figured out and imported (using the helper functions, etc.).
[Even getting that working turned out to be tricky - there were dependency issues to resolve, and while we'd 
like to just be able to import the functions rather than re-copying them here again, that wasn't working so well.]

In [6]:
model = build_model()

I0000 00:00:1731959070.011985   15103 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5564 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070, pci bus id: 0000:01:00.0, compute capability: 8.6


In [7]:
model.summary()

Again, as mentioned above, of note here is the learning rate chosen, which is 10^-4.

10^-3 caused divergence in later epochs. It presumably worked for the TorNet researchers, so we suspect the
difference is batch size: our smaller batches give noisier gradient estimates, which would make a larger
learning rate more likely to diverge in later epochs.

In [8]:
opt = keras.optimizers.Adam(learning_rate=1e-4)
loss = keras.losses.BinaryCrossentropy(from_logits=True)
model.compile(loss=loss, optimizer=opt)
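
The loss is built with `from_logits=True` because the model's last layer has no sigmoid: both heads end in a layer without an activation, so the output is a raw logit. The sketch below is a plain-Python illustration of the per-example binary cross-entropy computed directly from a logit, using the standard numerically stable form. It is not Keras's actual implementation, and the function names are ours.

```python
import math


def sigmoid(z: float) -> float:
    """Convert a logit to a probability (fine for moderate z)."""
    return 1.0 / (1.0 + math.exp(-z))


def bce_from_logit(y: float, z: float) -> float:
    """Stable binary cross-entropy for label y in {0, 1} and logit z."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def bce_from_probability(y: float, p: float) -> float:
    """Textbook form, which breaks down when p is exactly 0 or 1."""
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


# Both forms agree for a moderate logit...
print(bce_from_logit(1.0, 2.0), bce_from_probability(1.0, sigmoid(2.0)))
# ...but only the logit form stays finite for an extreme one.
print(bce_from_logit(1.0, -1000.0))  # 1000.0
```

Keras averages this quantity over the batch, which is the `loss` value shown in the progress bar.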

We create a checkpoint-saving callback to cope with the IPython kernel's frequent crashes.
After each pass through the data, we save the weights; before the next pass, we reload them and keep going.

We'll save the weights in `checkpoints/epoch_1e-4_test_{number_of_epoch}.weights.h5`

In [9]:
def checkpoint_creator(checkpoint_path):
    # saving the model's weights in case the iPython kernel crashes (which it likes to do)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                       save_weights_only=True,
                                       verbose=1)
    return cp_callback

def checkpoint_loader(checkpoint_path, model):
    model.load_weights(checkpoint_path)
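
Since the point of these checkpoints is to survive kernel crashes, it can help to work out which epoch to resume from rather than hard-coding it. The helper below is a hypothetical sketch (it is not something we ran in this notebook) that picks the highest epoch number out of a list of filenames following the `epoch_1e-4_test_{n}.weights.h5` naming used in the training loop.

```python
import re
from typing import Iterable, Optional

# Matches the checkpoint naming used in this notebook and captures the epoch number.
CHECKPOINT_PATTERN = re.compile(r"epoch_1e-4_test_(\d+)\.weights\.h5$")


def latest_epoch(filenames: Iterable[str]) -> Optional[int]:
    """Return the highest epoch number among matching checkpoint files, or None."""
    epochs = []
    for name in filenames:
        match = CHECKPOINT_PATTERN.search(name)
        if match:
            epochs.append(int(match.group(1)))  # compare as integers, so 10 > 9
    return max(epochs) if epochs else None


# Example with an in-memory listing; in practice this could be os.listdir("checkpoints").
files = ["epoch_1e-4_test_1.weights.h5", "epoch_1e-4_test_2.weights.h5", "notes.txt"]
print(latest_epoch(files))  # 2
```

After a crash, training could then resume at `latest_epoch(...) + 1` by loading that checkpoint with `checkpoint_loader`. Note that with `save_weights_only=True` only the model weights are restored, not the optimizer state.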

Training passes through the data. For each pass, we:
- Use a batch_size of 64 (to fit the data into GPU memory)
- Shuffle the data with a new random_state (a base seed plus the epoch number, so no seed is reused in a later epoch)
- Write a new checkpoint (in case any one epoch of training crashes/fails).

In [10]:
for epoch_num in range(1, 5):
    resampled = make_tf_loader(data_root = TORNET_DATA_INPUT_FOLDER, 
                               data_type = "train", # or 'test'
                               years = list(range(2013, 2023)),
                               batch_size = 64, 
                               weights = None,
                               include_az = False,
                               random_state = 21222324 + epoch_num,
                               select_keys = ALL_VARIABLES + ["coordinates", "range_folded_mask"],
                               tilt_last = True,
                               from_tfds = False,
                               tfds_data_version ="1.1.0")
    if epoch_num > 1:
        checkpoint_loader("checkpoints/epoch_1e-4_test_{}.weights.h5".format(str(epoch_num - 1)), model)
    model.fit(resampled, epochs=1, callbacks=[checkpoint_creator("checkpoints/epoch_1e-4_test_{}.weights.h5".format(epoch_num))])

I0000 00:00:1731921134.596652   10727 service.cc:148] XLA service 0x7ff8d801a6f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731921134.596800   10727 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3070, Compute Capability 8.6
2024-11-18 01:12:14.709335: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1731921135.018717   10727 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1731921162.192996   10727 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


   2682/Unknown 8333s 3s/step - loss: 1.1840
Epoch 1: saving model to checkpoints/epoch_1e-4_test_1.weights.h5


2024-11-18 03:31:02.986426: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]

2682/2682 ━━━━━━━━━━━━━━━━━━━━ 8334s 3s/step - loss: 1.1838
   2682/Unknown 8396s 3s/step - loss: 0.2602
Epoch 1: saving model to checkpoints/epoch_1e-4_test_2.weights.h5


2682/2682 ━━━━━━━━━━━━━━━━━━━━ 8397s 3s/step - loss: 0.2602


2024-11-18 05:51:05.895293: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 117964864 bytes after encountering the first element of size 117964864 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size


   2682/Unknown 8426s 3s/step - loss: 0.2322
Epoch 1: saving model to checkpoints/epoch_1e-4_test_3.weights.h5


2682/2682 ━━━━━━━━━━━━━━━━━━━━ 8427s 3s/step - loss: 0.2322
   2682/Unknown 8560s 3s/step - loss: 0.2203
Epoch 1: saving model to checkpoints/epoch_1e-4_test_4.weights.h5


2682/2682 ━━━━━━━━━━━━━━━━━━━━ 8561s 3s/step - loss: 0.2203


In [None]:
test_data = make_tf_loader(data_root = TORNET_DATA_INPUT_FOLDER, 
                              data_type = "test",
                              years = list(range(2013, 2023)),
                              batch_size = 64, 
                              weights = None,
                              include_az = False,
                              random_state = 5678,
                              select_keys = ALL_VARIABLES + ["coordinates", "range_folded_mask"],
                              tilt_last = True,
                              from_tfds = False,
                              tfds_data_version ="1.1.0")

for epoch_num in range(1, 5):
    metrics = [keras.metrics.AUC(from_logits=True,name='AUC')]
    checkpoint_loader("checkpoints/epoch_1e-4_test_{}.weights.h5".format(str(epoch_num)), model)
    model.compile(loss=loss, metrics=metrics)
    model.evaluate(test_data)

  saveable.load_own_variables(weights_store.get(inner_path))


491/491 ━━━━━━━━━━━━━━━━━━━━ 1602s 3s/step - AUC: 0.6884 - loss: 0.2697

491/491 ━━━━━━━━━━━━━━━━━━━━ 1585s 3s/step - AUC: 0.7462 - loss: 0.2362



      6/Unknown 21s 3s/step - AUC: 0.5984 - loss: 0.1884

In [10]:
test_data = make_tf_loader(data_root = TORNET_DATA_INPUT_FOLDER, 
                              data_type = "test",
                              years = list(range(2013, 2023)),
                              batch_size = 64, 
                              weights = None,
                              include_az = False,
                              random_state = 5678,
                              select_keys = ALL_VARIABLES + ["coordinates", "range_folded_mask"],
                              tilt_last = True,
                              from_tfds = False,
                              tfds_data_version ="1.1.0")

for epoch_num in range(3, 5):
    metrics = [keras.metrics.AUC(from_logits=True,name='AUC')]
    checkpoint_loader("checkpoints/epoch_1e-4_test_{}.weights.h5".format(str(epoch_num)), model)
    model.compile(loss=loss, metrics=metrics)
    model.evaluate(test_data)

  saveable.load_own_variables(weights_store.get(inner_path))
I0000 00:00:1731959099.802115   15192 service.cc:148] XLA service 0x7efb34003a80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731959099.812486   15192 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3070, Compute Capability 8.6
2024-11-18 11:44:59.851374: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1731959099.942555   15192 cuda_dnn.cc:529] Loaded cuDNN version 90300


      2/Unknown 11s 61ms/step - AUC: 0.3646 - loss: 0.1400

I0000 00:00:1731959106.498013   15192 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


491/491 ━━━━━━━━━━━━━━━━━━━━ 1564s 3s/step - AUC: 0.7861 - loss: 0.2157



491/491 ━━━━━━━━━━━━━━━━━━━━ 1613s 3s/step - AUC: 0.8153 - loss: 0.2007



## Results
| Epoch Number | Training Loss | Test AUC |
|--------------|---------------|----------|
| 1            | 1.1838        | 0.6884   |
| 2            | 0.2602        | 0.7462   |
| 3            | 0.2322        | 0.7861   |
| 4            | 0.2203        | 0.8153   |

We'll be running more epochs in the coming days to double-check, but so far a learning rate of 10^-4 looks much better than 10^-3 (which was diverging after just one epoch through the data).

With only four epochs completed, test AUC is still increasing with each pass, so more progress seems possible. According to the paper, we should eventually see around a .874 AUC once training is done, and we already see .815 after four epochs. (The paper reached .81 after two epochs, but, assuming its code's default parameters were used, it trained with a learning rate 10x larger than ours, so it's not too much of a surprise that our version needs a few more epochs to get to the same point.)
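
To put numbers on "still increasing", the short sketch below computes the epoch-over-epoch AUC gains from the results table and the remaining gap to the paper's reported .874. It is plain arithmetic on the values already shown above; the function name is ours.

```python
# AUC after each epoch, copied from the results table in this notebook.
auc_by_epoch = {1: 0.6884, 2: 0.7462, 3: 0.7861, 4: 0.8153}
PAPER_FINAL_AUC = 0.874  # fully trained AUC reported in the TorNet paper


def per_epoch_gains(auc: dict) -> dict:
    """AUC improvement of each epoch over the previous one, rounded to 4 places."""
    epochs = sorted(auc)
    return {cur: round(auc[cur] - auc[prev], 4) for prev, cur in zip(epochs, epochs[1:])}


gains = per_epoch_gains(auc_by_epoch)
gap = round(PAPER_FINAL_AUC - auc_by_epoch[4], 4)
print(gains)  # {2: 0.0578, 3: 0.0399, 4: 0.0292}
print(gap)    # 0.0587
```

The gains shrink each epoch (0.0578, 0.0399, 0.0292), so closing the remaining 0.0587 gap will likely take more than two additional epochs at this rate.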

Of note - it might be worth considering a learning rate decay for later epochs of training here. Given the smaller batch size that we're using compared to the original paper, learning rate decay might help dampen any noisiness being caused by the size of the batches as we approach the optimal weights for the model. This insight might also come in handy for when we train our later epochs in the augmented dataset runs of the model (which we hope improves performance by increasing the number of tornadic examples).
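
As a rough illustration of what such a decay could look like, here is a plain-Python sketch of an exponential schedule. The decay rate of 0.5 per pass is a made-up value for illustration, not something we have tuned, and 2682 is the number of steps per pass seen in the training log. Keras's `keras.optimizers.schedules.ExponentialDecay` uses the same formula, so a schedule like this could be passed to Adam in place of the fixed `learning_rate=1e-4`.

```python
def exponential_decay(step: int,
                      initial_lr: float = 1e-4,
                      decay_steps: int = 2682,   # one pass through the data at batch size 64
                      decay_rate: float = 0.5,   # hypothetical: halve the rate every pass
                      staircase: bool = False) -> float:
    """Learning rate at a given training step: initial_lr * decay_rate ** (step / decay_steps)."""
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_lr * decay_rate ** exponent


# Learning rate at the start of each of the first five passes.
for epoch in range(5):
    print(epoch + 1, exponential_decay(epoch * 2682))
```

With `staircase=True` the rate stays constant within a pass and only drops between passes, which fits how this notebook already trains one epoch per `model.fit` call.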