# In [None]:
import os
import glob
import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import itertools

# ------------------------------------------------------------
# 0) GPU setup (optional)
# ------------------------------------------------------------
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all.
# The memory-growth setting has to be the same on every visible GPU, so it is
# applied to all of them rather than only to the first device.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Memory growth can only be changed before the GPUs are initialised
    # (e.g. this code is re-run in an already active session); keep going
    # with the current setting instead of aborting.
    print(f"[WARN] Could not enable GPU memory growth: {err}")

# ------------------------------------------------------------
# Hyperparameters
# ------------------------------------------------------------
# Length of one flattened sample: a 29 x 29 grid with 2 channels.
FEATURE_DIM = 29 * 29 * 2
# Width of the one-hot label vector produced by the input pipeline.
N_LABELS   = 2
# Mini-batch size of the training and validation pipelines.
BATCH      = 128
# Number of passes of the training loop.
EPOCHS     = 30

# AAE-specific
# Unit counts of the two hidden dense-block widths used by every sub-network.
N_L1       = 1024
N_L2       = 768
# Size of the latent code z produced by the encoder.
LATENT_DIM = 64
# Weight of the gradient-penalty term added to both discriminator losses.
λ_gp       = 10.0

# Learning rates
# Reconstruction step (encoder + decoder).
LR_AE = 0.0005
# Discriminator on the latent code z.
LR_DZ = 0.0001
# Discriminator on the label vector y.
LR_DY = 0.0001
# Encoder update against both discriminators (plus label cross-entropy).
LR_G  = 5e-5

# Architecture options
# Activation applied after normalisation in every dense block.
ACTIVATION = 'elu'
# Dropout rate of every dense block; a value <= 0 omits the Dropout layer.
DROPOUT    = 0.2
# Any value other than the two listed results in no normalisation layer.
NORM_TYPE  = 'layer'  # 'layer' or 'batch'

# ------------------------------------------------------------
# 1) Helper functions: preprocessing & TFRecord creation
# ------------------------------------------------------------
# CSV captures to preprocess. Every entry must be unique: each one is
# converted into its own set of TFRecord files named after the entry.
# NOTE(review): "Attack_free_KIA_Soul_train.csv" used to be listed twice, which
# made that capture get preprocessed and written twice. If the second entry
# was meant to be a different capture, add the intended file name here.
datasets = [
    "Attack_free_CHEVROLET_Spark_train.csv",
    "Attack_free_KIA_Soul_train.csv",
    "Flooding_CHEVROLET_Spark_train.csv",
    "Flooding_KIA_Soul_train.csv",
    "Fuzzy_CHEVROLET_Spark_train.csv",
    "Attack_free_HY_Sonata_train.csv",
    "Fuzzy_dataset_HY_Sonata_train.csv",
    "Fuzzy_dataset_KIA_Soul_train.csv",
    "Malfunction_1st_dataset_HY_Sonata_train.csv",
    "Malfunction_1st_dataset_KIA_Soul_train.csv",
    "Malfunction_2nd_HY_Sonata_train.csv",
    "Malfunction_2nd_KIA_Soul_train.csv",
    "Replay_dataset_HY_Sonata_train.csv",
    "Replay_dataset_KIA_Soul_train.csv",
]
# Maps a dataset name to the CSV path it is read from (currently identical).
csv_map = {d: d for d in datasets}

def fill_flag(row):
    """Fill a non-string 'Flag' entry from the column named ``Data<DLC>``.

    Rows whose 'Flag' is already a string are returned untouched. Otherwise
    the value stored under ``'Data' + str(int(row['DLC']))`` is copied into
    'Flag'; when that column does not exist the current value is kept.

    The row is modified in place and also returned, so the function can be
    used with ``DataFrame.apply(..., axis=1)``.
    """
    current = row['Flag']
    if isinstance(current, str):
        return row
    source_col = f"Data{int(row['DLC'])}"
    row['Flag'] = row.get(source_col, current)
    return row

