# Setup

In [1]:
import dataset
import models.losses as losses
import tensorflow as tf
from models.metrics import *
import models.cnn_autoencoder_model as cnnmodel
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping

# Define mode constants (replacing tf_estimator.ModeKeys)
class ModeKeys:
    TRAIN = 'train'
    EVAL = 'eval'
    PREDICT = 'predict'

2025-11-19 23:25:50.875520: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
tf.config.list_physical_devices('GPU')

W0000 00:00:1763594753.051815   28140 gpu_device.cc:2342] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


[]

In [3]:
hparams = {
    # 数据路径
    'train_path': '../dataset/next_day_wildfire_spread_train*',
    'eval_path': '../dataset/next_day_wildfire_spread_eval*',
    'test_path': '../dataset/next_day_wildfire_spread_test*',
    
    # 特征
    'input_features': ['elevation', 'pdsi', 'NDVI', 'pr', 'sph', 'th', 'tmmn',
                  'tmmx', 'vs', 'erc', 'population', 'PrevFireMask'],
    'output_features': ['FireMask'],
    
    # 方位通道
    'azimuth_in_channel': None,
    'azimuth_out_channel': None,
    
    # 数据和模型参数
    'data_sample_size': 64,
    'sample_size': 32,
    'output_sample_size': 32,
    'batch_size': 128,
    'shuffle': False,
    'shuffle_buffer_size': 10000,
    'compression_type': None,
    'input_sequence_length': 1,
    'output_sequence_length': 1,
    'repeat': False,
    'clip_and_normalize': True,
    'clip_and_rescale': False,
    
    # 数据增强
    'random_flip': False,
    'random_rotate': False,
    'random_crop': False,
    'center_crop': True,
    
    # 其他参数
    'downsample_threshold': 0.0,
    'binarize_output': True
}

train_dataset = dataset.make_dataset(
    hparams,
    mode = ModeKeys.TRAIN
)
val_dataset = dataset.make_dataset(
    hparams,
    mode = ModeKeys.EVAL
)
test_dataset = dataset.make_dataset(
    hparams,
    mode = ModeKeys.PREDICT
)

In [4]:
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=30,
    restore_best_weights=True,
)

# Training

## Autoencoder

In [23]:
input_tensor = Input((32, 32, 12))
num_out_channels = 1
encoder_layers = [16,32]
decoder_layers = [32,16]
encoder_pools = [2,2]
decoder_pools = [2,2]
autoencoder_model = cnnmodel.create_model(
    input_tensor,
    num_out_channels,
    encoder_layers,
    decoder_layers,
    encoder_pools,
    decoder_pools,
)
autoencoder_model.summary()

In [24]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
autoencoder_model.compile(optimizer=optimizer,
              loss=losses.weighted_cross_entropy_with_logits_with_masked_class(pos_weight=3),
              metrics=[AUCWithMaskedClass(with_logits=True)])
history = autoencoder_model.fit(train_dataset, epochs=1000, validation_data=val_dataset, callbacks=[early_stopping])

Epoch 1/1000


ValueError: `logits` and `labels` must have the same shape, received ((None, 1024) vs (None, 32, 32, 1)).

In [None]:
autoencoder_model.evaluate(test_dataset)

## ResNet

In [None]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Input
import models.model_utils as model_utils
import models.cnn_autoencoder_model as cnn_autoencoder_model
from models.cnn_autoencoder_model import decoder
from tensorflow.compat.v2 import keras


layers_list = (32, 64, 128, 256, 256)
pools_list = (2, 2, 2, 2, 2)
decoder_layers = tuple(reversed(layers_list))
decoder_pools = tuple(reversed(pools_list))
num_out_channels = 1
l1_regularization = model_utils.L1_REGULARIZATION_DEFAULT
l2_regularization = model_utils.L2_REGULARIZATION_DEFAULT

# define input
conv_input = Input(shape=(32,32,12))

# add extra convolutional layer
conv_output = tf.keras.layers.Conv2D(16, (3, 3), padding='same')(conv_input)

# define resnet encoder
keras_resnet_encoder = ResNet50(weights=None,
                 include_top=False,
                input_shape=(32, 32, 16))

encoder_output = keras_resnet_encoder(conv_output)

# define resnet decoder
# decoder_input_img = Input(shape=keras_resnet_encoder.output_shape[1:])

x = decoder(encoder_output, decoder_layers, decoder_pools)
decoder_output = model_utils.conv2d_layer(
      filters=num_out_channels,
      kernel_size=model_utils.RES_SHORTCUT_KERNEL_SIZE,
      l1_regularization=l1_regularization,
      l2_regularization=l2_regularization)(x)

# keras_resnet_decoder = keras.Model(decoder_input_img, resnet_decoder)

# decoder_output = keras_resnet_decoder(encoder_output)

# define connected model
keras_model = keras.Model(inputs = conv_input, outputs = decoder_output)
keras_model.summary()

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
keras_model.compile(optimizer=optimizer,
              loss=losses.weighted_cross_entropy_with_logits_with_masked_class(pos_weight=3),
              metrics=[AUCWithMaskedClass(with_logits=True)])
history = keras_model.fit(train_dataset, epochs=1000, validation_data=val_dataset, callbacks=[early_stopping])

