In [49]:
!pip install scapy numpy torch



In [50]:

import os
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np
import torch
from scapy.all import rdpcap, IP, TCP, Raw


PCAP_PATH = "/content/Modbus_polling_only_6RTU(2).pcap"

SEQ_LEN = 32
MIN_PACKETS_PER_WINDOW = 4
TRAIN_SPLIT = 0.8


MAX_DELTA_T = 1.0
MAX_PKT_LEN = 1500.0
MIN_PKT_LEN = 1
MAX_UNIT_ID = 255.0
MAX_ADDR = 65535.0
MAX_QUANTITY = 125.0


def get_flow_key(pkt) -> Tuple[Tuple[str, int], Tuple[str, int]]:

    ip = pkt[IP]
    tcp = pkt[TCP]
    a = (ip.src, int(tcp.sport))
    b = (ip.dst, int(tcp.dport))
    return tuple(sorted([a, b]))


def is_modbus_tcp(pkt) -> bool:

    if IP not in pkt or TCP not in pkt:
        return False
    tcp = pkt[TCP]
    return int(tcp.sport) == 502 or int(tcp.dport) == 502


def parse_modbus_fields(payload: bytes):

    unit_id = 0
    func_code = 0
    addr = 0
    quantity = 0
    is_exception = 0

    if len(payload) < 8:
        return unit_id, func_code, addr, quantity, is_exception

    unit_id = payload[6]
    func_code = payload[7]

    is_exception = 1 if func_code >= 0x80 else 0

    if len(payload) >= 12 and func_code in {1, 2, 3, 4, 5, 6, 15, 16}:
        addr = int.from_bytes(payload[8:10], byteorder="big", signed=False)
        quantity = int.from_bytes(payload[10:12], byteorder="big", signed=False)

    return unit_id, func_code, addr, quantity, is_exception


print(f"Loading pcap from {PCAP_PATH}...")
packets = rdpcap(PCAP_PATH)
print(f"Total packets in pcap: {len(packets)}")


modbus_packets = [p for p in packets if is_modbus_tcp(p)]
print(f"Modbus/TCP packets: {len(modbus_packets)}")


flows: Dict[Tuple[Tuple[str, int], Tuple[str, int]], List[Dict]] = defaultdict(list)
func_codes_seen = set()

for pkt in modbus_packets:
    if IP not in pkt or TCP not in pkt:
        continue

    ip = pkt[IP]
    tcp = pkt[TCP]
    t = float(pkt.time)
    pkt_len = int(len(pkt))

    if int(tcp.dport) == 502:
        direction = +1.0
    elif int(tcp.sport) == 502:
        direction = -1.0
    else:

        direction = 0.0

    if Raw in pkt:
        payload = bytes(pkt[Raw].load)
    else:
        payload = b""

    unit_id, func_code, addr, quantity, is_exception = parse_modbus_fields(payload)
    func_codes_seen.add(func_code)

    key = get_flow_key(pkt)
    flows[key].append(
        {
            "time": t,
            "dir": direction,
            "pkt_len": pkt_len,
            "unit_id": unit_id,
            "func_code": func_code,
            "addr": addr,
            "quantity": quantity,
            "is_exception": is_exception,
        }
    )

print(f"Number of TCP flows with Modbus: {len(flows)}")
print(f"Unique Modbus function codes seen: {sorted(func_codes_seen)}")

FUNC_CODES = sorted(fc for fc in func_codes_seen if fc != 0)
fc_to_idx = {fc: i for i, fc in enumerate(FUNC_CODES)}
NUM_FC = len(FUNC_CODES)

print(f"Function code, one-hot index mapping:")
for fc, idx in fc_to_idx.items():
    print(f"       FC {fc} -> index {idx}")


IDX_DIR        = 0
IDX_DT_NORM    = 1
IDX_LEN_NORM   = 2
IDX_UNIT_NORM  = 3
IDX_ADDR_NORM  = 4
IDX_QTY_NORM   = 5
IDX_IS_EXC     = 6
IDX_IS_MODBUS  = 7
IDX_FC_START   = 8

BASE_FEATS = 8
FEATURE_DIM = BASE_FEATS + NUM_FC


all_sequences = []

