# Data loading and TFRecord creation
Make sure to enable the GPU P100 accelerator when running in the Kaggle environment.

In [None]:
"""
Only relevant if performing ablation for BERT pre-training, 
make sure to run it at the begining and the change hyperparamter configuration in 'main pretrained bock' and run 'plotting ablation result block'
"""
all_histories = {} 

In [1]:
import tensorflow as tf,keras
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import math,random,time,shutil,tempfile,glob,os,json,itertools
from pathlib import Path
from tensorflow.keras import layers, models, callbacks, mixed_precision,Model
from tensorflow.keras.initializers import HeNormal, GlorotUniform
from typing import Optional
import tensorflow.keras.metrics as metrics
from sklearn.model_selection import train_test_split
from collections import defaultdict
from typing import List,Dict,Any
from sklearn.metrics import f1_score, classification_report, hamming_loss


PARQUET_DIR = "/kaggle/input/tno-parquet-latest2"  # directory of pre-training *.parquet files
TFRECORD_DIR = "tfrecords_highd"  # output folder, relative to the current working directory
processed_dir = "/kaggle/working/"
SEQ_SIZES = [64, 128, 192]  # ~2s, ~5s, ~7s
BATCH_SIZES = [256, 128, 64]  # one batch size per entry of SEQ_SIZES
FEATURES = [str(i) for i in range(0,20)]  # parquet feature column names "0".."19"
NUM_FEATURES = len(FEATURES)
TARGET_SHARD_MEMORY_MB = 20  # approximate raw float32 payload per pre-training shard
BYTES_PER_FLOAT = 4
MASK_PROB = 0.9  # probability that a real frame is hidden during masked-frame pre-training
AUTOTUNE = tf.data.AUTOTUNE
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
# interpreters launched afterwards (e.g. subprocesses), not the running kernel.
os.environ['PYTHONHASHSEED'] = '42'
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)
os.makedirs(TFRECORD_DIR, exist_ok=True) 
mixed_precision.set_global_policy("mixed_float16")
# mixed_precision.set_global_policy("float32")

# Filled in by TFRecords_Creation (fine-tune / testing mode) via `global`.
CONTAINER_LEN = None  # Must match pre-trained BERT input size
ACTIVE_LEN = None  
LABEL_COLS=None
NUM_LABELS = None
STRIDE = None

# Very important!!
# Set the run-mode flags below as per your expectations:
#   USE_ABLATION  - pre-training ablation (fixed 43/5/rest train/val/test file split)
#   USE_FINE_TUNE - build fine-tuning records instead of pre-training shards
#   testing       - write only the held-out test records (fine-tune data path)
#   generate_file - actually (re)generate the TFRecords and zip them
USE_ABLATION, USE_FINE_TUNE, testing, generate_file = False, False, False, False



def _serialize_pretraining_example(sequence):
    """Serialize one (seq_len, NUM_FEATURES) window as a tf.train.Example.

    The window is flattened row-major (frame by frame) into a single float
    feature named ``"features"``.
    """
    feature = {
        "features": tf.train.Feature(float_list=tf.train.FloatList(value=sequence.flatten()))
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()


def _write_pretraining_shards(parquet_dir, stride_ratio):
    """Window every parquet file in ``parquet_dir`` and write sharded TFRecords.

    For each size in SEQ_SIZES, windows of ``size`` rows over the FEATURES
    columns are taken every ``int(size * stride_ratio)`` rows. They are written
    to ``TFRECORD_DIR/<split>/<split>_seq<size>_<shard:03d>.tfrecord`` in shards
    of roughly TARGET_SHARD_MEMORY_MB of raw float32 data.

    Raises:
        ValueError: if ``stride_ratio`` yields a stride below one row.
    """
    strides = {size: int(size * stride_ratio) for size in SEQ_SIZES}
    if any(stride < 1 for stride in strides.values()):
        raise ValueError(f"stride_ratio={stride_ratio} gives a stride < 1 for some size in {SEQ_SIZES}")

    parquet_files = sorted(glob.glob(os.path.join(parquet_dir, "*.parquet")))
    if not USE_ABLATION:
        split_to_files = {"train": parquet_files}
    else:
        # Fixed file-level split used for the ablation study.
        split_to_files = {
            "train": parquet_files[:43],
            "val": parquet_files[43:48],
            "test": parquet_files[48:],
        }
    all_sequences_by_split = {size: {split: [] for split in split_to_files} for size in SEQ_SIZES}

    for split, files in split_to_files.items():
        for parquet_file in files:
            # Read only the feature columns as one float32 matrix. Windows are
            # slices (views) of it, so overlapping windows of every size share
            # this buffer instead of each being a separately stacked copy.
            data = pd.read_parquet(parquet_file, columns=FEATURES)[FEATURES].to_numpy(dtype=np.float32)
            for size in SEQ_SIZES:
                windows = all_sequences_by_split[size][split]
                for start_idx in range(0, len(data) - size + 1, strides[size]):
                    windows.append(data[start_idx:start_idx + size])

    print("Collected sequences per split.")

    target_shard_bytes = TARGET_SHARD_MEMORY_MB * 1024 * 1024
    for size in SEQ_SIZES:
        seq_memory_bytes = size * NUM_FEATURES * BYTES_PER_FLOAT
        seq_per_shard = max(1, int(target_shard_bytes / seq_memory_bytes))
        for split in split_to_files:
            sequences = all_sequences_by_split[size][split]
            total_sequences = len(sequences)
            num_shards = math.ceil(total_sequences / seq_per_shard)

            out_dir = os.path.join(TFRECORD_DIR, split)
            os.makedirs(out_dir, exist_ok=True)

            for shard_id in range(num_shards):
                shard_filename = os.path.join(out_dir, f"{split}_seq{size}_{shard_id:03d}.tfrecord")
                start = shard_id * seq_per_shard
                with tf.io.TFRecordWriter(shard_filename) as writer:
                    for seq in sequences[start:start + seq_per_shard]:
                        writer.write(_serialize_pretraining_example(seq))
            print(f"Wrote {num_shards} shards for split={split}, seq_len={size}, total_seqs={total_sequences}")


def _tfr_bytes_feature(value):
    """Return a bytes_list feature; str values are UTF-8 encoded first."""
    if isinstance(value, str):
        value = value.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _tfr_float_feature(value):
    """Return a float_list feature from an iterable of floats."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def _serialize_finetune_example(x, y, mask, doc_id, veh_id):
    """Serialize one fine-tuning window (x, y, mask) plus its string IDs."""
    return tf.train.Example(features=tf.train.Features(feature={
        "x": _tfr_float_feature(x.flatten()),
        "y": _tfr_float_feature(y.flatten()),
        "mask": _tfr_float_feature(mask.flatten()),
        "doc_id": _tfr_bytes_feature(doc_id),
        "veh_id": _tfr_bytes_feature(veh_id),
    })).SerializeToString()


def _window_group(X, Y, container_len, active_len, stride):
    """Cut one group's frames into fixed-size, zero-padded windows.

    Args:
        X: (T, F) float32 features of one (document, vehicle) group.
        Y: (T, L) float32 labels aligned row-by-row with X.
        container_len: number of rows in every emitted window.
        active_len: maximum number of real rows per window; the remaining
            ``container_len - active_len`` rows are always zero padding.
        stride: step between consecutive window starts.

    Returns:
        List of (x, y, mask). x and y have ``container_len`` rows; mask is a
        float32 vector that is 1.0 on real frames and 0.0 on padding.

    Groups shorter than ``active_len`` give a single padded window. Longer
    groups are zero-padded at the end so that ``len - active_len`` is a
    multiple of ``stride``. The padded tail is masked out in the last window.
    """
    T = len(X)
    if T < active_len:
        total_pad = container_len - T
        x = np.pad(X, ((0, total_pad), (0, 0)), mode='constant', constant_values=0.0)
        y = np.pad(Y, ((0, total_pad), (0, 0)), mode='constant', constant_values=0.0)
        mask = np.concatenate([np.ones(T, dtype=np.float32), np.zeros(total_pad, dtype=np.float32)])
        return [(x, y, mask)]

    static_pad = container_len - active_len
    remaining = (T - active_len) % stride
    tail_pad = stride - remaining if remaining else 0
    if tail_pad:
        X = np.pad(X, ((0, tail_pad), (0, 0)), mode='constant', constant_values=0.0)
        Y = np.pad(Y, ((0, tail_pad), (0, 0)), mode='constant', constant_values=0.0)

    last_start = len(X) - active_len
    windows = []
    for start in range(0, last_start + 1, stride):
        x = np.pad(X[start:start + active_len], ((0, static_pad), (0, 0)), mode='constant', constant_values=0.0)
        y = np.pad(Y[start:start + active_len], ((0, static_pad), (0, 0)), mode='constant', constant_values=0.0)
        # Only the final window can contain the stride-alignment padding.
        n_real = active_len - tail_pad if start == last_start else active_len
        mask = np.concatenate([np.ones(n_real, dtype=np.float32), np.zeros(container_len - n_real, dtype=np.float32)])
        windows.append((x, y, mask))
    return windows


def _write_finetune_split(groups, split_name):
    """Window every (doc, vehicle) group and write ``TFRECORD_DIR/<split_name>.tfrecord``.

    Reads the module-level CONTAINER_LEN, ACTIVE_LEN, STRIDE, FEATURES and
    LABEL_COLS. With ``testing`` set, each group is first truncated to a whole
    number of 25-frame chunks.

    Args:
        groups: list of ((document_id, vehicle_id), DataFrame) pairs.
        split_name: output file stem ("train", "val" or "test").

    Returns:
        ``(out_file, doc_ids, veh_ids)`` (one id pair per written window) for
        the "test" split or in testing mode; otherwise just ``out_file``.
    """
    out_file = f"{TFRECORD_DIR}/{split_name}.tfrecord"
    tf.io.gfile.makedirs(TFRECORD_DIR)  # Ensure directory exists
    doc_ids_new, veh_ids_new = [], []

    # Records are streamed to disk as they are produced rather than buffered.
    with tf.io.TFRecordWriter(out_file) as writer:
        for (d, v), group_df in groups:
            X = group_df[FEATURES].to_numpy().astype("float32")
            Y = group_df[LABEL_COLS].to_numpy().astype("float32")
            if testing:
                # Keep only complete 25-frame chunks; the ragged tail is dropped.
                n_keep = (len(X) // 25) * 25
                X, Y = X[:n_keep], Y[:n_keep]
            for x, y, mask in _window_group(X, Y, CONTAINER_LEN, ACTIVE_LEN, STRIDE):
                writer.write(_serialize_finetune_example(x, y, mask, str(d), str(v)))
                doc_ids_new.append(d)
                veh_ids_new.append(v)

    if split_name == "test" or testing:
        return out_file, doc_ids_new, veh_ids_new
    return out_file


def TFRecords_Creation(PARQUET_DIR, stride_ratio):
    """Convert pre-processed parquet data into TFRecord files.

    The mode is selected by the module-level flags ``testing`` and
    ``USE_FINE_TUNE``:

    * Pre-training (both False): ``PARQUET_DIR`` is a directory of
      ``*.parquet`` files. Each file is cut into windows of every size in
      SEQ_SIZES and written as sharded ``TFRECORD_DIR/<split>/`` TFRecords
      with one flat float feature ``"features"``.
    * Fine-tuning / testing: ``PARQUET_DIR`` is ``[fine_tune_path, test_path]``.
      Rows are grouped by ``("document_id", "id")``, sorted by ``"frame"`` and
      cut into ACTIVE_LEN windows zero-padded to CONTAINER_LEN, with features
      ``x``, ``y``, ``mask``, ``doc_id`` and ``veh_id``. With ``testing`` only
      ``test.tfrecord`` is written from the test file (groups shorter than 25
      frames are skipped, stride = ACTIVE_LEN). Otherwise the fine-tune file is
      split by group into train/val/test (~70/15/15).

    Side effects: in fine-tune/testing mode sets the globals CONTAINER_LEN,
    ACTIVE_LEN, LABEL_COLS (every column from index 23 on), NUM_LABELS and STRIDE.

    Args:
        PARQUET_DIR: parquet directory (pre-training) or the pair of parquet paths.
        stride_ratio: window stride as a fraction of the window length.

    Raises:
        ValueError: if ``stride_ratio`` yields a stride below one frame.
    """
    global CONTAINER_LEN, ACTIVE_LEN, LABEL_COLS, NUM_LABELS, STRIDE
    if not testing and not USE_FINE_TUNE:
        _write_pretraining_shards(PARQUET_DIR, stride_ratio)
        return

    df = pd.read_parquet(PARQUET_DIR[1] if testing else PARQUET_DIR[0])
    CONTAINER_LEN = 192  # Must match pre-trained BERT input size
    ACTIVE_LEN = 192
    LABEL_COLS = df.columns[23:].tolist()
    NUM_LABELS = len(LABEL_COLS)
    # Testing uses non-overlapping windows; training/validation use stride_ratio.
    STRIDE = int(ACTIVE_LEN * (1.0 if testing else stride_ratio))
    if STRIDE < 1:
        raise ValueError(f"stride_ratio={stride_ratio} gives a window stride < 1")

    groups = []
    for key, g in df.groupby(["document_id", "id"], sort=False):
        if testing and len(g) < 25:
            continue  # too short to yield a single 25-frame chunk
        groups.append((key, g.sort_values("frame")))

    if testing:
        test_tfr, doc_ids_new, veh_ids_new = _write_finetune_split(groups, "test")
        print(f"Test TFRecord written to: {test_tfr}")
        print(f"Total sequences written: {len(doc_ids_new)}")
    else:
        # Split off 15% of the groups for TEST, then 0.15 / 0.85 of the rest
        # (15% of the total) for VAL, leaving ~70% for TRAIN.
        train_val_groups, test_groups = train_test_split(groups, test_size=0.15, random_state=42)
        train_groups, val_groups = train_test_split(train_val_groups, test_size=0.15 / 0.85, random_state=42)

        print(f"Train groups: {len(train_groups)} (~{len(train_groups)/len(groups):.1%})")
        print(f"Val groups:    {len(val_groups)} (~{len(val_groups)/len(groups):.1%})")
        print(f"Test groups:   {len(test_groups)} (~{len(test_groups)/len(groups):.1%})")

        train_tfr = _write_finetune_split(train_groups, "train")
        val_tfr = _write_finetune_split(val_groups, "val")
        test_tfr, _, _ = _write_finetune_split(test_groups, "test")

        print(f"TFRecord files written: Train: {train_tfr}, Val: {val_tfr}, Test: {test_tfr}")

    print("TFRecord generation complete.")

# Optionally (re)generate the TFRecords and zip them for download.
# NOTE: `testing` alone (USE_FINE_TUNE False) generates nothing here; the
# test-only path of TFRecords_Creation is reached through the fine-tune call.
if generate_file:
    if not testing and not USE_FINE_TUNE:
        TFRecords_Creation(PARQUET_DIR, 0.8)
        tfrecords_created = True
    elif USE_FINE_TUNE:
        TFRecords_Creation(
            [
                "/kaggle/input/fine-tuning-preprocessed/fine_tuning_preprocessed.parquet",
                "/kaggle/input/testing-preprocessed/testing_preprocessed.parquet",
            ],
            0.8,
        )
        tfrecords_created = True
    else:
        tfrecords_created = False

    if tfrecords_created:
        # Archive only the TFRecord folder (kept as the top-level entry of the
        # zip). Zipping the whole working directory would also pick up
        # unrelated outputs and the archive file being written inside it.
        tfrecord_root = os.path.abspath(TFRECORD_DIR)
        shutil.make_archive(
            "/kaggle/working/tfrecords_highd",
            "zip",
            root_dir=os.path.dirname(tfrecord_root),
            base_dir=os.path.basename(tfrecord_root),
        )

# BERT pre-training 

## Main pre-training block: TFRecord pre-processing and BERT definition

In [None]:
if USE_FINE_TUNE==False:
    # Every sequence is right-padded to MAX_SEQ_LEN before it reaches the model:
    # ablation runs only use the 64-frame shards, full runs the longest size.
    MAX_SEQ_LEN = 64 if USE_ABLATION else max(SEQ_SIZES)
    
    def parse_and_pad_example(example_proto, seq_len):
        """Decode one pre-training record and right-pad it to MAX_SEQ_LEN frames.

        Args:
            example_proto: serialized tf.train.Example with a flat float
                feature "features" of length seq_len * NUM_FEATURES.
            seq_len: number of frames stored in the record.

        Returns:
            (seq, valid_mask): seq is (MAX_SEQ_LEN, NUM_FEATURES) with zero rows
            after the first seq_len; valid_mask is float32, 1.0 on real frames.

        Raises:
            ValueError: if seq_len exceeds MAX_SEQ_LEN.
        """
        spec = {"features": tf.io.FixedLenFeature([seq_len * NUM_FEATURES], tf.float32)}
        flat = tf.io.parse_single_example(example_proto, spec)["features"]
        frames = tf.reshape(flat, (seq_len, NUM_FEATURES))
        n_pad = MAX_SEQ_LEN - seq_len  # plain Python int, checked at trace time
        if n_pad < 0:
            raise ValueError(f"seq_len {seq_len} > MAX_SEQ_LEN {MAX_SEQ_LEN}")
        padded = tf.pad(frames, [[0, n_pad], [0, 0]])
        real = tf.ones(seq_len, dtype=tf.float32)
        fake = tf.zeros(n_pad, dtype=tf.float32)
        return padded, tf.concat([real, fake], axis=0)
    
    def make_base_dataset(seq_len, batch_size, subset):
        """Create a shuffled, batched and prefetched dataset of padded sequences.

        Args:
            seq_len: window length of the shards to read (e.g. 64, 128, 192).
            batch_size: batch size; incomplete final batches are dropped.
            subset: "train", "val" or "test".

        Returns:
            tf.data.Dataset yielding (seq, valid_mask) batches as produced by
            parse_and_pad_example.

        Raises:
            FileNotFoundError: if no shard matches in the non-ablation layout
                (an empty file list would otherwise give a silently empty dataset).
        """
        if USE_ABLATION:
            # Ablation reads a fixed number of explicitly named shards.
            if subset == "train":
                print("train")
                n_shards = 50
            elif subset == "val":
                print("val")
                n_shards = 5
            else:
                print("test")
                n_shards = 5
            input_dir = f"/kaggle/input/tfrecords-{subset}"
            tfrecord_files = [
                os.path.join(input_dir, f"{subset}_seq{seq_len}_{i:03d}.tfrecord")
                for i in range(n_shards)
            ]
        else:
            if subset == "train":
                input_dir = f"/kaggle/input/tfrecords20-{subset}2"
            else:
                input_dir = f"/kaggle/input/tfrecords-{subset}"
            tfrecord_files = sorted(glob.glob(os.path.join(input_dir, f"{subset}_seq{seq_len}_*.tfrecord")))

        if not tfrecord_files:
            raise FileNotFoundError(
                f"No TFRecord shards for subset={subset!r}, seq_len={seq_len} in {input_dir}"
            )

        ds = tf.data.TFRecordDataset(tfrecord_files, num_parallel_reads=AUTOTUNE)
        ds = ds.map(lambda x: parse_and_pad_example(x, seq_len), num_parallel_calls=AUTOTUNE)
        ds = ds.shuffle(2000, seed=42).batch(batch_size, drop_remainder=True).prefetch(AUTOTUNE)
        return ds
    
    # Masked Frame Modeling applied on batches (vectorized)
    def apply_mfm_on_batch(batch_seq, padding_mask):
        """Apply Masked Frame Modelling to one batch (vectorised).

        Each real frame (padding_mask == 1) is hidden with probability
        MASK_PROB. A hidden frame's features are zeroed in the input and kept
        as the regression target; every other target position is zero.

        Args:
            batch_seq: [B, MAX_SEQ_LEN, F] padded trajectories.
            padding_mask: [B, MAX_SEQ_LEN] float, 1.0 on real frames.

        Returns:
            ((x_masked, padding_mask), y_target, sample_weight), where
            sample_weight is the [B, MAX_SEQ_LEN] hidden-frame mask, so only
            hidden frames are weighted in the loss.
        """
        n_batch = tf.shape(batch_seq)[0]
        draws = tf.random.uniform((n_batch, MAX_SEQ_LEN))
        hidden = tf.cast(draws < MASK_PROB, tf.float32) * padding_mask
        hidden_feat = hidden[..., tf.newaxis]  # broadcast over the feature axis
        x_masked = batch_seq * (1.0 - hidden_feat)
        y_target = batch_seq * hidden_feat
        return (x_masked, padding_mask), y_target, hidden
    
    # Build per-size datasets (these produce (inputs, y, sample_weight) tuples eventually)
    # Per-size pre-training datasets. Each element is
    # ((x_masked, padding_mask), y_target, sample_weight); tf.data unpacks the
    # (seq, mask) pair into apply_mfm_on_batch's two positional arguments.
    datasets_by_size = {}
    if USE_ABLATION:
        # 🔒 Only use seq_len = 64 for ablation study
        print(f"🧪 Ablation mode active: using only seq_len={MAX_SEQ_LEN} for all subsets")
        train_ds = make_base_dataset(MAX_SEQ_LEN, BATCH_SIZES[0], "train").map(apply_mfm_on_batch, num_parallel_calls=AUTOTUNE)
        val_ds = make_base_dataset(MAX_SEQ_LEN, BATCH_SIZES[0], "val").map(apply_mfm_on_batch, num_parallel_calls=AUTOTUNE)
        datasets_by_size = {"train": train_ds, "val": val_ds}
    else:
        for size, batch_size in zip(SEQ_SIZES, BATCH_SIZES):
            ds_train = make_base_dataset(size, batch_size, subset="train").map(apply_mfm_on_batch, num_parallel_calls=AUTOTUNE)
            datasets_by_size[size] = {"train": ds_train}
    
    # --------------------------
    # Custom Positional Embedding Layer
    # --------------------------
    class PositionalEmbedding(layers.Layer):
        """Add learned absolute position embeddings to a sequence.

        Positions 0..seq_len-1 of the input (seq_len = its second dimension)
        index a float32 table with ``max_seq_len`` rows of width ``d_model``.
        The looked-up rows are cast to the input dtype before being added, so
        a float16 input stays float16.
        """

        def __init__(self, max_seq_len, d_model, **kwargs):
            super().__init__(**kwargs)
            self.max_seq_len = max_seq_len
            self.d_model = d_model
            # NOTE(review): keep the attribute name `pos_emb`; the tracked
            # sub-layer name may matter when reloading existing checkpoints.
            self.pos_emb = layers.Embedding(input_dim=max_seq_len, output_dim=d_model,dtype="float32")

        def call(self, x):
            position_ids = tf.range(tf.shape(x)[1])
            table_rows = self.pos_emb(position_ids)
            return x + tf.cast(table_rows, x.dtype)

        def get_config(self):
            return {
                **super().get_config(),
                "max_seq_len": self.max_seq_len,
                "d_model": self.d_model,
            }
    
    def padfloat_to_attnbool(m):
        """Turn a float padding mask [B, T] into a boolean mask [B, 1, 1, T].

        Entries greater than 0.5 become True. The two inserted singleton axes
        are left for broadcasting against the attention scores
        (NOTE(review): relies on MultiHeadAttention accepting this shape).
        """
        keep = tf.cast(m > 0.5, tf.bool)
        return keep[:, tf.newaxis, tf.newaxis]
    
    # --------------------------
    # Model Builder Function
    # --------------------------
    # def build_model(max_seq_len=MAX_SEQ_LEN, num_features=NUM_FEATURES, d_model=64, num_heads=4, num_layers=2, dropout_rate=0.0):
    
    def build_model(max_seq_len=MAX_SEQ_LEN, num_features=NUM_FEATURES, d_model=256, num_heads=8, num_layers=8, dropout_rate=0.1):
        """Build the BERT-style pre-norm Transformer used for masked-frame pre-training.

        Inputs:
            "trajectory": (batch, max_seq_len, num_features) float32 frames.
            "pad_mask":   (batch, max_seq_len) float32, 1.0 on real frames.

        Architecture:
            LayerNorm -> Dense(d_model) -> GELU -> Dropout -> learned positional
            embedding, then ``num_layers`` pre-norm blocks of
              [LayerNorm -> padding-masked MultiHeadAttention -> Dense -> Dropout -> residual]
              [LayerNorm -> SwiGLU-style gated FFN (hidden int(d_model * 8 / 3))
               -> Dense(d_model) -> Dropout -> residual],
            and a final Dense projection back to ``num_features`` per frame.

        Returns:
            Uncompiled keras Model named "TrajectoryBERT".
        """
        # NOTE(review): keep the layer construction order unchanged. Unnamed
        # layers (the GELU activation, the post-attention Dense) get
        # auto-generated names that may depend on it, which can matter when
        # reloading checkpoints.

        # === Inputs ===
        traj_in = tf.keras.Input(shape=(max_seq_len, num_features), dtype=tf.float32, name="trajectory")
        pad_mask_in = tf.keras.Input(shape=(max_seq_len,), dtype=tf.float32, name="pad_mask")

        # === Input projection ===
        x = layers.LayerNormalization(epsilon=1e-6, dtype="float32", name="input_ln")(traj_in)
        x = layers.Dense(d_model, kernel_initializer=GlorotUniform(seed=42), name="input_proj")(x)
        x = layers.Activation("gelu")(x)
        x = layers.Dropout(dropout_rate, name="input_dropout")(x)

        # === Learned positional embeddings ===
        x = PositionalEmbedding(max_seq_len, d_model, name="positional_embedding")(x)

        # Boolean attention mask derived from the float padding mask, wrapped in
        # a Lambda layer so it is part of the functional graph.
        attn_mask = layers.Lambda(padfloat_to_attnbool, name="padfloat_to_attnbool")(pad_mask_in)

        # Gated-FFN hidden width; int() truncates (682 for d_model=256).
        ffn_hidden = int(d_model * 8 / 3)

        # === Transformer encoder blocks (pre-normalization) ===
        for i in range(num_layers):
            # --- Attention sub-block ---
            norm1 = layers.LayerNormalization(epsilon=1e-6, dtype="float32", name=f"attn_norm_{i}")(x)
            # key_dim uses floor division: if d_model is not a multiple of
            # num_heads, the combined head width is smaller than d_model.
            attn_out = layers.MultiHeadAttention(
                num_heads=num_heads,
                key_dim=d_model // num_heads,
                name=f"mha_{i}",
                use_bias=False,
                dtype="float32",
            )(norm1, norm1, attention_mask=attn_mask)
            # Extra projection on top of the attention output (part of the
            # established architecture, so kept).
            attn_out = layers.Dense(d_model, kernel_initializer=GlorotUniform(seed=42), dtype="float32")(attn_out)
            attn_out = layers.Dropout(dropout_rate, name=f"attn_dropout_{i}")(attn_out)
            x = layers.Add(name=f"attn_residual_{i}")([x, attn_out])

            # --- Gated feed-forward sub-block (swish gate * linear branch) ---
            norm2 = layers.LayerNormalization(epsilon=1e-6, dtype="float32", name=f"ffn_norm_{i}")(x)
            gate = layers.Dense(ffn_hidden, activation="swish", kernel_initializer=GlorotUniform(seed=42), name=f"ffn_gate_{i}", dtype="float32")(norm2)
            linear = layers.Dense(ffn_hidden, kernel_initializer=GlorotUniform(seed=42), name=f"ffn_linear_{i}", dtype="float32")(norm2)
            x_gated = layers.Multiply(name=f"ffn_multiply_{i}")([gate, linear])
            ffn = layers.Dense(d_model, kernel_initializer=GlorotUniform(seed=42), name=f"ffn_dense_out_{i}", dtype="float32")(x_gated)
            ffn = layers.Dropout(dropout_rate, name=f"ffn_dropout_{i}")(ffn)
            x = layers.Add(name=f"ffn_residual_{i}")([x, ffn])

        # NOTE: no final LayerNorm after the last pre-norm block; adding one
        # would change the architecture of existing checkpoints.

        # === Output projection back to the per-frame feature space ===
        outputs = layers.Dense(num_features, dtype="float32", kernel_initializer=GlorotUniform(seed=42), name="output_projection")(x)

        model = Model(inputs=[traj_in, pad_mask_in], outputs=outputs, name="TrajectoryBERT")
        print("✅ Model built successfully")
        return model
        
    # Curriculum stages as (sequence lengths, per-length weights); each stage
    # adds the next longer length. calculate_stage_steps_per_epoch scales each
    # length's batches per epoch by its weight.
    curriculum_stages = list(zip(
        [[64], [64, 128], [64, 128, 192]],
        [[1.0], [0.7, 0.3], [0.5, 0.3, 0.2]],
    ))

    # Epochs per curriculum stage (or for the whole ablation run).
    EPOCHS = 12

    # Build the model once; the warm-up + cosine LR schedule is attached when
    # it is compiled. (Wrap this in `strategy.scope()` for multi-device runs.)
    model = build_model()
    
    def calculate_stage_steps_per_epoch(seqs, weights, steps=None):
        """Return the weighted number of batches per epoch for one curriculum stage.

        Args:
            seqs: sequence lengths used in the stage.
            weights: one weight per entry of ``seqs``.
            steps: mapping seq_len -> batches per epoch. Defaults to the
                module-level ``train_steps`` counted from the training datasets.

        Returns:
            ``int(sum(steps[s] * w))``, i.e. truncated toward zero.

        Raises:
            ValueError: if ``seqs`` and ``weights`` differ in length (zip would
                otherwise silently ignore the extra entries).
        """
        if steps is None:
            steps = train_steps
        if len(seqs) != len(weights):
            raise ValueError(f"got {len(seqs)} sequence lengths but {len(weights)} weights")
        return int(sum(steps[s] * w for s, w in zip(seqs, weights)))
    
    class WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Linear warm-up to ``base_lr`` followed by cosine decay to ``min_lr``.

        For ``step < warmup_steps`` the rate is ``base_lr * step / warmup_steps``.
        Afterwards it is ``min_lr + (base_lr - min_lr) * 0.5 * (1 + cos(pi * p))``
        with ``p = (step - warmup_steps) / (total_steps - warmup_steps)`` clipped
        to [0, 1], so the rate stays at ``min_lr`` once ``total_steps`` is reached.
        """

        def __init__(self, base_lr, total_steps, warmup_steps, min_lr=1e-6):
            super().__init__()
            self.base_lr = base_lr
            self.warmup_steps = warmup_steps
            self.total_steps = total_steps
            self.min_lr = min_lr

        def __call__(self, step):
            # Denominators are floored at 1 so that a zero-length warm-up
            # (warmup_steps == 0, e.g. fewer than 10 scheduled steps) or decay
            # phase (total_steps == warmup_steps) cannot produce inf/NaN.
            # For any positive length the values are unchanged.
            warmup_den = tf.maximum(tf.cast(self.warmup_steps, tf.float32), 1.0)
            decay_den = tf.maximum(tf.cast(self.total_steps - self.warmup_steps, tf.float32), 1.0)

            # Linear warmup
            warmup_lr = self.base_lr * (tf.cast(step, tf.float32) / warmup_den)

            # Cosine decay after warmup
            progress = tf.cast(step - self.warmup_steps, tf.float32) / decay_den
            cosine_decay = 0.5 * (1 + tf.cos(np.pi * tf.clip_by_value(progress, 0.0, 1.0)))
            cosine_lr = self.min_lr + (self.base_lr - self.min_lr) * cosine_decay

            return tf.cond(step < self.warmup_steps, lambda: warmup_lr, lambda: cosine_lr)

        def get_config(self):
            """Return the constructor arguments so the schedule can be serialized."""
            return {
                "base_lr": self.base_lr,
                "total_steps": self.total_steps,
                "warmup_steps": self.warmup_steps,
                "min_lr": self.min_lr,
            }

        @classmethod
        def from_config(cls, config):
            """Rebuild the schedule from the dict produced by get_config."""
            return cls(**config)
    
    def get_lr_schedule(stage_steps):
        """Return a WarmUpCosine schedule spanning ``stage_steps`` optimizer steps.

        The first 10% of the steps (truncated to an int) are linear warm-up to
        1e-3, followed by cosine decay down to 1e-6.
        """
        return WarmUpCosine(
            base_lr=1e-3,
            total_steps=stage_steps,
            warmup_steps=int(stage_steps * 0.1),
            min_lr=1e-6,
        )
    
    def _count_batches(ds):
        """Return the number of batches in ``ds`` by iterating over it once."""
        return sum(1 for _ in ds)

    if USE_ABLATION:
        print("🧪 Ablation mode active: static training setup")
        total_steps = EPOCHS * _count_batches(datasets_by_size["train"])
        lr_schedule = get_lr_schedule(total_steps)
        model.compile(
            optimizer=mixed_precision.LossScaleOptimizer(
                tf.keras.optimizers.AdamW(lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9, clipnorm=1.0)
            ),
            loss="huber",
            metrics=["mae"],
        )
    else:
        # Batches per epoch for every pre-training sequence length, counted in
        # SEQ_SIZES order. calculate_stage_steps_per_epoch reads this global.
        train_steps = {size: _count_batches(datasets_by_size[size]["train"]) for size in SEQ_SIZES}

        # Total optimizer steps over all curriculum stages (EPOCHS epochs each);
        # this is the length of the warm-up + cosine schedule.
        total_calculated_steps = 0
        for seq_sizes, weights in curriculum_stages:
            stage_steps_per_epoch = calculate_stage_steps_per_epoch(seq_sizes, weights)
            total_calculated_steps += stage_steps_per_epoch * EPOCHS
        print("🎓 Curriculum Learning mode active")
        lr_schedule = get_lr_schedule(total_calculated_steps)
        optimizer = mixed_precision.LossScaleOptimizer(
            tf.keras.optimizers.AdamW(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9, clipnorm=1.0)
        )
        model.compile(optimizer=optimizer, loss="huber", metrics=["mae"])

## Creating pre-trained model checkpoint
If training is interrupted, make sure it resumes from the last checkpoint instead of restarting from epoch 1. Resuming only works if the checkpoint is downloaded from 'kaggle/working/bert' and then uploaded to 'kaggle/input/bert-1'.

In [None]:
if USE_FINE_TUNE==False:
    from tensorflow.keras.callbacks import Callback
    # Names Keras must resolve when a saved checkpoint is deserialized: the custom
    # layer, the LR schedule, the mask helper and the loss-scaling optimizer wrapper.
    custom_objects = dict(
        PositionalEmbedding=PositionalEmbedding,
        WarmUpCosine=WarmUpCosine,
        padfloat_to_attnbool=padfloat_to_attnbool,
        LossScaleOptimizer=tf.keras.mixed_precision.LossScaleOptimizer,
    )
    
    class TimeBasedEpochCheckpoint(callbacks.Callback):
        """Keras callback that checkpoints the full model at epoch end and prunes old copies.

        Each save produces ``<base>_stage<S>_e<E>_t<unix_ts>.keras`` in the directory of
        ``basepath`` plus a ``<same name>.meta.json`` sidecar holding stage, epoch,
        optimizer step, timestamp and the checkpoint path. The model is first written inside
        a temporary directory in the same folder and then moved into place, so an
        interrupted save does not leave a partial artifact under the final name. After each
        save only the newest ``last_n_checkpoints`` checkpoints (ranked by metadata-file
        mtime) are kept; old artifacts may be files or directories.

        Args:
            basepath: checkpoint path prefix, e.g. "/kaggle/working/bert".
            interval_minutes: minimum wall-clock minutes between saves (0 = every epoch).
            last_n_checkpoints: how many checkpoints to keep. Values below 1 are treated
                as 1 so the checkpoint that was just written is never pruned.
            verbose: print progress messages when non-zero.
        """
        def __init__(self, basepath, interval_minutes=0.0, last_n_checkpoints=5, verbose=1):
            super().__init__()
            self.basepath = basepath
            self.base_dir = os.path.dirname(basepath) or "."
            os.makedirs(self.base_dir, exist_ok=True)
            self.interval_seconds = float(interval_minutes) * 60.0
            # With 0 (or a negative value) pruning would delete the fresh checkpoint too.
            self.last_n_checkpoints = max(1, int(last_n_checkpoints))
            self.verbose = int(verbose)
            self.last_save_time = None
            self.stage_idx = 0
            self.base_basename = os.path.basename(basepath)

        def _log(self, message):
            """Print ``message`` with the callback prefix when verbose."""
            if self.verbose:
                print(f"[TimeBasedEpochCheckpoint] {message}")

        def set_stage(self, stage_idx: int):
            """Set the curriculum stage index embedded in subsequent checkpoint names."""
            self.stage_idx = int(stage_idx)

        def on_train_begin(self, logs=None):
            # Pretend the last save was more than one interval ago so the first epoch end
            # after (re)starting training always saves.
            self.last_save_time = time.time() - (self.interval_seconds + 1.0)

        def _read_global_step(self):
            """Return ``optimizer.iterations`` as an int, or None if it cannot be read."""
            try:
                return int(self.model.optimizer.iterations.numpy())
            except Exception:
                try:
                    return int(tf.keras.backend.get_value(self.model.optimizer.iterations))
                except Exception:
                    return None

        @staticmethod
        def _remove_path(path):
            """Delete ``path`` whether it is a directory or a file (raises OSError on failure)."""
            if os.path.isdir(path):
                shutil.rmtree(path)
            else:
                os.remove(path)

        def on_epoch_end(self, epoch, logs=None):
            now = time.time()
            if self.interval_seconds > 0 and (now - (self.last_save_time or 0.0)) < self.interval_seconds:
                self._log(f"skipping save at epoch {epoch} (interval not reached)")
                return

            ts = int(now)
            ckpt_name = f"{self.base_basename}_stage{self.stage_idx}_e{epoch}_t{ts}.keras"
            ckpt_path_target = os.path.join(self.base_dir, ckpt_name)
            meta_path = ckpt_path_target + ".meta.json"

            tmp_dir = None
            try:
                # Temp dir lives in base_dir so the final move stays on the same filesystem.
                tmp_dir = tempfile.mkdtemp(dir=self.base_dir)
                tmp_ckpt_path = os.path.join(tmp_dir, "model.keras")  # may become a file or a dir
                if self.verbose:
                    print(f"\n[TimeBasedEpochCheckpoint] Saving model to temporary path {tmp_ckpt_path} ...")

                # Save the model including optimizer state so training can resume exactly.
                self.model.save(tmp_ckpt_path, overwrite=True, include_optimizer=True)

                global_step = self._read_global_step()
                metadata = {
                    "stage": int(self.stage_idx),
                    "epoch": int(epoch),
                    "global_step": global_step,
                    "timestamp": ts,
                    "ckpt_path": ckpt_path_target
                }
                tmp_meta_path = os.path.join(tmp_dir, "meta.json")
                with open(tmp_meta_path, "w") as f:
                    json.dump(metadata, f, indent=2)

                # Clear any previous artifact with the same name (file or directory).
                if os.path.exists(ckpt_path_target):
                    try:
                        self._remove_path(ckpt_path_target)
                    except Exception as e:
                        self._log(f"Warning removing existing target {ckpt_path_target}: {e}")

                shutil.move(tmp_ckpt_path, ckpt_path_target)
                shutil.move(tmp_meta_path, meta_path)

                self.last_save_time = now
                self._log(
                    f"Saved checkpoint {ckpt_path_target} (stage={self.stage_idx}, epoch={epoch}, global_step={global_step})"
                )
                self._prune_old_checkpoints()
            finally:
                # Best-effort cleanup; a leftover temp dir is harmless.
                if tmp_dir and os.path.exists(tmp_dir):
                    shutil.rmtree(tmp_dir, ignore_errors=True)

        def _prune_old_checkpoints(self):
            """Delete the oldest checkpoints until only ``last_n_checkpoints`` remain.

            Checkpoints are found through their ``.meta.json`` files and ranked by the
            metadata file's modification time (oldest removed first).
            """
            pattern = os.path.join(self.base_dir, f"{self.base_basename}_stage*_e*_t*.keras.meta.json")
            metas = glob.glob(pattern)
            num_to_remove = len(metas) - self.last_n_checkpoints
            if num_to_remove <= 0:
                return

            for meta_path in sorted(metas, key=os.path.getmtime)[:num_to_remove]:
                try:
                    self._prune_one(meta_path)
                except Exception as e:
                    self._log(f"Warning while pruning {meta_path}: {e}")

        def _prune_one(self, meta_path):
            """Remove the checkpoint referenced by ``meta_path`` and then the metadata file."""
            with open(meta_path, "r") as f:
                meta_obj = json.load(f)
            ckpt_path = meta_obj.get("ckpt_path") or meta_obj.get("ckpt_dir")  # support older field name
            if ckpt_path:
                if not os.path.exists(ckpt_path):
                    self._log(f"Old checkpoint path not found (already removed?): {ckpt_path}")
                else:
                    try:
                        if os.path.isdir(ckpt_path):
                            shutil.rmtree(ckpt_path)
                            self._log(f"Removed old checkpoint directory: {ckpt_path}")
                        elif os.path.isfile(ckpt_path):
                            os.remove(ckpt_path)
                            self._log(f"Removed old checkpoint file: {ckpt_path}")
                        else:
                            # Special file / dangling symlink: try a plain remove.
                            try:
                                os.remove(ckpt_path)
                                self._log(f"Removed old checkpoint (unknown type): {ckpt_path}")
                            except Exception:
                                self._log(f"Could not remove unknown checkpoint path: {ckpt_path}")
                    except Exception as e:
                        self._log(f"Warning while removing {ckpt_path}: {e}")
            try:
                os.remove(meta_path)
            except Exception as e:
                self._log(f"Warning removing meta file {meta_path}: {e}")
    
    # Epoch-end checkpoint callback. Files are written next to CHECKPOINT_BASE as
    # bert_stage{S}_e{E}_t{unix_ts}.keras plus a matching .keras.meta.json sidecar;
    # only the newest checkpoint is kept.
    CHECKPOINT_BASE = "/kaggle/working/bert"
    time_ckpt = TimeBasedEpochCheckpoint(
        CHECKPOINT_BASE,
        interval_minutes=0.0,  # no minimum interval: save at every epoch end
        last_n_checkpoints=1,
        verbose=1,
    )
    
    def find_latest_meta(base_dir, basename):
        """Return the path of the newest checkpoint metadata file in ``base_dir``, or None.

        Candidates are ``<basename>_stage*_e*_t*.keras.meta.json`` files. They are ranked
        by the Unix timestamp embedded in the ``_t<ts>`` part of the file name. Unlike
        modification times, the name is kept when checkpoints are downloaded and
        re-uploaded (e.g. into /kaggle/input). The file mtime is only used as a tie-breaker
        and for names without a parseable timestamp.
        """
        import re

        pattern = os.path.join(base_dir, f"{basename}_stage*_e*_t*.keras.meta.json")
        ts_re = re.compile(r"_t(\d+)\.keras\.meta\.json$")

        def _recency(path):
            match = ts_re.search(os.path.basename(path))
            embedded_ts = int(match.group(1)) if match else -1
            return (embedded_ts, os.path.getmtime(path))

        metas = glob.glob(pattern)
        return max(metas, key=_recency) if metas else None
    
    def load_checkpoint_from_meta(meta_path):
        """Load the checkpoint described by a ``.meta.json`` file.

        The stored ``ckpt_path`` (or the legacy ``ckpt_dir`` key) is an absolute path from
        the run that wrote it. Only its file name is kept, and the file is looked up in the
        directory that contains ``meta_path`` (e.g. after re-uploading to /kaggle/input).

        Returns:
            (model, meta): the restored model and the parsed metadata dict. When
            ``tf.keras.models.load_model`` fails, the model comes from ``build_model()`` +
            ``load_weights()`` instead (NOTE(review): it may then be uncompiled, depending
            on build_model).

        Raises:
            FileNotFoundError: the metadata has no checkpoint path, or the resolved file
                does not exist.
        """
        with open(meta_path, "r") as fh:
            meta = json.load(fh)

        stored_path = meta.get("ckpt_path") or meta.get("ckpt_dir")
        if not stored_path:
            raise FileNotFoundError("metadata missing ckpt_path")

        local_path = os.path.join(os.path.dirname(meta_path), os.path.basename(stored_path))
        print("[resume] rewritten ckpt_path =", local_path)
        if not os.path.exists(local_path):
            raise FileNotFoundError(f"rewritten ckpt_path does not exist: {local_path}")

        try:
            # First choice: full model including optimizer state.
            print(f"[resume] trying tf.keras.models.load_model({local_path})")
            restored = tf.keras.models.load_model(
                local_path, custom_objects=custom_objects, compile=True
            )
            print("[resume] load_model succeeded (full model + optimizer restored)")
        except Exception as err:
            print(f"[resume] load_model failed: {err}. Falling back to build_model() + load_weights()")
            restored = build_model()
            restored.load_weights(local_path)
        return restored, meta
    
    
    
    
    if USE_ABLATION:
        # Ablation run: a single fit on the ablation dataset with early stopping.
        train_history = {"loss": [], "val_loss": [], "mae": [], "val_mae": []}
        s = time.time()
        es = callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)
        history = model.fit(
            datasets_by_size["train"],
            validation_data=datasets_by_size["val"],
            epochs=EPOCHS,
            callbacks=[es],
            verbose=1,
        )

        train_history["loss"].extend(history.history["loss"])
        train_history["val_loss"].extend(history.history["val_loss"])
        train_history["mae"].extend(history.history["mae"])
        train_history["val_mae"].extend(history.history["val_mae"])
        # Curves are stored under this run's ablation label for the plotting cell.
        # NOTE(review): the label is hard-coded; change it for every ablation run.
        all_histories['GeLU'] = history.history
        e = time.time()
        print(f"Training finished. Elapsed time: {(e-s)/60:.2f} minutes")

    else:
        # ---------------------------
        # Attempt to resume from the last checkpoint (if any)
        # ---------------------------
        BASE_DIR = "/kaggle/input/bert-1"  # uploaded copy of /kaggle/working checkpoints
        BASENAME = "bert"
        latest_meta = find_latest_meta(BASE_DIR, BASENAME)

        resume_info = None
        if latest_meta is not None:
            print("[resume] Found latest meta:", latest_meta)
            try:
                model, meta = load_checkpoint_from_meta(latest_meta)
                # load_model(compile=True) restores the optimizer; the build_model()
                # fallback may not, so compile with the original optimizer settings.
                if getattr(model, "optimizer", None) is None:
                    print("[resume] model was not compiled; compiling manually")
                    lr_sched = get_lr_schedule(total_calculated_steps)
                    opt = tf.keras.optimizers.AdamW(
                        learning_rate=lr_sched,
                        beta_1=0.9, beta_2=0.98,
                        epsilon=1e-9,
                        clipnorm=1.0
                    )
                    policy = mixed_precision.global_policy()
                    policy_name = policy.name if hasattr(policy, "name") else str(policy)
                    if policy_name.startswith("mixed"):
                        opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
                    model.compile(optimizer=opt, loss="huber", metrics=["mae"])
                else:
                    print("[resume] model already compiled (optimizer restored)")

                # Align optimizer.iterations with the saved global_step so the LR
                # schedule continues where it stopped.
                try:
                    gs = meta.get("global_step")
                    if gs is not None:
                        target_step = int(gs)
                        current_step = int(model.optimizer.iterations.numpy())
                        if current_step == target_step:
                            print(f"[resume] optimizer.iterations is already {current_step}. No action needed.")
                        else:
                            print(f"[resume] Adjusting optimizer.iterations from {current_step} to {target_step}")
                            model.optimizer.iterations.assign(target_step)
                except Exception as e:
                    # Non-fatal: load_model normally restores iterations itself.
                    print(f"[resume] Warning: Could not manually set optimizer iterations: {e}")

                saved_stage = int(meta.get("stage", 1))
                saved_epoch = int(meta.get("epoch", 0))
                resume_info = {"stage": saved_stage, "epoch": saved_epoch, "meta": meta}
                print(f"[resume] Resume prepared: stage={saved_stage}, epoch={saved_epoch}")
            except Exception as e:
                print("[resume] resume attempt failed:", e)
                resume_info = None
        else:
            print("[resume] No previous checkpoint found. Starting fresh.")

        # Curriculum training: one fit per stage, reusing the same model weights.
        time_ckpt.on_train_begin()
        for stage_idx, (seqs, weights) in enumerate(curriculum_stages, 1):
            print(f"\nStage {stage_idx} sequences={seqs} weights={weights}")
            time_ckpt.set_stage(stage_idx)
            # When resuming, stages before the saved one are already done.
            if resume_info is not None and resume_info["stage"] > stage_idx:
                print(f"[resume] skipping stage {stage_idx} (already completed)")
                continue
            merged_train = tf.data.Dataset.sample_from_datasets(
                [datasets_by_size[s]["train"] for s in seqs], weights, seed=42, stop_on_empty_dataset=True
            )
            merged_train = merged_train.repeat()
            # Steps per epoch: weighted sum of each bucket's batches per epoch.
            stage_train_steps = int(sum(train_steps[s] * w for s, w in zip(seqs, weights)))

            # The saved epoch index is the last completed epoch, so continue at the next one.
            if resume_info is not None and resume_info["stage"] == stage_idx:
                initial_epoch = int(resume_info["epoch"]) + 1
                print(f"[resume] Resuming stage {stage_idx} from epoch {initial_epoch}")
            else:
                initial_epoch = 0

            history = model.fit(
                merged_train,
                epochs=EPOCHS,
                initial_epoch=initial_epoch,
                steps_per_epoch=stage_train_steps,
                callbacks=[time_ckpt],
                verbose=1
            )

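## (optional) Averaging weights from the last three pre-training checkpoints
Averaging the weights of the last few checkpoints is a common practice in LLM/MLM pre-training to help the model generalise better; however, we did not observe any improvement. Note that the checkpoint callback above keeps only `last_n_checkpoints=1` checkpoint, so increase that value if you want to average the last three epochs.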

In [None]:
if USE_FINE_TUNE==False and USE_ABLATION==False:
    # Average the weights of the newest N_LAST pre-training checkpoints and save the
    # result as one model. Only as many checkpoints as the checkpoint callback kept
    # (`last_n_checkpoints`) are available for averaging.
    BASE_DIR = "/kaggle/working"         # directory where checkpoints/meta JSONs live
    BASENAME = "bert"                   # basename used by TimeBasedEpochCheckpoint (bert -> bert_stage*_e*_t*.keras[.meta.json])
    N_LAST = 3                          # number of most recent checkpoints to average
    OUTPATH = "/kaggle/working/BERT_Final_Averaged_Model.keras"  # final save path

    custom_objects = {
        "PositionalEmbedding": PositionalEmbedding,
        "WarmUpCosine": WarmUpCosine,
        "padfloat_to_attnbool": padfloat_to_attnbool,
        "LossScaleOptimizer": tf.keras.mixed_precision.LossScaleOptimizer,
    }

    # 1) Newest N_LAST metadata files (by modification time, newest first).
    pattern = os.path.join(BASE_DIR, f"{BASENAME}_stage*_e*_t*.keras.meta.json")
    meta_paths = sorted(glob.glob(pattern), key=os.path.getmtime, reverse=True)[:N_LAST]

    if not meta_paths:
        raise RuntimeError(f"No checkpoint metadata found with pattern {pattern}")

    print("Averaging these checkpoint metas (newest first):")
    for p in meta_paths:
        print(" ", p)

    # 2) Load every checkpoint and accumulate its weights in float64.
    accum_weights = None  # running float64 sums, one array per model weight
    ref_weights = None    # weights of the newest loaded checkpoint (dtype/shape reference)
    count = 0
    for meta_path in meta_paths:
        with open(meta_path, "r") as f:
            meta = json.load(f)
        ckpt_path = meta.get("ckpt_path") or meta.get("ckpt_dir")  # support variations
        if not ckpt_path:
            print(f"Skipping meta {meta_path}: no ckpt_path found")
            continue
        if not os.path.exists(ckpt_path):
            # The stored path is absolute; fall back to the file next to the metadata
            # (same lookup rule as the resume loader) before giving up.
            local_path = os.path.join(os.path.dirname(meta_path), os.path.basename(ckpt_path))
            if not os.path.exists(local_path):
                print(f"Skipping meta {meta_path}: ckpt path does not exist: {ckpt_path}")
                continue
            ckpt_path = local_path

        print(f"Loading checkpoint: {ckpt_path}")
        # compile=False: only the weights are needed, not the optimizer.
        try:
            m = tf.keras.models.load_model(ckpt_path, custom_objects=custom_objects, compile=False)
        except Exception as e:
            print(f"load_model failed for {ckpt_path} with error: {e}")
            print("Attempting to load weights via build_model() + load_weights(...)")
            m = build_model()
            m.load_weights(ckpt_path)  # may still fail; let exception bubble if so

        w = [np.asarray(arr) for arr in m.get_weights()]
        if ref_weights is None:
            ref_weights = w
            accum_weights = [np.zeros(arr.shape, dtype=np.float64) for arr in w]

        if len(w) != len(accum_weights):
            raise RuntimeError(f"Weight length mismatch for checkpoint {ckpt_path}: expected {len(accum_weights)} arrays, got {len(w)}")

        for i, (acc, arr) in enumerate(zip(accum_weights, w)):
            # Guard against silent broadcasting between mismatched architectures.
            if arr.shape != acc.shape:
                raise RuntimeError(
                    f"Weight shape mismatch for checkpoint {ckpt_path} at index {i}: "
                    f"expected {acc.shape}, got {arr.shape}"
                )
            acc += arr.astype(np.float64)

        count += 1

    if count == 0:
        raise RuntimeError("No valid checkpoints were loaded for averaging")

    # 3) Average floating-point weights and cast each back to its original dtype.
    #    Non-float state (e.g. integer counters/seeds) cannot be meaningfully averaged,
    #    so it is taken from the newest loaded checkpoint.
    avg_weights = []
    for ref, acc in zip(ref_weights, accum_weights):
        if np.issubdtype(ref.dtype, np.floating):
            avg_weights.append((acc / float(count)).astype(ref.dtype))
        else:
            avg_weights.append(ref)

    # 4) Build a fresh model with the same architecture and load the averaged weights.
    final_model = build_model()   # must use the same args as the original model
    final_model.set_weights(avg_weights)

    # 5) Save the averaged model (.keras format).
    final_model.save(OUTPATH, overwrite=True)
    print(f"Final averaged model saved to: {OUTPATH} (averaged {count} checkpoints)")

## Plotting ablation results
Make sure the variable 'USE_ABLATION' was set to True (this cell only runs in ablation mode).

In [None]:
# --------------------------
# Plot train/val loss curves
# --------------------------
if USE_FINE_TUNE==False and USE_ABLATION==True:
    def plot_ablation_histories(all_histories, metric="loss", epoch_limit=None, figsize=(8,5)):
        """Plot train and validation curves of every ablation run on one figure.

        Args:
            all_histories: dict mapping an ablation label to a Keras ``History.history``
                dict that contains ``metric`` and ``"val_" + metric`` lists.
            metric: "loss" or "mae"; the metric and its ``val_`` counterpart are drawn.
            epoch_limit: optional int that truncates the x-axis.
            figsize: matplotlib figure size.

        Runs shorter than the longest one (e.g. stopped early) are padded by repeating
        their last value; an empty series is drawn as NaN. Train curves are solid,
        validation curves dashed, with a shaded band between them. Prints a message and
        returns without plotting when there is nothing to plot.
        """
        if not all_histories:
            print("[plot_ablation_histories] nothing to plot: all_histories is empty")
            return

        labels = sorted(all_histories.keys())

        # Longest run determines the x-axis (optionally capped by epoch_limit).
        max_epochs = max(len(all_histories[k][metric]) for k in labels)
        if epoch_limit is not None:
            max_epochs = min(max_epochs, epoch_limit)
        if max_epochs <= 0:
            print(f"[plot_ablation_histories] nothing to plot: no '{metric}' values recorded")
            return

        def _to_length(values, n):
            """Return ``values`` as exactly ``n`` floats (edge-padded; all-NaN if empty)."""
            arr = np.asarray(values, dtype=float)
            if arr.size == 0:
                return np.full(n, np.nan)
            if arr.size < n:
                arr = np.pad(arr, (0, n - arr.size), mode="edge")
            return arr[:n]

        cmap = plt.get_cmap("tab10")
        colors = [cmap(i % 10) for i in range(len(labels))]
        x = np.arange(1, max_epochs + 1)

        plt.figure(figsize=figsize)
        band_patches = []  # legend entries for the per-run bands

        for color, label in zip(colors, labels):
            hist = all_histories[label]
            train = _to_length(hist[metric], max_epochs)
            val = _to_length(hist[f"val_{metric}"], max_epochs)

            # Shaded band between the train and validation curves.
            plt.fill_between(x, np.minimum(train, val), np.maximum(train, val),
                             color=color, alpha=0.12, linewidth=0, zorder=1)
            # Lines on top: train solid, validation dashed.
            plt.plot(x, train, linestyle='-', linewidth=1.5, color=color, alpha=0.9, zorder=3)
            plt.plot(x, val, linestyle='--', linewidth=1.2, color=color, alpha=0.9, zorder=4)

            band_patches.append(Patch(facecolor=color, alpha=0.12, label=f"Loss = {label}"))

        metric_name = "MAE" if metric == "mae" else metric.capitalize()
        plt.title(f"Train & Validation {metric_name} — Loss function ablation comparison")
        plt.xlabel("Epoch", fontsize=13)
        plt.ylabel(metric_name, fontsize=13)
        plt.grid(True, linestyle=":", alpha=0.5)

        # Two legends: line styles (top-left) and run colours (top-right). A second
        # plt.legend call replaces the first, so the first is re-attached as an artist.
        from matplotlib.lines import Line2D
        line_proxies = [
            Line2D([0], [0], color='k', linestyle='-', linewidth=1.5, label='Train'),
            Line2D([0], [0], color='k', linestyle='--', linewidth=1.2, label='Val')
        ]
        style_legend = plt.legend(handles=line_proxies, loc='upper left')
        plt.legend(handles=band_patches, loc='upper right', ncol=1, framealpha=0.9)
        plt.gca().add_artist(style_legend)
        plt.tight_layout()
        plt.show()
        
    # Loss curves first, then MAE curves; both capped at EPOCHS epochs.
    plot_ablation_histories(all_histories, "loss", EPOCHS)
    plot_ablation_histories(all_histories, "mae", EPOCHS)

# Fine-tuning module
Make sure the variable 'USE_FINE_TUNE' is set to True.

In [None]:
# test
if USE_FINE_TUNE:
    # Fine-tuning runs in full float32 (the LoRA wrapper below also computes in float32).
    tf.keras.mixed_precision.set_global_policy("float32")
    # --------------------------------
    # 2. LORA IMPLEMENTATION
    # --------------------------------
    class LoRALayer(tf.keras.layers.Wrapper):
        """Low-rank adapter (LoRA) around a frozen ``Dense`` layer.

        Output = ``activation(x @ W + b + (x @ A) @ B * alpha / rank)``, where ``W``/``b``
        are the wrapped Dense kernel/bias and ``A`` (in_dim x rank), ``B`` (rank x out_dim)
        are the trainable adapter weights. The low-rank update is added *before* the Dense
        activation, as in the LoRA formulation (h = W0·x + B·A·x). Adding it after a
        non-linear activation would not correspond to a weight update. ``B`` starts at
        zero, so a freshly wrapped layer reproduces the base Dense output exactly.

        All computation runs in float32 regardless of the input dtype.
        """
        def __init__(self, layer, rank=4, alpha=4, **kwargs):
            super().__init__(layer, **kwargs)
            self.rank = int(rank)
            self.alpha = float(alpha)
            self.scaling = None  # alpha / rank, set in build()

        def build(self, input_shape):
            in_dim = int(input_shape[-1])
            out_dim = self.layer.units  # wrapped layer must be a Dense

            self.lora_A = self.add_weight(
                name=self.layer.name + "_lora_A",
                shape=(in_dim, self.rank),
                initializer="random_normal",
                trainable=True,
                dtype="float32"
            )
            self.lora_B = self.add_weight(
                name=self.layer.name + "_lora_B",
                shape=(self.rank, out_dim),
                initializer="zeros",
                trainable=True,
                dtype="float32"
            )
            self.scaling = self.alpha / self.rank
            # Wrapper.build also builds the wrapped Dense (creating kernel/bias).
            super().build(input_shape)

        def call(self, inputs, **kwargs):
            x_32 = tf.cast(inputs, tf.float32)
            dense = self.layer
            # Frozen base projection, pre-activation.
            z = tf.matmul(x_32, tf.cast(dense.kernel, tf.float32))
            if getattr(dense, "bias", None) is not None:
                z = z + tf.cast(dense.bias, tf.float32)
            # Trainable low-rank update, added to the pre-activation.
            z = z + tf.matmul(tf.matmul(x_32, self.lora_A), self.lora_B) * self.scaling
            activation = getattr(dense, "activation", None)
            return z if activation is None else activation(z)

        def get_config(self):
            config = super().get_config()
            config.update({"rank": self.rank, "alpha": self.alpha})
            return config
    
    def add_lora_to_encoder(encoder, rank=4, alpha=8):
        """Clone ``encoder`` with every top-level Dense wrapped in a LoRALayer.

        Each Dense is rebuilt from its config in float32, wrapped, and given the original
        kernel/bias. Every layer is then frozen except the LoRA wrappers, whose inner
        Dense also stays frozen, so only the adapter matrices train. Non-Dense layers are
        reused as-is (cloning Lambda/custom layers would otherwise fail).
        """
        def _wrap_dense(layer):
            if not isinstance(layer, tf.keras.layers.Dense):
                return layer
            dense_cfg = layer.get_config()
            for key in ('dtype', 'dtype_policy'):
                dense_cfg.pop(key, None)
            dense_cfg['dtype'] = 'float32'
            return LoRALayer(tf.keras.layers.Dense.from_config(dense_cfg), rank=rank, alpha=alpha)

        lora_encoder = tf.keras.models.clone_model(encoder, clone_function=_wrap_dense)

        print("Copying weights...")
        for src, dst in zip(encoder.layers, lora_encoder.layers):
            if isinstance(dst, LoRALayer):
                dst.layer.set_weights(src.get_weights())

        print("Freezing base layers...")
        for lyr in lora_encoder.layers:
            is_lora = isinstance(lyr, LoRALayer)
            # Setting trainable on a wrapper propagates to its inner layer, so re-freeze
            # the wrapped Dense afterwards.
            lyr.trainable = is_lora
            if is_lora:
                lyr.layer.trainable = False
        return lora_encoder
    
    # --------------------------------
    # 3. DATASET & TRAINING SETUP
    # --------------------------------
    def parse_record(example):
        """Decode one serialized fine-tuning example.

        Returns ``((x, mask), y, mask)`` with x of shape (fine_tune_seq_len, NUM_FEATURES),
        y of shape (fine_tune_seq_len, NUM_LABELS) and mask of shape (fine_tune_seq_len,).
        The mask appears twice: as a model input and as the third tuple element, which
        Keras treats as the sample weight.
        """
        seq_len = fine_tune_seq_len
        spec = {
            "x": tf.io.FixedLenFeature([seq_len * NUM_FEATURES], tf.float32),
            "y": tf.io.FixedLenFeature([seq_len * NUM_LABELS], tf.float32),
            "mask": tf.io.FixedLenFeature([seq_len], tf.float32),
        }
        parsed = tf.io.parse_single_example(example, spec)
        features = tf.reshape(parsed["x"], (seq_len, NUM_FEATURES))
        labels = tf.reshape(parsed["y"], (seq_len, NUM_LABELS))
        valid = tf.reshape(parsed["mask"], (seq_len,))
        return (features, valid), labels, valid
    
    def make_dataset(path, batch_size, training,drop_remainder):
        """Build a batched, prefetched dataset of ``parse_record`` tuples from TFRecord file(s).

        When ``training`` is true, examples are shuffled with a 2000-element buffer and a
        fixed seed before batching.
        """
        records = tf.data.TFRecordDataset(path, num_parallel_reads=AUTOTUNE).map(
            parse_record, num_parallel_calls=AUTOTUNE
        )
        if training:
            records = records.shuffle(2000, seed=42)
        batched = records.batch(batch_size, drop_remainder=drop_remainder)
        return batched.prefetch(AUTOTUNE)


    def parse_record2(example):
        """Decode one serialized example that also carries document and vehicle IDs.

        Returns ``((x, mask), y, mask, doc_id, veh_id)``: x is
        (fine_tune_seq_len, NUM_FEATURES), y is (fine_tune_seq_len, NUM_LABELS), mask is
        (fine_tune_seq_len,), and doc_id / veh_id are scalar tf.string tensors stored as
        bytes in the TFRecord.
        """
        seq_len = fine_tune_seq_len
        spec = {
            "x": tf.io.FixedLenFeature([seq_len * NUM_FEATURES], tf.float32),
            "y": tf.io.FixedLenFeature([seq_len * NUM_LABELS], tf.float32),
            "mask": tf.io.FixedLenFeature([seq_len], tf.float32),
            "doc_id": tf.io.FixedLenFeature([], tf.string),
            "veh_id": tf.io.FixedLenFeature([], tf.string),
        }
        parsed = tf.io.parse_single_example(example, spec)

        features = tf.reshape(parsed["x"], (seq_len, NUM_FEATURES))
        labels = tf.reshape(parsed["y"], (seq_len, NUM_LABELS))
        valid = tf.reshape(parsed["mask"], (seq_len,))

        # ((features), labels, sample_weight, doc_id, veh_id)
        return (features, valid), labels, valid, parsed["doc_id"], parsed["veh_id"]
    
    def make_dataset2(path, batch_size, training, drop_remainder):
        """Build a batched, prefetched dataset of ``parse_record2`` tuples (with IDs).

        ``training`` is accepted for signature parity with ``make_dataset`` but is
        ignored: this dataset is never shuffled.
        """
        records = tf.data.TFRecordDataset(path, num_parallel_reads=AUTOTUNE)
        records = records.map(parse_record2, num_parallel_calls=AUTOTUNE)  # 5-element tuples
        batched = records.batch(batch_size, drop_remainder=drop_remainder)
        return batched.prefetch(AUTOTUNE)

    
    
    def build_finetune_model(encoder):
        """Attach a per-timestep sigmoid head with NUM_LABELS outputs to ``encoder``.

        The model takes ``[sequence, mask]`` inputs of shapes
        (fine_tune_seq_len, NUM_FEATURES) and (fine_tune_seq_len,), feeds both to the
        encoder, and applies a float32 Dense(NUM_LABELS, sigmoid) to its output.
        """
        seq_inputs = tf.keras.Input(shape=(fine_tune_seq_len, NUM_FEATURES), dtype=tf.float32)
        mask_inputs = tf.keras.Input(shape=(fine_tune_seq_len,), dtype=tf.float32)
        encoded = encoder([seq_inputs, mask_inputs])
        probs = tf.keras.layers.Dense(NUM_LABELS, activation="sigmoid", dtype="float32")(encoded)
        return tf.keras.Model([seq_inputs, mask_inputs], probs)

    # [Keep your WarmUpCosine, PositionalEmbedding, padfloat_to_attnbool, custom_objects here]

    class PositionalEmbedding(tf.keras.layers.Layer):
        """Add a learned absolute position embedding to the input sequence.

        Input is expected as (batch, seq_len, d_model), with seq_len at axis 1 and
        seq_len <= max_seq_len (TODO confirm against the encoder). Position i gets row i of
        a float32 embedding table, cast to the input dtype. Keep attribute and config
        names unchanged: saved checkpoints are deserialized through them.
        """
        def __init__(self, max_seq_len, d_model, **kwargs):
            super().__init__(**kwargs)
            self.max_seq_len = max_seq_len
            self.d_model = d_model
            self.pos_emb = tf.keras.layers.Embedding(
                input_dim=max_seq_len, output_dim=d_model, dtype="float32", name="pos_embedding"
            )

        def build(self, input_shape):
            # Create the embedding table eagerly so its weights exist before the first call.
            self.pos_emb.build((None,))
            super().build(input_shape)

        def call(self, x):
            position_ids = tf.range(tf.shape(x)[1])
            table_rows = tf.cast(self.pos_emb(position_ids), x.dtype)
            return x + table_rows

        def get_config(self):
            base = super().get_config()
            base.update(max_seq_len=self.max_seq_len, d_model=self.d_model)
            return base



    
    def padfloat_to_attnbool(m):
        """Convert a float padding mask into a boolean attention mask.

        Entries greater than 0.5 become True, and two singleton axes are inserted at
        axis 1, e.g. (batch, seq) -> (batch, 1, 1, seq).
        """
        keep = tf.cast(m > 0.5, tf.bool)
        for _ in range(2):
            keep = tf.expand_dims(keep, 1)
        return keep

    class WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Linear warm-up to ``base_lr`` followed by cosine decay to ``min_lr``.

        lr(step) = base_lr * step / warmup_steps                          if step < warmup_steps
                 = min_lr + (base_lr - min_lr) * 0.5 * (1 + cos(pi * p))    otherwise,
        where p = clip((step - warmup_steps) / (total_steps - warmup_steps), 0, 1).

        Both denominators are clamped to at least 1. Without the clamp, degenerate
        configurations produce 0/0 = NaN (e.g. total_steps == 0 at step 0), which would
        corrupt the weights.
        """
        def __init__(self, base_lr, total_steps, warmup_steps, min_lr=1e-6):
            super().__init__()
            self.base_lr = base_lr
            self.warmup_steps = warmup_steps
            self.total_steps = total_steps
            self.min_lr = min_lr

        def __call__(self, step):
            step_f = tf.cast(step, tf.float32)
            warmup_f = tf.cast(self.warmup_steps, tf.float32)

            # Linear warm-up.
            warmup_lr = self.base_lr * (step_f / tf.maximum(warmup_f, 1.0))

            # Cosine decay after warm-up.
            decay_steps = tf.maximum(tf.cast(self.total_steps - self.warmup_steps, tf.float32), 1.0)
            progress = tf.clip_by_value((step_f - warmup_f) / decay_steps, 0.0, 1.0)
            cosine_decay = 0.5 * (1.0 + tf.cos(np.pi * progress))
            cosine_lr = self.min_lr + (self.base_lr - self.min_lr) * cosine_decay

            return tf.cond(step < self.warmup_steps, lambda: warmup_lr, lambda: cosine_lr)

        def get_config(self):
            """Return the constructor arguments so Keras can serialize the schedule."""
            return {
                "base_lr": self.base_lr,
                "total_steps": self.total_steps,
                "warmup_steps": self.warmup_steps,
                "min_lr": self.min_lr,
            }

        @classmethod
        def from_config(cls, config):
            return cls(**config)
    
    def get_lr_schedule(stage_steps):
        """Return a fresh warm-up + cosine schedule spanning ``stage_steps`` steps.

        The first 1 % of the steps (truncated by ``int``) are linear warm-up.
        The peak rate is 1e-3 and the cosine floor is 1e-6.
        """
        n_warmup = int(stage_steps * 0.01)
        schedule_kwargs = dict(
            base_lr=1e-3,
            total_steps=stage_steps,
            warmup_steps=n_warmup,
            min_lr=1e-6,
        )
        return WarmUpCosine(**schedule_kwargs)


    class MaskedBinaryCrossentropy(tf.keras.losses.Loss):
        """Binary cross-entropy averaged over valid (unmasked) frames only.

        When a dataset yields ``(inputs, y, valid_mask)``, as ``test_ds`` does
        here, Keras passes ``valid_mask`` to the loss as ``sample_weight``.

        Keras' base ``Loss.__call__`` never forwards ``sample_weight`` to
        ``call``; it applies the weight itself after ``call`` has returned.
        As a result the masked branch of the previous implementation was
        unreachable, and padded frames were averaged into the loss.
        ``__call__`` is overridden here so the mask reaches ``call``. The
        result is normalised by the number of valid frames:
        ``sum(frame_loss * mask) / sum(mask)``.
        """

        def __init__(self, from_logits=False, name="masked_bce"):
            super().__init__(name=name)
            self.from_logits = from_logits
            # With reduction NONE, BinaryCrossentropy already averages over the
            # last (label) axis: (B, T, L) inputs give a (B, T) per-frame loss.
            self.bce = tf.keras.losses.BinaryCrossentropy(
                from_logits=from_logits,
                reduction=tf.keras.losses.Reduction.NONE,
            )

        def __call__(self, y_true, y_pred, sample_weight=None):
            # Route sample_weight into `call` instead of letting the base class
            # multiply it onto an already-reduced value.
            with tf.name_scope(self.name):
                return self.call(y_true, y_pred, sample_weight=sample_weight)

        def call(self, y_true, y_pred, sample_weight=None):
            """Return the mean per-frame BCE over frames with non-zero weight.

            y_true, y_pred: (B, T, L)
            sample_weight (frame mask): (B, T), or None for an unmasked mean.
            """
            frame_loss = self.bce(y_true, y_pred)
            # Defensive: if per-label values come back, average over labels.
            if len(frame_loss.shape) == len(y_pred.shape):
                frame_loss = tf.reduce_mean(frame_loss, axis=-1)

            if sample_weight is None:
                return tf.reduce_mean(frame_loss)

            weights = tf.cast(sample_weight, frame_loss.dtype)
            # The epsilon keeps an all-padding batch at 0 instead of NaN.
            return tf.reduce_sum(frame_loss * weights) / (tf.reduce_sum(weights) + 1e-8)

        def get_config(self):
            """Serialize constructor arguments (the base config's ``reduction`` is not accepted by ``__init__``)."""
            return {"name": self.name, "from_logits": self.from_logits}


    #-------------- F1 Score evaluation-------#
    def collect_predictions(model, dataset):
        """Run ``model`` over ``dataset`` and stack outputs together with IDs.

        Each dataset element must be ``((x, attn_mask), y, valid_mask,
        doc_id, veh_id)``. ``y`` and ``valid_mask`` are converted with
        ``.numpy()``. The ID batches are handed to ``np.concatenate`` unchanged.

        Returns:
            ``(y_true, y_pred, valid_mask, doc_ids, veh_ids)``, each
            concatenated over batches along axis 0.
        """
        columns = ([], [], [], [], [])
        for (x, attn_mask), y, valid_mask, doc_id, veh_id in dataset:
            batch_pred = model.predict_on_batch((x, attn_mask))
            row = (y.numpy(), batch_pred, valid_mask.numpy(), doc_id, veh_id)
            for column, value in zip(columns, row):
                column.append(value)
        return tuple(np.concatenate(column, axis=0) for column in columns)
    
    def collect_predictions2(model, dataset):
        """Run ``model`` over ``dataset`` and stack labels, predictions and masks.

        ``dataset`` yields ``((x, attn_mask), y, valid_mask)``. ``y`` and
        ``valid_mask`` must expose ``.numpy()``. The three outputs are
        concatenated over batches along axis 0 and returned as a tuple.
        """
        labels_parts, probs_parts, masks_parts = [], [], []
        for inputs, labels, frame_mask in dataset:
            x, attn_mask = inputs
            probs_parts_item = model.predict_on_batch((x, attn_mask))
            labels_parts.append(labels.numpy())
            probs_parts.append(probs_parts_item)
            masks_parts.append(frame_mask.numpy())

        stacked_labels = np.concatenate(labels_parts, axis=0)
        stacked_probs = np.concatenate(probs_parts, axis=0)
        stacked_masks = np.concatenate(masks_parts, axis=0)
        return stacked_labels, stacked_probs, stacked_masks

    def _valid_frames(y_true, y_prob, mask):
        """Flatten (N, T, L) arrays to (N*T, L) rows and keep rows whose mask > 0.5."""
        keep = mask.reshape(-1) > 0.5
        return y_true.reshape(-1, NUM_LABELS)[keep], y_prob.reshape(-1, NUM_LABELS)[keep]

    def _tune_thresholds(y_true, y_prob):
        """Per label, return the threshold in linspace(0.1, 0.9, 17) with the highest F1.

        Ties keep the earliest (smallest) candidate. A label whose F1 is 0 for
        every candidate keeps the default 0.5.
        """
        thresholds = []
        for i in range(NUM_LABELS):
            best_f1, best_t = 0.0, 0.5
            for t in np.linspace(0.1, 0.9, 17):
                f1 = f1_score(y_true[:, i], (y_prob[:, i] > t).astype(int), zero_division=0)
                if f1 > best_f1:
                    best_f1, best_t = f1, t
            thresholds.append(best_t)
        return thresholds

    def F1_eval(model, test_ds, test_ds2):
        """Tune per-label thresholds on ``test_ds`` and score ``test_ds2``.

        Args:
            model: fine-tuned Keras model taking ``(x, attn_mask)``.
            test_ds: dataset yielding ``((x, attn_mask), y, valid_mask)``.
            test_ds2: dataset yielding ``((x, attn_mask), y, valid_mask,
                doc_id, veh_id)``. The IDs are ignored here.

        Returns:
            ``(hamming_loss, macro_f1)`` over the valid frames of ``test_ds2``.
            ``macro_f1`` is the "macro avg" F1 from ``classification_report``.
        """
        # --- 1. Threshold selection ---
        # NOTE(review): thresholds are tuned on test_ds and then scored on
        # test_ds2. Both paths point at a "test" split, so the reported F1 may
        # be optimistic; confirm the two splits are disjoint.
        y_thr, p_thr, m_thr = collect_predictions2(model, test_ds)
        best_thresholds = _tune_thresholds(*_valid_frames(y_thr, p_thr, m_thr))

        # --- 2. Final frame-level metrics ---
        y_eval, p_eval, m_eval, _doc_ids, _veh_ids = collect_predictions(model, test_ds2)
        y_clean, p_clean = _valid_frames(y_eval, p_eval, m_eval)
        # Broadcast one threshold per label column.
        y_pred_bin = (p_clean > np.asarray(best_thresholds)).astype(int)

        report_dict = classification_report(
            y_clean.astype(int),
            y_pred_bin,
            target_names=LABEL_COLS,
            zero_division=0,
            output_dict=True
        )
        h_loss = hamming_loss(y_clean, y_pred_bin)
        macro_f1 = report_dict.get('macro avg', {}).get('f1-score')
        return h_loss, macro_f1

    #---------------F1 Score evaluation end---#
    
    # Input window length; must match the length the pre-trained encoder was built for.
    fine_tune_seq_len = 192

    # Custom classes/functions Keras needs to deserialize the pre-trained checkpoint.
    custom_objects = {
        "PositionalEmbedding": PositionalEmbedding,
        "WarmUpCosine": WarmUpCosine,
        "padfloat_to_attnbool": padfloat_to_attnbool,
        "LossScaleOptimizer": tf.keras.mixed_precision.LossScaleOptimizer,
    }

    # Ensure these paths match the write_split output.
    train_tfr = ["/kaggle/input/tfrecords-finetuning-192/train.tfrecord"]
    val_tfr = ["/kaggle/input/tfrecords-finetuning-192/val.tfrecord"]
    test_tfr = ["/kaggle/input/tfrecords-finetuning-192/test.tfrecord"]
    # Read via make_dataset2; its batches also carry document/vehicle IDs.
    test_tfr2 = ["/kaggle/input/tfrecords-testing-192-doc-veh/test.tfrecord"]

    # `time` is imported at the top of the notebook.
    s = time.time()
    # LoRA (rank, alpha) configurations to sweep.
    c = [(4, 4)]
    for a, b in c:
        BATCH_SIZE = 4
        EPOCHS = 100

        train_ds = make_dataset(train_tfr, BATCH_SIZE, training=True, drop_remainder=True)
        # NOTE(review): validation is built with training=True. If make_dataset
        # shuffles/augments in training mode, val_loss (used by EarlyStopping)
        # is noisier than necessary; confirm this is intended.
        val_ds = make_dataset(val_tfr, BATCH_SIZE, training=True, drop_remainder=False)
        test_ds = make_dataset(test_tfr, BATCH_SIZE, training=False, drop_remainder=False)
        test_ds2 = make_dataset2(test_tfr2, BATCH_SIZE, training=False, drop_remainder=False)

        SEEDS = [42]

        results = []
        for seed in SEEDS:
            tf.random.set_seed(seed)
            np.random.seed(seed)
            pretrained = tf.keras.models.load_model(
                "/kaggle/input/bert-final-2/bert_stage3_e11_t1765702225.keras",
                custom_objects=custom_objects,
                compile=False
            )

            encoder_lora = add_lora_to_encoder(pretrained, rank=a, alpha=b)
            model = build_finetune_model(encoder_lora)

            # Total optimizer steps for the LR schedule (one pass counts the batches).
            steps = EPOCHS * sum(1 for _ in train_ds)
            lr = get_lr_schedule(steps)

            model.compile(
                optimizer=tf.keras.optimizers.AdamW(lr),
                loss=MaskedBinaryCrossentropy(from_logits=False),
                metrics=[metrics.AUC(name='auprc', curve='PR')]
            )

            callback_list = [
                tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=8, restore_best_weights=True),
            ]

            history = model.fit(
                train_ds,
                validation_data=val_ds,
                epochs=EPOCHS,
                callbacks=callback_list,
                verbose=0
            )
            # evaluate() returns [loss, auprc] given the compile() call above.
            test_results = model.evaluate(test_ds, verbose=0)
            h_loss, macro_f1 = F1_eval(model, test_ds, test_ds2)
            results.append({
                'auprc': test_results[1],
                'hamming_loss': float(f"{h_loss:.4f}"),
                'macro_f1': float(f"{macro_f1:.4f}")
            })

        auprc_mean = np.mean([r['auprc'] for r in results])
        auprc_std = np.std([r['auprc'] for r in results])
        f1_mean = np.mean([r['macro_f1'] for r in results])
        f1_std = np.std([r['macro_f1'] for r in results])
        h_mean = np.mean([r['hamming_loss'] for r in results])
        h_std = np.std([r['hamming_loss'] for r in results])
        print("config", b)
        print(f"AUPRC: {auprc_mean:.4f} ± {auprc_std:.4f}")
        print(f"f1: {f1_mean:.4f} ± {f1_std:.4f}")
        print(f"Hamming loss: {h_mean:.4f} ± {h_std:.4f}")

        print("trainable_weights", sum(w.numpy().size for w in model.trainable_weights))
        e = time.time()
        print(f"Training finished. Elapsed time: {(e-s)/60:.2f} minutes")

## Evaluation of the fine-tuned BERT

In [None]:
if USE_FINE_TUNE:
    def collect_predictions(model, dataset):
        """Predict on every batch of ``dataset`` and stack results with window IDs.

        Elements must unpack as ``((x, attn_mask), y, valid_mask, doc_id,
        veh_id)``. ``y`` and ``valid_mask`` are converted via ``.numpy()``. The
        ID batches go to ``np.concatenate`` unchanged.

        Returns:
            ``(y_true, y_pred, valid_mask, doc_ids, veh_ids)`` concatenated
            along axis 0.
        """
        order = ("y_true", "y_pred", "mask", "doc", "veh")
        stacks = {key: [] for key in order}
        for batch in dataset:
            (x, attn_mask), y, valid_mask, doc_id, veh_id = batch
            stacks["y_pred"].append(model.predict_on_batch((x, attn_mask)))
            stacks["y_true"].append(y.numpy())
            stacks["mask"].append(valid_mask.numpy())
            stacks["doc"].append(doc_id)
            stacks["veh"].append(veh_id)
        return tuple(np.concatenate(stacks[key], axis=0) for key in order)
    
    def collect_predictions2(model, dataset):
        """Predict on every batch of ``dataset``; return stacked labels, predictions, masks.

        Elements must unpack as ``((x, attn_mask), y, valid_mask)``, with ``y``
        and ``valid_mask`` exposing ``.numpy()``. Each output is concatenated
        along axis 0.
        """
        order = ("y_true", "y_pred", "mask")
        stacks = {key: [] for key in order}
        for batch in dataset:
            (x, attn_mask), y, valid_mask = batch
            stacks["y_pred"].append(model.predict_on_batch((x, attn_mask)))
            stacks["y_true"].append(y.numpy())
            stacks["mask"].append(valid_mask.numpy())
        return tuple(np.concatenate(stacks[key], axis=0) for key in order)
    
    # --- 1. THRESHOLD FINDING STAGE (on test_ds) ---
    # NOTE(review): thresholds are tuned on test_ds and reported on test_ds2;
    # if both hold the same windows, the scores below are optimistic. Confirm.
    y_test_thr, y_pred_thr, mask_thr = collect_predictions2(model, test_ds)

    # Collapse (windows, frames) into rows and drop frames whose mask is <= 0.5.
    mask_thr_flat = mask_thr.reshape(-1) > 0.5
    y_test_thr_flat = y_test_thr.reshape(-1, NUM_LABELS)
    y_pred_thr_flat = y_pred_thr.reshape(-1, NUM_LABELS)
    y_test_thr_clean = y_test_thr_flat[mask_thr_flat]
    y_pred_thr_clean = y_pred_thr_flat[mask_thr_flat]

    # Per label: scan 17 candidates in [0.1, 0.9] and keep the first one that
    # strictly improves F1. 0.5 is kept if no candidate beats F1 = 0.
    candidate_thresholds = np.linspace(0.1, 0.9, 17)
    best_thresholds = []
    for label_idx in range(NUM_LABELS):
        truth_col = y_test_thr_clean[:, label_idx]
        prob_col = y_pred_thr_clean[:, label_idx]
        top_t, top_f1 = 0.5, 0.0
        for cand in candidate_thresholds:
            score = f1_score(truth_col, (prob_col > cand).astype(int), zero_division=0)
            if score > top_f1:
                top_t, top_f1 = cand, score
        best_thresholds.append(top_t)

    print("Optimal thresholds (from test_ds):")
    print(best_thresholds)

    # --- 2. FINAL METRICS STAGE (on test_ds2, whose batches also carry IDs) ---
    y_test2, y_pred_test2, mask_test2, doc_ids_test2, veh_ids_test2 = collect_predictions(model, test_ds2)

    mask_test2_flat = mask_test2.reshape(-1) > 0.5
    y_test2_flat = y_test2.reshape(-1, NUM_LABELS)
    y_pred_test2_flat = y_pred_test2.reshape(-1, NUM_LABELS)
    y_test2_clean = y_test2_flat[mask_test2_flat]
    y_pred_test2_clean = y_pred_test2_flat[mask_test2_flat]

    # Frame-level binary predictions (valid frames only), one threshold per column.
    y_pred_bin = np.zeros_like(y_pred_test2_clean, dtype=int)
    for label_idx, thr in enumerate(best_thresholds):
        y_pred_bin[:, label_idx] = y_pred_test2_clean[:, label_idx] > thr

    # Window-shaped (N_seq, seq_len, NUM_LABELS) binary predictions, consumed by
    # pretty_transition_report in the next cell.
    threshold_array = np.array(best_thresholds).reshape(1, 1, NUM_LABELS)
    y_pred_bin_seq = (y_pred_test2 > threshold_array).astype(int)

    print("FINAL TEST RESULTS (test_ds2 - Frame-level)")
    print(classification_report(
        y_test2_clean.astype(int),
        y_pred_bin,
        target_names=LABEL_COLS,
        zero_division=0
    ))

    print("Hamming Loss:", hamming_loss(y_test2_clean, y_pred_bin))

In [None]:
if USE_FINE_TUNE:
    # Transitions scored by pretty_transition_report: each key is a "previous"
    # label and each listed value a "current" label. A transition instance is
    # counted when the previous label's run ends and the current label starts
    # on that first frame after the end, or allow_curr_offset frames later.
    # Dict order also fixes the print order of the transition report.
    dependent_pairs = {
        "Cut-in in front of ego vehicle": [
            "Lead vehicle accelerating",
            "Lead vehicle cruising",
            "Lead vehicle decelerating"
        ],
        "Cut-out in front of ego vehicle": [
            "Ego vehicle driving in lane without lead vehicle",
            "Lead vehicle cruising",
            "Lead vehicle decelerating",
            "Lead vehicle accelerating"
        ],
        "Ego merging into an occupied lane": [
            "Lead vehicle cruising",
            "Lead vehicle decelerating",
            "Lead vehicle accelerating"
        ],
        "Ego vehicle performing lane change": [
            "Ego vehicle driving in lane without lead vehicle"
        ],
        "Ego vehicle performing lane change with vehicle behind": [
            "Ego vehicle driving in lane without lead vehicle"
        ],
        "Ego vehicle driving in lane without lead vehicle": [
            "Ego vehicle approaching slower lead vehicle",
            "Lead vehicle accelerating",
            "Lead vehicle cruising",
            "Lead vehicle decelerating"
        ]
    }
    
    def compute_boundary_metrics_fast(y_true, y_pred, length=None):
        """Mean absolute onset error between two binary 1-D signals.

        An onset is a 0 -> 1 step; its position is the index of the first 1.
        For every true onset, the distance to the nearest predicted onset is
        taken, and the distances are averaged. If ``length`` is given, the mean
        is divided by it. Returns ``np.nan`` when either signal has no onset.
        """
        def onsets(signal):
            rising = (signal[:-1] == 0) & (signal[1:] == 1)
            return np.nonzero(rising)[0] + 1

        true_on = onsets(y_true)
        pred_on = onsets(y_pred)
        if true_on.size == 0 or pred_on.size == 0:
            return np.nan

        # pred_on is ascending, so searchsorted gives each true onset's right
        # neighbour. Clipping to the valid range makes the missing neighbour at
        # either end fall back onto the existing one.
        pos = np.searchsorted(pred_on, true_on)
        last = pred_on.size - 1
        left = pred_on[np.clip(pos - 1, 0, last)]
        right = pred_on[np.clip(pos, 0, last)]
        dist = np.minimum(np.abs(true_on - left), np.abs(true_on - right))

        mabe = float(np.mean(dist))
        return mabe if length is None else mabe / length
    
    def pretty_transition_report(
        y_sequences: np.ndarray,
        y_pred_proba: np.ndarray = None,
        doc_ids: np.ndarray = None,
        veh_ids: np.ndarray = None,
        labels: List[str] = None,
        best_thresholds: List[float] = None,
        y_pred_bin: np.ndarray = None,
        dependent_pairs: Dict[str, List[str]] = None,
        match_tolerance_prev: int = 1,
        allow_curr_offset: int = 1
    ) -> Dict[str, Any]:
        """Compute and print per-vehicle boundary, missed-detection and transition metrics.

        Windows are grouped by ``(doc_id, veh_id)`` and concatenated, in input
        order, into one timeline per vehicle. All metrics are computed on these
        timelines.
        NOTE(review): this assumes a vehicle's windows are chronologically
        ordered and non-overlapping, and padded frames are not excluded.
        Confirm against the window stride/padding used when writing the
        TFRecords.

        Args:
            y_sequences: ground truth, shape ``(n_seq, seq_len, n_labels)``.
            y_pred_proba: predicted probabilities with the same shape. Only
                used when ``y_pred_bin`` is not given.
            doc_ids, veh_ids: per-window identifiers, one entry per window.
            labels: label names, one per last-axis column.
            best_thresholds: per-label thresholds applied to ``y_pred_proba``.
            y_pred_bin: precomputed binary predictions (same shape as
                ``y_sequences``). Takes precedence over the probabilities.
            dependent_pairs: ``{prev_label: [curr_label, ...]}`` transitions to score.
            match_tolerance_prev: accepted for backward compatibility; unused.
            allow_curr_offset: a ``curr`` onset "follows" a ``prev`` end when it
                lies on the first frame after the end or ``allow_curr_offset``
                frames later. Applied to both ground truth and predictions.

        Returns:
            dict with ``boundary_summary``, ``avg_mabe``,
            ``missed_percent_vehicle``, ``avg_miss_rate_vehicle``,
            ``support_counts_vehicle``, ``missed_counts_vehicle``,
            ``transition_summary``, ``transition_support``,
            ``missed_transitions``, ``avg_transition_mabe`` and
            ``overall_miss_rate``.

        Raises:
            ValueError: if neither ``y_pred_bin`` nor (``y_pred_proba`` and
                ``best_thresholds``) is given, or if ``doc_ids``/``veh_ids`` is missing.
        """
        # Basic checks
        assert labels is not None, "Provide labels list"
        if doc_ids is None or veh_ids is None:
            raise ValueError("doc_ids and veh_ids must be provided.")
        n_seq, seq_len, n_labels = y_sequences.shape  # also enforces a 3-D input

        # Binary predictions: use the given ones, otherwise threshold the probabilities.
        if y_pred_bin is None:
            if y_pred_proba is None or best_thresholds is None:
                raise ValueError("Either y_pred_bin OR (y_pred_proba and best_thresholds) must be provided.")
            thr = np.array(best_thresholds).reshape(1, 1, -1)
            y_pred_bin = (y_pred_proba.astype(np.float32) > thr).astype(np.int32)
        else:
            y_pred_bin = y_pred_bin.astype(np.int32)
        y_sequences = y_sequences.astype(np.int32)

        pairs = dependent_pairs or {}
        label2idx = {l: i for i, l in enumerate(labels)}

        def ends(col):
            """Indices of the first 0 frame after each run of 1s (1 -> 0 steps)."""
            return np.where((col[:-1] == 1) & (col[1:] == 0))[0] + 1

        def starts(col):
            """Indices of the first 1 frame of each run of 1s (0 -> 1 steps)."""
            return np.where((col[1:] == 1) & (col[:-1] == 0))[0] + 1

        # 1) (doc, veh) -> ordered list of window indices
        seqs_by_vehicle = defaultdict(list)
        for idx, (d, v) in enumerate(zip(doc_ids, veh_ids)):
            seqs_by_vehicle[(d, v)].append(idx)

        boundary_results = defaultdict(list)       # label -> per-vehicle normalised MABE
        missed_counts_vehicle = defaultdict(int)   # label -> vehicles where label never predicted
        total_support_vehicle = defaultdict(int)   # label -> vehicles where label present
        transition_results = defaultdict(list)     # pair -> per-instance normalised distances
        transition_support = defaultdict(int)      # pair -> true transition instances
        missed_transitions = defaultdict(int)      # pair -> true instances with no predicted pair

        # Single pass per vehicle: each timeline is built once and reused.
        for seq_indices in seqs_by_vehicle.values():
            y_true_vehicle = np.concatenate([y_sequences[i] for i in seq_indices], axis=0)
            y_pred_vehicle = np.concatenate([y_pred_bin[i] for i in seq_indices], axis=0)
            vehicle_len = len(y_true_vehicle)

            # 2) boundary MABE and 3) per-vehicle presence / miss counts
            for i, lab in enumerate(labels):
                true_col = y_true_vehicle[:, i]
                pred_col = y_pred_vehicle[:, i]
                mabe = compute_boundary_metrics_fast(true_col, pred_col, length=vehicle_len)
                if not np.isnan(mabe):
                    boundary_results[lab].append(mabe)
                if true_col.sum() > 0:
                    total_support_vehicle[lab] += 1
                    if pred_col.sum() == 0:
                        missed_counts_vehicle[lab] += 1

            # 4) dependent transitions prev_label -> curr_label
            for prev_label, next_labels in pairs.items():
                if prev_label not in label2idx:
                    continue
                p_idx = label2idx[prev_label]
                pred_prev_end = ends(y_pred_vehicle[:, p_idx])
                true_prev_end = ends(y_true_vehicle[:, p_idx])

                for curr_label in next_labels:
                    if curr_label not in label2idx:
                        continue
                    c_idx = label2idx[curr_label]
                    pair = (prev_label, curr_label)
                    true_curr_start = starts(y_true_vehicle[:, c_idx])
                    pred_curr_start = starts(y_pred_vehicle[:, c_idx])

                    # Predicted curr onsets that directly follow some predicted prev end.
                    pred_pairs_curr = [
                        cs
                        for p in pred_prev_end
                        for cs in pred_curr_start[(pred_curr_start == p) | (pred_curr_start == p + allow_curr_offset)]
                    ]
                    pred_cs_arr = np.array(pred_pairs_curr)

                    for te in true_prev_end:
                        follows = (true_curr_start == te) | (true_curr_start == te + allow_curr_offset)
                        if not follows.any():
                            continue  # not a true prev -> curr instance
                        transition_support[pair] += 1

                        if len(pred_pairs_curr) == 0:
                            # no predicted dependent pair anywhere in this vehicle -> missed
                            missed_transitions[pair] += 1
                            continue

                        # Prefer the onset on the end frame, else the offset one.
                        true_cs = int(true_curr_start[follows][0])
                        best_dist = np.min(np.abs(pred_cs_arr - true_cs)) / vehicle_len
                        transition_results[pair].append(best_dist)

        # Per-label boundary summary in label order (NaN where never measurable).
        boundary_summary = {
            lab: float(np.mean(boundary_results[lab])) if boundary_results.get(lab) else np.nan
            for lab in labels
        }
        valid_boundary = [v for v in boundary_summary.values() if not np.isnan(v)]
        avg_mabe = float(np.nanmean(valid_boundary)) if valid_boundary else np.nan

        missed_percent_vehicle = {lab: (100.0 * missed_counts_vehicle.get(lab, 0) / total_support_vehicle.get(lab, 1))
                                  if total_support_vehicle.get(lab, 0) > 0 else np.nan for lab in labels}

        # Per-pair MABE (mean of per-instance distances)
        transition_summary = {
            pair: float(np.mean(dists)) if len(dists) > 0 else 0
            for pair, dists in transition_results.items()
        }

        print("\nBoundary metrics per class (averaged across vehicles):")
        for lab in labels:
            mabe = boundary_summary.get(lab, np.nan)
            if np.isnan(mabe):
                print(f"{lab}: MABE = n/a")
            else:
                # Values are normalised by vehicle length (< 1), so show 4 decimals.
                print(f"{lab}: MABE = {mabe:.4f}")
        print(f"\nAverage MABE across all classes: {avg_mabe:.4f}\n")

        print("Missed Scenario Detection Rates (% of vehicles where class was present but never predicted):")
        for lab in labels:
            pct = missed_percent_vehicle.get(lab, np.nan)
            total_support = total_support_vehicle.get(lab, 0)
            if np.isnan(pct):
                print(f"{lab}: n/a")
            else:
                print(f"{lab}: {pct:.2f}% missed (support: {total_support})")
        valid_miss = [v for v in missed_percent_vehicle.values() if not np.isnan(v)]
        avg_miss_rate_vehicle = float(np.nanmean(valid_miss)) if valid_miss else np.nan
        print(f"\nAverage Miss Rate Across All Classes (per vehicle): {avg_miss_rate_vehicle:.2f}%\n")

        # overall transition stats
        avg_transition_mabe = float(np.nanmean([v for v in transition_summary.values() if not np.isnan(v)])) if len(transition_summary) > 0 else np.nan
        total_transitions = sum(transition_support.values())
        total_missed = sum(missed_transitions.values())
        overall_miss_rate = (total_missed / total_transitions * 100) if total_transitions > 0 else 0.0

        # --- 5) Summarize transition metrics ---
        print("\nTransition-based MABE (evaluated per vehicle):")
        has_transitions = False
        for prev_label, next_labels in pairs.items():
            for curr_label in next_labels:
                pair = (prev_label, curr_label)
                total = transition_support.get(pair, 0)
                if total == 0:
                    continue  # skip transitions that never occurred in ground truth

                has_transitions = True
                missed = missed_transitions.get(pair, 0)
                mabe_vals = transition_results.get(pair, [])
                mean_mabe = np.mean(mabe_vals) if len(mabe_vals) > 0 else np.nan
                missed_rate = (missed / total * 100) if total > 0 else 0.0

                print(f"{prev_label} -> {curr_label}: "
                      f"MABE = {mean_mabe:.4f} | Support = {total} | "
                      f"Missed = {missed} ({missed_rate:.1f}%)")

        if has_transitions:
            print(f"\nAverage Transition MABE Across All Pairs: {avg_transition_mabe:.4f}")
            print(f"Overall Missed Transition Detection Rate (per instance): {overall_miss_rate:.2f}%")
        else:
            print("\n⚠️ No valid dependent transitions detected in the test set.")

        return {
            "boundary_summary": boundary_summary,
            "avg_mabe": avg_mabe,
            "missed_percent_vehicle": missed_percent_vehicle,
            "avg_miss_rate_vehicle": avg_miss_rate_vehicle,
            "support_counts_vehicle": {lab: total_support_vehicle.get(lab, 0) for lab in labels},
            "missed_counts_vehicle": {lab: missed_counts_vehicle.get(lab, 0) for lab in labels},
            "transition_summary": transition_summary,
            "transition_support": dict(transition_support),
            "missed_transitions": dict(missed_transitions),
            "avg_transition_mabe": avg_transition_mabe,
            "overall_miss_rate": overall_miss_rate
        }
    
    # -------------------------------
    # Per-vehicle transition report on test_ds2
    # -------------------------------
    print("\n--- Running Transition Report ---")
    report_inputs = {
        "y_sequences": y_test2,             # window-shaped ground truth
        "y_pred_proba": y_pred_test2,       # probabilities (ignored when y_pred_bin is given)
        "doc_ids": doc_ids_test2,
        "veh_ids": veh_ids_test2,
        "labels": LABEL_COLS,
        "best_thresholds": best_thresholds,
        "y_pred_bin": y_pred_bin_seq,       # thresholded window predictions from the evaluation cell
        "dependent_pairs": dependent_pairs,
    }
    result = pretty_transition_report(**report_inputs)

    print("\n--- Full Metrics Report Generated ---")