def convert_canid_bits(cid):
    """Convert a hexadecimal CAN identifier into a 29-element bit vector.

    Args:
        cid: identifier as a hex string (anything else is passed through
            ``str`` first).

    Returns:
        ``np.uint8`` array of length 29, most significant bit first. An
        identifier that cannot be parsed as hex, is negative, or does not fit
        into 29 bits yields an all-zero vector, so the result always has the
        same length and can be stacked with the other rows.
    """
    n_bits = 29
    try:
        value = int(str(cid), 16)
    except ValueError:
        return np.zeros(n_bits, dtype=np.uint8)
    if not 0 <= value < (1 << n_bits):
        # Wider values would produce a longer vector and break np.stack.
        return np.zeros(n_bits, dtype=np.uint8)
    return np.array([int(ch) for ch in format(value, f'0{n_bits}b')], dtype=np.uint8)

def hex_to_int(x):
    """Parse *x* as a hexadecimal integer, returning 0 when it is not valid hex.

    The value is converted with ``str`` and stripped of surrounding
    whitespace first, so missing values (NaN/None) and non-hex tokens all map
    to 0. Only ``ValueError`` is handled, so interrupts and unrelated errors
    are no longer swallowed.
    """
    try:
        return int(str(x).strip(), 16)
    except ValueError:
        return 0

def preprocess_windows(csv_file):
    """Turn a CAN capture CSV into labelled, flattened 29x29x2 windows.

    The CSV is read without a header as Timestamp, canID, DLC, Data0..Data7,
    Flag. Rows lacking a Timestamp or canID are dropped, a non-string Flag is
    filled through ``fill_flag``, the data columns are parsed as hex bytes and
    the frames are ordered by Timestamp. Frames are then grouped into
    consecutive, non-overlapping windows of 29 frames; leftover frames at the
    end are discarded.

    For each window two 29x29 binary channels are built:
      * channel 0: the 29-bit CAN identifier of each of the 29 frames;
      * channel 1: the 64 payload bits of the window's last frame, arranged
        as 8x8 and resized to 29x29 with nearest-neighbour interpolation.

    Args:
        csv_file: path of the CSV capture.

    Returns:
        List of ``(features, label)`` tuples, where ``features`` is the
        flattened window as a list of 0/1 ints and ``label`` is 1 if any frame
        of the window has Flag 'T' (case-insensitive), else 0. An empty list
        is returned when the capture holds fewer than 29 usable frames.
    """
    print(f"[DATA] Processing {csv_file}")
    data_cols = [f'Data{i}' for i in range(8)]
    attrs = ['Timestamp', 'canID', 'DLC'] + data_cols + ['Flag']
    df = pd.read_csv(csv_file, header=None, names=attrs, low_memory=False)
    df['Timestamp'] = pd.to_numeric(df['Timestamp'], errors='coerce')
    df['DLC']       = pd.to_numeric(df['DLC'], errors='coerce').fillna(0).astype(int)
    df = df.dropna(subset=['Timestamp', 'canID'])

    win = 29
    if len(df) < win:
        # Not even one full window: return early instead of failing later in
        # np.stack on an empty column.
        return []

    df = df.apply(fill_flag, axis=1)
    for col in data_cols:
        df[col] = df[col].apply(hex_to_int).astype(np.uint8)
    df['Flag']    = df['Flag'].astype(str).str.upper().eq('T').astype(np.uint8)
    df['canBits'] = df['canID'].apply(convert_canid_bits)
    df = df.sort_values('Timestamp')

    bits_all   = np.stack(df['canBits'].values)
    data_bytes = df[data_cols].values
    flags_all  = df['Flag'].values

    # Keep only whole windows; the trailing remainder is dropped.
    N      = len(bits_all) // win
    bits   = bits_all[:N * win].reshape(N, win, 29)
    data   = data_bytes[:N * win].reshape(N, win, 8)
    flags  = flags_all[:N * win].reshape(N, win)

    rows = []
    for i in range(N):
        id_img   = bits[i].astype(np.uint8)
        # Only the payload of the last frame in the window is encoded.
        last_b   = data[i, -1, :]
        b8       = np.unpackbits(last_b, axis=0).reshape(8, 8)
        data_img = cv2.resize(b8.astype(np.float32), (29, 29), interpolation=cv2.INTER_NEAREST) > 0.5
        two_ch   = np.stack([id_img, data_img.astype(np.uint8)], axis=-1)
        feat_int = two_ch.flatten().tolist()
        lbl      = int(flags[i].any())
        rows.append((feat_int, lbl))
    return rows