for key, pkt_list in flows.items():

    pkt_list = sorted(pkt_list, key=lambda x: x["time"])

    flow_feats = []
    prev_t = pkt_list[0]["time"] if pkt_list else 0.0

    for info in pkt_list:
        t = info["time"]
        delta_t = t - prev_t
        prev_t = t

        if delta_t < 0:
            delta_t = 0.0
        delta_t = min(delta_t, MAX_DELTA_T)
        delta_t_norm = delta_t / MAX_DELTA_T

        pkt_len_norm = min(info["pkt_len"], MAX_PKT_LEN) / MAX_PKT_LEN
        unit_id_norm = info["unit_id"] / MAX_UNIT_ID if MAX_UNIT_ID > 0 else 0.0
        addr_norm    = info["addr"] / MAX_ADDR       if MAX_ADDR    > 0 else 0.0
        quantity_norm = min(info["quantity"], MAX_QUANTITY) / MAX_QUANTITY

        is_exception = float(info["is_exception"])
        direction    = float(info["dir"])

        fc = info["func_code"]
        has_modbus = fc in fc_to_idx

        if not has_modbus:
            continue

        vec = np.zeros(FEATURE_DIM, dtype=np.float32)

        vec[IDX_DIR]       = direction
        vec[IDX_DT_NORM]   = delta_t_norm
        vec[IDX_LEN_NORM]  = pkt_len_norm
        vec[IDX_UNIT_NORM] = unit_id_norm
        vec[IDX_ADDR_NORM] = addr_norm
        vec[IDX_QTY_NORM]  = quantity_norm
        vec[IDX_IS_EXC]    = is_exception
        vec[IDX_IS_MODBUS] = 1.0

        idx = fc_to_idx[fc]
        vec[IDX_FC_START + idx] = 1.0

        flow_feats.append(vec)




    n_packets = len(flow_feats)
    if n_packets == 0:
        continue

    for start in range(0, n_packets, SEQ_LEN):
        window = flow_feats[start:start + SEQ_LEN]
        if len(window) < MIN_PACKETS_PER_WINDOW:
            continue

        pad_rows = SEQ_LEN - len(window)
        if pad_rows > 0:
            pad_vec = np.zeros(FEATURE_DIM, dtype=np.float32)
            pad_vec[:BASE_FEATS] = -1.0
            pad_block = [pad_vec.copy() for _ in range(pad_rows)]
            window = window + pad_block

        seq = np.stack(window, axis=0)
        all_sequences.append(seq)

all_sequences = np.stack(all_sequences, axis=0) if all_sequences else np.zeros(
    (0, SEQ_LEN, FEATURE_DIM), dtype=np.float32
)

print(f"Built {all_sequences.shape[0]} sequences of shape "
      f"{SEQ_LEN} x {FEATURE_DIM}")


num_sequences = all_sequences.shape[0]
indices = np.arange(num_sequences)
np.random.shuffle(indices)

train_cutoff = int(TRAIN_SPLIT * num_sequences)
train_idx = indices[:train_cutoff]
val_idx = indices[train_cutoff:]

train_data = all_sequences[train_idx]
val_data = all_sequences[val_idx]

print(f"Train sequences: {train_data.shape[0]}")
print(f"Val sequences:   {val_data.shape[0]}")

train_tensor = torch.from_numpy(train_data)
val_tensor = torch.from_numpy(val_data)

os.makedirs("prepared_data", exist_ok=True)
torch.save(train_tensor, os.path.join("prepared_data", "modbus_train.pt"))
torch.save(val_tensor, os.path.join("prepared_data", "modbus_val.pt"))

metadata = {
    "seq_len": SEQ_LEN,
    "feature_dim": FEATURE_DIM,
    "base_feats": [
        "direction",
        "delta_t_norm",
        "pkt_len_norm",
        "unit_id_norm",
        "addr_norm",
        "quantity_norm",
        "is_exception",
        "is_modbus",
    ],
    "func_codes": FUNC_CODES,
    "fc_to_idx": fc_to_idx,
    "normalization": {
        "MAX_DELTA_T": MAX_DELTA_T,
        "MAX_PKT_LEN": MAX_PKT_LEN,
        "MAX_UNIT_ID": MAX_UNIT_ID,
        "MAX_ADDR": MAX_ADDR,
        "MAX_QUANTITY": MAX_QUANTITY,
    },
    "train_split": TRAIN_SPLIT,
    "min_packets_per_window": MIN_PACKETS_PER_WINDOW,
}

torch.save(metadata, os.path.join("prepared_data", "modbus_metadata.pt"))
print("Saved prepared_data/modbus_train.pt, modbus_val.pt, modbus_metadata.pt")


[INFO] Loading pcap from /content/Modbus_polling_only_6RTU(2).pcap...
[INFO] Total packets in pcap: 58325
[INFO] Modbus/TCP packets: 57834
[INFO] Number of TCP flows with Modbus: 6087
[INFO] Unique Modbus function codes seen: [0, 1, 2, 3]
[INFO] Function code â†’ one-hot index mapping (excluding 0):
       FC 1 -> index 0
       FC 2 -> index 1
       FC 3 -> index 2
[INFO] Built 338 sequences of shape 32 x 11
[INFO] Train sequences: 270
[INFO] Val sequences:   68
[INFO] Saved prepared_data/modbus_train.pt, modbus_val.pt, modbus_metadata.pt


In [51]:
import os
import torch
from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [52]:


train_path = "prepared_data/modbus_train.pt"
val_path   = "prepared_data/modbus_val.pt"
meta_path  = "prepared_data/modbus_metadata.pt"

train_tensor = torch.load(train_path)
val_tensor   = torch.load(val_path)
metadata     = torch.load(meta_path)

SEQ_LEN     = metadata["seq_len"]
FEATURE_DIM = metadata["feature_dim"]

FUNC_CODES = metadata["func_codes"]
NUM_FC     = len(FUNC_CODES)

print("Train tensor:", train_tensor.shape)
print("Val tensor:  ", val_tensor.shape)
print("SEQ_LEN:", SEQ_LEN, "FEATURE_DIM:", FEATURE_DIM)
print("Function codes:", FUNC_CODES)


