# Create Multi-Format Dataset from PCAP Files (Organized by Label)

**Objective:** Process PCAP files from GCS to extract packet payloads and build a dataset of up to 12k samples per label, saved as Parquet shards plus sample PNGs.

## Overview

This notebook:
1. Reads PCAP files from `gs://ai-cyber/datasets/unsw-nb15/pcap/pcaps 17-2-2015/`
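2. Extracts each IPv4 packet's bytes from the IP header onward (up to 1500 bytes)
3. Matches packets to UNSW-NB15 flow records in the CSV files to assign attack-category labels
4. Encodes each payload as a 5-channel 32×32 image (only the first 1024 bytes fit the grid)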
5. **Saves data ORGANIZED BY LABEL** for easy selective downloading
6. Outputs in Parquet (for ML) and PNG (for visualization) formats

In [None]:
# Imports
import os
import gc
import json
import struct
import numpy as np
import pandas as pd
import tensorflow as tf
from google.cloud import storage
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from collections import Counter, defaultdict
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import io
import time
from datetime import datetime
import hashlib
from PIL import Image
import pyarrow as pa
import pyarrow.parquet as pq
import subprocess
import warnings
warnings.filterwarnings('ignore')

# Configuration
CONFIG = {
    'bucket_name': 'ai-cyber',
    'input_prefix': 'datasets/unsw-nb15/pcap/pcaps 17-2-2015/',
    'output_prefix': 'datasets/unsw-organized-by-label/',
    'samples_per_class': 12000,
    'payload_bytes': 1500,  # Bytes kept per packet, starting at the IP header
    'test_size': 0.15,
    'val_size': 0.15,
    'random_seed': 42,
    'packets_per_pcap': 100000,  # Max packets read from each PCAP file
    'shard_size': 1000,  # Samples per shard
    'max_workers': 4,
    'save_sample_pngs': 100,  # PNGs saved per label, split and format
    'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S')
}

# Image format configurations
IMAGE_FORMATS = {
    '5channel_32x32': {'shape': (32, 32), 'channels': 5, 'method': 'multiview'}
}

np.random.seed(CONFIG['random_seed'])

print("✓ Environment configured")
print(f"✓ Will process PCAP files from: gs://{CONFIG['bucket_name']}/{CONFIG['input_prefix']}")
print(f"✓ Output: gs://{CONFIG['bucket_name']}/{CONFIG['output_prefix']}")
print(f"✓ Target samples per class: {CONFIG['samples_per_class']:,}")
print(f"✓ Data will be ORGANIZED BY LABEL for easy access!")
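With these settings, a label that reaches the 12,000-sample cap should end up with roughly 8,400 / 1,800 / 1,800 train/val/test samples. The sketch below recomputes those expectations and the resulting Parquet shard counts from copies of the `CONFIG` values; since every sample's split is drawn at random, real counts will scatter around these numbers, and labels that never reach the cap get proportionally fewer.

```python
import math

# Copies of the relevant CONFIG values
samples_per_class = 12000
test_size, val_size = 0.15, 0.15
shard_size = 1000

# Expected samples per split for a label that hits the cap
expected = {
    'train': round(samples_per_class * (1 - test_size - val_size)),
    'val': round(samples_per_class * val_size),
    'test': round(samples_per_class * test_size),
}
# Parquet shards per split (the last shard is usually partial)
expected_shards = {split: math.ceil(n / shard_size) for split, n in expected.items()}

print(expected)         # {'train': 8400, 'val': 1800, 'test': 1800}
print(expected_shards)  # {'train': 9, 'val': 2, 'test': 2}
```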

## PCAP Processing Functions

In [None]:
def read_pcap_packets(pcap_data, max_packets=None):
    """Extract IPv4 packets from classic (libpcap) PCAP data.

    Each returned 'payload' starts at the IP header (link-layer header
    stripped) and is truncated to CONFIG['payload_bytes'] bytes.
    """
    packets = []

    # PCAP global header is 24 bytes
    if len(pcap_data) < 24:
        return packets

    # Byte order and timestamp resolution come from the magic number
    magic_le = struct.unpack('<I', pcap_data[:4])[0]
    if magic_le in (0xa1b2c3d4, 0xa1b23c4d):
        endian = '<'
    elif magic_le in (0xd4c3b2a1, 0x4d3cb2a1):
        endian = '>'
    else:
        return packets  # Not classic PCAP (e.g. pcapng)
    magic = struct.unpack(endian + 'I', pcap_data[:4])[0]
    ts_divisor = 1e9 if magic == 0xa1b23c4d else 1e6  # Nanosecond vs microsecond timestamps

    # Link-layer type: 113 = Linux cooked capture (SLL, 16-byte header), 1 = Ethernet (14 bytes)
    network = struct.unpack(endian + 'I', pcap_data[20:24])[0]
    is_linux_cooked = (network == 113)
    link_header_len = 16 if is_linux_cooked else 14

    # Skip global header
    offset = 24
    packet_count = 0

    # Each record has a 16-byte header: ts_sec, ts_frac, incl_len, orig_len
    while offset + 16 <= len(pcap_data):
        ts_sec, ts_frac, incl_len, orig_len = struct.unpack(endian + 'IIII', pcap_data[offset:offset+16])
        offset += 16

        # Stop on a truncated final record
        if offset + incl_len > len(pcap_data):
            break

        # Store packet start offset for later use
        packet_start = offset
        packet_data = pcap_data[offset:offset+incl_len]
        offset += incl_len

        # The EtherType (Ethernet) or protocol field (SLL) is the 2 bytes just before the IP header
        if len(packet_data) > link_header_len:
            ether_type = struct.unpack('>H', packet_data[link_header_len-2:link_header_len])[0]
            payload = packet_data[link_header_len:]

            # Keep IPv4 packets with more than a bare 20-byte IP header
            if ether_type == 0x0800 and len(payload) > 20:
                packets.append({
                    'timestamp': ts_sec + ts_frac / ts_divisor,
                    'payload': payload[:CONFIG['payload_bytes']],  # Starts at the IP header
                    'length': incl_len,
                    'offset': packet_start,  # Byte offset of the packet within the file
                    'raw_packet': packet_data  # Full frame, used for label matching
                })

        packet_count += 1
        if max_packets and packet_count >= max_packets:
            break

    return packets

print("✓ PCAP processing functions ready")
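The parser above depends on a handful of fixed offsets: a 24-byte global header whose last field is the link type, a 16-byte record header, and (for link type 113) a 16-byte Linux cooked header whose protocol field occupies bytes [14:16]. As a self-contained check of that layout, the sketch below builds a one-packet capture in memory and reads the IP header back out. `build_sll_ipv4_pcap`, `first_ip_packet` and the toy 10.0.0.1 → 10.0.0.2 TCP packet are hypothetical fixtures, not part of the notebook or the UNSW-NB15 data, and this reader only handles little-endian files.

```python
import struct

def build_sll_ipv4_pcap(ip_packet: bytes, ts_sec: int = 0, ts_usec: int = 0) -> bytes:
    """Build a one-packet, little-endian classic PCAP with Linux cooked (SLL) framing."""
    # Global header: magic, version 2.4, thiszone, sigfigs, snaplen, link type 113 (SLL)
    global_header = struct.pack('<IHHiIII', 0xa1b2c3d4, 2, 4, 0, 0, 65535, 113)
    # 16-byte SLL header: packet type, ARPHRD type, address length, 8-byte address, protocol
    sll_header = struct.pack('>HHH8sH', 0, 1, 6, bytes(8), 0x0800)
    frame = sll_header + ip_packet
    # 16-byte record header: ts_sec, ts_usec, incl_len, orig_len
    record_header = struct.pack('<IIII', ts_sec, ts_usec, len(frame), len(frame))
    return global_header + record_header + frame

def first_ip_packet(pcap: bytes) -> bytes:
    """Return the first record's bytes after the link-layer header (little-endian files only)."""
    network = struct.unpack('<I', pcap[20:24])[0]
    link_len = 16 if network == 113 else 14
    incl_len = struct.unpack('<IIII', pcap[24:40])[2]
    frame = pcap[40:40 + incl_len]
    ether_type = struct.unpack('>H', frame[link_len - 2:link_len])[0]
    if ether_type != 0x0800:
        raise ValueError(f"not IPv4: 0x{ether_type:04x}")
    return frame[link_len:]

# Toy IPv4/TCP packet: 20-byte IP header (IHL=5, protocol 6), then ports 1234 -> 80
ip_header = bytes([0x45, 0, 0, 40, 0, 0, 0, 0, 64, 6, 0, 0, 10, 0, 0, 1, 10, 0, 0, 2])
tcp_header = struct.pack('>HH', 1234, 80) + bytes(16)
pcap = build_sll_ipv4_pcap(ip_header + tcp_header, ts_sec=1000, ts_usec=500000)

ip = first_ip_packet(pcap)
print(ip[9], '.'.join(map(str, ip[12:16])), struct.unpack('>HH', ip[20:24]))
# 6 10.0.0.1 (1234, 80)
```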

In [None]:
# Load UNSW CSV files for labeling
print("🔍 Loading UNSW CSV files for labeling...")

# Download and load flow CSV files - focusing on Feb 17 data (CSV 2, 3, 4)
flow_dfs = []
for i in [2, 3, 4]:  # Skip CSV 1 due to January formatting issues
    csv_path = f'/tmp/UNSW-NB15_{i}.csv'
    if not os.path.exists(csv_path):
        print(f"Downloading UNSW-NB15_{i}.csv...")
        subprocess.run(['gsutil', 'cp', f"gs://{CONFIG['bucket_name']}/datasets/unsw-nb15/csv/UNSW-NB15_{i}.csv", csv_path], check=True)
    
    # Load with proper column names based on features file
    column_names = ['srcip', 'sport', 'dstip', 'dsport', 'proto', 'state', 'dur', 'sbytes', 'dbytes', 
                    'sttl', 'dttl', 'sloss', 'dloss', 'service', 'Sload', 'Dload', 'Spkts', 'Dpkts',
                    'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len',
                    'Sjit', 'Djit', 'Stime', 'Ltime', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack', 'ackdat',
                    'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd',
                    'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm',
                    'ct_dst_sport_ltm', 'ct_dst_src_ltm', 'attack_cat', 'Label']
    
    df = pd.read_csv(csv_path, names=column_names, low_memory=False)
    flow_dfs.append(df)
    print(f"✓ Loaded {len(df)} flow records from UNSW-NB15_{i}.csv")

# Combine all flow records
all_flows = pd.concat(flow_dfs, ignore_index=True)
print(f"\n✓ Total flow records loaded: {len(all_flows):,}")

# Normalize attack_cat: strip stray whitespace and treat blanks as "no category"
all_flows['attack_cat'] = all_flows['attack_cat'].fillna('').astype(str).str.strip()
has_category = all_flows['attack_cat'] != ''

# Flows with an attack category
labeled_flows = all_flows[has_category]
print(f"✓ Flow records with attack labels: {len(labeled_flows):,}")

# Add normal flows (Label == 0 and no attack category); .assign avoids writing into a slice
normal_flows = all_flows[(all_flows['Label'] == 0) & ~has_category].assign(attack_cat='Normal')
labeled_flows = pd.concat([labeled_flows, normal_flows], ignore_index=True)
print(f"✓ Total labeled flows (including Normal): {len(labeled_flows):,}")

# Get attack category distribution
attack_categories = labeled_flows['attack_cat'].value_counts()
print(f"\n📊 Attack categories found:")
for cat, count in attack_categories.items():
    print(f"   {cat}: {count:,} samples")

# Create lookup structure for fast matching
print("\n🔧 Building flow lookup structures...")

def parse_port(port_str):
    """Parse port number that might be in hex format"""
    if pd.isna(port_str):
        return 0
    port_str = str(port_str).strip()
    if port_str.startswith('0x'):
        return int(port_str, 16)
    else:
        try:
            return int(port_str)
        except ValueError:
            return 0

# Create a combined lookup key for each flow
flow_lookup = {}
for idx, row in labeled_flows.iterrows():
    try:
        # Parse ports (handle hex format)
        src_port = parse_port(row['sport'])
        dst_port = parse_port(row['dsport'])
        
        # Create lookup keys for both directions, for every second in [Stime - 2, Ltime + 2]
        # Note: CSV times seem to be off by ~1 second from PCAP times, hence the tolerance
        for t in range(int(row['Stime']) - 2, int(row['Ltime']) + 3):
            # Key format: (src_ip, src_port, dst_ip, dst_port, proto, time)
            key1 = (row['srcip'], src_port, row['dstip'], dst_port, row['proto'], t)
            key2 = (row['dstip'], dst_port, row['srcip'], src_port, row['proto'], t)

            flow_lookup[key1] = row['attack_cat']
            flow_lookup[key2] = row['attack_cat']
    except Exception:
        # Skip flows with parsing errors
        continue
    
    if idx % 50000 == 0:
        print(f"   Processed {idx:,} flows...")

print(f"\n✓ Flow lookup table ready with {len(flow_lookup):,} entries")

def parse_packet_linux_cooked(packet_data):
    """Parse packet from Linux cooked capture format"""
    if len(packet_data) < 16:
        return None, None, None, None, None
    
    # Linux cooked header is 16 bytes
    proto = struct.unpack('>H', packet_data[14:16])[0]
    
    if proto != 0x0800:  # Not IPv4
        return None, None, None, None, None
    
    # IP header starts at offset 16
    ip_data = packet_data[16:]
    if len(ip_data) < 20:
        return None, None, None, None, None
    
    # Extract protocol
    protocol = ip_data[9]
    
    # Extract source and destination IPs
    src_ip = '.'.join(map(str, ip_data[12:16]))
    dst_ip = '.'.join(map(str, ip_data[16:20]))
    
    # Get IP header length
    ip_header_len = (ip_data[0] & 0x0F) * 4
    
    # Parse transport layer for ports
    transport_data = ip_data[ip_header_len:]
    
    src_port = None
    dst_port = None
    
    if protocol in [6, 17] and len(transport_data) >= 4:  # TCP or UDP
        src_port = (transport_data[0] << 8) | transport_data[1]
        dst_port = (transport_data[2] << 8) | transport_data[3]
    
    return src_ip, src_port, dst_ip, dst_port, protocol

def get_label_from_packet(packet_data, timestamp):
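    """
    Match a packet to a labeled flow via the flow_lookup table.
    Assumes Linux cooked (SLL) framing, as parsed by parse_packet_linux_cooked;
    frames from Ethernet captures will not be recognized and return None.
    """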
    src_ip, src_port, dst_ip, dst_port, protocol = parse_packet_linux_cooked(packet_data)
    
    if src_ip is None:
        return None
    
    # Convert protocol number to name
    proto_map = {6: 'tcp', 17: 'udp', 1: 'icmp'}
    proto_name = proto_map.get(protocol, str(protocol))
    
    # Create lookup key
    packet_time = int(timestamp)
    
    # Try to find exact match
    if src_port is not None:
        key = (src_ip, src_port, dst_ip, dst_port, proto_name, packet_time)
        if key in flow_lookup:
            return flow_lookup[key]
    
    # If no exact match, try without ports (for ICMP or fragmented packets)
    for port_combo in [(0, 0), (-1, -1)]:
        key = (src_ip, port_combo[0], dst_ip, port_combo[1], proto_name, packet_time)
        if key in flow_lookup:
            return flow_lookup[key]
    
    return None

print("\n✓ Packet labeling function ready")
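The lookup trades memory for speed: every flow is expanded into one key per second per direction, so a flow lasting D = Ltime - Stime seconds costs 2(D + 5) dictionary entries with the ±2 s tolerance. The standalone sketch below reproduces that keying scheme on a single hypothetical flow (`add_flow` and `match_packet` are illustrative names, not notebook functions) so the tolerance window and bidirectional matching can be checked in isolation.

```python
from typing import Dict, Optional, Tuple

# (src_ip, src_port, dst_ip, dst_port, proto, whole-second timestamp)
FlowKey = Tuple[str, int, str, int, str, int]

def add_flow(lookup: Dict[FlowKey, str], srcip: str, sport: int, dstip: str, dport: int,
             proto: str, stime: int, ltime: int, label: str, tolerance: int = 2) -> None:
    """Register a flow in both directions for every second in [stime - tolerance, ltime + tolerance]."""
    for t in range(stime - tolerance, ltime + tolerance + 1):
        lookup[(srcip, sport, dstip, dport, proto, t)] = label
        lookup[(dstip, dport, srcip, sport, proto, t)] = label

def match_packet(lookup: Dict[FlowKey, str], src_ip: str, src_port: Optional[int],
                 dst_ip: str, dst_port: Optional[int], proto: str,
                 timestamp: float) -> Optional[str]:
    """Exact lookup at one-second resolution, with a port-less (0, 0) fallback for ICMP."""
    t = int(timestamp)
    if src_port is not None:
        hit = lookup.get((src_ip, src_port, dst_ip, dst_port, proto, t))
        if hit is not None:
            return hit
    return lookup.get((src_ip, 0, dst_ip, 0, proto, t))

# Hypothetical flow: 10.0.0.1:1043 -> 10.0.0.2:80, Stime=1000, Ltime=1001
flow_table: Dict[FlowKey, str] = {}
add_flow(flow_table, '10.0.0.1', 1043, '10.0.0.2', 80, 'tcp', 1000, 1001, 'Exploits')

print(len(flow_table))  # 12: six seconds (998..1003) x two directions
print(match_packet(flow_table, '10.0.0.2', 80, '10.0.0.1', 1043, 'tcp', 1003.9))  # Exploits
print(match_packet(flow_table, '10.0.0.2', 80, '10.0.0.1', 1043, 'tcp', 1004.0))  # None
```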

## Image Encoding Functions

In [3]:
def hilbert_curve_positions(n):
    """Generate Hilbert curve positions for n×n grid"""
    def hilbert(x, y, xi, xj, yi, yj, n):
        if n <= 0:
            yield x + (xi + yi) // 2, y + (xj + yj) // 2
        else:
            for i in hilbert(x, y, yi//2, yj//2, xi//2, xj//2, n-1):
                yield i
            for i in hilbert(x + xi//2, y + xj//2, xi//2, xj//2, yi//2, yj//2, n-1):
                yield i
            for i in hilbert(x + xi//2 + yi//2, y + xj//2 + yj//2, xi//2, xj//2, yi//2, yj//2, n-1):
                yield i
            for i in hilbert(x + xi//2 + yi, y + xj//2 + yj, -yi//2, -yj//2, -xi//2, -xj//2, n-1):
                yield i

    return list(hilbert(0, 0, n, 0, 0, n, int(np.log2(n))))

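def spiral_positions(n):
    """Generate (row, col) positions for an n×n grid, spiralling outward from the centre"""
    positions = []
    # Walk the spiral in centred coordinates, then shift into [0, n)
    offset = (n - 1) // 2
    x, y = 0, 0
    dx, dy = 0, -1

    for _ in range(n * n):
        if 0 <= x + offset < n and 0 <= y + offset < n:
            positions.append((x + offset, y + offset))

        # Turn at the corners of the square spiral
        if x == y or (x < 0 and x == -y) or (x > 0 and x == 1 - y):
            dx, dy = -dy, dx
        x, y = x + dx, y + dy

    return positions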

def encode_payload_multiformat(payload_bytes, format_config):
    """Encode payload bytes into various image formats"""
    # Convert payload to numpy array of uint8
    if isinstance(payload_bytes, (bytes, bytearray)):
        payload_bytes = np.frombuffer(payload_bytes, dtype=np.uint8)
    else:
        payload_bytes = np.array(payload_bytes, dtype=np.uint8)

    height, width = format_config['shape']
    channels = format_config['channels']
    method = format_config['method']

    # Pad or truncate to height*width bytes (a 32×32 grid uses only the first 1024)
    target_pixels = height * width
    if len(payload_bytes) < target_pixels:
        payload_bytes = np.pad(payload_bytes, (0, target_pixels - len(payload_bytes)), 'constant')
    else:
        payload_bytes = payload_bytes[:target_pixels]

    if method == 'sequential':
        # Simple sequential reshape
        image = payload_bytes.reshape(height, width)
        if channels == 1:
            return image.astype(np.float32) / 255.0
        else:
            # Create RGB by repeating grayscale
            image_norm = image.astype(np.float32) / 255.0
            return np.stack([image_norm] * channels, axis=-1)

    elif method == 'hilbert':
        # Hilbert curve mapping
        positions = hilbert_curve_positions(min(height, width))
        image = np.zeros((height, width, 3), dtype=np.float32)

        for i, (x, y) in enumerate(positions[:len(payload_bytes)]):
            if x < height and y < width:
                val = payload_bytes[i] / 255.0
                image[x, y] = [val, val * 0.7, val * 0.5]

        return image

    elif method == 'spiral':
        # Spiral mapping
        positions = spiral_positions(min(height, width))
        image = np.zeros((height, width, 3), dtype=np.float32)

        for i, (x, y) in enumerate(positions[:len(payload_bytes)]):
            if 0 <= x < height and 0 <= y < width:
                val = payload_bytes[i] / 255.0
                image[x, y] = [val * 0.5, val, val * 0.7]

        return image

    elif method == 'multiview':
        # 5-channel representation
        image = np.zeros((height, width, 5), dtype=np.float32)

        # Channel 1: Raw bytes
        image[:, :, 0] = payload_bytes.reshape(height, width) / 255.0

        # Channel 2: Header emphasis (first 64 bytes)
        header_channel = np.zeros(target_pixels)
        header_channel[:64] = payload_bytes[:64] / 255.0
        image[:, :, 1] = header_channel.reshape(height, width)

        # Channel 3: Byte frequency
        byte_freq = np.bincount(payload_bytes.astype(int), minlength=256)
        freq_map = byte_freq[payload_bytes] / (np.max(byte_freq) + 1e-10)
        image[:, :, 2] = freq_map.reshape(height, width)

        # Channel 4: Local Shannon entropy over non-overlapping 16-byte windows
        entropy_map = np.zeros(target_pixels)
        window = 16
        # Step through every full window, including the last one
        for i in range(0, len(payload_bytes) - window + 1, window):
            window_bytes = payload_bytes[i:i+window]
            _, counts = np.unique(window_bytes, return_counts=True)
            probs = counts / window
            entropy = -np.sum(probs * np.log2(probs + 1e-10))
            # Divide by 8 bits (max entropy of a byte). A 16-byte window can reach
            # at most log2(16) = 4 bits, so this channel lies in [0, 0.5]
            entropy_map[i:i+window] = entropy / 8
        image[:, :, 3] = entropy_map.reshape(height, width)

        # Channel 5: Gradient magnitude
        grad = np.abs(np.diff(payload_bytes.astype(float)))
        grad_padded = np.pad(grad, (0, 1), 'edge')
        image[:, :, 4] = grad_padded.reshape(height, width) / 255.0

        return image

print("✓ Image encoding functions ready")

✓ Image encoding functions ready
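Channel 4 can be cross-checked with a version that first reshapes the bytes into whole windows, so every full window, including the last one, is guaranteed to be scored. This is a sketch assuming the same 16-byte window and division by 8 bits as the cell above; `local_entropy` and its `max_bits` parameter are hypothetical. Setting `max_bits=4.0` (that is, log2(16)) would instead stretch the channel over the full [0, 1] range.

```python
import numpy as np

def local_entropy(payload, window: int = 16, max_bits: float = 8.0) -> np.ndarray:
    """Per-byte map of Shannon entropy over non-overlapping windows.

    Bytes in a trailing partial window (if any) are left at 0.
    """
    payload = np.asarray(payload, dtype=np.uint8)
    n_full = (len(payload) // window) * window
    out = np.zeros(len(payload), dtype=np.float32)
    for b, block in enumerate(payload[:n_full].reshape(-1, window)):
        counts = np.bincount(block, minlength=256)
        probs = counts[counts > 0] / window  # Only byte values that occur
        entropy_bits = -np.sum(probs * np.log2(probs))
        out[b * window:(b + 1) * window] = entropy_bits / max_bits
    return out

# 64 windows, each containing the 16 distinct values 0..15 -> 4 bits per window
ramp = np.tile(np.arange(16, dtype=np.uint8), 64)
print(local_entropy(ramp).min(), local_entropy(ramp).max())  # 0.5 0.5
print(local_entropy(ramp, max_bits=4.0)[-1])                  # 1.0
```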


## Storage Functions (Parquet and PNG)

In [None]:
# Parquet helper functions
class MultiFormatDataWriter:
    """Writes data in Parquet and PNG formats - ORGANIZED BY LABEL"""

    def __init__(self, bucket, base_path, shard_size):
        self.bucket = bucket
        self.base_path = base_path
        self.shard_size = shard_size
        # Organize by label/format/split
        self.current_shard = defaultdict(list)
        self.shard_counts = defaultdict(int)
        self.png_counts = defaultdict(int)
        self.manifest = {
            'parquet': defaultdict(lambda: defaultdict(list)),
            'png': defaultdict(list)
        }

    def add_sample(self, sample, split, format_name):
        """Add a sample and write to all formats"""
        # Include label in the key
        label = sample['label']
        key = f"{label}/{format_name}/{split}"
        self.current_shard[key].append(sample)

        # Write PNG immediately (for first N samples per class)
        label_key = f"{format_name}/{split}/{label}"
        if self.png_counts[label_key] < CONFIG['save_sample_pngs']:
            self._write_png(sample, split, format_name)
            self.png_counts[label_key] += 1

        # Write shard if full
        if len(self.current_shard[key]) >= self.shard_size:
            self._write_shard(key)

    def _write_shard(self, key):
        """Write a shard in Parquet format"""
        if not self.current_shard[key]:
            return

        # Parse label from key
        label, format_name, split = key.split('/')
        shard_num = self.shard_counts[key]
        samples = self.current_shard[key]

        # Write Parquet shard - ORGANIZED BY LABEL
        parquet_path = f"{self.base_path}parquet/{format_name}/{label}/{split}/shard_{shard_num:05d}.parquet"

        # Prepare data for Parquet
        data = {
            'sample_id': [],
            'label': [],
            'image_format': [],
            'image_data': [],
            'height': [],
            'width': [],
            'channels': [],
            'payload_bytes': []
        }

        for sample in samples:
            image = sample['image']
            data['sample_id'].append(sample['sample_id'])
            data['label'].append(sample['label'])
            data['image_format'].append(format_name)
            data['image_data'].append(image.flatten().tolist())
            data['height'].append(image.shape[0])
            data['width'].append(image.shape[1])
            data['channels'].append(image.shape[2] if len(image.shape) > 2 else 1)
            data['payload_bytes'].append(sample['payload_bytes'].tolist())

        # Create table and save
        table = pa.table(data)
        buffer = io.BytesIO()
        pq.write_table(table, buffer)
        buffer.seek(0)

        blob = self.bucket.blob(parquet_path)
        blob.upload_from_file(buffer)

        self.manifest['parquet'][label][f"{format_name}/{split}"].append({
            'shard_num': shard_num,
            'path': parquet_path,
            'num_samples': len(samples)
        })

        # Clear shard and increment counter
        self.current_shard[key] = []
        self.shard_counts[key] += 1

        print(f"   ✓ Wrote shard {shard_num} for {label}/{format_name}/{split}")

    def _write_png(self, sample, split, format_name):
        """Write individual PNG file"""
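        image = sample['image']
        label = sample['label']  # Used in the output path below

        # Scale [0, 1] floats to 0-255 (clip guards against values slightly above 1.0)
        if image.dtype == np.float32 or image.dtype == np.float64:
            image_uint8 = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
        else:
            image_uint8 = image.astype(np.uint8)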

        # Create PIL image
        if len(image_uint8.shape) == 2:
            pil_image = Image.fromarray(image_uint8, mode='L')
        elif image_uint8.shape[2] == 3:
            pil_image = Image.fromarray(image_uint8, mode='RGB')
        else:
            # For 5-channel, save first 3 as RGB
            pil_image = Image.fromarray(image_uint8[:, :, :3], mode='RGB')

        # Save to buffer
        buffer = io.BytesIO()
        pil_image.save(buffer, format='PNG')
        buffer.seek(0)

        # Upload to GCS - ORGANIZED BY LABEL
        path = f"{self.base_path}png/{format_name}/{label}/{split}/{sample['sample_id']}.png"
        blob = self.bucket.blob(path)
        blob.upload_from_file(buffer, content_type='image/png')

    def finalize(self):
        """Write remaining shards and save manifests"""
        # Write remaining shards
        for key in list(self.current_shard.keys()):
            if self.current_shard[key]:
                self._write_shard(key)

        # Save combined manifest
        manifest_data = {
            'timestamp': CONFIG['timestamp'],
            'shard_size': self.shard_size,
            'formats': {
                'parquet': dict(self.manifest['parquet']),
                'png': dict(self.png_counts)
            },
            'total_shards': dict(self.shard_counts),
            'image_formats': IMAGE_FORMATS,
            'labels': list(self.manifest['parquet'].keys())  # List all labels processed
        }

        manifest_blob = self.bucket.blob(f"{self.base_path}manifest.json")
        manifest_blob.upload_from_string(
            json.dumps(manifest_data, indent=2),
            content_type='application/json'
        )

        return manifest_data

print("✓ Multi-format writer ready - NOW ORGANIZED BY LABEL (Parquet + PNG only)!")

✓ Multi-format writer ready - NOW ORGANIZED BY LABEL (Parquet + PNG only)!
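Because the label is a path component, a shard's format, label and split can be recovered from its object name alone, which is what makes per-label downloads possible without opening any Parquet file. A small sketch of that round trip, assuming the `parquet/{format}/{label}/{split}/shard_{n:05d}.parquet` layout used by `_write_shard` and labels that contain no `/` (the writer's own `key.split('/')` already relies on that); `ShardLocation`, `shard_path` and `parse_shard_path` are hypothetical helpers, and the timestamp in the base path is made up.

```python
from typing import NamedTuple

class ShardLocation(NamedTuple):
    format_name: str
    label: str
    split: str
    shard_num: int

def shard_path(base_path: str, loc: ShardLocation) -> str:
    """Build a shard path in the parquet/format/label/split layout."""
    return (f"{base_path}parquet/{loc.format_name}/{loc.label}/{loc.split}/"
            f"shard_{loc.shard_num:05d}.parquet")

def parse_shard_path(base_path: str, path: str) -> ShardLocation:
    """Inverse of shard_path; raises ValueError for anything else."""
    if not path.startswith(base_path):
        raise ValueError(f"outside base path: {path}")
    parts = path[len(base_path):].split('/')
    if len(parts) != 5 or parts[0] != 'parquet':
        raise ValueError(f"not a Parquet shard path: {path}")
    _, format_name, label, split, filename = parts
    if not (filename.startswith('shard_') and filename.endswith('.parquet')):
        raise ValueError(f"unexpected file name: {filename}")
    shard_num = int(filename[len('shard_'):-len('.parquet')])
    return ShardLocation(format_name, label, split, shard_num)

# Hypothetical run timestamp in the base path
base = 'datasets/unsw-organized-by-label/20250101_000000/'
loc = ShardLocation('5channel_32x32', 'Normal', 'train', 3)
path = shard_path(base, loc)
print(path)
assert parse_shard_path(base, path) == loc
```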


## Process PCAP Files

In [None]:
# Initialize GCS
client = storage.Client()
bucket = client.bucket(CONFIG['bucket_name'])

# List all PCAP files
print("🔍 Discovering PCAP files...")
pcap_files = []

# List all blobs in the PCAP directory
all_blobs = list(bucket.list_blobs(prefix=CONFIG['input_prefix']))

# Extract PCAP files (all in one folder for UNSW)
for blob in all_blobs:
    if blob.name.endswith('.pcap'):
        pcap_files.append({
            'path': blob.name,
            'size_mb': blob.size / (1024 * 1024)
        })

print(f"\n📊 Found {len(pcap_files)} PCAP files")
print(f"Total size: {sum(f['size_mb'] for f in pcap_files):.1f} MB")

# Show first few files
print("\n📁 Sample files:")
for file_info in pcap_files[:5]:
    print(f"   {file_info['path'].split('/')[-1]}: {file_info['size_mb']:.1f} MB")

In [None]:
# Initialize writer
output_base = f"{CONFIG['output_prefix']}{CONFIG['timestamp']}/"
writer = MultiFormatDataWriter(bucket, output_base, CONFIG['shard_size'])

print("\n🚀 Processing PCAP files and creating dataset...")
print(f"Target: {CONFIG['samples_per_class']} samples per label\n")

# Track progress
sample_count = 0
label_counts = Counter()
skipped_count = 0
start_time = time.time()

# Process each PCAP file
for file_info in pcap_files:
    print(f"\n📦 Processing: {file_info['path'].split('/')[-1]} ({file_info['size_mb']:.1f} MB)")

    try:
        # Download PCAP file
        blob = bucket.blob(file_info['path'])
        pcap_data = blob.download_as_bytes()

        # Extract packets
        packets = read_pcap_packets(pcap_data, max_packets=CONFIG['packets_per_pcap'])
        print(f"   Extracted {len(packets)} packets")

        # Process packets
        packets_processed = 0
        packets_labeled = 0
        
        for packet in packets:
            # Get label from flow lookup using the raw packet data
            label = get_label_from_packet(packet['raw_packet'], packet['timestamp'])
            
            # Skip packets without labels
            if label is None:
                skipped_count += 1
                continue
            
            packets_labeled += 1
            
            # Skip if we already have enough samples for this label
            if label_counts[label] >= CONFIG['samples_per_class']:
                continue

            # Use the pre-extracted payload
            payload = packet['payload']
            if isinstance(payload, (bytes, bytearray)):
                payload_array = np.frombuffer(payload, dtype=np.uint8)
            else:
                payload_array = np.array(payload, dtype=np.uint8)

            # Pad to 1500 bytes if needed
            if len(payload_array) < CONFIG['payload_bytes']:
                payload_array = np.pad(payload_array,
                                     (0, CONFIG['payload_bytes'] - len(payload_array)),
                                     'constant')
            else:
                payload_array = payload_array[:CONFIG['payload_bytes']]

            # Generate sample ID
            sample_id = f"{label}_{label_counts[label]:06d}"

            # Determine split
            rand_val = np.random.random()
            if rand_val < CONFIG['test_size']:
                split = 'test'
            elif rand_val < CONFIG['test_size'] + CONFIG['val_size']:
                split = 'val'
            else:
                split = 'train'

            # Create images for all formats and save
            for format_name, format_config in IMAGE_FORMATS.items():
                image = encode_payload_multiformat(payload_array, format_config)

                # Add sample (will be saved in Parquet and PNG)
                writer.add_sample({
                    'image': image,
                    'label': label,
                    'sample_id': sample_id,
                    'payload_bytes': payload_array
                }, split, format_name)

            label_counts[label] += 1
            sample_count += 1
            packets_processed += 1

            if sample_count % 1000 == 0:
                elapsed = time.time() - start_time
                rate = sample_count / elapsed
                print(f"   Progress: {sample_count:,} total samples ({rate:.0f} samples/sec)")
                print(f"   Label distribution: {dict(label_counts)}")

        # Clear memory
        del pcap_data
        gc.collect()

        print(f"   ✓ Processed {file_info['path'].split('/')[-1]}")
        print(f"      Packets with labels: {packets_labeled}/{len(packets)} ({packets_labeled / max(len(packets), 1) * 100:.1f}%)")
        print(f"      Samples created: {packets_processed}")

    except Exception as e:
        print(f"   ⚠️ Error processing {file_info['path']}: {e}")
        import traceback
        traceback.print_exc()
        continue

    # Check if we have enough samples for all labels
    all_labels_complete = True
    for cat in attack_categories.index:
        if label_counts.get(cat, 0) < CONFIG['samples_per_class']:
            all_labels_complete = False
            break
    
    if all_labels_complete:
        print(f"\n✓ Reached target samples for all labels, stopping...")
        break

# Finalize
print("\n💾 Finalizing all storage formats...")
manifest = writer.finalize()

print(f"\n✅ Dataset creation complete!")
print(f"📁 Location: gs://{CONFIG['bucket_name']}/{output_base}")
print(f"📊 Total samples: {sample_count:,}")
print(f"⏱️ Total time: {(time.time() - start_time)/60:.1f} minutes")
print(f"\n📈 Samples per label:")
for label, count in sorted(label_counts.items()):
    print(f"   {label}: {count:,} samples")
print(f"\n⚠️ Skipped {skipped_count:,} packets without labels")
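The split logic draws one uniform number per sample and compares it against cumulative thresholds (test below 0.15, val below 0.30, train otherwise). The standalone sketch below isolates that rule, using Python's `random` module instead of `np.random` to stay dependency-free; `assign_split` is a hypothetical name. Because the draw is made per packet rather than per flow, packets from the same flow can land in different splits, so the 70/15/15 ratio holds only in expectation and train and test are not guaranteed to be flow-disjoint.

```python
import random
from collections import Counter

def assign_split(rand_val: float, test_size: float = 0.15, val_size: float = 0.15) -> str:
    """Map a uniform draw in [0, 1) to a split using cumulative thresholds."""
    if rand_val < test_size:
        return 'test'
    if rand_val < test_size + val_size:
        return 'val'
    return 'train'

# Simulate one label reaching the 12,000-sample cap
rng = random.Random(42)
counts = Counter(assign_split(rng.random()) for _ in range(12000))
print({split: round(n / 12000, 3) for split, n in counts.items()})
```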

## Summary

This notebook turns PCAP files into a label-organized dataset stored in two formats:

### What's Created:

1. **Parquet Files** (`.parquet`)
   - Columnar format for data analysis and ML training
   - Easy to load into Pandas/PyTorch/TensorFlow
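   - Contains: sample ID, label, image format name, flattened image, image shape (height, width, channels), and the padded 1500-byte payload

2. **PNG Files** (`.png`)
   - Sample images for visualization
   - Up to the first 100 samples per label and split; only the first 3 of the 5 channels are rendered, as RGB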
   - Organized by label for easy browsing

### Dataset Structure - ORGANIZED BY LABEL:
```
gs://ai-cyber/datasets/unsw-organized-by-label/[timestamp]/
├── parquet/
│   ├── 5channel_32x32/
│   │   ├── Normal/
│   │   │   ├── train/
│   │   │   │   ├── shard_00000.parquet
│   │   │   │   └── ...
│   │   │   ├── val/
│   │   │   └── test/
│   │   ├── Generic/
│   │   │   ├── train/
│   │   │   ├── val/
│   │   │   └── test/
│   │   └── Exploits/
│   │       ├── train/
│   │       ├── val/
│   │       └── test/
├── png/
│   ├── 5channel_32x32/
│   │   ├── Normal/
│   │   │   ├── train/
│   │   │   │   ├── Normal_000001.png
│   │   │   │   └── ...
│   │   │   ├── val/
│   │   │   └── test/
│   │   └── [other labels...]
└── manifest.json
```

### Key Features:
- **ORGANIZED BY LABEL** - Each label has its own folder!
- Easy to download specific labels without parsing everything
- Clear structure: format → label → split → shards
- Manifest includes label list for easy discovery
- Sampling capped at 12k samples per label; classes are balanced only if every label reaches the cap
- Only the 5-channel 32×32 format is generated
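To fetch one label without listing the whole bucket, `manifest.json` can be queried directly. A sketch, assuming the manifest layout produced by `finalize()` above (`formats` → `parquet` → label → `"{format}/{split}"` → list of `{shard_num, path, num_samples}`); `shards_for` and the toy manifest are hypothetical. The returned object names can then be passed to `gsutil cp` or to `bucket.blob(...).download_to_filename(...)`.

```python
import json
from typing import List

def shards_for(manifest: dict, label: str, format_name: str, split: str) -> List[str]:
    """List Parquet shard paths for one label/format/split, in shard order."""
    by_label = manifest['formats']['parquet'].get(label, {})
    entries = by_label.get(f"{format_name}/{split}", [])
    return [e['path'] for e in sorted(entries, key=lambda e: e['shard_num'])]

# Hypothetical manifest excerpt in the shape written by finalize()
manifest_text = json.dumps({
    'labels': ['Exploits'],
    'formats': {'parquet': {'Exploits': {'5channel_32x32/train': [
        {'shard_num': 1, 'path': 'run/parquet/5channel_32x32/Exploits/train/shard_00001.parquet', 'num_samples': 412},
        {'shard_num': 0, 'path': 'run/parquet/5channel_32x32/Exploits/train/shard_00000.parquet', 'num_samples': 1000},
    ]}}},
})
toy_manifest = json.loads(manifest_text)

print(toy_manifest['labels'])
for shard in shards_for(toy_manifest, 'Exploits', '5channel_32x32', 'train'):
    print(shard)
```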