def write_tfrecord(rows, base):
    """Shuffle *rows* and write them as train/val/test TFRecord files.

    The rows are shuffled in place (the caller's list is reordered) and split
    70% / 15% / remainder. Three files ``{base}_train.tfrecord``,
    ``{base}_val.tfrecord`` and ``{base}_test.tfrecord`` are always written,
    even when a split (or *rows* itself) is empty.

    Args:
        rows: list of ``(features, label)`` tuples; ``features`` is a sequence
            of ints and ``label`` an int.
        base: file-name prefix of the three output files.
    """
    np.random.shuffle(rows)
    ntr = int(0.7 * len(rows))
    nvl = int(0.15 * len(rows))
    splits = {
        'train': rows[:ntr],
        'val':   rows[ntr:ntr + nvl],
        'test':  rows[ntr + nvl:],
    }
    for ph, ch in splits.items():
        fn = f"{base}_{ph}.tfrecord"
        with tf.io.TFRecordWriter(fn) as writer:
            for feat, lbl in ch:
                ex = tf.train.Example(features=tf.train.Features(feature={
                    'features': tf.train.Feature(int64_list=tf.train.Int64List(value=feat)),
                    'label':    tf.train.Feature(int64_list=tf.train.Int64List(value=[lbl]))
                }))
                writer.write(ex.SerializeToString())

# Create/check TFRecords
# Every capture needs six files: {attack, Normal_} x {train, val, test}.
# Captures are checked one by one, so only those with missing files are
# preprocessed again. The attack files are written even when a capture has no
# attack windows; otherwise attack-free captures would never satisfy the
# check and everything would be regenerated on every run.
expected = []
announced = False
for a in datasets:
    needed = [
        f"{prefix}{a}_{ph}.tfrecord"
        for ph in ('train', 'val', 'test')
        for prefix in ('', 'Normal_')
    ]
    expected.extend(needed)
    if all(os.path.exists(f) for f in needed):
        continue
    if not announced:
        print("[DATA] TFRecords missing, preprocessing...")
        announced = True
    src = csv_map[a]
    if not os.path.exists(src):
        print(f"[WARN] {src} not found")
        continue
    rows    = preprocess_windows(src)
    normals = [r for r in rows if r[1] == 0]
    attacks = [r for r in rows if r[1] == 1]
    write_tfrecord(normals, f"Normal_{a}")
    write_tfrecord(attacks, a)
if not announced:
    print("[DATA] All TFRecords found.")

# ------------------------------------------------------------
# 2) tf.data pipeline
# ------------------------------------------------------------
def parse_feat(proto):
    """Decode one serialized example into ``(features, one_hot_label)``.

    Args:
        proto: serialized ``tf.train.Example`` holding an int64 'features'
            list of length FEATURE_DIM and a single int64 'label'.

    Returns:
        Tuple of the features cast to float32 and the label one-hot encoded
        over N_LABELS classes.
    """
    schema = {
        'features': tf.io.FixedLenFeature([FEATURE_DIM], tf.int64),
        'label':    tf.io.FixedLenFeature([1], tf.int64),
    }
    parsed = tf.io.parse_single_example(proto, schema)
    features = tf.cast(parsed['features'], tf.float32)
    label_id = tf.cast(parsed['label'][0], tf.int32)
    return features, tf.one_hot(label_id, N_LABELS)

# Training uses the normal (label 0) windows only.
train_files = glob.glob('Normal_*_train.tfrecord')
if not train_files:
    raise FileNotFoundError(
        "No 'Normal_*_train.tfrecord' files found; run the preprocessing step first."
    )