Train tensor: torch.Size([270, 32, 11])
Val tensor:   torch.Size([68, 32, 11])
SEQ_LEN: 32 FEATURE_DIM: 11
Function codes: [1, 2, 3]


In [53]:
BATCH_SIZE = 128

train_ds = TensorDataset(train_tensor)
val_ds   = TensorDataset(val_tensor)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)


In [54]:
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 200

def make_beta_schedule(T, start=1e-4, end=0.02):
    return torch.linspace(start, end, T)

betas = make_beta_schedule(T).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=device), alphas_cumprod[:-1]], dim=0)

sqrt_alphas_cumprod     = torch.sqrt(alphas_cumprod)
sqrt_one_minus_ac       = torch.sqrt(1.0 - alphas_cumprod)
sqrt_recip_alphas       = torch.sqrt(1.0 / alphas)
posterior_variance      = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)


def extract(a, t, x_shape):
    out = a.gather(-1, t)
    return out.reshape(-1, 1, 1).expand(x_shape)


def q_sample(x0, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x0)
    sqrt_a_bar = extract(sqrt_alphas_cumprod, t, x0.shape)
    sqrt_one_minus = extract(sqrt_one_minus_ac, t, x0.shape)
    return sqrt_a_bar * x0 + sqrt_one_minus * noise



In [55]:
import math

def timestep_embedding(timesteps, dim):
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000) * torch.arange(0, half, device=timesteps.device).float() / half
    )
    args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0,1))
    return emb

class ModbusDiffusionModel(nn.Module):
    def __init__(
        self,
        seq_len,
        feature_dim,
        hidden_dim=128,
        num_layers=4,
        num_heads=4,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim


        self.input_proj = nn.Linear(feature_dim, hidden_dim)


        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, hidden_dim))


        self.time_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)


        self.output_proj = nn.Linear(hidden_dim, feature_dim)

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.input_proj.weight)
        nn.init.zeros_(self.input_proj.bias)
        nn.init.xavier_uniform_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def forward(self, x, t):
        """
        x: (B, S, F)
        t: (B,) long
        """
        B, S, F = x.shape
        assert S == self.seq_len


        t_emb = timestep_embedding(t, self.hidden_dim)
        t_emb = self.time_mlp(t_emb)
        t_emb = t_emb[:, None, :]


        h = self.input_proj(x)
        h = h + self.pos_emb + t_emb


        h = self.transformer(h)


        eps_pred = self.output_proj(h)
        return eps_pred



In [56]:
model = ModbusDiffusionModel(
    seq_len=SEQ_LEN,
    feature_dim=FEATURE_DIM,
    hidden_dim=128,
    num_layers=4,
    num_heads=4,
).to(device)

sum(p.numel() for p in model.parameters()) / 1e6



0.833163

In [57]:
import tqdm

EPOCHS = 50
LR = 1e-4

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)


In [58]:
def train_one_epoch(model, loader, optimizer, epoch):
    model.train()
    total_loss = 0.0
    num_batches = 0

    for (x0,) in tqdm.tqdm(loader, desc=f"Epoch {epoch} [train]", leave=False):
        x0 = x0.to(device)

        B = x0.size(0)

        t = torch.randint(low=0, high=T, size=(B,), device=device, dtype=torch.long)

        noise = torch.randn_like(x0)
        x_t = q_sample(x0, t, noise)

        eps_pred = model(x_t, t)
        loss = F.mse_loss(eps_pred, noise)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / max(num_batches, 1)


@torch.no_grad()
def eval_one_epoch(model, loader, epoch):
    model.eval()
    total_loss = 0.0
    num_batches = 0

    for (x0,) in tqdm.tqdm(loader, desc=f"Epoch {epoch} [val]", leave=False):
        x0 = x0.to(device)
        B = x0.size(0)
        t = torch.randint(low=0, high=T, size=(B,), device=device, dtype=torch.long)

        noise = torch.randn_like(x0)
        x_t = q_sample(x0, t, noise)

        eps_pred = model(x_t, t)
        loss = F.mse_loss(eps_pred, noise)

        total_loss += loss.item()
        num_batches += 1

    return total_loss / max(num_batches, 1)


In [59]:
best_val = float("inf")
ckpt_path = "modbus_diffusion.pt"

for epoch in range(1, EPOCHS + 1):
    train_loss = train_one_epoch(model, train_loader, optimizer, epoch)
    val_loss   = eval_one_epoch(model, val_loader, epoch)

    print(f"Epoch {epoch:03d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")

    if val_loss < best_val:
        best_val = val_loss
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "metadata": metadata,
                "betas": betas,
            },
            ckpt_path,
        )
        print(f"  -> Saved new best model to {ckpt_path}")




Epoch 001 | train_loss=2.2248 | val_loss=1.6959
  -> Saved new best model to modbus_diffusion.pt




Epoch 002 | train_loss=1.6988 | val_loss=1.5259
  -> Saved new best model to modbus_diffusion.pt




Epoch 003 | train_loss=1.5267 | val_loss=1.3436
  -> Saved new best model to modbus_diffusion.pt




