# Running the TorNet model

The TorNet paper gave us a pretty good indication of how to set up their baseline CNN model with CoordConv and VGG blocks.
Even better, the TorNet repository gave us the model blocks themselves.

Unfortunately, running their model out of the box with the default settings did NOT perform as anticipated, for the following reasons:
- A batch size of 128 was too large to fit in the memory of the GPU we use here.
    - We lowered the batch size until it fit in memory; 64 turned out to work well.
- A learning rate of 1e-3 (as used in their "how to train the model" examples) was also too large, even with the Adam optimizer.
    - The paper doesn't specify a learning rate. It only states that hyperparameters were found using a grid search.
    - After several rounds of trial and error (searching on a log scale between 1e-3 and 1e-6), we found that a learning rate of around 1.5e-6 worked reasonably well.
- The paper suggests training for at least about 3 epochs before performance becomes reasonable. However, the IPython kernel kept dying even though the GPU stayed healthy on every run (sometimes after just a single epoch).
    - Whenever we tried to run more epochs in one call using `.repeat()` (with shuffling, etc.), the IPython kernel crashed without fail.
        - It's unclear why this happens. In theory, when the dataset "runs out of data" the model should simply stop that epoch, keep its current state, and continue... but that hasn't been the case so far. For some reason, the "run out of data" exception seems to kill the kernel from time to time.
    - The only successful strategy (so far) has been the following:
        - For each epoch:
            - Reshuffle the data into a new TF dataset
            - Batch the dataset into `batch_size`-sized batches
                - In our case, this gives 2,682 batches of 64 training examples each.
                - Notably, even with `drop_remainder = True`, TensorFlow still seems to complain about running out of data within a single epoch.
            - Run a single epoch on that dataset
            - Save a checkpoint of the model immediately (in case the IPython kernel crashes)
    - If the kernel ever crashes, we can then recover from the model checkpoint that we've saved weights to.
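The per-epoch strategy above can be sketched as a small driver loop. This is a minimal sketch and not code from this notebook. `run_epochs` and its callable arguments are hypothetical names. They are passed in so the loop logic can be checked without TensorFlow or the TorNet data. In the notebook, `make_dataset` would wrap `make_tf_loader(..., random_state=seed)`, `train_one_epoch` would call `model.fit(ds, epochs=1)`, and the save and load callables would use the checkpoint files.

```python
from typing import Callable, List

def run_epochs(start_epoch: int,
               end_epoch: int,
               base_seed: int,
               make_dataset: Callable[[int], object],
               train_one_epoch: Callable[[object], None],
               save_checkpoint: Callable[[int], None],
               load_checkpoint: Callable[[int], None]) -> List[int]:
    """Run epochs start_epoch..end_epoch (inclusive), one freshly built dataset per epoch.

    Returns the shuffle seeds used, so they can be recorded and never reused.
    """
    seeds_used = []
    for epoch in range(start_epoch, end_epoch + 1):
        if epoch > 1:
            # Restore the weights saved at the end of the previous epoch
            # (this is what makes a restart after a kernel crash possible).
            load_checkpoint(epoch - 1)
        seed = base_seed + epoch           # a new, recorded seed per epoch
        seeds_used.append(seed)
        ds = make_dataset(seed)            # reshuffled + batched, finite dataset
        train_one_epoch(ds)                # e.g. model.fit(ds, epochs=1)
        save_checkpoint(epoch)             # e.g. checkpoints/epoch_{epoch}.weights.h5
    return seeds_used
```

Reloading the previous epoch's weights is redundant when the kernel has survived. It mirrors what the training cells below do, though, and it makes each iteration safe to rerun after a crash. Because every epoch gets a brand-new finite dataset, nothing here relies on `.repeat()`.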

Using the fixes, modifications, and adjustments above, we're going to try to reproduce the performance of the original TorNet model (to within a reasonable degree). This will show that their methodology does indeed produce a model with good predictive power [which we then hope to improve on].

In [1]:
import sys

import os
import glob
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

# just the location for the input data
TORNET_DATA_INPUT_FOLDER = "/mnt/c/users/handypark/Documents/Grad_School_Courses/CS_230/tornet"

2024-11-15 10:46:04.823945: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-15 10:46:05.007354: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1731696365.123059    2159 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1731696365.157377    2159 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-15 10:46:05.353461: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
# Just making sure that we are indeed using GPU-based TensorFlow and not CPU-based TensorFlow.
# The results are printed explicitly, because a bare expression is only displayed when it is the last line of a cell.
print("Built with CUDA:", tf.test.is_built_with_cuda())
print("Visible GPUs:", tf.config.list_physical_devices('GPU'))

# We tried enabling experimental memory growth in some runs, but it didn't work out well (it also crashed a lot).
# Kept here, commented out, for reference:
# gpus = tf.config.list_physical_devices('GPU')
# if gpus:
#     try:
#         # Currently, memory growth needs to be the same across GPUs
#         for gpu in gpus:
#             tf.config.experimental.set_memory_growth(gpu, True)
#         logical_gpus = tf.config.list_logical_devices('GPU')
#         print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
#     except RuntimeError as e:
#         # Memory growth must be set before GPUs have been initialized
#         print(e)

## TorNet's helper functions

Below are the main TorNet helper functions we use to load the dataset into TensorFlow.
There weren't many changes here. These are mostly as-is from the TorNet repo.
We've annotated them with new comments to show how they will be used later on.

In [3]:
"""
TorNet's data loading code, copied in manually here for loading data into TensorFlow.
For some reason, importing the data loading code via `from tornet.data.tf.loader` wasn't working as expected,
so we copy that code over here instead.
"""
from typing import List, Dict

from tornet.data.loader import query_catalog, read_file
from tornet.data.constants import ALL_VARIABLES
from tornet.data import preprocess as pp

def create_tf_dataset(files: List[str],
                      variables: List[str]=ALL_VARIABLES,
                      n_frames:int=1,
                      tilt_last: bool=True) -> tf.data.Dataset:
    """
    This is TorNet's main function for loading the data from the folder where it's all stored.
    
    As they state, this creates a TF dataset object whose generator calls read_file
    on each NetCDF file, reading the files into the dataset one at a time.
    """
    assert len(files)>0
    # grab one file to get keys, shapes, etc.
    data = read_file(files[0],variables=variables,n_frames=n_frames, tilt_last=tilt_last)
    
    output_signature = { k:tf.TensorSpec(shape=data[k].shape,dtype=data[k].dtype,name=k) for k in data }
    def gen():
        for f in files:
            yield read_file(f,variables=variables,n_frames=n_frames, tilt_last=tilt_last)
    ds = tf.data.Dataset.from_generator(gen,
                                        output_signature=output_signature)
    return ds
    

def preproc(ds: tf.data.Dataset,
            weights:Dict=None,
            include_az:bool=False,
            select_keys:list=None,
            tilt_last:bool=True):
    """
    This is Tornet's preprocessing function for taking the raw dataset loaded from the files (in create_tf_dataset)
    and then doing a few things:

    - Remove the time dimension (since we only care about detection at a given time t)
    - Add coordinates (so that we can run CoordConv layers later)
    - Split the data into its inputs and label outputs
    - Adding weights (if we decide to weight the data at all)

    Once the preprocessing is done, the data is basically ready to be trained on.
    """
    
    # Remove time dimesnion
    ds = ds.map(pp.remove_time_dim)

    # Add coordinate tensors
    ds = ds.map(lambda d: pp.add_coordinates(d,include_az=include_az,tilt_last=tilt_last,backend=tf))

    # split into X,y
    ds = ds.map(pp.split_x_y)

    # Add sample weights
    if weights:
        ds = ds.map(lambda x,y:  pp.compute_sample_weight(x,y,**weights, backend=tf) )
    
        # select keys for input
        if select_keys is not None:
            ds = ds.map(lambda x,y,w: (pp.select_keys(x,keys=select_keys),y,w))
    else:
        if select_keys is not None:
            ds = ds.map(lambda x,y: (pp.select_keys(x,keys=select_keys),y))

    return ds
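`create_tf_dataset` builds its `output_signature` from the first file only, so every later file must produce the same keys, shapes, and dtypes. `tf.data.Dataset.from_generator` raises an error when a generated element doesn't match the signature. Because the generator is lazy, a mismatch would only surface partway through an epoch. The sketch below is a hypothetical pre-check that uses only NumPy. `signature_of` and `mismatches` are our own names, and the sample keys and shapes in the test are made up for illustration. It mimics how the signature is derived from one sample and then compares other samples against it.

```python
import numpy as np
from typing import Dict, List, Tuple

Signature = Dict[str, Tuple[tuple, np.dtype]]

def signature_of(sample: Dict[str, np.ndarray]) -> Signature:
    """(shape, dtype) per key, analogous to building output_signature from files[0]."""
    return {k: (v.shape, v.dtype) for k, v in sample.items()}

def mismatches(sample: Dict[str, np.ndarray], signature: Signature) -> List[str]:
    """Keys that are missing, unexpected, or whose shape/dtype differ from the signature."""
    bad = set(sample) ^ set(signature)          # keys present in only one of the two
    for k in set(sample) & set(signature):
        if (sample[k].shape, sample[k].dtype) != signature[k]:
            bad.add(k)
    return sorted(bad)

# Usage sketch (hypothetical; read_file comes from tornet.data.loader):
# sig = signature_of(read_file(files[0], variables=ALL_VARIABLES, n_frames=1, tilt_last=True))
# for f in files[1:]:
#     bad = mismatches(read_file(f, variables=ALL_VARIABLES, n_frames=1, tilt_last=True), sig)
#     if bad:
#         print(f, bad)
```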

In [4]:
def make_tf_loader(data_root: str, 
            data_type:str='train', # or 'test'
            years: list=list(range(2013,2023)),
            batch_size: int=128, 
            weights: Dict=None,
            include_az: bool=False,
            random_state:int=1234,
            select_keys: list=None,
            tilt_last: bool=True,
            from_tfds: bool=False,
            tfds_data_version: str='1.1.0',
            num_epochs: int=3):
    """
    This TorNet library function is used to load the data into Tensorflow.
    We're going to use the `create_tf_dataset` function from above, 
    then we'll use `preproc` to preprocess it.

    One important note - we tried a bunch of different functions for shuffling 
    and batching, repeating the dataset, etc. to try to be able to run
    many epochs with one function call. It wasn't working.
    Even the `drop_remainder=True` that we've added here to ds.batch
    seems to not really have an effect, as the model training
    still throws an error at the end of training about running out of data.
    """
    
    if from_tfds: # fast loader
        import tensorflow_datasets as tfds
        import tornet.data.tfds.tornet.tornet_dataset_builder # registers 'tornet'
        ds = tfds.load('tornet:%s' % tfds_data_version ,split='+'.join(['%s-%d' % (data_type,y) for y in years]))
        # Assumes data was saved with tilt_last=True and converts it to tilt_last=False
        if not tilt_last:
            ds = ds.map(lambda d: pp.permute_dims(d,(0,3,1,2), backend=tf))
    else: # Load directly from netcdf files
        file_list = query_catalog(data_root, data_type, years, random_state)
        ds = create_tf_dataset(file_list,variables=ALL_VARIABLES,n_frames=1, tilt_last=tilt_last) 

    ds = preproc(ds,weights,include_az,select_keys,tilt_last)
    ds = ds.prefetch(tf.data.AUTOTUNE)

    # this has been adjusted to include drop_remainder=True
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
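With `drop_remainder=True`, `ds.batch` only emits full batches. An epoch therefore has floor(N / batch_size) steps, and at most batch_size - 1 examples are skipped. The 2,682 steps Keras reports per epoch correspond to 2,682 × 64 = 171,648 examples. The total N isn't stated in this notebook, so the 171,700 below is a hypothetical value used only to show the arithmetic.

Two related observations, offered tentatively:
- Datasets built with `from_generator` have unknown cardinality, so Keras can't know the epoch length up front. That is why the progress bars read `2682/Unknown`.
- The `OUT_OF_RANGE: End of sequence` lines in the logs carry an `I` (INFO) prefix. They may simply be the iterator signalling the normal end of the data rather than an error.

```python
def num_batches(n_examples: int, batch_size: int, drop_remainder: bool) -> int:
    """Number of batches ds.batch(batch_size, drop_remainder=...) yields for n_examples."""
    full, remainder = divmod(n_examples, batch_size)
    if drop_remainder or remainder == 0:
        return full
    return full + 1          # the last, smaller batch is kept

# What Keras reported per epoch: 2,682 steps of 64 examples each
examples_per_epoch = 2682 * 64      # 171,648 examples

# Hypothetical dataset size, for illustration only:
n = 171_700
print(num_batches(n, 64, drop_remainder=True))   # 2682 (n % 64 = 52 examples are dropped)
print(num_batches(n, 64, drop_remainder=False))  # 2683 (the last batch has 52 examples)
```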

## TorNet Baseline CNN Model Definition

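As stated in the comment at the top of the code block below, this is just the TorNet model code that was used in the paper.
Our goal is to run this model, which consists of:
- Normalizing the inputs
- Adding the coordinate information for CoordConv to work properly
- Running 4 VGG blocks, each with 2 or 3 CoordConv2D layers followed by 2x2 max pooling (of both the features and the coordinate channels) and dropout
- A head of 1x1 convolutions followed by global max pooling, which produces a single output logit (this is the default `head='maxpool'`; the alternative `'mlp'` head flattens the features and uses dense layers instead)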
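The CoordConv idea in the list above amounts to appending coordinate channels to the feature maps before each convolution. Filters can then respond to *where* a pattern sits as well as *what* it looks like.

The NumPy sketch below shows only that concatenation step. It is not TorNet's `CoordConv2D`. The simple row/column grid in [-1, 1] is the classic CoordConv choice and is only a stand-in for the radar coordinates that `pp.add_coordinates` actually produces; both helper names are hypothetical.

With 2 coordinate channels, a 64-filter feature map becomes a 66-channel convolution input. That is consistent with the `f32[64,66,120,240]` shape in the cuDNN log line further down. In `vgg_block`, the coordinate tensor is max-pooled alongside the features so the two keep the same spatial size, which is the condition the shape check below enforces.

```python
import numpy as np

def simple_coord_grid(batch: int, h: int, w: int) -> np.ndarray:
    """Classic CoordConv row/column grid in [-1, 1], shape (batch, h, w, 2).

    Stand-in only: TorNet's coordinates come from pp.add_coordinates instead.
    """
    rows = np.linspace(-1.0, 1.0, h, dtype=np.float32)
    cols = np.linspace(-1.0, 1.0, w, dtype=np.float32)
    rr, cc = np.meshgrid(rows, cols, indexing="ij")    # each (h, w)
    grid = np.stack([rr, cc], axis=-1)                 # (h, w, 2)
    return np.broadcast_to(grid, (batch, h, w, 2))

def coordconv_inputs(x: np.ndarray, c: np.ndarray) -> np.ndarray:
    """Append coordinate channels c to feature channels x (channels-last).

    A CoordConv layer then runs an ordinary convolution over this stacked tensor.
    """
    if x.shape[:-1] != c.shape[:-1]:
        raise ValueError(f"batch/spatial shapes differ: {x.shape[:-1]} vs {c.shape[:-1]}")
    return np.concatenate([x, c], axis=-1)

features = np.zeros((1, 120, 240, 64), dtype=np.float32)   # e.g. output of a 64-filter layer
coords = simple_coord_grid(1, 120, 240)
print(coordconv_inputs(features, coords).shape)             # (1, 120, 240, 66)
```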

Later on, in our experiments to improve on this model, we'll likely still use some of this code for preprocessing the data (normalization and the CoordConv2D inputs).

Notably, for our data augmentation experiments, we'll only really need to change the input dataset
(and not any of the model code itself).
In comparison, for the YOLO transfer learning experiments, we likely won't be able to reuse a good chunk of this
pipeline (although some of the layers defined here might still be useful).

In [5]:
"""
This is just the TorNet model code that was used in the paper.
The goal is to run this model, which consists of:
- Normalizing the inputs
- Adding the coordinate information for CoordConv to work properly
- Running 4 VGG blocks, each with 2 or 3 CoordConv2D layers, 2x2 max pooling, and dropout
- A final block of 1x1 Conv2D layers and global max pooling to get the output logit
  (a raw score; apply a sigmoid to turn it into a probability)
"""

from typing import Dict, List, Tuple
import numpy as np
import keras
from tornet.models.keras.layers import CoordConv2D, FillNaNs
from tornet.data.constants import CHANNEL_MIN_MAX, ALL_VARIABLES


def build_model(shape:Tuple[int]=(120,240,2),
                c_shape:Tuple[int]=(120,240,2),
                input_variables:List[str]=ALL_VARIABLES,
                start_filters:int=64,
                l2_reg:float=0.001,
                background_flag:float=-3.0,
                include_range_folded:bool=True,
                head='maxpool'):
    # Create input layers for each input_variables
    inputs = {}
    for v in input_variables:
        inputs[v]=keras.Input(shape,name=v)
    n_sweeps=shape[2]
    
    # Normalize inputs and concatenate along channel dim
    normalized_inputs=keras.layers.Concatenate(axis=-1,name='Concatenate1')(
        [normalize(inputs[v],v) for v in input_variables]
        )

    # Replace nan pixel with background flag
    normalized_inputs = FillNaNs(background_flag)(normalized_inputs)

    # Add channel for range folded gates 
    if include_range_folded:
        range_folded = keras.Input(shape[:2]+(n_sweeps,),name='range_folded_mask')
        inputs['range_folded_mask']=range_folded
        normalized_inputs = keras.layers.Concatenate(axis=-1,name='Concatenate2')(
               [normalized_inputs,range_folded])
        
    # Input coordinate information
    cin=keras.Input(c_shape,name='coordinates')
    inputs['coordinates']=cin

    x,c = normalized_inputs,cin
    
    x,c = vgg_block(x,c, filters=start_filters,   ksize=3, l2_reg=l2_reg, n_convs=2, drop_rate=0.1)   # (60,120)
    x,c = vgg_block(x,c, filters=2*start_filters, ksize=3, l2_reg=l2_reg, n_convs=2, drop_rate=0.1)  # (30,60)
    x,c = vgg_block(x,c, filters=4*start_filters, ksize=3, l2_reg=l2_reg, n_convs=3, drop_rate=0.1)  # (15,30)
    x,c = vgg_block(x,c, filters=8*start_filters, ksize=3, l2_reg=l2_reg, n_convs=3, drop_rate=0.1)  # (8,15): 'same' pooling rounds 15/2 up
    #x,c = vgg_block(x,c, filters=8*start_filters, ksize=3, l2_reg=l2_reg, n_convs=3)  # (4,8)
    
    if head=='mlp':
        # MLP head
        x = keras.layers.Flatten()(x) 
        x = keras.layers.Dense(units = 4096, activation ='relu')(x) 
        x = keras.layers.Dense(units = 2024, activation ='relu')(x) 
        output = keras.layers.Dense(1)(x)
    elif head=='maxpool':
        # Per gridcell
        x = keras.layers.Conv2D(filters=512, kernel_size=1,
                          kernel_regularizer=keras.regularizers.l2(l2_reg),
                          activation='relu')(x)
        x = keras.layers.Conv2D(filters=256, kernel_size=1,
                          kernel_regularizer=keras.regularizers.l2(l2_reg),
                          activation='relu')(x)
        x = keras.layers.Conv2D(filters=1, kernel_size=1,name='heatmap')(x)
        # Max in scene
        output = keras.layers.GlobalMaxPooling2D()(x)

    return keras.Model(inputs=inputs,outputs=output)


def vgg_block(x,c, filters=64, ksize=3, n_convs=2, l2_reg=1e-6, drop_rate=0.0):

    for _ in range(n_convs):
        x,c = CoordConv2D(filters=filters,
                          kernel_size=ksize,
                          kernel_regularizer=keras.regularizers.l2(l2_reg),
                          padding='same',
                          activation='relu')([x,c])
    x = keras.layers.MaxPool2D(pool_size =2, strides =2, padding ='same')(x)
    c = keras.layers.MaxPool2D(pool_size =2, strides =2, padding ='same')(c)
    if drop_rate>0:
        x = keras.layers.Dropout(rate=drop_rate)(x)
    return x,c


def normalize(x,
              name:str):
    """
    Channel-wise normalization using known CHANNEL_MIN_MAX
    """
    min_max = np.array(CHANNEL_MIN_MAX[name]) # [2,]
    n_sweeps=x.shape[-1]
    
    # choose mean,var to get approximate [-1,1] scaling
    var=((min_max[1]-min_max[0])/2)**2 # scalar
    var=np.array(n_sweeps*[var,])    # [n_sweeps,]
    
    offset=(min_max[0]+min_max[1])/2    # scalar
    offset=np.array(n_sweeps*[offset,]) # [n_sweeps,]

    return keras.layers.Normalization(mean=offset,
                                      variance=var,
                                      name='Normalize_%s' % name)(x)
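The `normalize` layer picks `mean` and `variance` so that each channel's known range [min, max] maps onto roughly [-1, 1]. With offset = (min + max)/2 and variance = ((max - min)/2)², Keras' `Normalization` computes (x - offset)/sqrt(variance) = (x - offset)/((max - min)/2).

A quick NumPy check of that arithmetic follows. The range (-20, 60) is made up for illustration and is not taken from `CHANNEL_MIN_MAX`, and `minmax_normalize` is our own helper. Values outside [min, max] land outside [-1, 1], which is why the scaling is only approximate. NaNs pass through unchanged and are then replaced by `FillNaNs` with the background flag of -3.0, which sits well outside that range.

```python
import numpy as np

def minmax_normalize(x: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
    """Mimic normalize(): Normalization(mean=offset, variance=var) with range-based stats."""
    offset = (vmin + vmax) / 2.0            # passed as `mean`
    variance = ((vmax - vmin) / 2.0) ** 2   # passed as `variance`
    return (x - offset) / np.sqrt(variance)

# Hypothetical channel range, for illustration only (not CHANNEL_MIN_MAX)
vals = np.array([-20.0, 20.0, 60.0, 80.0, np.nan])
print(minmax_normalize(vals, -20.0, 60.0))   # values: -1.0, 0.0, 1.0, 1.5, nan
```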

## Running the Model

Ok, we've got all of the TorNet model components in place (imported, or re-copied in the helper functions above).
[Even getting that working turned out to be tricky. There were dependency issues to resolve, and while we'd 
like to simply import the functions rather than re-copying them here, that didn't work well.]

In [6]:
model = build_model()

I0000 00:00:1731696386.667789    2159 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5564 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070, pci bus id: 0000:01:00.0, compute capability: 8.6


In [7]:
model.summary()
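`model.summary()` lists each layer's output shape, but the printed table isn't included in this export. We can check the spatial sizes by hand instead. With `padding='same'` and stride 2, TensorFlow's pooling output size is ceil(n / 2), so an odd size rounds up and the fourth VGG block turns 15×30 into 8×15.

The 1x1 convolutions in the `maxpool` head keep that 8×15 grid (the per-cell "heatmap"). `GlobalMaxPooling2D` then takes its maximum to give one logit per example. `pooled_sizes` is our own helper.

```python
import math
from typing import List, Tuple

def pooled_sizes(h: int, w: int, n_blocks: int, stride: int = 2) -> List[Tuple[int, int]]:
    """Spatial size after each stride-2 MaxPool2D with padding='same' (ceil(n / stride))."""
    sizes = []
    for _ in range(n_blocks):
        h, w = math.ceil(h / stride), math.ceil(w / stride)
        sizes.append((h, w))
    return sizes

print(pooled_sizes(120, 240, 4))   # [(60, 120), (30, 60), (15, 30), (8, 15)]
```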

Again, as mentioned above, note the learning rate we chose: 1.5e-6.

We had to test a number of different learning rates on a log scale before finding one at which the gradients would neither vanish nor explode.
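A log-scale search like the one described above can be sketched with `np.logspace`. This is a hedged illustration, not the grid we actually ran (which isn't recorded here), and `log_grid` and `refine_around` are hypothetical helpers. A coarse grid with two points per decade between 1e-3 and 1e-6 does not contain 1.5e-6. Landing on a value like that would take a finer second pass around the best coarse candidate.

```python
import math
import numpy as np

def log_grid(high: float, low: float, points_per_decade: int = 2) -> np.ndarray:
    """Learning rates evenly spaced in log10 space, from high down to low (inclusive)."""
    decades = math.log10(high / low)
    num = int(round(decades * points_per_decade)) + 1
    return np.logspace(math.log10(high), math.log10(low), num=num)

def refine_around(best: float, half_width_decades: float = 0.5, num: int = 5) -> np.ndarray:
    """A finer log-spaced grid centred on the best coarse value."""
    center = math.log10(best)
    return np.logspace(center - half_width_decades, center + half_width_decades, num=num)

coarse = log_grid(1e-3, 1e-6)    # 1e-3, ~3.2e-4, 1e-4, ..., 1e-6 (7 values)
fine = refine_around(1e-6)       # ~3.2e-7 up to ~3.2e-6, centred on 1e-6
```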

In [8]:
opt = keras.optimizers.Adam(learning_rate=1.5e-6)
loss = keras.losses.BinaryCrossentropy(from_logits=True)
model.compile(loss=loss, optimizer=opt)
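The model ends in `GlobalMaxPooling2D` over a 1x1-conv heatmap with no sigmoid, so it outputs a raw logit. That is why the loss uses `from_logits=True`.

Below is a NumPy sketch of numerically stable binary cross-entropy on logits: max(z, 0) - z·y + log(1 + e^(-|z|)). This equals -[y log σ(z) + (1 - y) log(1 - σ(z))], but it never evaluates log(0) for large |z|. It averages over examples, like the Keras default reduction. To turn model outputs into tornado probabilities at inference time, apply a sigmoid. `bce_from_logits` and `sigmoid` are our own helpers, not Keras functions.

```python
import numpy as np

def bce_from_logits(y_true, logits) -> float:
    """Mean binary cross-entropy computed directly from logits (numerically stable form)."""
    y = np.asarray(y_true, dtype=np.float64)
    z = np.asarray(logits, dtype=np.float64)
    per_example = np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))
    return float(per_example.mean())

def sigmoid(z):
    """Convert logits to probabilities."""
    return 1.0 / (1.0 + np.exp(-np.asarray(z, dtype=np.float64)))
```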

We create checkpoint saving and loading functions to deal with IPython's many kernel crashes.
After each full pass through the data (all batches), we save the weights, reload them, and keep going.

We save the weights in `checkpoints/epoch_{epoch_number}.weights.h5`.

In [9]:
def checkpoint_creator(checkpoint_path):
    # saving the model's weights in case the iPython kernel crashes (which it likes to do)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                       save_weights_only=True,
                                       verbose=1)
    return cp_callback

def checkpoint_loader(checkpoint_path, model):
    model.load_weights(checkpoint_path)
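Because the kernel crashes are unpredictable, it helps to find the most recent checkpoint automatically instead of editing the epoch number by hand. Here is a minimal sketch, assuming the `checkpoints/epoch_{n}.weights.h5` naming above; `latest_checkpoint` is a hypothetical helper. It compares epoch numbers numerically, since a plain string sort would put `epoch_10` before `epoch_9`. It does not check whether the newest file is complete, so a crash during the save itself could still leave a bad file.

```python
import os
import re
from typing import Optional, Tuple

_CKPT_NAME = re.compile(r"^epoch_(\d+)\.weights\.h5$")

def latest_checkpoint(directory: str = "checkpoints") -> Optional[Tuple[int, str]]:
    """Return (epoch, path) of the highest-numbered epoch_{n}.weights.h5 file, or None."""
    if not os.path.isdir(directory):
        return None
    best = None
    for name in os.listdir(directory):
        match = _CKPT_NAME.match(name)
        if match:
            epoch = int(match.group(1))     # numeric, so 10 > 9
            if best is None or epoch > best[0]:
                best = (epoch, os.path.join(directory, name))
    return best

# Usage sketch with the functions above:
# found = latest_checkpoint("checkpoints")
# if found is not None:
#     last_epoch, path = found
#     checkpoint_loader(path, model)   # then continue with epoch last_epoch + 1
```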

First pass through the data. As with every pass, we set it up with:
- A batch_size of 64 (to fit the data into GPU memory)
- A fresh random_state for shuffling (we pick a new seed for each epoch and record it here so we don't reuse it in a later epoch)
- A new checkpoint file for each epoch (in case any one epoch of training crashes or fails)

In [15]:
preprocessed = make_tf_loader(data_root = TORNET_DATA_INPUT_FOLDER, 
                              data_type = "train", # or 'test'
                              years = list(range(2013, 2023)),
                              batch_size = 64, 
                              weights = None,
                              include_az = False,
                              random_state = 5678,
                              select_keys = ALL_VARIABLES + ["coordinates", "range_folded_mask"],
                              tilt_last = True,
                              from_tfds = False,
                              tfds_data_version ="1.1.0")

In [16]:
model.fit(preprocessed, epochs=1, callbacks=[checkpoint_creator("checkpoints/epoch_1.weights.h5")])

I0000 00:00:1731561021.211059     123 service.cc:148] XLA service 0x7f62b4018dc0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731561021.211288     123 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3070, Compute Capability 8.6
2024-11-13 21:10:21.325432: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1731561021.646291     123 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1731561048.909073     123 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


   2682/Unknown 8494s 3s/step - loss: 3.4064
Epoch 1: saving model to checkpoints/epoch_1.weights.h5


2024-11-13 23:31:50.022932: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-11-13 23:31:50.023266: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_16]]
2024-11-13 23:31:50.023730: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5090890319797716369
2024-11-13 23:31:50.023751: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 17464119409452371891
2024-11-13 23:31:50.023775: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1750716295784296344
2024-11-13 23:31:50.023779: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16184661796340794336
2024-11-

2682/2682 ━━━━━━━━━━━━━━━━━━━━ 8495s 3s/step - loss: 3.4063


<keras.src.callbacks.history.History at 0x7f63c40c3b80>

In [17]:
resample_2 = make_tf_loader(data_root = TORNET_DATA_INPUT_FOLDER, 
                            data_type = "train", # or 'test'
                            years = list(range(2013, 2023)),
                            batch_size = 64, 
                            weights = None,
                            include_az = False,
                            random_state = 9101112,
                            select_keys = ALL_VARIABLES + ["coordinates", "range_folded_mask"],
                            tilt_last = True,
                            from_tfds = False,
                            tfds_data_version ="1.1.0")

In [19]:
# load weights from the previous iteration
checkpoint_loader("checkpoints/epoch_1.weights.h5", model)
model.fit(resample_2, epochs=1, callbacks=[checkpoint_creator("checkpoints/epoch_2.weights.h5")])

   2682/Unknown 8202s 3s/step - loss: 2.8337
Epoch 1: saving model to checkpoints/epoch_2.weights.h5


2024-11-14 03:19:51.753462: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_16]]
  self.gen.throw(typ, value, traceback)
2024-11-14 03:19:51.756199: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5090890319797716369
2024-11-14 03:19:51.756213: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 17464119409452371891
2024-11-14 03:19:51.756221: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1750716295784296344
2024-11-14 03:19:51.756224: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16184661796340794336
2024-11-14 03:19:51.756228: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 11418751161415009090
20

2682/2682 ━━━━━━━━━━━━━━━━━━━━ 8203s 3s/step - loss: 2.8336


<keras.src.callbacks.history.History at 0x7f628c7bfca0>

In [21]:
resample_3 = make_tf_loader(data_root = TORNET_DATA_INPUT_FOLDER, 
                            data_type = "train", # or 'test'
                            years = list(range(2013, 2023)),
                            batch_size = 64, 
                            weights = None,
                            include_az = False,
                            random_state = 13141516,
                            select_keys = ALL_VARIABLES + ["coordinates", "range_folded_mask"],
                            tilt_last = True,
                            from_tfds = False,
                            tfds_data_version ="1.1.0")

# load weights from the previous iteration
checkpoint_loader("checkpoints/epoch_2.weights.h5", model)
model.fit(resample_3, epochs=1, callbacks=[checkpoint_creator("checkpoints/epoch_3.weights.h5")])

2024-11-14 09:51:14.792883: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 117964864 bytes after encountering the first element of size 117964864 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size


   2682/Unknown 8552s 3s/step - loss: 2.4142
Epoch 1: saving model to checkpoints/epoch_3.weights.h5


2024-11-14 12:13:43.901597: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5090890319797716369
2024-11-14 12:13:43.903729: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 17464119409452371891
2024-11-14 12:13:43.903745: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1750716295784296344
2024-11-14 12:13:43.903749: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16184661796340794336
2024-11-14 12:13:43.903752: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 11418751161415009090
2024-11-14 12:13:43.903756: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 7629939211194108882
2024-11-14 12:13:43.903758: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv 

2682/2682 ━━━━━━━━━━━━━━━━━━━━ 8553s 3s/step - loss: 2.4142


<keras.src.callbacks.history.History at 0x7f628a401100>

In [22]:
resample_4 = make_tf_loader(data_root = TORNET_DATA_INPUT_FOLDER, 
                            data_type = "train", # or 'test'
                            years = list(range(2013, 2023)),
                            batch_size = 64, 
                            weights = None,
                            include_az = False,
                            random_state = 17181920,
                            select_keys = ALL_VARIABLES + ["coordinates", "range_folded_mask"],
                            tilt_last = True,
                            from_tfds = False,
                            tfds_data_version ="1.1.0")

# load weights from the previous iteration
checkpoint_loader("checkpoints/epoch_3.weights.h5", model)
model.fit(resample_4, epochs=1, callbacks=[checkpoint_creator("checkpoints/epoch_4.weights.h5")])

   2682/Unknown 8814s 3s/step - loss: 2.0757
Epoch 1: saving model to checkpoints/epoch_4.weights.h5


2024-11-14 14:40:42.147395: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_16]]
2024-11-14 14:40:42.150001: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5090890319797716369
2024-11-14 14:40:42.150012: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 17464119409452371891
2024-11-14 14:40:42.150030: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1750716295784296344
2024-11-14 14:40:42.150035: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16184661796340794336
2024-11-14 14:40:42.150039: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 11418751161415009090
2024-11-14 14:40:42.150043: I tensorflow/c

2682/2682 ━━━━━━━━━━━━━━━━━━━━ 8816s 3s/step - loss: 2.0757


<keras.src.callbacks.history.History at 0x7f6285f38b80>

In [None]:
for epoch_num in range(5, 11):
    resampled = make_tf_loader(data_root = TORNET_DATA_INPUT_FOLDER, 
                               data_type = "train", # or 'test'
                               years = list(range(2013, 2023)),
                               batch_size = 64, 
                               weights = None,
                               include_az = False,
                               random_state = 21222324 + epoch_num,
                               select_keys = ALL_VARIABLES + ["coordinates", "range_folded_mask"],
                               tilt_last = True,
                               from_tfds = False,
                               tfds_data_version ="1.1.0")
    checkpoint_loader("checkpoints/epoch_{}.weights.h5".format(str(epoch_num - 1)), model)
    model.fit(resampled, epochs=1, callbacks=[checkpoint_creator("checkpoints/epoch_{}.weights.h5".format(epoch_num))])

  saveable.load_own_variables(weights_store.get(inner_path))
I0000 00:00:1731657636.601262     158 service.cc:148] XLA service 0x7f2a3c018f60 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731657636.601541     158 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3070, Compute Capability 8.6
2024-11-15 00:00:36.714626: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1731657637.008340     158 cuda_dnn.cc:529] Loaded cuDNN version 90300
2024-11-15 00:00:50.275453: E external/local_xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng0{} for conv (f32[64,66,120,240]{3,2,1,0}, u8[0]{0}) custom-call(f32[64,64,120,240]{3,2,1,0}, f32[64,66,3,3]{3,2,1,0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBackwardInput", backend_config={"cudnn_conv_backend_

   2682/Unknown 7868s 3s/step - loss: 1.8027
Epoch 1: saving model to checkpoints/epoch_5.weights.h5
2682/2682 ━━━━━━━━━━━━━━━━━━━━ 7869s 3s/step - loss: 1.8027


2024-11-15 02:11:40.366837: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-11-15 02:11:40.366870: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_8]]
2024-11-15 02:11:40.366887: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 10096804583211940195
2024-11-15 02:11:40.366892: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 4071073031941817423
2024-11-15 02:11:40.366897: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15403878188317645847
2024-11-15 02:11:40.366901: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 9429846262897810975
2024-11-1

   2682/Unknown 7800s 3s/step - loss: 1.5928
Epoch 1: saving model to checkpoints/epoch_6.weights.h5


2024-11-15 04:21:41.627528: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_8]]
2024-11-15 04:21:41.629586: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 10096804583211940195
2024-11-15 04:21:41.629597: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 4071073031941817423
2024-11-15 04:21:41.629606: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15403878188317645847
2024-11-15 04:21:41.629612: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 9429846262897810975
2024-11-15 04:21:41.629615: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 4583138412789806657
2024-11-15 04:21:41.629621: I tensorflow/cor

2682/2682 ━━━━━━━━━━━━━━━━━━━━ 7800s 3s/step - loss: 1.5928


2024-11-15 04:21:45.785778: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 117964864 bytes after encountering the first element of size 117964864 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size


   2682/Unknown 7777s 3s/step - loss: 1.4316
Epoch 1: saving model to checkpoints/epoch_7.weights.h5


2682/2682 ━━━━━━━━━━━━━━━━━━━━ 7778s 3s/step - loss: 1.4316


2024-11-15 06:31:25.847160: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 117964864 bytes after encountering the first element of size 117964864 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size


   1423/Unknown 15055s 11s/step - loss: 1.3196
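The numbers in these training logs are worth a quick back-of-the-envelope check. The sketch below uses only values printed above (the 117,964,864-byte batch element from the prefetch-autotuner warning, 64 examples per batch, 2,682 steps, and roughly 7,800 s per epoch). It assumes every example has identical tensor shapes, so the batch element size scales linearly with `batch_size`; that assumption is ours, not something the loader documents.

```python
# Back-of-the-envelope numbers, taken directly from the training logs above.
BATCH_BYTES = 117_964_864  # "first element of size 117964864 bytes" (prefetch autotuner warning)
BATCH_SIZE = 64            # examples per batch in these runs
STEPS_PER_EPOCH = 2682     # "2682/2682" in the Keras progress bar
EPOCH_SECONDS = 7800       # wall time reported for one epoch

MIB = 1024 ** 2

# Assumption: every example has identical tensor shapes, so bytes scale linearly with batch size.
bytes_per_example = BATCH_BYTES / BATCH_SIZE


def batch_bytes(batch_size: int) -> float:
    """Estimated size of one batch element for a given batch size."""
    return bytes_per_example * batch_size


seconds_per_step = EPOCH_SECONDS / STEPS_PER_EPOCH
examples_per_epoch = STEPS_PER_EPOCH * BATCH_SIZE

print(f"batch of 64    : {batch_bytes(64) / MIB:.2f} MiB")
print(f"batch of 32    : {batch_bytes(32) / MIB:.2f} MiB")
print(f"per example    : {bytes_per_example / MIB:.2f} MiB")
print(f"examples/epoch : {examples_per_epoch}")
print(f"seconds/step   : {seconds_per_step:.2f}")
print(f"hours/epoch    : {EPOCH_SECONDS / 3600:.2f}")
```

Under that assumption, a single batch of 64 is already about 112.5 MiB of host memory (roughly 1.76 MiB per example), halving the batch to 32 would bring it to about 56.25 MiB, and one pass over the 171,648 training examples takes a little over two hours at about 2.9 s/step.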

In [10]:
for epoch_num in range(8, 11):
    # Rebuild the training loader every epoch with a new seed so the data is reshuffled.
    resampled = make_tf_loader(data_root=TORNET_DATA_INPUT_FOLDER,
                               data_type="train",  # or "test"
                               years=list(range(2013, 2023)),
                               batch_size=64,
                               weights=None,
                               include_az=False,
                               random_state=21222324 + epoch_num,
                               select_keys=ALL_VARIABLES + ["coordinates", "range_folded_mask"],
                               tilt_last=True,
                               from_tfds=False,
                               tfds_data_version="1.1.0")
    # Resume from the weights saved after the previous epoch.
    checkpoint_loader(f"checkpoints/epoch_{epoch_num - 1}.weights.h5", model)
    # Train for exactly one pass over the freshly shuffled data, then save this epoch's weights.
    model.fit(resampled, epochs=1,
              callbacks=[checkpoint_creator(f"checkpoints/epoch_{epoch_num}.weights.h5")])

I0000 00:00:1731696404.942044    2274 service.cc:148] XLA service 0x7f7e08005e10 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731696404.942332    2274 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3070, Compute Capability 8.6
2024-11-15 10:46:45.071143: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1731696405.412703    2274 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1731696428.801904    2274 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


     26/Unknown 99s 3s/step - loss: 1.3240

KeyboardInterrupt: 
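The training cell above follows a simple resume pattern: epoch `n` rebuilds the loader with `random_state = 21222324 + n`, restores the weights saved after epoch `n - 1`, trains for one pass, and saves `epoch_n.weights.h5`. The sketch below only enumerates that plan in plain Python, so it is easy to see which checkpoint a restarted run should start from. `EpochPlan`, `plan_epochs`, and `next_epoch_to_run` are hypothetical helpers, not part of the notebook or the TorNet repository, and treating epoch 1 as starting from a fresh model is an assumption about how the earlier cells were run.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

BASE_SEED = 21222324  # base seed used in the training cell above


def ckpt_path(epoch: int, folder: str = "checkpoints") -> str:
    """Checkpoint naming scheme used by the notebook: checkpoints/epoch_<n>.weights.h5."""
    return f"{folder}/epoch_{epoch}.weights.h5"


@dataclass(frozen=True)
class EpochPlan:
    """Hypothetical record of what one iteration of the training loop does."""
    epoch: int
    seed: int                 # random_state passed to the loader for this epoch
    load_from: Optional[str]  # weights restored before training (None = fresh model, assumed)
    save_to: str              # weights written by the checkpoint callback


def plan_epochs(start: int, stop: int, base_seed: int = BASE_SEED) -> list[EpochPlan]:
    """Mirror range(start, stop) from the training loop, one entry per epoch."""
    return [
        EpochPlan(
            epoch=n,
            seed=base_seed + n,
            load_from=ckpt_path(n - 1) if n > 1 else None,
            save_to=ckpt_path(n),
        )
        for n in range(start, stop)
    ]


def next_epoch_to_run(saved: set[str], stop: int) -> int:
    """First epoch (counting from 1) whose checkpoint is missing from `saved`; `stop` if none are."""
    n = 1
    while n < stop and ckpt_path(n) in saved:
        n += 1
    return n


if __name__ == "__main__":
    for step in plan_epochs(8, 11):
        print(step)
    # In the output shown, checkpoints up to epoch_7 exist and epoch 8 never finished.
    saved = {ckpt_path(n) for n in range(1, 8)}
    print(next_epoch_to_run(saved, stop=11))  # 8
```

After the interrupt above, the plan says to reload `epoch_7.weights.h5` and resume at epoch 8, which is exactly what re-running the cell does.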

In [12]:
# Build the test loader once and reuse it for every checkpoint.
test_data = make_tf_loader(data_root=TORNET_DATA_INPUT_FOLDER,
                           data_type="test",
                           years=list(range(2013, 2023)),
                           batch_size=64,
                           weights=None,
                           include_az=False,
                           random_state=5678,
                           select_keys=ALL_VARIABLES + ["coordinates", "range_folded_mask"],
                           tilt_last=True,
                           from_tfds=False,
                           tfds_data_version="1.1.0")

# Evaluate the weights saved after each training epoch (epoch_1 ... epoch_7).
for epoch_num in range(1, 8):
    metrics = [keras.metrics.AUC(from_logits=True, name='AUC')]
    checkpoint_loader(f"checkpoints/epoch_{epoch_num}.weights.h5", model)
    model.compile(loss=loss, metrics=metrics)
    model.evaluate(test_data)


491/491 ━━━━━━━━━━━━━━━━━━━━ 1434s 3s/step - AUC: 0.6317 - loss: 2.9407


2024-11-15 11:13:42.744267: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-11-15 11:13:42.744310: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_2]]
491/491 ━━━━━━━━━━━━━━━━━━━━ 1441s 3s/step - AUC: 0.6375 - loss: 2.5019


2024-11-15 11:37:43.610944: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_2]]
491/491 ━━━━━━━━━━━━━━━━━━━━ 1442s 3s/step - AUC: 0.6439 - loss: 2.1461


491/491 ━━━━━━━━━━━━━━━━━━━━ 1477s 3s/step - AUC: 0.6501 - loss: 1.8581


2024-11-15 12:26:22.954071: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_2]]
491/491 ━━━━━━━━━━━━━━━━━━━━ 1495s 3s/step - AUC: 0.6569 - loss: 1.6297


491/491 ━━━━━━━━━━━━━━━━━━━━ 1512s 3s/step - AUC: 0.6613 - loss: 1.4592


491/491 ━━━━━━━━━━━━━━━━━━━━ 1519s 3s/step - AUC: 0.6603 - loss: 1.3301


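The evaluation loop prints one summary line per checkpoint, in the order `epoch_1` through `epoch_7`. A minimal sketch for collecting those lines into a small table and picking the checkpoint with the highest test AUC follows. `parse_eval_line` and its regular expression are hypothetical helpers that assume the `name: value` layout of the Keras summary lines shown above; the lines themselves are copied from that output with the terminal color codes removed.

```python
from __future__ import annotations

import re

# Summary lines printed by model.evaluate() for epoch_1 ... epoch_7, in loop order.
EVAL_LINES = [
    "491/491 ━━━━━━━━━━━━━━━━━━━━ 1434s 3s/step - AUC: 0.6317 - loss: 2.9407",
    "491/491 ━━━━━━━━━━━━━━━━━━━━ 1441s 3s/step - AUC: 0.6375 - loss: 2.5019",
    "491/491 ━━━━━━━━━━━━━━━━━━━━ 1442s 3s/step - AUC: 0.6439 - loss: 2.1461",
    "491/491 ━━━━━━━━━━━━━━━━━━━━ 1477s 3s/step - AUC: 0.6501 - loss: 1.8581",
    "491/491 ━━━━━━━━━━━━━━━━━━━━ 1495s 3s/step - AUC: 0.6569 - loss: 1.6297",
    "491/491 ━━━━━━━━━━━━━━━━━━━━ 1512s 3s/step - AUC: 0.6613 - loss: 1.4592",
    "491/491 ━━━━━━━━━━━━━━━━━━━━ 1519s 3s/step - AUC: 0.6603 - loss: 1.3301",
]

# Matches "name: value" pairs such as "AUC: 0.6317" or "loss: 2.9407".
METRIC_RE = re.compile(r"(\w+): ([0-9.]+)")


def parse_eval_line(line: str) -> dict[str, float]:
    """Return the metrics that follow the first ' - ' separator in a Keras summary line."""
    _, _, metrics_part = line.partition(" - ")
    return {name: float(value) for name, value in METRIC_RE.findall(metrics_part)}


results = {epoch: parse_eval_line(line) for epoch, line in enumerate(EVAL_LINES, start=1)}
best_epoch = max(results, key=lambda e: results[e]["AUC"])

for epoch, m in results.items():
    marker = "  <- best AUC" if epoch == best_epoch else ""
    print(f"epoch_{epoch}: AUC={m['AUC']:.4f}  loss={m['loss']:.4f}{marker}")
```

On these seven runs, test AUC rises through epoch 6 (0.6613) and dips slightly at epoch 7 (0.6603) while test loss keeps falling, so if AUC is the selection criterion, `epoch_6.weights.h5` would be the checkpoint to keep.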