In [1]:
import argparse
import os
from itertools import islice

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm.keras import TqdmCallback
from wandb.keras import WandbMetricsLogger
# Import gradcam
from tf_explain.core.grad_cam import GradCAM
import wandb
from configure_dataframes import directory_to_dataframe
from data_preparation_utils import get_datasets
from metric_utils import log_wandb_print_class_report, plot_roc_curve
from modelbuilder import ModelBuilder, TransferLearningModelBuilder
from train_utils import load_config
from tensorflow.keras.models import load_model

2023-08-17 22:26:51.225385: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-17 22:26:51.241262: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-08-17 22:26:51.435972: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-08-17 22:26:51.437187: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run parameters: the W&B entity/project to log to, and the name shared by the
# base config file and the model artifact used further down.
entity = "i4ds_radio_sunburst_detection"
project_name = "radio_sunburst_detection_main"
config_name = "transfer_4"

In [3]:
# Start a W&B run configured with the base YAML config selected by `config_name`.
config = load_config(os.path.join("model_base_configs", config_name + ".yaml"))
wandb.init(project=project_name, entity=entity, config=config)
# Drop the local name: the cells below read their settings from wandb.config.
del config

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mvincenzo-timmel[0m ([33mi4ds_radio_sunburst_detection[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Load the dataframe produced by directory_to_dataframe().
data_df = directory_to_dataframe()

# Keep only the configured instruments, if the run config names any.
if "instrument_to_use" in wandb.config:
    wanted_instruments = wandb.config["instrument_to_use"]
    data_df = data_df[data_df.instrument.isin(wanted_instruments)]

# Split into a training part and a held-out part; test_size is the complement
# of the configured train_size.
train_fraction = wandb.config["train_size"]
train_df, test_df = get_datasets(
    data_df,
    train_size=train_fraction,
    test_size=1 - train_fraction,
    burst_frac=wandb.config["burst_frac"],
    sort_by_time=wandb.config["sort_by_time"],
    only_unique_time_periods=True,
)

# Cut the held-out part in two: the first half becomes the validation set,
# the second half remains the test set.
halfway = len(test_df) // 2
val_df, test_df = test_df.iloc[:halfway], test_df.iloc[halfway:]

Class balance in train dataset:
label              burst  no_burst
instrument                        
australia_assa_02    812      7308
--------------------------------------------------
Class balance in test dataset:
label              burst  no_burst
instrument                        
australia_assa_02    348      3132
--------------------------------------------------


In [5]:
# Build the model helper (`mb`) and the matching image generator (`datagen`)
# for the model type named in the run config.
model_type = wandb.config["model"]
if model_type == "transfer":
    mb = TransferLearningModelBuilder(model_params=wandb.config)

    # Read the flag once: a missing key is then reported here instead of in the
    # middle of iterating the generator, and the preprocessing function no
    # longer queries wandb.config on every call.
    elim_wrong_channels = wandb.config["elim_wrong_channels"]

    def ppf(x):
        """Apply the transfer-learning builder's preprocessing to `x`."""
        return mb.preprocess_input(x, ewc=elim_wrong_channels)

    datagen = ImageDataGenerator(preprocessing_function=ppf)
elif model_type == "autoencoder":
    # NOTE(review): this branch passes only the "model_params" sub-config while
    # the transfer branch passes the whole config -- confirm this is intended.
    mb = ModelBuilder(model_params=wandb.config["model_params"])
    datagen = ImageDataGenerator()
else:
    raise ValueError("Model not implemented.")

In [6]:
# Iterator over the test images, built from the file paths and labels in test_df.
test_ds = datagen.flow_from_dataframe(
    test_df,
    # Columns holding the image path and the label.
    x_col="file_path",
    y_col="label_keras",
    class_mode="binary",
    # Image format handed to the model.
    color_mode="grayscale",
    target_size=(256, 256),
    # Batching; shuffling is off so the batch order is fixed.
    batch_size=wandb.config["batch_size"],
    shuffle=False,
    seed=42,
)

Found 1740 validated image filenames belonging to 2 classes.


In [7]:
# Reference the ":latest" version of the model artifact named <config_name> on W&B.
artifact = wandb.use_artifact(
    f"{entity}/{project_name}/{config_name}:latest",
    type="model",
)

# Download the artifact; the returned directory is where "model.keras" is
# expected to be found.
artifact_dir = artifact.download()
model_path = os.path.join(artifact_dir, "model.keras")

[34m[1mwandb[0m: Downloading large artifact transfer_4:latest, 210.90MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.7


In [8]:
# Load the trained Keras model from the downloaded artifact file.
# NOTE(review): safe_mode=False switches off Keras' safe-loading checks; per the
# Keras documentation this permits unsafe lambda deserialization, which can
# execute arbitrary code contained in the file. Keep it only while the artifact
# source is fully trusted.
model = load_model(model_path, safe_mode=False)

In [12]:
# Materialise all test images as one NumPy array.
# Batches are fetched by index instead of by advancing the iterator, so the
# result always starts at the first batch and does not depend on how many
# batches other cells have already consumed from test_ds. With shuffle=False
# the rows therefore stay in the same order as test_df.
steps = len(test_ds)  # number of batches in test_ds
# Each test_ds[i] is an (images, labels) pair; only the images are kept.
X_test = np.concatenate([test_ds[batch_idx][0] for batch_idx in range(steps)])

In [13]:
# Take the first test image; slicing (rather than indexing) keeps the leading
# batch axis, so the result is still a batch of one.
X_test_sample = X_test[0:1]

In [14]:
# Run Grad-CAM on the sampled image(s) and write the resulting grid image to
# the current working directory.
explainer = GradCAM()
grid = explainer.explain(
    validation_data=(X_test_sample, None),  # images only, no labels passed
    model=model,
    # 1 is burst.
    # NOTE(review): this needs a model with at least two output units; if the
    # model ends in a single sigmoid unit, only index 0 exists -- confirm.
    class_index=1,
)

explainer.save(grid, ".", "grad_cam.png")