Epoch 004 | train_loss=1.3622 | val_loss=1.1879
  -> Saved new best model to modbus_diffusion.pt




Epoch 005 | train_loss=1.2218 | val_loss=1.0966
  -> Saved new best model to modbus_diffusion.pt




Epoch 006 | train_loss=1.1443 | val_loss=1.0225
  -> Saved new best model to modbus_diffusion.pt




Epoch 007 | train_loss=1.0723 | val_loss=0.9795
  -> Saved new best model to modbus_diffusion.pt




Epoch 008 | train_loss=1.0227 | val_loss=0.9035
  -> Saved new best model to modbus_diffusion.pt




Epoch 009 | train_loss=0.9649 | val_loss=0.8676
  -> Saved new best model to modbus_diffusion.pt




Epoch 010 | train_loss=0.9340 | val_loss=0.8123
  -> Saved new best model to modbus_diffusion.pt




Epoch 011 | train_loss=0.8964 | val_loss=0.7992
  -> Saved new best model to modbus_diffusion.pt




Epoch 012 | train_loss=0.8703 | val_loss=0.7976
  -> Saved new best model to modbus_diffusion.pt




Epoch 013 | train_loss=0.8467 | val_loss=0.7521
  -> Saved new best model to modbus_diffusion.pt




Epoch 014 | train_loss=0.8235 | val_loss=0.7058
  -> Saved new best model to modbus_diffusion.pt




Epoch 015 | train_loss=0.7846 | val_loss=0.6963
  -> Saved new best model to modbus_diffusion.pt




Epoch 016 | train_loss=0.7579 | val_loss=0.6448
  -> Saved new best model to modbus_diffusion.pt




Epoch 017 | train_loss=0.7418 | val_loss=0.6192
  -> Saved new best model to modbus_diffusion.pt




Epoch 018 | train_loss=0.7072 | val_loss=0.6213




Epoch 019 | train_loss=0.6634 | val_loss=0.5812
  -> Saved new best model to modbus_diffusion.pt




Epoch 020 | train_loss=0.6402 | val_loss=0.5422
  -> Saved new best model to modbus_diffusion.pt




Epoch 021 | train_loss=0.6264 | val_loss=0.5248
  -> Saved new best model to modbus_diffusion.pt




Epoch 022 | train_loss=0.6236 | val_loss=0.5242
  -> Saved new best model to modbus_diffusion.pt




Epoch 023 | train_loss=0.5797 | val_loss=0.4341
  -> Saved new best model to modbus_diffusion.pt




Epoch 024 | train_loss=0.5417 | val_loss=0.4230
  -> Saved new best model to modbus_diffusion.pt




Epoch 025 | train_loss=0.5179 | val_loss=0.4464




Epoch 026 | train_loss=0.5105 | val_loss=0.3915
  -> Saved new best model to modbus_diffusion.pt




Epoch 027 | train_loss=0.4936 | val_loss=0.3996




Epoch 028 | train_loss=0.4468 | val_loss=0.3819
  -> Saved new best model to modbus_diffusion.pt




Epoch 029 | train_loss=0.4836 | val_loss=0.3253
  -> Saved new best model to modbus_diffusion.pt




Epoch 030 | train_loss=0.4408 | val_loss=0.3473




Epoch 031 | train_loss=0.4124 | val_loss=0.3644




Epoch 032 | train_loss=0.3874 | val_loss=0.3217
  -> Saved new best model to modbus_diffusion.pt




Epoch 033 | train_loss=0.3888 | val_loss=0.3094
  -> Saved new best model to modbus_diffusion.pt




Epoch 034 | train_loss=0.3845 | val_loss=0.3085
  -> Saved new best model to modbus_diffusion.pt


                                                     

Epoch 035 | train_loss=0.3921 | val_loss=0.2916




  -> Saved new best model to modbus_diffusion.pt




Epoch 036 | train_loss=0.3661 | val_loss=0.2532
  -> Saved new best model to modbus_diffusion.pt




Epoch 037 | train_loss=0.3572 | val_loss=0.2970


                                                     

Epoch 038 | train_loss=0.3434 | val_loss=0.2505




  -> Saved new best model to modbus_diffusion.pt




Epoch 039 | train_loss=0.3414 | val_loss=0.2575




Epoch 040 | train_loss=0.3211 | val_loss=0.2676




Epoch 041 | train_loss=0.3021 | val_loss=0.2724




Epoch 042 | train_loss=0.3333 | val_loss=0.1845
  -> Saved new best model to modbus_diffusion.pt




Epoch 043 | train_loss=0.3202 | val_loss=0.2316




Epoch 044 | train_loss=0.3195 | val_loss=0.2022




Epoch 045 | train_loss=0.3137 | val_loss=0.2152




Epoch 046 | train_loss=0.3087 | val_loss=0.1998




Epoch 047 | train_loss=0.2717 | val_loss=0.1871




Epoch 048 | train_loss=0.2955 | val_loss=0.2360




Epoch 049 | train_loss=0.3206 | val_loss=0.2148


                                                     

Epoch 050 | train_loss=0.2985 | val_loss=0.2281