In [None]:
keras_model.evaluate(test_dataset)

## UNet

In [None]:
def expend_as(tensor, rep):
     return Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=3),
                          arguments={'repnum': rep})(tensor)

def double_conv_layer(x, filter_size, size, dropout, batch_norm=False):
    axis = 3
    conv = SeparableConv2D(size, (filter_size, filter_size), padding='same')(x)
    if batch_norm is True:
        conv = BatchNormalization(axis=axis)(conv)
    conv = Activation('relu')(conv)
    conv = SeparableConv2D(size, (filter_size, filter_size), padding='same')(conv)
    if batch_norm is True:
        conv = BatchNormalization(axis=axis)(conv)
    conv = Activation('relu')(conv)
    if dropout > 0:
        conv = Dropout(dropout)(conv)

    shortcut = Conv2D(size, kernel_size=(1, 1), padding='same')(x)
    if batch_norm is True:
        shortcut = BatchNormalization(axis=axis)(shortcut)

    res_path = add([shortcut, conv])
    return res_path

def encoder(inputs):
    num_filters = [16, 32, 64, 128]
    skip_connections = []
    x = inputs

    for i, f in enumerate(num_filters):
        a = double_conv_layer(x, 3, f, 0.1, True)
        skip_connections.append(a)
        x = MaxPooling2D(pool_size=(2, 2))(a)
    
    return x, skip_connections

def bottleneck(inputs):
    x = inputs
    f = 256
    
    x3 = double_conv_layer(x, 3, f, 0.1, True)
    
    return x3

def decoder(inputs, skip_connections):
    num_filters = [128, 64, 32, 16]
    skip_connections.reverse()
    x = inputs
    batch_norm = True
    
    for i, f in enumerate(num_filters):
        
        x_up = UpSampling2D(size=(2, 2), data_format="channels_last")(x)
        x_att = concatenate([x_up, skip_connections[i]], axis=-1)
        
        x = double_conv_layer(x_att, 3, f, 0.1, True)
    return x

def output(inputs):
    x = Conv2D(1, kernel_size=(1,1))(inputs)
    x = BatchNormalization()(x)
    # x = Activation('sigmoid')(x)
    
    return x

inputs = Input((32, 32, 12))
# s = layers.experimental.preprocessing.Rescaling(1.0 / 255)(inputs)
s = inputs
x, skip_1 = encoder(s)
x = bottleneck(x)
x = decoder(x, skip_1)
outputs = output(x)
unet_model = Model(inputs, outputs)
unet_model.summary()

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
unet_model.compile(optimizer=optimizer,
              loss=losses.weighted_cross_entropy_with_logits_with_masked_class(pos_weight=2),
              metrics=[AUCWithMaskedClass(with_logits=True)])
history = unet_model.fit(train_dataset, epochs=1000, validation_data=val_dataset, callbacks=[early_stopping])

In [None]:
unet_model.evaluate(test_dataset)

## ViT

In [5]:
import numpy as np
from glob import glob

import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Conv2D, concatenate

import keras_vision_transformer as create_swin_unet
from keras_vision_transformer.create_swin_unet import swin_unet_2d_base

import keras_vision_transformer.create_swin_unet_add_convolution
from keras_vision_transformer.create_swin_unet_add_convolution import swin_unet_2d_base

filter_num_begin = 128     # number of channels in the first downsampling block; it is also the number of embedded dimensions
depth = 4                  # the depth of SwinUNET; depth=4 means three down/upsampling levels and a bottom level
stack_num_down = 2         # number of Swin Transformers per downsampling level
stack_num_up = 2           # number of Swin Transformers per upsampling level
patch_size = (2, 2)        # Extract 4-by-4 patches from the input image. Height and width of the patch must be equal.
num_heads = [4, 8, 8, 8]   # number of attention heads per down/upsampling level
window_size = [4, 2, 2, 2] # the size of attention window per down/upsampling level
num_mlp = 512              # number of MLP nodes within the Transformer
shift_window=True          # Apply window shifting, i.e., Swin-MSA

# define input size
input_size = (32,32,12)
IN = Input(input_size)

# Base architecture
X = swin_unet_2d_base(IN, filter_num_begin, depth, stack_num_down, stack_num_up,
                      patch_size, num_heads, window_size, num_mlp,
                      shift_window=shift_window, name='swin_unet')

# define output: remove activation function
n_labels = 1
OUT = Conv2D(n_labels, kernel_size=1, use_bias=False, padding='same')(X)

# Model Configuration
keras_model = Model(inputs=IN, outputs=OUT)

keras_model.summary()



In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
keras_model.compile(optimizer=optimizer,
              loss=losses.weighted_cross_entropy_with_logits_with_masked_class(pos_weight=4),
              metrics=[AUCWithMaskedClass(with_logits=True)])
history = keras_model.fit(train_dataset, epochs=1000, validation_data=val_dataset, callbacks=[early_stopping])

Epoch 1/1000


2025-11-19 23:26:34.933193: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:390] TFRecordDataset `buffer_size` is unspecified, default to 262144


      3/Unknown [1m142s[0m 41s/step - auc_with_masked_class: 0.0272 - loss: 1.2897

In [None]:
keras_model.evaluate(test_dataset)