# Each element is (noisy input, clean target, one-hot label): small Gaussian
# noise is added to the input while the clean vector is kept as target.
train_ds = (
    tf.data.TFRecordDataset(train_files, num_parallel_reads=tf.data.AUTOTUNE)
    .map(parse_feat, tf.data.AUTOTUNE)
    .map(lambda x, y: (x + tf.random.normal(tf.shape(x), 0, 0.01), x, y), tf.data.AUTOTUNE)
    .shuffle(10000).repeat()
    .batch(BATCH).prefetch(tf.data.AUTOTUNE)
)
# The dataset repeats forever, so the epoch length is derived from a one-off
# count of the stored records.
total = sum(1 for _ in tf.data.TFRecordDataset(train_files))
if total == 0:
    raise ValueError("The training TFRecord files contain no records.")
# At least one step per epoch, so later per-epoch averages never divide by 0.
steps_per_epoch = max(1, total // BATCH)
print(f"[PIPE] Total records: {total}, steps/epoch: {steps_per_epoch}")

# ------------------------------------------------------------
# 3) AAE Model definition
# ------------------------------------------------------------
class AAE(tf.keras.Model):
    """Adversarial autoencoder with a latent head and a label head.

    Sub-networks (all built from the same kind of dense block):
      * encoder  : e1 -> e2, then ``ez`` (latent code) and ``ey`` (label logits)
      * decoder  : d1 -> d2 -> ``dout`` (sigmoid over FEATURE_DIM outputs)
      * critic z : dz1 -> dz2 -> ``dzout`` (one unbounded score per sample)
      * critic y : dy1 -> dy2 -> ``dyout`` (one unbounded score per sample)

    No ``call`` method is defined; the sub-networks are used through the
    helper methods below.

    NOTE(review): the blocks are invoked without a ``training`` argument, so
    Dropout / BatchNormalization presumably run in inference mode inside the
    custom training step — confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        def dense_block(units):
            """Dense -> optional normalisation -> activation -> optional dropout."""
            layers = [tf.keras.layers.Dense(units)]
            if NORM_TYPE == 'layer':
                layers.append(tf.keras.layers.LayerNormalization())
            elif NORM_TYPE == 'batch':
                layers.append(tf.keras.layers.BatchNormalization())
            layers.append(tf.keras.layers.Activation(ACTIVATION))
            if DROPOUT > 0:
                layers.append(tf.keras.layers.Dropout(DROPOUT))
            return tf.keras.Sequential(layers)

        # Encoder
        self.e1   = dense_block(N_L1)
        self.e2   = dense_block(N_L2)
        self.ez   = tf.keras.layers.Dense(LATENT_DIM)
        self.ey   = tf.keras.layers.Dense(N_LABELS)

        # Decoder
        self.d1   = dense_block(N_L2)
        self.d2   = dense_block(N_L1)
        self.dout = tf.keras.layers.Dense(FEATURE_DIM, activation='sigmoid')

        # Discriminator on the latent code
        self.dz1  = dense_block(N_L1)
        self.dz2  = dense_block(N_L2)
        self.dzout= tf.keras.layers.Dense(1)

        # Discriminator on the label vector
        self.dy1  = dense_block(N_L1)
        self.dy2  = dense_block(N_L2)
        self.dyout= tf.keras.layers.Dense(1)

    def encode(self, x):
        """Return ``(z, softmax(label logits), label logits)`` for input *x*."""
        h      = self.e2(self.e1(x))
        z      = self.ez(h)
        logits = self.ey(h)
        return z, tf.nn.softmax(logits), logits

    def decode(self, z, y):
        """Reconstruct the input from latent code *z* concatenated with *y*."""
        h = tf.concat([z, y], axis=1)
        h = self.d1(h)
        h = self.d2(h)
        return self.dout(h)

    def discriminate_z(self, z):
        """Score a batch of latent codes (one unbounded value per sample)."""
        h = self.dz1(z)
        h = self.dz2(h)
        return self.dzout(h)

    def discriminate_y(self, y):
        """Score a batch of label vectors (one unbounded value per sample)."""
        h = self.dy1(y)
        h = self.dy2(h)
        return self.dyout(h)

    def gradient_penalty(self, f, real, fake):
        """Penalise deviation of ``f``'s gradient norm from 1.

        The gradient is evaluated at points drawn uniformly on the segments
        between paired *real* and *fake* samples.

        Args:
            f: callable mapping a batch to per-sample scores.
            real, fake: 2-D tensors of identical shape.

        Returns:
            Scalar mean of ``(||grad f|| - 1) ** 2`` over the batch.
        """
        # Dynamic batch size: the static ``real.shape[0]`` is None when the
        # batch dimension is unknown at trace time.
        batch = tf.shape(real)[0]
        alpha = tf.random.uniform([batch, 1], 0, 1)
        interm = real + alpha * (fake - real)
        with tf.GradientTape() as tape:
            tape.watch(interm)
            pred = f(interm)
        grads = tape.gradient(pred, interm)
        # The epsilon keeps the square root differentiable at zero gradient.
        slopes= tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-8)
        return tf.reduce_mean((slopes - 1)**2)

# Single model instance shared by training, saving and evaluation.
aae = AAE()

# ------------------------------------------------------------
# 4) Losses & Optimizers
# ------------------------------------------------------------
# Reconstruction loss, and label loss computed on the encoder's raw logits.
mse = tf.keras.losses.MeanSquaredError()
ce  = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# One Adam optimizer per objective, each with its own learning rate.
opt_ae, opt_dz, opt_dy, opt_g = (
    tf.keras.optimizers.Adam(rate) for rate in (LR_AE, LR_DZ, LR_DY, LR_G)
)

# Lists to track losses (one entry per epoch)
(
    train_re_losses,
    val_re_losses,
    train_dz_losses,
    train_dy_losses,
    train_g_losses,
) = ([] for _ in range(5))

@tf.function
def train_step(xn, xc, y):
    """Run one update of each of the four AAE objectives on a batch.

    Args:
        xn: noisy input batch fed to the encoder in the reconstruction phase.
        xc: clean batch; reconstruction target and encoder input of the
            adversarial phases.
        y:  one-hot labels; "real" samples for the label discriminator and
            target of the label cross-entropy.

    Returns:
        Tuple ``(loss_re, loss_dz, loss_dy, loss_g)`` of scalar tensors.
    """
    # 1) Reconstruction: encoder + decoder minimise the MSE to the clean input.
    with tf.GradientTape() as t_ae:
        z, yp, _ = aae.encode(xn)
        xr = aae.decode(z, yp)
        loss_re = mse(xc, xr)
    vars_ae = (
        aae.e1.trainable_variables + aae.e2.trainable_variables
        + aae.ez.trainable_variables + aae.ey.trainable_variables
        + aae.d1.trainable_variables + aae.d2.trainable_variables
        + aae.dout.trainable_variables
    )
    grads_ae = t_ae.gradient(loss_re, vars_ae)
    opt_ae.apply_gradients(zip(grads_ae, vars_ae))

    # 2) Latent discriminator: standard-normal samples are "real", the codes
    #    computed in phase 1 are "fake" (only discriminator weights change).
    with tf.GradientTape() as t_dz:
        # Dynamic batch size; the static shape may be unknown while tracing.
        z_real = tf.random.normal([tf.shape(xn)[0], LATENT_DIM])
        dz_r = aae.discriminate_z(z_real)
        dz_f = aae.discriminate_z(z)
        gp   = aae.gradient_penalty(aae.discriminate_z, z_real, z)
        loss_dz = tf.reduce_mean(dz_f) - tf.reduce_mean(dz_r) + λ_gp * gp
    vars_dz = aae.dz1.trainable_variables + aae.dz2.trainable_variables + aae.dzout.trainable_variables
    grads_dz = t_dz.gradient(loss_dz, vars_dz)
    opt_dz.apply_gradients(zip(grads_dz, vars_dz))

    # 3) Label discriminator: one-hot labels are "real", the encoder's softmax
    #    output on the clean batch is "fake".
    with tf.GradientTape() as t_dy:
        dy_r = aae.discriminate_y(y)
        _, yp_enc, _ = aae.encode(xc)
        dy_f = aae.discriminate_y(yp_enc)
        gp_y = aae.gradient_penalty(aae.discriminate_y, y, yp_enc)
        loss_dy = tf.reduce_mean(dy_f) - tf.reduce_mean(dy_r) + λ_gp * gp_y
    vars_dy = aae.dy1.trainable_variables + aae.dy2.trainable_variables + aae.dyout.trainable_variables
    grads_dy = t_dy.gradient(loss_dy, vars_dy)
    opt_dy.apply_gradients(zip(grads_dy, vars_dy))

    # 4) Generator: the encoder is updated to raise both discriminator scores
    #    and to fit the labels (cross-entropy on its logits).
    with tf.GradientTape() as t_g:
        z_enc, y_enc, logits_enc = aae.encode(xc)
        loss_g = -tf.reduce_mean(aae.discriminate_z(z_enc))
        loss_g += -tf.reduce_mean(aae.discriminate_y(y_enc))
        loss_g += ce(y, logits_enc)
    vars_g = (
        aae.e1.trainable_variables + aae.e2.trainable_variables
        + aae.ez.trainable_variables + aae.ey.trainable_variables
    )
    grads_g = t_g.gradient(loss_g, vars_g)
    opt_g.apply_gradients(zip(grads_g, vars_g))

    return loss_re, loss_dz, loss_dy, loss_g

# ------------------------------------------------------------
# 5) Training loop
# ------------------------------------------------------------
def _safe_mean(total, count):
    """Return ``total / count``, or NaN when nothing was accumulated."""
    return total / count if count else float('nan')


# The set of validation files does not change during training.
val_files = glob.glob('Normal_*_val.tfrecord')

for epoch in range(1, EPOCHS + 1):
    print(f"[TRAIN] Epoch {epoch}/{EPOCHS}")
    epoch_re, epoch_dz, epoch_dy, epoch_g = 0, 0, 0, 0
    it = iter(train_ds)
    for step in range(steps_per_epoch):
        xn, xc, y = next(it)
        # Convert the scalar loss tensors to Python floats once per step.
        lr, ldz, ldy, lg = (float(loss) for loss in train_step(xn, xc, y))
        epoch_re += lr
        epoch_dz += ldz
        epoch_dy += ldy
        epoch_g  += lg
        if step % 100 == 0:
            print(f" step {step}/{steps_per_epoch} | recon={lr:.4f} dz={ldz:.4f} dy={ldy:.4f} gen={lg:.4f}")

    # average losses
    train_re_losses.append(_safe_mean(epoch_re, steps_per_epoch))
    train_dz_losses.append(_safe_mean(epoch_dz, steps_per_epoch))
    train_dy_losses.append(_safe_mean(epoch_dy, steps_per_epoch))
    train_g_losses.append(_safe_mean(epoch_g, steps_per_epoch))

    # validation recon loss: mean of the per-batch MSE between the clean
    # input and its reconstruction
    val_loss, val_steps = 0, 0
    for fn in val_files:
        ds_val = tf.data.TFRecordDataset(fn).map(parse_feat).batch(BATCH)
        for x_val, _ in ds_val:
            z_val, y_val, _ = aae.encode(x_val)
            x_rec = aae.decode(z_val, y_val)
            val_loss += float(mse(x_val, x_rec))
            val_steps += 1
    # NaN (instead of a ZeroDivisionError) when no validation data exists.
    val_re_losses.append(_safe_mean(val_loss, val_steps))
    print(f"[VALID] recon={val_re_losses[-1]:.4f}")

# ------------------------------------------------------------
# 6) Save encoder & decoder
# ------------------------------------------------------------
from tensorflow.keras.layers import Input, Activation, Concatenate
from tensorflow.keras.models import Model

# Encoder: re-wire the trained encoder layers into a functional model that
# maps a feature vector to [latent code, softmax label probabilities].
enc_in = Input(shape=(FEATURE_DIM,))
h = aae.e1(enc_in)
h = aae.e2(h)
z_enc = aae.ez(h)
y_logits = aae.ey(h)
y_enc = Activation('softmax')(y_logits)
encoder = Model(inputs=enc_in, outputs=[z_enc, y_enc], name='aae_encoder')

# Decoder: maps [latent code, label vector] back to a feature vector using
# the trained decoder layers.
z_in = Input(shape=(LATENT_DIM,))
y_in = Input(shape=(N_LABELS,))
h2 = Concatenate()([z_in, y_in])
h2 = aae.d1(h2)
h2 = aae.d2(h2)
dec_out = aae.dout(h2)
decoder = Model(inputs=[z_in, y_in], outputs=dec_out, name='aae_decoder')

for model, path in ((encoder, 'aae_encoder.keras'), (decoder, 'aae_decoder.keras')):
    model.save(path)
print("[SAVE] Encoder & decoder saved")

# ------------------------------------------------------------
# 7) Evaluation
# ------------------------------------------------------------
# Per-sample reconstruction error is the anomaly score. Every test file is
# labelled as a whole: files starting with 'Normal_' are class 0, all other
# test files class 1.
errs, ys = [], []
for fn in glob.glob('*_test.tfrecord'):
    label = 0 if fn.startswith('Normal_') else 1
    ds_eval = tf.data.TFRecordDataset(fn).map(parse_feat).batch(256)
    for x_batch, _ in ds_eval:
        z_p, y_p = encoder(x_batch)
        x_r = decoder([z_p, y_p])
        e = tf.reduce_mean((x_batch - x_r)**2, axis=1).numpy()
        errs.append(e)
        ys.append(np.full(e.shape, label))
if not errs:
    raise RuntimeError("No test records found in '*_test.tfrecord' files.")
errs = np.concatenate(errs)
ys   = np.concatenate(ys)

fpr, tpr, ths = roc_curve(ys, errs)
roc_auc = auc(fpr, tpr)
# Operating point maximising TPR - FPR (Youden's J statistic).
opt_idx = np.argmax(tpr - fpr)
opt_thr = ths[opt_idx]

# roc_curve counts a sample as positive when score >= threshold, so the same
# rule is used here; a strict '>' would not match the TPR/FPR printed below.
y_pred = (errs >= opt_thr).astype(int)

print(f"[RESULT] ROC AUC: {roc_auc:.4f}, Thr: {opt_thr:.6f}, TPR: {tpr[opt_idx]:.3f}, FPR: {fpr[opt_idx]:.3f}")
print("[RESULT] Confusion Matrix:")
# Explicit labels keep the matrix 2x2 (and the report's target names valid)
# even if one class is absent from the test data or the predictions.
cm = confusion_matrix(ys, y_pred, labels=[0, 1])
print(cm)
print("[RESULT] Classification Report:")
print(classification_report(ys, y_pred, labels=[0, 1], target_names=['Normal','Attack']))

# ------------------------------------------------------------
# 8) Plotting
# ------------------------------------------------------------
# Shared x axis of the per-epoch curves.
epoch_axis = range(1, EPOCHS + 1)

# Reconstruction loss curves
plt.figure()
for curve, curve_label in ((train_re_losses, 'Train recon'), (val_re_losses, 'Val recon')):
    plt.plot(epoch_axis, curve, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Reconstruction Loss')
plt.legend()
plt.show()

# Adversarial losses
plt.figure()
for curve, curve_label in (
    (train_dz_losses, 'Disc_z'),
    (train_dy_losses, 'Disc_y'),
    (train_g_losses, 'Generator'),
):
    plt.plot(epoch_axis, curve, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Wasserstein Loss')
plt.title('Adversarial Losses')
plt.legend()
plt.show()

# ROC curve (the dashed diagonal marks a random classifier)
plt.figure()
plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.3f}')
plt.plot([0,1],[0,1],'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve (Test Set)')
plt.legend()
plt.show()

# Confusion matrix heatmap
classes = ['Normal','Attack']
tick_marks = np.arange(len(classes))
plt.figure()
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
# Cells darker than half of the maximum count get white text for contrast.
thresh = cm.max() / 2
for i, j in np.ndindex(*cm.shape):
    text_color = "white" if cm[i, j] > thresh else "black"
    plt.text(j, i, cm[i, j], horizontalalignment="center", color=text_color)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.show()

# Error distribution histogram
plt.figure()
for class_id, class_name in ((0, 'Normal'), (1, 'Attack')):
    plt.hist(errs[ys==class_id], bins=50, alpha=0.5, label=class_name)
plt.xlabel('Reconstruction error')
plt.ylabel('Count')
plt.title('Error Distribution')
plt.legend()
plt.show()