In [60]:
ckpt = torch.load("modbus_diffusion.pt", map_location=device)
model.load_state_dict(ckpt["model_state_dict"])
betas = ckpt["betas"].to(device)

alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=device), alphas_cumprod[:-1]], dim=0)

sqrt_alphas_cumprod     = torch.sqrt(alphas_cumprod)
sqrt_one_minus_ac       = torch.sqrt(1.0 - alphas_cumprod)
sqrt_recip_alphas       = torch.sqrt(1.0 / alphas)
posterior_variance      = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
model.to(device).eval()


ModbusDiffusionModel(
  (input_proj): Linear(in_features=11, out_features=128, bias=True)
  (time_mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): SiLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_proj): Linear(in_feat

In [61]:
@torch.no_grad()
def p_sample(model, x_t, t):

    betas_t = extract(betas, t, x_t.shape)
    sqrt_one_minus_ac_t = extract(sqrt_one_minus_ac, t, x_t.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x_t.shape)
    ac_t = extract(alphas_cumprod, t, x_t.shape)
    ac_prev_t = extract(alphas_cumprod_prev, t, x_t.shape)

    eps_theta = model(x_t, t)


    mean = sqrt_recip_alphas_t * (
        x_t - betas_t / sqrt_one_minus_ac_t * eps_theta
    )


    noise = torch.randn_like(x_t)
    mask = (t > 0).float().reshape(-1, 1, 1)
    var = extract(posterior_variance, t, x_t.shape)
    sample = mean + mask * torch.sqrt(var) * noise
    return sample


@torch.no_grad()
def sample_sequences(model, num_samples):

    model.eval()
    x_t = torch.randn(num_samples, SEQ_LEN, FEATURE_DIM, device=device)

    for step in reversed(range(T)):
        t = torch.full((num_samples,), step, device=device, dtype=torch.long)
        x_t = p_sample(model, x_t, t)

    return x_t.cpu()


In [62]:
num_samples = 64
synthetic = sample_sequences(model, num_samples)
synthetic.shape


torch.Size([64, 32, 11])

In [63]:
import torch
import numpy as np

BASE_FEATS = len(metadata["base_feats"])
FUNC_CODES = metadata["func_codes"]
NUM_FC     = len(FUNC_CODES)
norms      = metadata["normalization"]

MAX_DELTA_T    = norms["MAX_DELTA_T"]
MAX_PKT_LEN    = norms["MAX_PKT_LEN"]
MAX_UNIT_ID    = norms["MAX_UNIT_ID"]
MAX_ADDR       = norms["MAX_ADDR"]
MAX_QUANTITY   = norms["MAX_QUANTITY"]


SEQ_LEN = metadata["seq_len"]
FEATURE_DIM = metadata["feature_dim"]

def is_padding_row(row: np.ndarray, pad_sentinel=-1.0) -> bool:
    return np.allclose(row[:BASE_FEATS], pad_sentinel, atol=1e-5)


In [64]:
def decode_row(row: np.ndarray):
    """
    Given a (FEATURE_DIM,) row from the diffusion output, return a dict of
    denormalized fields or None if it's padding.
    """
    if is_padding_row(row):
        return None

    base      = row[:BASE_FEATS]
    fc_onehot = row[BASE_FEATS:]


    direction_raw   = float(base[0])
    delta_t_norm    = float(base[1])
    pkt_len_norm    = float(base[2])
    unit_id_norm    = float(base[3])
    addr_norm       = float(base[4])
    quantity_norm   = float(base[5])
    is_exception    = float(base[6])
    is_modbus_raw   = float(base[7])


    delta_t_norm  = float(np.clip(delta_t_norm,  0.0, 1.0))
    pkt_len_norm  = float(np.clip(pkt_len_norm,  0.0, 1.0))
    unit_id_norm  = float(np.clip(unit_id_norm,  0.0, 1.0))
    addr_norm     = float(np.clip(addr_norm,     0.0, 1.0))
    quantity_norm = float(np.clip(quantity_norm, 0.0, 1.0))


    direction_raw = float(np.clip(direction_raw, -1.0, 1.0))
    direction     = +1.0 if direction_raw >= 0.0 else -1.0


    is_modbus = 1.0


    delta_t  = delta_t_norm * MAX_DELTA_T
    pkt_len  = pkt_len_norm * MAX_PKT_LEN
    unit_id  = unit_id_norm * MAX_UNIT_ID
    addr     = addr_norm * MAX_ADDR
    quantity = quantity_norm * MAX_QUANTITY


    delta_t  = max(delta_t, 0.0)
    pkt_len  = max(int(round(pkt_len)), 1)
    unit_id  = int(np.clip(round(unit_id),  0, MAX_UNIT_ID))
    addr     = int(np.clip(round(addr),     0, MAX_ADDR))
    quantity = int(np.clip(round(quantity), 1, MAX_QUANTITY))

    is_exception_flag = 1 if is_exception > 0.5 else 0


    if fc_onehot.sum() < 1e-3:

        func_code = FUNC_CODES[0]
    else:
        idx = int(np.argmax(fc_onehot))
        func_code = FUNC_CODES[idx]

    return {
        "direction": direction,
        "delta_t": delta_t,
        "pkt_len": pkt_len,
        "unit_id": unit_id,
        "func_code": func_code,
        "addr": addr,
        "quantity": quantity,
        "is_exception": is_exception_flag,
        "is_modbus": True,
    }

In [65]:
from scapy.all import IP, TCP, Raw, wrpcap

def sequences_to_pcap(
    synthetic_sequences: torch.Tensor,
    metadata,
    output_pcap="synthetic_modbus.pcap",
    client_ip="10.0.0.1",
    server_ip="10.0.0.2",
    client_port=40000,
    server_port=502,
):
    """
    Convert synthetic sequences (B, S, F) into a Modbus/TCP PCAP.

    Each non-padding *Modbus* row -> one Modbus request-like packet.
    We use direction to swap src/dst, and delta_t to set approximate timestamps.
    Rows that are not Modbus (is_modbus == 0) are skipped here.
    """
    sequences = synthetic_sequences.detach().cpu().numpy()
    num_seqs, seq_len, feat_dim = sequences.shape

    packets = []
    tx_id = 0
    current_time = 0.0

    for seq_idx in range(num_seqs):
        prev_time = current_time
        for row in sequences[seq_idx]:
            fields = decode_row(row)
            if fields is None:
                continue


            func_code = fields["func_code"]
            if func_code is None:
                continue

            dt = max(0.0, float(fields["delta_t"]))
            prev_time += dt

            direction = fields["direction"]
            unit_id   = fields["unit_id"]
            addr      = fields["addr"]
            quantity  = fields["quantity"]

            pdu = bytes([
                int(func_code) & 0xFF,
                (int(addr) >> 8) & 0xFF,
                int(addr) & 0xFF,
                (int(quantity) >> 8) & 0xFF,
                int(quantity) & 0xFF,
            ])

            tx_id = int((tx_id + 1) % 65536)
            protocol_id = 0
            length = int(len(pdu) + 1)
            unit_id = int(unit_id) & 0xFF

            mbap = (
                int(tx_id).to_bytes(2, "big") +
                int(protocol_id).to_bytes(2, "big") +
                int(length).to_bytes(2, "big") +
                int(unit_id).to_bytes(1, "big")
            )

            payload = mbap + pdu


            if direction > 0:
                src_ip, dst_ip = client_ip, server_ip
                sport, dport = client_port, server_port
            else:
                src_ip, dst_ip = server_ip, client_ip
                sport, dport = server_port, client_port

            pkt = IP(src=src_ip, dst=dst_ip) / TCP(sport=sport, dport=dport) / Raw(load=payload)
            pkt.time = prev_time
            packets.append(pkt)


        current_time = prev_time + 0.1

    wrpcap(output_pcap, packets)
    print(f"[INFO] Wrote {len(packets)} synthetic Modbus packets to {output_pcap}")


In [66]:
num_samples = 128
synthetic = sample_sequences(model, num_samples=num_samples)

sequences_to_pcap(synthetic, metadata, output_pcap="synthetic_modbus.pcap")


[INFO] Wrote 4096 synthetic Modbus packets to synthetic_modbus.pcap


In [67]:
from scapy.all import rdpcap, IP, TCP, Raw

def check_modbus_pcap(pcap_path, valid_func_codes=None):
    packets = rdpcap(pcap_path)
    total_modbus = 0
    alright = 0
    malformed = 0
    bad_fc = 0

    if valid_func_codes is None:
        valid_func_codes = set()

    for pkt in packets:
        if IP not in pkt or TCP not in pkt or Raw not in pkt:
            continue
        tcp = pkt[TCP]
        if int(tcp.sport) != 502 and int(tcp.dport) != 502:
            continue

        total_modbus += 1
        payload = bytes(pkt[Raw].load)


        if len(payload) < 8 + 1:
            malformed += 1
            continue


        tx_id = int.from_bytes(payload[0:2], "big")
        proto_id = int.from_bytes(payload[2:4], "big")
        length = int.from_bytes(payload[4:6], "big")
        unit_id = payload[6]

        bytes_after_length = len(payload) - 6
        length_ok = (length == bytes_after_length)

        if proto_id != 0 or not length_ok:
            malformed += 1
            continue

        func_code = payload[7]

        if valid_func_codes and func_code not in valid_func_codes:
            bad_fc += 1
        else:
            alright += 1

    print(f"[CHECK] Total Modbus/TCP packets: {total_modbus}")
    print(f"[CHECK]   Valid syntax & known func code: {alright}")
    print(f"[CHECK]   Malformed packets: {malformed}")
    print(f"[CHECK]   Unknown/odd function codes: {bad_fc}")

    return {
        "total_modbus": total_modbus,
        "alright": alright,
        "malformed": malformed,
        "bad_fc": bad_fc,
    }


valid_fcs = set(FUNC_CODES)
results = check_modbus_pcap("synthetic_modbus.pcap", valid_func_codes=valid_fcs)


[CHECK] Total Modbus/TCP packets: 4096
[CHECK]   Valid syntax & known func code: 4096
[CHECK]   Malformed packets: 0
[CHECK]   Unknown/odd function codes: 0


In [68]:
def sequences_to_feature_rows(tensor: torch.Tensor):
    """
    tensor: (N, S, F)
    Returns: numpy array of shape (num_real_rows, F) containing only non-padding rows.
    """
    arr = tensor.detach().cpu().numpy()
    rows = []

    for seq in arr:
        for row in seq:
            if not is_padding_row(row):
                rows.append(row)

    return np.stack(rows, axis=0) if rows else np.zeros((0, FEATURE_DIM), dtype=np.float32)

real_rows = sequences_to_feature_rows(train_tensor)
synthetic_rows = sequences_to_feature_rows(synthetic)

real_rows.shape, synthetic_rows.shape


((1080, 11), (4096, 11))

In [73]:
real_np = real_rows


norms = metadata["normalization"]
MAX_DELTA_T_OLD   = norms["MAX_DELTA_T"]
MAX_PKT_LEN_OLD   = norms["MAX_PKT_LEN"]
MAX_UNIT_ID_OLD   = norms["MAX_UNIT_ID"]
MAX_ADDR_OLD      = norms["MAX_ADDR"]
MAX_QUANTITY_OLD  = norms["MAX_QUANTITY"]

p = 0.99

dt_eff_norm   = np.quantile(real_np[:, 1], p)
len_eff_norm  = np.quantile(real_np[:, 2], p)
uid_eff_norm  = np.quantile(real_np[:, 3], p)
addr_eff_norm = np.quantile(real_np[:, 4], p)
qty_eff_norm  = np.quantile(real_np[:, 5], p)

DELTA_T_MAX_EFF   = max(dt_eff_norm  * MAX_DELTA_T_OLD,  1e-6)
PKT_LEN_MAX_EFF   = max(len_eff_norm * MAX_PKT_LEN_OLD,  1.0)
UNIT_ID_MAX_EFF   = max(uid_eff_norm * MAX_UNIT_ID_OLD,  1.0)
ADDR_MAX_EFF      = max(addr_eff_norm* MAX_ADDR_OLD,     1.0)
QUANTITY_MAX_EFF  = max(qty_eff_norm* MAX_QUANTITY_OLD,  1.0)

print("Effective maxima (raw units):")
print("DELTA_T_MAX_EFF:  ", DELTA_T_MAX_EFF)
print("PKT_LEN_MAX_EFF:  ", PKT_LEN_MAX_EFF)
print("UNIT_ID_MAX_EFF:  ", UNIT_ID_MAX_EFF)
print("ADDR_MAX_EFF:     ", ADDR_MAX_EFF)
print("QUANTITY_MAX_EFF: ", QUANTITY_MAX_EFF)


Effective maxima (raw units):
DELTA_T_MAX_EFF:   0.0068676313
PKT_LEN_MAX_EFF:   71.0
UNIT_ID_MAX_EFF:   1.0
ADDR_MAX_EFF:      2094.0
QUANTITY_MAX_EFF:  125.0


In [69]:
def fc_hist(rows: np.ndarray):
    if rows.shape[0] == 0:
        return np.zeros(NUM_FC, dtype=np.float32)

    is_modbus_mask = rows[:, 7] >= 0.5
    modbus_rows = rows[is_modbus_mask]

    if modbus_rows.shape[0] == 0:
        return np.zeros(NUM_FC, dtype=np.float32)

    fc_onehot = modbus_rows[:, BASE_FEATS:]
    idxs = np.argmax(fc_onehot, axis=1)

    counts = np.bincount(idxs, minlength=NUM_FC)
    probs = counts / counts.sum()
    return probs

real_fc_dist = fc_hist(real_rows)
syn_fc_dist  = fc_hist(synthetic_rows)

print("Function codes:", FUNC_CODES)
print("Real FC dist:     ", real_fc_dist)
print("Synthetic FC dist:", syn_fc_dist)


Function codes: [1, 2, 3]
Real FC dist:      [0.31111111 0.40185185 0.28703704]
Synthetic FC dist: [0.59899329 0.28355705 0.11744966]


In [70]:
def col_stats(rows: np.ndarray, name: str, col_idx: int):
    col = rows[:, col_idx]
    print(f"{name}: mean={col.mean():.4f}, std={col.std():.4f}, "
          f"min={col.min():.4f}, max={col.max():.4f}")

print("Real:")
col_stats(real_rows, "direction", 0)
col_stats(real_rows, "delta_t_norm", 1)
col_stats(real_rows, "pkt_len_norm", 2)
col_stats(real_rows, "unit_id_norm", 3)
col_stats(real_rows, "addr_norm", 4)
col_stats(real_rows, "quantity_norm", 5)
col_stats(real_rows, "is_modbus", 7)

print("\nSynthetic:")
col_stats(synthetic_rows, "direction", 0)
col_stats(synthetic_rows, "delta_t_norm", 1)
col_stats(synthetic_rows, "pkt_len_norm", 2)
col_stats(synthetic_rows, "unit_id_norm", 3)
col_stats(synthetic_rows, "addr_norm", 4)
col_stats(synthetic_rows, "quantity_norm", 5)
col_stats(synthetic_rows, "is_modbus", 7)


Real:
direction: mean=0.0000, std=1.0000, min=-1.0000, max=1.0000
delta_t_norm: mean=0.0007, std=0.0012, min=0.0000, max=0.0198
pkt_len_norm: mean=0.0440, std=0.0015, min=0.0427, max=0.0473
unit_id_norm: mean=0.0039, std=0.0000, min=0.0039, max=0.0039
addr_norm: mean=0.0045, std=0.0110, min=0.0000, max=0.0320
quantity_norm: mean=0.0595, std=0.2012, min=0.0000, max=1.0000
is_modbus: mean=1.0000, std=0.0000, min=1.0000, max=1.0000

Synthetic:
direction: mean=-0.5768, std=0.7130, min=-2.9218, max=4.4400
delta_t_norm: mean=-0.2583, std=0.5801, min=-2.4784, max=3.4079
pkt_len_norm: mean=-0.8007, std=0.5600, min=-3.6276, max=4.0832
unit_id_norm: mean=-0.6051, std=0.4567, min=-2.7923, max=2.1016
addr_norm: mean=-0.7049, std=0.5080, min=-2.7801, max=2.5170
quantity_norm: mean=-0.3147, std=0.5320, min=-2.5233, max=3.3789
is_modbus: mean=-0.1896, std=0.8763, min=-2.5915, max=6.6062


In [71]:
def decoded_feature_rows_from_sequences(tensor: torch.Tensor):

    arr = tensor.detach().cpu().numpy()
    decoded_rows = []

    for seq in arr:
        for row in seq:
            if is_padding_row(row):
                continue
            d = decode_row(row)
            if d is None or not d["is_modbus"]:
                continue


            dt_norm  = d["delta_t"]  / MAX_DELTA_T
            len_norm = d["pkt_len"]  / MAX_PKT_LEN
            uid_norm = d["unit_id"]  / MAX_UNIT_ID if MAX_UNIT_ID > 0 else 0.0
            addr_norm = d["addr"]    / MAX_ADDR    if MAX_ADDR    > 0 else 0.0
            qty_norm  = d["quantity"]/ MAX_QUANTITY

            decoded_rows.append([
                d["direction"],
                dt_norm,
                len_norm,
                uid_norm,
                addr_norm,
                qty_norm,
                float(d["is_modbus"]),
            ])

    return np.array(decoded_rows, dtype=np.float32) if decoded_rows else np.zeros((0,7), dtype=np.float32)

decoded_syn = decoded_feature_rows_from_sequences(synthetic)

print("Real:")
col_stats(real_rows, "direction", 0)
col_stats(real_rows, "delta_t_norm", 1)
col_stats(real_rows, "pkt_len_norm", 2)
col_stats(real_rows, "unit_id_norm", 3)
col_stats(real_rows, "addr_norm", 4)
col_stats(real_rows, "quantity_norm", 5)
col_stats(real_rows, "is_modbus", 7)


print("\n\nDecoded Synthetic:")
col_stats(decoded_syn, "direction", 0)
col_stats(decoded_syn, "delta_t_norm", 1)
col_stats(decoded_syn, "pkt_len_norm", 2)
col_stats(decoded_syn, "unit_id_norm", 3)
col_stats(decoded_syn, "addr_norm", 4)
col_stats(decoded_syn, "quantity_norm", 5)
col_stats(decoded_syn, "is_modbus", 6)


Real:
direction: mean=0.0000, std=1.0000, min=-1.0000, max=1.0000
delta_t_norm: mean=0.0007, std=0.0012, min=0.0000, max=0.0198
pkt_len_norm: mean=0.0440, std=0.0015, min=0.0427, max=0.0473
unit_id_norm: mean=0.0039, std=0.0000, min=0.0039, max=0.0039
addr_norm: mean=0.0045, std=0.0110, min=0.0000, max=0.0320
quantity_norm: mean=0.0595, std=0.2012, min=0.0000, max=1.0000
is_modbus: mean=1.0000, std=0.0000, min=1.0000, max=1.0000


Decoded Synthetic:
direction: mean=-0.7173, std=0.6968, min=-1.0000, max=1.0000
delta_t_norm: mean=0.0954, std=0.2489, min=0.0000, max=1.0000
pkt_len_norm: mean=0.0316, std=0.1494, min=0.0007, max=1.0000
unit_id_norm: mean=0.0284, std=0.1308, min=0.0000, max=1.0000
addr_norm: mean=0.0329, std=0.1477, min=0.0000, max=1.0000
quantity_norm: mean=0.0856, std=0.2152, min=0.0080, max=1.0000
is_modbus: mean=1.0000, std=0.0000, min=1.0000, max=1.0000


In [72]:
check_modbus_pcap("synthetic_modbus.pcap", valid_func_codes=set(FUNC_CODES))


[CHECK] Total Modbus/TCP packets: 4096
[CHECK]   Valid syntax & known func code: 4096
[CHECK]   Malformed packets: 0
[CHECK]   Unknown/odd function codes: 0


{'total_modbus': 4096, 'alright': 4096, 'malformed': 0, 'bad_fc': 0}