In [None]:
# Cell 1 – Mount Drive | Imports | Seed | CONFIG
from pathlib import Path
import os, time, math, pickle, random, warnings
import numpy as np
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from tqdm import tqdm  # Add tqdm for progress tracking
warnings.filterwarnings("ignore")

# Mount Google Drive
def in_colab():
    """Return True when the `google.colab` package can be imported, else False.

    Any exception raised by the import (not only ImportError) is treated as
    "not running in Colab".
    """
    try:
        import google.colab  # noqa
    except Exception:
        return False
    return True

# Mount Google Drive at /content/drive, but only when the notebook runs inside Colab.
if in_colab():
    from google.colab import drive
    # force_remount=False is passed explicitly, i.e. no remount is requested.
    drive.mount("/content/drive", force_remount=False)

# Reproducibility
def set_seed(seed=42):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int): Value passed to every seeding function.

    Note:
        When CUDA is available, cuDNN is configured with `deterministic = False`
        and `benchmark = True`. GPU runs are therefore seeded but not set up
        for deterministic cuDNN kernels.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if not torch.cuda.is_available():
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Seed all RNGs once, before anything else draws random numbers.
# The same value is duplicated below as CONFIG["SEED"].
set_seed(42)

# CONFIG Declaration
# DEVICE is chosen once here and stored in CONFIG; the functions in later cells read config["DEVICE"].
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device count: {torch.cuda.device_count()}")

# Central experiment configuration. Comments below note where each key is read in the
# cells visible in this notebook section; keys marked "not read" may be used elsewhere.
CONFIG = {
    "DEVICE": DEVICE,
    "SEED": 42,  # used as random_state when sub-sampling the Yahoo/Criteo dataframes
    "NUM_WORKERS": 0,  # not read by the DataLoaders built in the visible cells
    "BATCH_SIZE": 32 if DEVICE == "cpu" else 64,
    "LR": 1e-3,  # not read in the visible cells (per-dataset lr_teacher / lr_student are used)
    "DEFAULT_EPOCHS": 30,  # fallback when a dataset is missing from EPOCHS_PER_DATASET
    "EPOCHS_PER_DATASET": {"CIFAR10": 15, "CIFAR100": 30, "CINIC10": 30, "YAHOO":10,"CRITEO":15},
    "KD_TEMPERATURE": 4.0,  # not read in the visible cells; train_epoch_unified uses DATASET_HP[...]["tau"]
    "KD_LAMBDA": 0.7,  # weight of the distillation term; (1 - KD_LAMBDA) weights cross-entropy
    "AT_PGD_STEPS": 5,  # not read in the visible cells; generate_adv_examples takes one gradient step
    "AT_STEP_SIZE": 2/255,  # not read in the visible cells
    "AUX_SIZE": 1000,  # not read in the visible cells
    # Per-dataset hyperparameters. Read in the visible cells: "tau" (distillation temperature),
    # "adv_eps" (perturbation budget), "lr_teacher", "lr_student". "k" and "eps_k" are not read here.
    "DATASET_HP": {
        "CIFAR10":  {"tau": 4.0, "k": 3, "eps_k": 0.15, "lr_teacher": 1e-3, "lr_student": 5e-4, "adv_eps": 8/255},
        "CIFAR100": {"tau": 5.0, "k": 5, "eps_k": 0.15, "lr_teacher": 1e-3, "lr_student": 5e-4, "adv_eps": 8/255},
        "CINIC10":  {"tau": 3.5, "k": 3, "eps_k": 0.15, "lr_teacher": 1e-3, "lr_student": 5e-4, "adv_eps": 8/255},
         "YAHOO":    {"tau": 4.0, "k": 5, "eps_k": 0.1, "lr_teacher": 5e-4, "lr_student": 1e-4, "adv_eps": 0.05},
        "CRITEO":   {"tau": 4.0, "k": 5, "eps_k": 0.1, "lr_teacher": 1e-3, "lr_student": 5e-4, "adv_eps": 0.1}
    },

    # (train, test) sample counts; index 0 / 1 are used when sampling the Yahoo and Criteo frames.
    "SIZES": {
        "CIFAR10":  (50_000, 10_000),
        "CIFAR100": (50_000, 10_000),
        "CINIC10":  (90_000, 90_000),
        "YAHOO":    (50_000, 20_000),
        "CRITEO":   (80_000, 20_000)
    },
    # Bounds applied with torch.clamp to perturbed inputs (adversarial examples, active-attack queries).
    "CLAMP_MIN": -3.0,
    "CLAMP_MAX":  3.0,
    # Root under which "VFL_Results" is created.
    # NOTE(review): Drive is mounted at /content/drive, but this points at /content,
    # so results do not land on the mounted Drive - confirm this is intended.
    "DRIVE_PATH": "/content",
    "PIN_MEMORY": False,  # not read in the visible cells
    "PREFETCH_FACTOR": None,  # not read in the visible cells

}

os.makedirs(CONFIG["DRIVE_PATH"], exist_ok=True)
print("Device:", CONFIG["DEVICE"])
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB")
    # set_seed() above already sets this flag when CUDA is available; repeated here explicitly.
    torch.backends.cudnn.benchmark = True
    print("cuDNN benchmark enabled for GPU performance.")
else:
    print("WARNING: Running on CPU - training will be slow!")

Mounted at /content/drive
PyTorch CUDA available: True
Current CUDA device: 0
CUDA device count: 1
Device: cuda
GPU: Tesla T4
GPU Memory: 15.83 GB
cuDNN benchmark enabled for GPU performance.


In [None]:
# Cell 2 - VFL Training, LIA/Adversarial Attack and Evaluation Functions

from tqdm import tqdm
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score, roc_auc_score, top_k_accuracy_score
import numpy as np

# Single-gradient-step adversarial example generation for the second input (xp)
def generate_adv_examples(model, xa, xp, y, config, dataset_name):
    """Perturb `xp` with one gradient step that increases the cross-entropy loss.

    The model is switched to eval mode while the gradient is computed and is
    always switched back to its previous mode afterwards, even if the forward
    or backward pass raises.

    Args:
        model: Callable as `model(xa, xp)` returning logits; must support
            `.eval()`, `.train(mode)` and expose `.training`.
        xa: First input, passed through unchanged.
        xp: Second input; the perturbation is applied to this tensor.
        y: Class-index targets for `F.cross_entropy`.
        config (dict): Must provide `config["DATASET_HP"][dataset_name]["adv_eps"]`
            and, for the vision datasets, `CLAMP_MIN` / `CLAMP_MAX`.
        dataset_name (str): "CIFAR10", "CIFAR100" and "CINIC10" use a signed-gradient
            step followed by clamping; any other name uses a per-sample
            L2-normalised gradient step without clamping.

    Returns:
        torch.Tensor: Detached perturbed copy of `xp` with the same shape.
    """
    eps = config["DATASET_HP"][dataset_name]["adv_eps"]
    is_vision = dataset_name in ("CIFAR10", "CIFAR100", "CINIC10")

    xp_adv = xp.clone().detach().requires_grad_(True)
    original_mode = model.training
    model.eval()
    try:
        # enable_grad lets this work when the caller is inside torch.no_grad()
        # (the robust evaluation path does that).
        with torch.enable_grad():
            loss = F.cross_entropy(model(xa, xp_adv), y)
            grad = torch.autograd.grad(loss, xp_adv, retain_graph=False, create_graph=False)[0]
    finally:
        # Restore train/eval mode even when the forward/backward pass fails.
        model.train(original_mode)

    if is_vision:
        perturbation = eps * grad.sign()
    else:
        # Per-sample L2 norm, reshaped to broadcast over every non-batch axis
        # so inputs of any rank are handled.
        per_sample_norm = grad.reshape(grad.shape[0], -1).norm(p=2, dim=1)
        per_sample_norm = per_sample_norm.view(-1, *([1] * (grad.dim() - 1)))
        perturbation = eps * grad / (per_sample_norm + 1e-8)

    xp_perturbed = xp + perturbation
    if is_vision:
        xp_perturbed = torch.clamp(xp_perturbed, config["CLAMP_MIN"], config["CLAMP_MAX"])
    return xp_perturbed.detach()


def train_epoch_unified(model, loader, optimizer, config, dataset_name, teacher=None, use_at=False, use_dp=False, noise_multiplier=0.3, use_embedding_noise=False, embedding_noise_level=0.1, sched=None):
    """Run one training epoch and return `(average loss, accuracy in percent)`.

    Args:
        model: Exposes `bottom_a`, `bottom_p` and `top_model` sub-modules.
        loader: Iterable yielding `((xa, xp), y)` batches.
        optimizer: Optimizer stepped once per batch.
        config (dict): Reads `DEVICE`, `DATASET_HP[dataset_name]` and, when a
            teacher is given, `KD_LAMBDA`.
        dataset_name (str): Key into `config["DATASET_HP"]`; also shown in the progress bar.
        teacher: Optional model called as `teacher(xa, xp)`; when given, the loss
            mixes a temperature-scaled KL term (temperature `hp["tau"]`) with
            cross-entropy.
        use_at (bool): Replace `xp` with the output of `generate_adv_examples`.
        use_dp (bool): Add Gaussian noise (std `noise_multiplier`) to the gradient
            flowing back through `bottom_p`'s output via a tensor hook.
        noise_multiplier (float): Std of the gradient noise used with `use_dp`.
        use_embedding_noise (bool): Add Gaussian noise to `bottom_p`'s output.
        embedding_noise_level (float): Std of that embedding noise.
        sched: Optional scheduler stepped once per batch.

    Returns:
        tuple[float, float]: Sample-weighted mean loss and accuracy (0-100);
        `(0.0, 0.0)` if the loader produced no samples.
    """
    model.train()
    if teacher:
        teacher.eval()

    loss_sum, hits, seen = 0, 0, 0
    progress = tqdm(loader, desc=f'Training {dataset_name}', ncols=110)
    device = config["DEVICE"]
    hp = config["DATASET_HP"][dataset_name]

    def _noisy_grad(grad):
        # Gradient hook: perturb the gradient of bottom_p's output with Gaussian noise.
        return grad + torch.randn_like(grad) * noise_multiplier

    for (xa, xp), y in progress:
        xa = xa.to(device)
        xp = xp.to(device)
        y = y.to(device)
        optimizer.zero_grad()

        if use_at:
            xp_train = generate_adv_examples(model, xa, xp, y, config, dataset_name)
        else:
            xp_train = xp

        h_a = model.bottom_a(xa)
        h_p = model.bottom_p(xp_train)

        if use_embedding_noise:
            h_p = h_p + torch.randn_like(h_p) * embedding_noise_level

        if use_dp:
            h_p.register_hook(_noisy_grad)

        logits = model.top_model(h_a, h_p)

        if teacher:
            with torch.no_grad():
                teacher_logits = teacher(xa, xp_train)
            temperature = hp["tau"]
            student_log_probs = F.log_softmax(logits / temperature, dim=1)
            teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=1)
            kd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction='batchmean', log_target=True) * (temperature * temperature)
            ce_loss = F.cross_entropy(logits, y)
            loss = config["KD_LAMBDA"] * kd_loss + (1 - config["KD_LAMBDA"]) * ce_loss
        else:
            loss = F.cross_entropy(logits, y)

        loss.backward()
        optimizer.step()
        if sched:
            sched.step()

        batch_size = y.size(0)
        loss_sum += loss.item() * batch_size
        hits += (logits.argmax(1) == y).sum().item()
        seen += batch_size
        progress.set_postfix({'loss': f'{loss.item():.4f}'})

    if seen > 0:
        return loss_sum / seen, 100 * hits / seen
    return 0.0, 0.0

def evaluate_model_unified(model, loader, config, dataset_name, robust=False):
    """Evaluate `model` on `loader` and return `(avg loss, top-1 %, top-5 %)`.

    Args:
        model: Called as `model(xa, xp)`; put into eval mode by this function.
        loader: Iterable yielding `((xa, xp), y)` batches.
        config (dict): Reads `DEVICE`; forwarded to `generate_adv_examples`
            when `robust` is True.
        dataset_name (str): Forwarded to `generate_adv_examples`.
        robust (bool): When True, `xp` is replaced by adversarially perturbed
            inputs before the forward pass.

    Returns:
        tuple: Sample-weighted mean cross-entropy, top-1 accuracy and top-5
        accuracy (both 0-100). For two-class outputs top-5 equals top-1.
        An empty loader yields `(0.0, 0.0, 0.0)` instead of raising.
    """
    model.eval()
    y_true_list, y_probs_list = [], []
    total_loss, total = 0, 0
    device = config["DEVICE"]
    with torch.no_grad():
        for (xa, xp), y in loader:
            xa, xp, y = xa.to(device), xp.to(device), y.to(device)
            # generate_adv_examples re-enables grad internally, so it works inside no_grad.
            xp_eval = generate_adv_examples(model, xa, xp, y, config, dataset_name) if robust else xp
            logits = model(xa, xp_eval)
            loss = F.cross_entropy(logits, y)
            total_loss += loss.item() * y.size(0)
            total += y.size(0)
            y_true_list.extend(y.cpu().numpy())
            y_probs_list.extend(F.softmax(logits, dim=1).cpu().numpy())

    if total == 0:
        # Nothing was evaluated: there is no probability matrix to read the class count from.
        return 0.0, 0.0, 0.0

    loss_avg = total_loss / total
    y_true, y_probs = np.array(y_true_list), np.array(y_probs_list)
    num_classes = y_probs.shape[1]
    if num_classes == 2:
        preds = np.argmax(y_probs, axis=1)
        top1_acc = (preds == y_true).mean() * 100
        top5_acc = top1_acc
    else:
        top1_acc = top_k_accuracy_score(y_true, y_probs, k=1, labels=range(num_classes)) * 100
        top5_acc = top_k_accuracy_score(y_true, y_probs, k=min(5, num_classes), labels=range(num_classes)) * 100
    return loss_avg, top1_acc, top5_acc

def run_privacy_attack_unified(*args, **kwargs):
    """Alias that forwards every argument unchanged to `run_privacy_attack_vision_multimode`."""
    return run_privacy_attack_vision_multimode(*args, **kwargs)

_PRIVACY_ATTACK_TYPES = ("passive", "perturbed", "direct", "active")


def _privacy_attack_features(vfl_model, xa, xp, config, attack_type, active_eps, active_queries, noise_sigma=None):
    """Build the attacker's detached input features for one batch.

    - "direct": the output of `bottom_p(xp)`.
    - "active": gradients of `logits.sum()` w.r.t. `bottom_p`'s output for
      `active_queries` randomly perturbed (and clamped) copies of `xp`,
      concatenated along dim 1.
    - "passive" / "perturbed": gradient of `logits.sum()` w.r.t. `bottom_p(xp)`;
      for "perturbed", Gaussian noise with std `noise_sigma` is added when
      `noise_sigma` is not None.
    """
    # Gradients are needed even when the caller runs under torch.no_grad().
    with torch.enable_grad():
        if attack_type == "direct":
            feat = vfl_model.bottom_p(xp)
        elif attack_type == "active":
            h_a = vfl_model.bottom_a(xa)
            grads = []
            for _ in range(active_queries):
                delta = (torch.rand_like(xp) * 2 - 1) * active_eps
                xp_q = torch.clamp(xp + delta, config["CLAMP_MIN"], config["CLAMP_MAX"])
                h_p_q = vfl_model.bottom_p(xp_q)
                h_p_q.requires_grad_()
                logits_q = vfl_model.top_model(h_a, h_p_q)
                grads.append(torch.autograd.grad(logits_q.sum(), h_p_q, retain_graph=False)[0])
            feat = torch.cat(grads, dim=1)
        else:
            h_p = vfl_model.bottom_p(xp)
            h_p.requires_grad_()
            h_a = vfl_model.bottom_a(xa)
            logits = vfl_model.top_model(h_a, h_p)
            feat = torch.autograd.grad(logits.sum(), h_p, retain_graph=False)[0]
            if attack_type == "perturbed" and noise_sigma is not None:
                feat = feat + torch.randn_like(feat) * noise_sigma
    return feat.detach()


def run_privacy_attack_vision_multimode(vfl_model, train_loader, test_loader, config, num_classes, attack_type="passive", aux_batches_limit=4, attacker_lr=5e-4, attacker_epochs=2, active_eps=0.02, active_queries=2, perturbed_sigma=0.1):
    """Train a small MLP attacker to predict labels and report its accuracy.

    The attacker is fitted on features extracted from the first
    `aux_batches_limit` batches of `train_loader` and scored on `test_loader`.

    Args:
        vfl_model: Model exposing `bottom_a`, `bottom_p`, `top_model`; set to eval mode here.
        train_loader: Iterable of `((xa, xp), y)` batches used as auxiliary data.
        test_loader: Iterable of `((xa, xp), y)` batches used for scoring; must be
            non-empty (its first batch is used to infer the feature width).
        config (dict): Reads `DEVICE`, and `CLAMP_MIN` / `CLAMP_MAX` for "active".
        num_classes (int): Number of attacker output classes.
        attack_type (str): One of "passive", "perturbed", "direct", "active".
        aux_batches_limit (int): Number of training batches given to the attacker.
        attacker_lr (float): AdamW learning rate for the attacker.
        attacker_epochs (int): Passes over the auxiliary batches.
        active_eps (float): Half-width of the uniform input perturbation for "active".
        active_queries (int): Number of perturbed queries per batch for "active".
        perturbed_sigma (float): Std of the noise added to gradients for "perturbed"
            while fitting the attacker.

    Returns:
        tuple[float, float]: Attacker top-1 and top-5 accuracy in percent
        (top-5 equals top-1 when `num_classes == 2`).

    Raises:
        ValueError: If `attack_type` is not one of the supported modes.
    """
    if attack_type not in _PRIVACY_ATTACK_TYPES:
        raise ValueError(
            f"Unknown attack_type {attack_type!r}; expected one of {_PRIVACY_ATTACK_TYPES}"
        )

    device = config["DEVICE"]
    vfl_model.eval()

    # Infer the width of bottom_p's output from a single test example.
    (_, xp_ex), _ = next(iter(test_loader))
    with torch.no_grad():
        h_p_dim = vfl_model.bottom_p(xp_ex[:1].to(device)).shape[1]
    feat_dim = h_p_dim * active_queries if attack_type == "active" else h_p_dim

    attacker = nn.Sequential(nn.Linear(feat_dim, 128), nn.ReLU(inplace=True), nn.Dropout(0.2), nn.Linear(128, num_classes)).to(device)
    opt = torch.optim.AdamW(attacker.parameters(), lr=attacker_lr, weight_decay=1e-4)

    # The attacker only sees the first `aux_batches_limit` training batches.
    aux_data = []
    for i, batch in enumerate(train_loader):
        if i >= aux_batches_limit:
            break
        aux_data.append(batch)

    attacker.train()
    for _ in range(attacker_epochs):
        for (xa, xp), y in aux_data:
            xa, xp, y = xa.to(device), xp.to(device), y.to(device)
            feat = _privacy_attack_features(vfl_model, xa, xp, config, attack_type, active_eps, active_queries, noise_sigma=perturbed_sigma)
            opt.zero_grad(set_to_none=True)
            loss = F.cross_entropy(attacker(feat), y)
            loss.backward()
            opt.step()

    attacker.eval()
    y_true, y_probs = [], []
    with torch.no_grad():
        for (xa, xp), y in test_loader:
            xa, xp, y = xa.to(device), xp.to(device), y.to(device)
            # NOTE(review): at evaluation time "perturbed" uses the clean gradient
            # (no noise), unlike the fitting loop above. Kept as in the original;
            # confirm this asymmetry is intended.
            feat = _privacy_attack_features(vfl_model, xa, xp, config, attack_type, active_eps, active_queries, noise_sigma=None)
            probs = F.softmax(attacker(feat), dim=1).cpu()
            y_true.extend(y.cpu().tolist())
            y_probs.extend(probs.tolist())

    y_true, y_probs = np.array(y_true), np.array(y_probs)
    if num_classes == 2:
        preds = np.argmax(y_probs, axis=1)
        asr_top1 = (preds == y_true).mean() * 100
        asr_top5 = asr_top1
    else:
        asr_top1 = top_k_accuracy_score(y_true, y_probs, k=1, labels=range(num_classes)) * 100
        asr_top5 = top_k_accuracy_score(y_true, y_probs, k=min(5, num_classes), labels=range(num_classes)) * 100
    return float(asr_top1), float(asr_top5)

# Status line printed once the definitions in this cell have been executed.
print("Unified training, evaluation, and attack functions and metric reporting functions declared")

✓ Unified training, evaluation, and attack functions are ready (with correct metric reporting).


In [None]:
# Cell 3 – Checkpoint Saving Utility
from pathlib import Path
import torch

def save_checkpoint(model, dataset_name, model_type, epoch, config):
    """Write `model.state_dict()` to a per-modality checkpoints folder.

    The file is stored at
    `config["DRIVE_PATH"]/VFL_Results/<VISION|TABULAR_TEXT>/checkpoints/`
    under the name `<dataset_name>_<model_type>_epoch_<epoch+1>.pth`.
    Any exception is caught and reported with a printed message; nothing is raised.

    Args:
        model (nn.Module): Model whose state dict is saved.
        dataset_name (str): "CIFAR10", "CIFAR100" and "CINIC10" go under VISION;
            every other name goes under TABULAR_TEXT.
        model_type (str): Label embedded in the filename (e.g. "OA", "KDk").
        epoch (int): 0-indexed epoch; the filename uses `epoch + 1`.
        config (dict): Must contain "DRIVE_PATH".
    """
    vision_datasets = ("CIFAR10", "CIFAR100", "CINIC10")
    try:
        modality = "VISION" if dataset_name in vision_datasets else "TABULAR_TEXT"

        # <DRIVE_PATH>/VFL_Results/<modality>/checkpoints, created on demand.
        target_dir = Path(config["DRIVE_PATH"]).joinpath("VFL_Results", modality, "checkpoints")
        target_dir.mkdir(parents=True, exist_ok=True)

        target_file = target_dir / f"{dataset_name}_{model_type}_epoch_{epoch+1}.pth"
        torch.save(model.state_dict(), target_file)

        print(f"    → Checkpoint saved: {target_file.name}")

    except Exception as err:
        print(f"    ⚠️ Error saving checkpoint: {err}")
        print("    Please ensure your Google Drive is mounted and accessible.")

# Status line printed once save_checkpoint has been defined.
print("Checkpoint saving utility is ready")

✓ Cell 6a: Checkpoint saving utility is ready.


In [None]:
# Cell 4 – Text & Tabular Dataset Loading and Preprocessing
!pip install -q datasets sentence_transformers fastparquet

import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import zipfile
import shutil
from pathlib import Path

# VFL Dataset Class for Text/Tabular
class VFLTextTabularDataset(Dataset):
    """In-memory dataset yielding `((xa, xp), y)` samples for two feature views.

    `Xa` and `Xp` are copied into float32 tensors and `y` into an int64 tensor
    at construction time.
    """

    def __init__(self, Xa, Xp, y):
        float_dtype, label_dtype = torch.float32, torch.long
        self.Xa = torch.tensor(Xa, dtype=float_dtype)
        self.Xp = torch.tensor(Xp, dtype=float_dtype)
        self.y = torch.tensor(y, dtype=label_dtype)

    def __len__(self):
        """Number of samples, taken from the label tensor."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the feature pair and label stored at `idx`."""
        features = (self.Xa[idx], self.Xp[idx])
        label = self.y[idx]
        return features, label

# Main Loader Dictionary
# Maps dataset name -> {"train": DataLoader, "test": DataLoader, "num_classes": int};
# entries are added below only for datasets that load successfully.
TABULAR_TEXT_LOADERS = {}
# Local working directory for downloaded / extracted dataset files.
DATA_PATH = Path("/content/datasets")
DATA_PATH.mkdir(parents=True, exist_ok=True)

# Yahoo! Answers (Text)
# Question titles become the first view (Xa) and question contents the second view (Xp);
# both are encoded into sentence embeddings. Labels come from the 'topic' column.
print("Preparing Yahoo! Answers dataset...")
try:
    dataset = load_dataset("yahoo_answers_topics")
    # Sub-sample train/test to the sizes in CONFIG["SIZES"]["YAHOO"], seeded for repeatability.
    train_df = dataset['train'].to_pandas().sample(n=CONFIG["SIZES"]["YAHOO"][0], random_state=CONFIG["SEED"])
    test_df = dataset['test'].to_pandas().sample(n=CONFIG["SIZES"]["YAHOO"][1], random_state=CONFIG["SEED"])
    print("  Generating text embeddings with SentenceTransformer...")
    encoder = SentenceTransformer('all-mpnet-base-v2', device=CONFIG["DEVICE"])
    X_train_title = encoder.encode(train_df['question_title'].tolist(), show_progress_bar=True)
    X_train_content = encoder.encode(train_df['question_content'].tolist(), show_progress_bar=True)
    X_test_title = encoder.encode(test_df['question_title'].tolist(), show_progress_bar=True)
    X_test_content = encoder.encode(test_df['question_content'].tolist(), show_progress_bar=True)
    y_train, y_test = train_df['topic'].values, test_df['topic'].values
    train_ds, test_ds = VFLTextTabularDataset(X_train_title, X_train_content, y_train), VFLTextTabularDataset(X_test_title, X_test_content, y_test)
    TABULAR_TEXT_LOADERS["YAHOO"] = {
        # drop_last=True: the final partial training batch is dropped.
        "train": DataLoader(train_ds, batch_size=CONFIG["BATCH_SIZE"], shuffle=True, drop_last=True),
        "test": DataLoader(test_ds, batch_size=CONFIG["BATCH_SIZE"], shuffle=False),
        # Hard-coded class count (not derived from the labels).
        "num_classes": 10
    }
    print(f"  Yahoo! Answers loaded. Train={len(train_ds)}, Test={len(test_ds)}")
except Exception as e:
    # Best-effort: on any failure the dataset is simply left out of TABULAR_TEXT_LOADERS,
    # so downstream cells must cope with a missing "YAHOO" key.
    print(f"  Could not load Yahoo! Answers dataset. Error: {e}")

# Criteo CTR (Tabular)
# Expects a criteo.zip either already in DATA_PATH or on the mounted Drive; the archive
# must contain train.parquet and test.parquet (optionally nested in a 'criteo' folder).
print("\nPreparing Criteo CTR dataset...")
criteo_zip = DATA_PATH / "criteo.zip"
criteo_path = DATA_PATH / "criteo"
drive_criteo_zip = Path("/content/drive/MyDrive/datasets/criteo.zip")

# Copy the archive from Drive only if a local copy does not exist yet.
if drive_criteo_zip.exists() and not criteo_zip.exists():
    print("  Copying criteo.zip to local disk for speed...")
    shutil.copy(drive_criteo_zip, criteo_zip)

if criteo_zip.exists():
    # Extraction is skipped when the target directory already exists.
    if not criteo_path.exists():
        print(f"  Extracting Criteo from {criteo_zip}...")
        with zipfile.ZipFile(criteo_zip, 'r') as zf:
            zf.extractall(criteo_path)

    # Handle archives that unpack into criteo/criteo/.
    potential_nested_path = criteo_path / 'criteo'
    if potential_nested_path.exists() and potential_nested_path.is_dir():
        print("  Nested 'criteo' directory found. Adjusting path.")
        criteo_path = potential_nested_path

    try:
        # Sub-sample train/test to the sizes in CONFIG["SIZES"]["CRITEO"], seeded for repeatability.
        train_df = pd.read_parquet(criteo_path / 'train.parquet').sample(n=CONFIG["SIZES"]["CRITEO"][0], random_state=CONFIG["SEED"])
        test_df = pd.read_parquet(criteo_path / 'test.parquet').sample(n=CONFIG["SIZES"]["CRITEO"][1], random_state=CONFIG["SEED"])

        # Columns at positions 1..13 are taken as the features.
        # NOTE(review): presumably column 0 is the label and 1..13 are the dense
        # numeric features - this relies on the parquet column order; confirm.
        dense_feature_cols = train_df.columns[1:14]

        # Missing values are filled with 0 before scaling; the scaler is fit on train only.
        scaler = StandardScaler()
        X_train_dense = scaler.fit_transform(train_df[dense_feature_cols].fillna(0))
        X_test_dense = scaler.transform(test_df[dense_feature_cols].fillna(0))

        y_train, y_test = train_df['label'].values, test_df['label'].values

        # Vertical split: first 6 features -> Xa, remaining features -> Xp.
        Xa_train, Xp_train = X_train_dense[:, :6], X_train_dense[:, 6:]
        Xa_test, Xp_test = X_test_dense[:, :6], X_test_dense[:, 6:]

        train_ds, test_ds = VFLTextTabularDataset(Xa_train, Xp_train, y_train), VFLTextTabularDataset(Xa_test, Xp_test, y_test)

        TABULAR_TEXT_LOADERS["CRITEO"] = {
            # drop_last=True: the final partial training batch is dropped.
            "train": DataLoader(train_ds, batch_size=CONFIG["BATCH_SIZE"], shuffle=True, drop_last=True),
            "test": DataLoader(test_ds, batch_size=CONFIG["BATCH_SIZE"], shuffle=False),
            # Hard-coded class count (not derived from the labels).
            "num_classes": 2
        }
        print(f"  Criteo CTR loaded. Train={len(train_ds)}, Test={len(test_ds)}")

    except Exception as e:
        # Best-effort: on any failure "CRITEO" is left out of TABULAR_TEXT_LOADERS.
        print(f"  Could not process Criteo dataset. Error: {e}")
else:
    # No archive available locally or on Drive: "CRITEO" is left out of TABULAR_TEXT_LOADERS.
    print("  Criteo dataset zip not found. Skipped.")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m72.0 MB/s[0m eta [36m0:00:00[0m
[?25hPreparing Yahoo! Answers dataset...


README.md: 0.00B [00:00, ?B/s]

yahoo_answers_topics/train-00000-of-0000(…):   0%|          | 0.00/241M [00:00<?, ?B/s]

yahoo_answers_topics/train-00001-of-0000(…):   0%|          | 0.00/270M [00:00<?, ?B/s]

yahoo_answers_topics/test-00000-of-00001(…):   0%|          | 0.00/21.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1400000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/60000 [00:00<?, ? examples/s]

  Generating text embeddings with SentenceTransformer...


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/1563 [00:00<?, ?it/s]

Batches:   0%|          | 0/1563 [00:00<?, ?it/s]

Batches:   0%|          | 0/625 [00:00<?, ?it/s]

Batches:   0%|          | 0/625 [00:00<?, ?it/s]

  Yahoo! Answers loaded. Train=50000, Test=20000

Preparing Criteo CTR dataset...
  Copying criteo.zip to local disk for speed...
  Extracting Criteo from /content/datasets/criteo.zip...
  Nested 'criteo' directory found. Adjusting path.
  Criteo CTR loaded. Train=80000, Test=20000


In [None]:
# Cell 5 – VFL Text/Tabular Top and Bottom Model Definitions
class BottomModel_TabularText(nn.Module):
    """Feature extractor (bottom model) built from fully connected blocks.

    Args:
        in_features (int): Width of the input vectors.
        layers (sequence of int): Hidden widths. Each entry adds
            Linear -> BatchNorm1d -> ReLU -> Dropout(0.3). An empty (or falsy)
            value builds a single width-preserving Linear -> ReLU instead, so
            the output width equals `in_features`.

    The default is an immutable tuple, so it cannot be mutated and shared
    between instances. The module order inside `self.net` is unchanged.
    """

    def __init__(self, in_features, layers=(512, 256)):
        super().__init__()
        net = []
        if layers:
            for out_features in layers:
                net.extend([
                    nn.Linear(in_features, out_features),
                    nn.BatchNorm1d(out_features),
                    nn.ReLU(inplace=True),
                    nn.Dropout(0.3),
                ])
                in_features = out_features
        else:
            net.extend([
                nn.Linear(in_features, in_features),
                nn.ReLU(inplace=True),
            ])

        self.net = nn.Sequential(*net)

    def forward(self, x):
        """Apply the stacked blocks to `x`."""
        return self.net(x)

class TopModel_TabularText(nn.Module):
    """Classifier head (top model) that fuses the two bottom-model outputs.

    Args:
        in_features (int): Width of the concatenated `[h_a, h_p]` vector.
        num_classes (int): Number of output logits.
        layers (sequence of int): Hidden widths; each entry adds
            Linear -> BatchNorm1d -> ReLU. An empty sequence feeds the
            concatenation straight into the final linear classifier.

    The default is an immutable tuple, so it cannot be mutated and shared
    between instances. The module order inside `self.net` is unchanged.
    """

    def __init__(self, in_features, num_classes, layers=(512, 256)):
        super().__init__()
        net = []
        for out_features in layers:
            net.extend([
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                nn.ReLU(inplace=True),
            ])
            in_features = out_features
        self.net = nn.Sequential(*net)
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, h_a, h_p):
        """Concatenate `h_a` and `h_p` along dim 1 and return class logits."""
        h = torch.cat([h_a, h_p], dim=1)
        features = self.net(h)
        return self.classifier(features)

class VFLModel_TT(nn.Module):
    """Two bottom models joined by a top model.

    Holds the three sub-modules as `bottom_a`, `bottom_p` and `top_model`
    (registered in that order).
    """

    def __init__(self, bottom_a, bottom_p, top_model):
        super().__init__()
        self.bottom_a = bottom_a
        self.bottom_p = bottom_p
        self.top_model = top_model

    def forward(self, x_a, x_p):
        """Run each input through its bottom model, then fuse with the top model."""
        h_a = self.bottom_a(x_a)
        h_p = self.bottom_p(x_p)
        return self.top_model(h_a, h_p)

def build_vfl_model_tabular_text(dataset_name, num_classes):
    """Assemble the VFL model for "YAHOO" or "CRITEO" and move it to CONFIG["DEVICE"].

    Args:
        dataset_name (str): "YAHOO" (two 768-wide inputs, one 256-unit hidden layer
            per bottom model) or "CRITEO" (6- and 7-wide inputs, bottom models
            without hidden layers).
        num_classes (int): Number of logits produced by the top model.

    Returns:
        VFLModel_TT: The assembled model on the configured device.

    Raises:
        ValueError: For any other dataset name.
    """
    # (bottom_a input width, bottom_p input width, bottom hidden layers,
    #  top input width, top hidden layers)
    architectures = {
        "YAHOO": (768, 768, [256], 512, [256, 128]),
        "CRITEO": (6, 7, [], 13, [32, 16]),
    }
    if dataset_name not in architectures:
        raise ValueError(f"Unknown dataset name for model building: {dataset_name}")

    a_width, p_width, bottom_layers, top_width, top_layers = architectures[dataset_name]

    # Construction order (bottom_a, bottom_p, top) is kept so parameter
    # initialisation consumes the RNG in the same sequence.
    bottom_a = BottomModel_TabularText(in_features=a_width, layers=list(bottom_layers))
    bottom_p = BottomModel_TabularText(in_features=p_width, layers=list(bottom_layers))
    top_model = TopModel_TabularText(in_features=top_width, num_classes=num_classes, layers=top_layers)

    return VFLModel_TT(bottom_a, bottom_p, top_model).to(CONFIG["DEVICE"])

# Status line printed once the model classes and builder have been defined.
print("Text/Tabular model architectures defined")

✓ Cell 10: Optimized Text/Tabular model architectures are ready.


In [None]:
# Cell 6 – Main Text/Tabular Training Loop
import warnings
import time
from pathlib import Path
import torch
import torch.optim as optim

warnings.filterwarnings("ignore", category=UserWarning)

TABTEXT_GROUP = ["YAHOO", "CRITEO"]
# dataset name -> {variant label -> checkpoint filename inside drive_results_dir_tt}
SAVED_MODELS_TABTEXT = {}

drive_results_dir_tt = Path(CONFIG["DRIVE_PATH"]) / "VFL_Results" / "TABULAR_TEXT"
drive_results_dir_tt.mkdir(parents=True, exist_ok=True)


def _fit_with_early_stopping(model, ckpt_path, dataset_name, tr_loader, te_loader,
                             lr, epochs, patience, teacher=None, use_at=False):
    """Train `model` with AdamW, keep the best checkpoint, and reload it.

    The monitored metric is clean validation accuracy, or robust accuracy when
    `use_at` is True. Epochs without improvement are only counted from the third
    epoch on (`ep > 1`); training stops once `patience` such epochs accumulated.

    NOTE(review): `te_loader` is used both for model selection / early stopping
    and for the accuracy that is reported - confirm a separate validation split
    is not required.

    Returns:
        The same `model`, with the best checkpoint's weights loaded.
    """
    opt = optim.AdamW(model.parameters(), lr=lr)
    metric_label = "Robust ACC" if use_at else "Val ACC"
    # Start below any reachable accuracy so the first epoch always writes a checkpoint;
    # the reload below then never depends on a file left over from an earlier run.
    best_metric, no_imp = float("-inf"), 0

    for ep in range(epochs):
        t0 = time.perf_counter()
        train_loss, train_acc = train_epoch_unified(model, tr_loader, opt, CONFIG, dataset_name, teacher=teacher, use_at=use_at)
        val_loss, val_acc, _ = evaluate_model_unified(model, te_loader, CONFIG, dataset_name)
        if use_at:
            _, rob_acc, _ = evaluate_model_unified(model, te_loader, CONFIG, dataset_name, robust=True)
            elapsed = time.perf_counter() - t0
            print(f"  Epoch {ep+1:02d}/{epochs} | Train Loss={train_loss:.4f} | Train ACC={train_acc:6.2f}% | Val ACC={val_acc:6.2f}% | Robust={rob_acc:6.2f}% | Time={elapsed:.1f}s")
            monitored = rob_acc
        else:
            elapsed = time.perf_counter() - t0
            print(f"  Epoch {ep+1:02d}/{epochs} | Train Loss={train_loss:.4f} | Train ACC={train_acc:6.2f}% | Val Loss={val_loss:.4f} | Val ACC={val_acc:6.2f}% | Time={elapsed:.1f}s")
            monitored = val_acc

        if monitored > best_metric:
            best_metric, no_imp = monitored, 0
            torch.save(model.state_dict(), ckpt_path)
            print(f"    → Best model saved ({metric_label}: {monitored:.2f}%)")
        elif ep > 1:
            no_imp += 1
        if no_imp >= patience:
            print("  Early stopping triggered.")
            break

    model.load_state_dict(torch.load(ckpt_path, map_location=CONFIG["DEVICE"]))
    return model


print("Starting Text/Tabular VFL Training with Final Optimizations\n" + "="*70)

for dataset_name in TABTEXT_GROUP:
    # Dataset loading is best-effort; skip datasets that never made it into the loader dict.
    if dataset_name not in TABULAR_TEXT_LOADERS:
        print(f"\n{dataset_name} – loaders not available (dataset was not loaded); skipping.")
        continue

    # Setup
    loaders = TABULAR_TEXT_LOADERS[dataset_name]
    tr_loader, te_loader = loaders["train"], loaders["test"]
    num_classes = loaders["num_classes"]
    epochs = CONFIG["EPOCHS_PER_DATASET"].get(dataset_name, CONFIG["DEFAULT_EPOCHS"])
    hp = CONFIG["DATASET_HP"][dataset_name]
    SAVED_MODELS_TABTEXT[dataset_name] = {}
    patience = 3
    builder_fn = build_vfl_model_tabular_text

    # Phase 1: Original Architecture OA/Teacher
    print(f"\n{dataset_name} – Phase 1: Training OA (Teacher)")
    OA = builder_fn(dataset_name, num_classes)
    _fit_with_early_stopping(OA, drive_results_dir_tt / f"{dataset_name}_OA_best.pth", dataset_name,
                             tr_loader, te_loader, hp["lr_teacher"], epochs, patience)
    SAVED_MODELS_TABTEXT[dataset_name]["OA"] = f"{dataset_name}_OA_best.pth"

    # Phase 2: Knowledge Distillation (KD) from the OA teacher
    print(f"\n{dataset_name} – Phase 2: Training KDk")
    KDk = builder_fn(dataset_name, num_classes)
    _fit_with_early_stopping(KDk, drive_results_dir_tt / f"{dataset_name}_KDk_best.pth", dataset_name,
                             tr_loader, te_loader, hp["lr_student"], epochs, patience, teacher=OA)
    SAVED_MODELS_TABTEXT[dataset_name]["KDk"] = f"{dataset_name}_KDk_best.pth"

    # Phase 3: KDk with Adversarial Training (KDk+AT); selection is on robust accuracy
    print(f"\n{dataset_name} – Phase 3: Training KDk+AT")
    KDkAT = builder_fn(dataset_name, num_classes)
    _fit_with_early_stopping(KDkAT, drive_results_dir_tt / f"{dataset_name}_KDkAT_best.pth", dataset_name,
                             tr_loader, te_loader, hp["lr_student"], epochs, patience, teacher=OA, use_at=True)
    SAVED_MODELS_TABTEXT[dataset_name]["KDk+AT"] = f"{dataset_name}_KDkAT_best.pth"

    # Phase 4: KDk+AT+DP - fine-tune the KDk+AT weights with gradient / embedding noise
    print(f"\n{dataset_name} – Phase 4: Training KDk+AT+DP")
    KDkATDP = builder_fn(dataset_name, num_classes)
    KDkATDP.load_state_dict(torch.load(drive_results_dir_tt / f"{dataset_name}_KDkAT_best.pth", map_location=CONFIG["DEVICE"]))

    if dataset_name == "YAHOO":
        finetune_lr = hp["lr_student"] / 25
        grad_noise = 4.0
        embed_noise = 0.3
        finetune_epochs = 20
        use_embed_noise_flag = True
    else: # Criteo
        finetune_lr = hp["lr_student"] / 10
        grad_noise = 1.2
        embed_noise = 0.1
        finetune_epochs = 15
        use_embed_noise_flag = False

    opt = optim.AdamW(KDkATDP.parameters(), lr=finetune_lr)
    print(f"  Applying DP with grad_noise: {grad_noise}, embed_noise: {embed_noise}, LR: {finetune_lr:.2e}, Epochs: {finetune_epochs}")

    # No early stopping or best-model selection in this phase: every epoch runs and
    # the final weights are what gets saved (the filename still ends in "_best").
    for ep in range(finetune_epochs):
        t0 = time.perf_counter()
        train_loss, train_acc = train_epoch_unified(KDkATDP, tr_loader, opt, CONFIG, dataset_name, teacher=OA, use_at=True, use_dp=True,
                                                      noise_multiplier=grad_noise,
                                                      use_embedding_noise=use_embed_noise_flag,
                                                      embedding_noise_level=embed_noise)
        val_loss, val_acc, _ = evaluate_model_unified(KDkATDP, te_loader, CONFIG, dataset_name)
        _, rob_acc, _ = evaluate_model_unified(KDkATDP, te_loader, CONFIG, dataset_name, robust=True)
        elapsed = time.perf_counter() - t0
        print(f"  Epoch {ep+1:02d}/{finetune_epochs} | Train Loss={train_loss:.4f} | Train ACC={train_acc:6.2f}% | Val ACC={val_acc:6.2f}% | Robust={rob_acc:6.2f}% | Time={elapsed:.1f}s")

    torch.save(KDkATDP.state_dict(), drive_results_dir_tt / f"{dataset_name}_KDkATDP_best.pth")
    SAVED_MODELS_TABTEXT[dataset_name]["KDk+AT+DP"] = f"{dataset_name}_KDkATDP_best.pth"

    # Release model references and cached GPU memory before the next dataset.
    del OA, KDk, KDkAT, KDkATDP, opt
    torch.cuda.empty_cache()

print("\n" + "="*70 + "\nTraining completed for all Text/Tabular datasets!")

Starting Text/Tabular VFL Training with Final Optimizations

YAHOO – Phase 1: Training OA (Teacher)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 117.55it/s, loss=1.0331]


  Epoch 01/10 | Train Loss=0.9953 | Train ACC= 69.14% | Val Loss=0.8635 | Val ACC= 72.01% | Time=7.6s
    → Best model saved (Val ACC: 72.01%)


Training YAHOO: 100%|██████████████████████████████████████████| 781/781 [00:09<00:00, 83.63it/s, loss=0.8523]


  Epoch 02/10 | Train Loss=0.8209 | Train ACC= 73.36% | Val Loss=0.8518 | Val ACC= 72.27% | Time=10.3s
    → Best model saved (Val ACC: 72.27%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 140.99it/s, loss=0.6749]


  Epoch 03/10 | Train Loss=0.7701 | Train ACC= 74.76% | Val Loss=0.8482 | Val ACC= 72.44% | Time=6.2s
    → Best model saved (Val ACC: 72.44%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 170.61it/s, loss=0.8690]


  Epoch 04/10 | Train Loss=0.7297 | Train ACC= 75.74% | Val Loss=0.8532 | Val ACC= 72.52% | Time=5.1s
    → Best model saved (Val ACC: 72.52%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 189.03it/s, loss=0.7041]


  Epoch 05/10 | Train Loss=0.6911 | Train ACC= 76.89% | Val Loss=0.8687 | Val ACC= 72.50% | Time=4.7s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 167.62it/s, loss=0.7772]


  Epoch 06/10 | Train Loss=0.6570 | Train ACC= 77.71% | Val Loss=0.8754 | Val ACC= 72.46% | Time=5.4s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 190.85it/s, loss=0.6979]


  Epoch 07/10 | Train Loss=0.6253 | Train ACC= 78.76% | Val Loss=0.8922 | Val ACC= 72.42% | Time=4.6s
  Early stopping triggered.

YAHOO – Phase 2: Training KDk


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 158.20it/s, loss=0.7651]


  Epoch 01/10 | Train Loss=1.1906 | Train ACC= 64.52% | Val Loss=0.8977 | Val ACC= 71.43% | Time=5.5s
    → Best model saved (Val ACC: 71.43%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 139.49it/s, loss=0.4514]


  Epoch 02/10 | Train Loss=0.4954 | Train ACC= 72.66% | Val Loss=0.8718 | Val ACC= 72.16% | Time=6.1s
    → Best model saved (Val ACC: 72.16%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 158.52it/s, loss=0.3996]


  Epoch 03/10 | Train Loss=0.4109 | Train ACC= 73.83% | Val Loss=0.8622 | Val ACC= 72.52% | Time=5.5s
    → Best model saved (Val ACC: 72.52%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 138.99it/s, loss=0.3643]


  Epoch 04/10 | Train Loss=0.3779 | Train ACC= 74.69% | Val Loss=0.8571 | Val ACC= 72.52% | Time=6.1s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 160.54it/s, loss=0.4034]


  Epoch 05/10 | Train Loss=0.3541 | Train ACC= 75.21% | Val Loss=0.8588 | Val ACC= 72.47% | Time=5.4s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 152.17it/s, loss=0.2937]


  Epoch 06/10 | Train Loss=0.3395 | Train ACC= 75.69% | Val Loss=0.8556 | Val ACC= 72.70% | Time=5.9s
    → Best model saved (Val ACC: 72.70%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 149.11it/s, loss=0.2974]


  Epoch 07/10 | Train Loss=0.3302 | Train ACC= 76.38% | Val Loss=0.8519 | Val ACC= 72.73% | Time=5.8s
    → Best model saved (Val ACC: 72.73%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 157.17it/s, loss=0.3276]


  Epoch 08/10 | Train Loss=0.3214 | Train ACC= 76.62% | Val Loss=0.8515 | Val ACC= 72.69% | Time=5.5s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 137.03it/s, loss=0.2929]


  Epoch 09/10 | Train Loss=0.3155 | Train ACC= 76.94% | Val Loss=0.8498 | Val ACC= 72.71% | Time=6.2s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 158.54it/s, loss=0.3602]


  Epoch 10/10 | Train Loss=0.3067 | Train ACC= 77.34% | Val Loss=0.8500 | Val ACC= 72.76% | Time=5.4s
    → Best model saved (Val ACC: 72.76%)

YAHOO – Phase 3: Training KDk+AT


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 109.65it/s, loss=0.5498]


  Epoch 01/10 | Train Loss=1.2847 | Train ACC= 59.48% | Val ACC= 71.54% | Robust= 68.43% | Time=9.2s
    → Best model saved (Robust ACC: 68.43%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 122.94it/s, loss=0.3946]


  Epoch 02/10 | Train Loss=0.5405 | Train ACC= 69.69% | Val ACC= 71.95% | Robust= 69.11% | Time=8.4s
    → Best model saved (Robust ACC: 69.11%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 119.09it/s, loss=0.5027]


  Epoch 03/10 | Train Loss=0.4482 | Train ACC= 70.86% | Val ACC= 72.46% | Robust= 69.42% | Time=9.0s
    → Best model saved (Robust ACC: 69.42%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 109.98it/s, loss=0.4380]


  Epoch 04/10 | Train Loss=0.4119 | Train ACC= 71.58% | Val ACC= 72.59% | Robust= 69.69% | Time=8.7s
    → Best model saved (Robust ACC: 69.69%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 124.08it/s, loss=0.3712]


  Epoch 05/10 | Train Loss=0.3912 | Train ACC= 72.27% | Val ACC= 72.61% | Robust= 69.73% | Time=8.2s
    → Best model saved (Robust ACC: 69.73%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 115.28it/s, loss=0.3329]


  Epoch 06/10 | Train Loss=0.3743 | Train ACC= 72.82% | Val ACC= 72.61% | Robust= 69.84% | Time=8.4s
    → Best model saved (Robust ACC: 69.84%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 110.90it/s, loss=0.4034]


  Epoch 07/10 | Train Loss=0.3647 | Train ACC= 73.05% | Val ACC= 72.81% | Robust= 70.02% | Time=8.7s
    → Best model saved (Robust ACC: 70.02%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 123.76it/s, loss=0.3403]


  Epoch 08/10 | Train Loss=0.3563 | Train ACC= 73.47% | Val ACC= 72.77% | Robust= 70.09% | Time=7.9s
    → Best model saved (Robust ACC: 70.09%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 109.67it/s, loss=0.4104]


  Epoch 09/10 | Train Loss=0.3484 | Train ACC= 73.57% | Val ACC= 72.81% | Robust= 69.95% | Time=8.7s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 118.10it/s, loss=0.3312]


  Epoch 10/10 | Train Loss=0.3432 | Train ACC= 73.97% | Val ACC= 72.61% | Robust= 69.93% | Time=8.8s

YAHOO – Phase 4: Training KDk+AT+DP
  Applying DP with grad_noise: 4.0, embed_noise: 0.3, LR: 4.00e-06, Epochs: 20


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 120.47it/s, loss=0.3235]


  Epoch 01/20 | Train Loss=0.3480 | Train ACC= 73.81% | Val ACC= 72.69% | Robust= 70.17% | Time=8.1s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 107.44it/s, loss=0.4418]


  Epoch 02/20 | Train Loss=0.3471 | Train ACC= 74.07% | Val ACC= 72.65% | Robust= 69.92% | Time=8.9s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 115.16it/s, loss=0.2940]


  Epoch 03/20 | Train Loss=0.3456 | Train ACC= 74.08% | Val ACC= 72.72% | Robust= 70.01% | Time=8.9s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 113.48it/s, loss=0.3416]


  Epoch 04/20 | Train Loss=0.3464 | Train ACC= 74.04% | Val ACC= 72.75% | Robust= 69.92% | Time=8.5s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 105.06it/s, loss=0.3353]


  Epoch 05/20 | Train Loss=0.3439 | Train ACC= 74.03% | Val ACC= 72.76% | Robust= 70.10% | Time=9.0s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 117.93it/s, loss=0.2591]


  Epoch 06/20 | Train Loss=0.3442 | Train ACC= 74.13% | Val ACC= 72.71% | Robust= 70.08% | Time=8.7s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 113.24it/s, loss=0.3195]


  Epoch 07/20 | Train Loss=0.3448 | Train ACC= 74.01% | Val ACC= 72.71% | Robust= 69.92% | Time=8.5s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 104.89it/s, loss=0.3750]


  Epoch 08/20 | Train Loss=0.3435 | Train ACC= 74.31% | Val ACC= 72.65% | Robust= 69.99% | Time=9.1s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 116.48it/s, loss=0.2980]


  Epoch 09/20 | Train Loss=0.3436 | Train ACC= 74.12% | Val ACC= 72.70% | Robust= 70.06% | Time=8.6s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 107.77it/s, loss=0.4131]


  Epoch 10/20 | Train Loss=0.3445 | Train ACC= 74.01% | Val ACC= 72.73% | Robust= 70.02% | Time=8.8s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 103.62it/s, loss=0.3449]


  Epoch 11/20 | Train Loss=0.3425 | Train ACC= 74.07% | Val ACC= 72.70% | Robust= 70.03% | Time=9.1s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 117.32it/s, loss=0.3875]


  Epoch 12/20 | Train Loss=0.3430 | Train ACC= 74.14% | Val ACC= 72.65% | Robust= 70.01% | Time=8.5s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 107.06it/s, loss=0.2625]


  Epoch 13/20 | Train Loss=0.3420 | Train ACC= 74.05% | Val ACC= 72.76% | Robust= 70.04% | Time=8.9s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 107.26it/s, loss=0.3474]


  Epoch 14/20 | Train Loss=0.3406 | Train ACC= 74.31% | Val ACC= 72.61% | Robust= 69.89% | Time=8.9s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 118.98it/s, loss=0.2909]


  Epoch 15/20 | Train Loss=0.3412 | Train ACC= 74.19% | Val ACC= 72.66% | Robust= 69.98% | Time=8.2s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 106.33it/s, loss=0.3335]


  Epoch 16/20 | Train Loss=0.3413 | Train ACC= 74.30% | Val ACC= 72.76% | Robust= 70.05% | Time=9.0s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 110.13it/s, loss=0.3597]


  Epoch 17/20 | Train Loss=0.3422 | Train ACC= 74.07% | Val ACC= 72.61% | Robust= 70.08% | Time=9.0s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 116.26it/s, loss=0.3401]


  Epoch 18/20 | Train Loss=0.3411 | Train ACC= 74.31% | Val ACC= 72.78% | Robust= 69.98% | Time=8.3s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 103.98it/s, loss=0.2522]


  Epoch 19/20 | Train Loss=0.3394 | Train ACC= 74.30% | Val ACC= 72.68% | Robust= 69.98% | Time=9.1s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 109.62it/s, loss=0.3348]


  Epoch 20/20 | Train Loss=0.3405 | Train ACC= 74.29% | Val ACC= 72.78% | Robust= 70.01% | Time=9.2s

CRITEO – Phase 1: Training OA (Teacher)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 203.31it/s, loss=0.5645]


  Epoch 01/15 | Train Loss=0.5333 | Train ACC= 74.56% | Val Loss=0.5210 | Val ACC= 75.08% | Time=6.6s
    → Best model saved (Val ACC: 75.08%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 186.55it/s, loss=0.4990]


  Epoch 02/15 | Train Loss=0.5213 | Train ACC= 75.23% | Val Loss=0.5190 | Val ACC= 75.50% | Time=7.3s
    → Best model saved (Val ACC: 75.50%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 201.47it/s, loss=0.5345]


  Epoch 03/15 | Train Loss=0.5187 | Train ACC= 75.32% | Val Loss=0.5168 | Val ACC= 75.48% | Time=6.6s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:12<00:00, 103.03it/s, loss=0.6148]


  Epoch 04/15 | Train Loss=0.5178 | Train ACC= 75.47% | Val Loss=0.5155 | Val ACC= 75.52% | Time=12.9s
    → Best model saved (Val ACC: 75.52%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:08<00:00, 141.42it/s, loss=0.4844]


  Epoch 05/15 | Train Loss=0.5168 | Train ACC= 75.53% | Val Loss=0.5163 | Val ACC= 75.43% | Time=9.4s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 201.68it/s, loss=0.6540]


  Epoch 06/15 | Train Loss=0.5159 | Train ACC= 75.77% | Val Loss=0.5169 | Val ACC= 75.61% | Time=6.6s
    → Best model saved (Val ACC: 75.61%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 193.68it/s, loss=0.4783]


  Epoch 07/15 | Train Loss=0.5154 | Train ACC= 75.77% | Val Loss=0.5143 | Val ACC= 75.81% | Time=7.0s
    → Best model saved (Val ACC: 75.81%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 198.50it/s, loss=0.4434]


  Epoch 08/15 | Train Loss=0.5150 | Train ACC= 75.75% | Val Loss=0.5148 | Val ACC= 75.64% | Time=6.7s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 195.64it/s, loss=0.5176]


  Epoch 09/15 | Train Loss=0.5140 | Train ACC= 75.83% | Val Loss=0.5148 | Val ACC= 75.84% | Time=7.0s
    → Best model saved (Val ACC: 75.84%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 197.01it/s, loss=0.4544]


  Epoch 10/15 | Train Loss=0.5137 | Train ACC= 75.78% | Val Loss=0.5236 | Val ACC= 75.08% | Time=6.8s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 203.98it/s, loss=0.5043]


  Epoch 11/15 | Train Loss=0.5139 | Train ACC= 75.81% | Val Loss=0.5171 | Val ACC= 75.61% | Time=6.7s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 190.85it/s, loss=0.6279]


  Epoch 12/15 | Train Loss=0.5136 | Train ACC= 75.82% | Val Loss=0.5180 | Val ACC= 75.05% | Time=7.0s
  Early stopping triggered.

CRITEO – Phase 2: Training KDk


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 167.82it/s, loss=0.1485]


  Epoch 01/15 | Train Loss=0.1900 | Train ACC= 73.12% | Val Loss=0.5225 | Val ACC= 74.97% | Time=8.0s
    → Best model saved (Val ACC: 74.97%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:08<00:00, 141.77it/s, loss=0.1633]


  Epoch 02/15 | Train Loss=0.1613 | Train ACC= 75.15% | Val Loss=0.5189 | Val ACC= 74.98% | Time=9.2s
    → Best model saved (Val ACC: 74.98%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:08<00:00, 156.02it/s, loss=0.1880]


  Epoch 03/15 | Train Loss=0.1597 | Train ACC= 75.32% | Val Loss=0.5186 | Val ACC= 75.09% | Time=8.4s
    → Best model saved (Val ACC: 75.09%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 173.51it/s, loss=0.1271]


  Epoch 04/15 | Train Loss=0.1585 | Train ACC= 75.49% | Val Loss=0.5172 | Val ACC= 75.23% | Time=7.6s
    → Best model saved (Val ACC: 75.23%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:08<00:00, 155.94it/s, loss=0.1933]


  Epoch 05/15 | Train Loss=0.1578 | Train ACC= 75.68% | Val Loss=0.5169 | Val ACC= 75.28% | Time=8.4s
    → Best model saved (Val ACC: 75.28%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 170.63it/s, loss=0.1247]


  Epoch 06/15 | Train Loss=0.1576 | Train ACC= 75.74% | Val Loss=0.5159 | Val ACC= 75.42% | Time=7.9s
    → Best model saved (Val ACC: 75.42%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 159.69it/s, loss=0.1465]


  Epoch 07/15 | Train Loss=0.1572 | Train ACC= 75.81% | Val Loss=0.5165 | Val ACC= 75.67% | Time=8.2s
    → Best model saved (Val ACC: 75.67%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 161.54it/s, loss=0.1598]


  Epoch 08/15 | Train Loss=0.1569 | Train ACC= 75.85% | Val Loss=0.5156 | Val ACC= 75.81% | Time=8.3s
    → Best model saved (Val ACC: 75.81%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 169.01it/s, loss=0.1512]


  Epoch 09/15 | Train Loss=0.1567 | Train ACC= 75.91% | Val Loss=0.5170 | Val ACC= 75.37% | Time=7.8s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:08<00:00, 154.52it/s, loss=0.1588]


  Epoch 10/15 | Train Loss=0.1566 | Train ACC= 75.80% | Val Loss=0.5149 | Val ACC= 75.78% | Time=8.5s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 171.22it/s, loss=0.1577]


  Epoch 11/15 | Train Loss=0.1564 | Train ACC= 75.91% | Val Loss=0.5191 | Val ACC= 75.09% | Time=7.7s
  Early stopping triggered.

CRITEO – Phase 3: Training KDk+AT


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 126.21it/s, loss=0.1537]


  Epoch 01/15 | Train Loss=0.1840 | Train ACC= 73.25% | Val ACC= 74.86% | Robust= 74.52% | Time=11.2s
    → Best model saved (Robust ACC: 74.52%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 126.52it/s, loss=0.1755]


  Epoch 02/15 | Train Loss=0.1698 | Train ACC= 74.43% | Val ACC= 74.91% | Robust= 74.41% | Time=11.2s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 129.91it/s, loss=0.2228]


  Epoch 03/15 | Train Loss=0.1684 | Train ACC= 74.51% | Val ACC= 75.24% | Robust= 74.44% | Time=11.2s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 137.98it/s, loss=0.1742]


  Epoch 04/15 | Train Loss=0.1675 | Train ACC= 74.60% | Val ACC= 74.92% | Robust= 74.37% | Time=10.7s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 132.11it/s, loss=0.1624]


  Epoch 05/15 | Train Loss=0.1670 | Train ACC= 74.72% | Val ACC= 74.95% | Robust= 74.47% | Time=10.7s
  Early stopping triggered.

CRITEO – Phase 4: Training KDk+AT+DP
  Applying DP with grad_noise: 1.2, embed_noise: 0.1, LR: 5.00e-05, Epochs: 15


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 125.49it/s, loss=0.1772]


  Epoch 01/15 | Train Loss=0.1703 | Train ACC= 74.52% | Val ACC= 74.95% | Robust= 74.47% | Time=11.2s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 124.84it/s, loss=0.1802]


  Epoch 02/15 | Train Loss=0.1701 | Train ACC= 74.50% | Val ACC= 74.58% | Robust= 74.39% | Time=11.3s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 128.99it/s, loss=0.1794]


  Epoch 03/15 | Train Loss=0.1698 | Train ACC= 74.52% | Val ACC= 75.03% | Robust= 74.37% | Time=11.3s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 135.67it/s, loss=0.1821]


  Epoch 04/15 | Train Loss=0.1696 | Train ACC= 74.54% | Val ACC= 74.89% | Robust= 74.48% | Time=10.8s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 129.74it/s, loss=0.1493]


  Epoch 05/15 | Train Loss=0.1695 | Train ACC= 74.45% | Val ACC= 74.91% | Robust= 74.43% | Time=10.9s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 125.27it/s, loss=0.1719]


  Epoch 06/15 | Train Loss=0.1694 | Train ACC= 74.51% | Val ACC= 74.89% | Robust= 74.45% | Time=11.2s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 124.21it/s, loss=0.1513]


  Epoch 07/15 | Train Loss=0.1691 | Train ACC= 74.54% | Val ACC= 75.03% | Robust= 74.44% | Time=11.3s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 128.37it/s, loss=0.1592]


  Epoch 08/15 | Train Loss=0.1690 | Train ACC= 74.61% | Val ACC= 75.09% | Robust= 74.48% | Time=11.4s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 131.18it/s, loss=0.1796]


  Epoch 09/15 | Train Loss=0.1690 | Train ACC= 74.51% | Val ACC= 74.92% | Robust= 74.45% | Time=11.3s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 126.57it/s, loss=0.1806]


  Epoch 10/15 | Train Loss=0.1688 | Train ACC= 74.61% | Val ACC= 74.97% | Robust= 74.37% | Time=11.2s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 122.58it/s, loss=0.1812]


  Epoch 11/15 | Train Loss=0.1686 | Train ACC= 74.61% | Val ACC= 74.92% | Robust= 74.44% | Time=11.5s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 122.68it/s, loss=0.1590]


  Epoch 12/15 | Train Loss=0.1686 | Train ACC= 74.54% | Val ACC= 74.56% | Robust= 74.40% | Time=11.5s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 121.27it/s, loss=0.1962]


  Epoch 13/15 | Train Loss=0.1685 | Train ACC= 74.58% | Val ACC= 74.95% | Robust= 74.41% | Time=11.6s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 127.29it/s, loss=0.1747]


  Epoch 14/15 | Train Loss=0.1683 | Train ACC= 74.52% | Val ACC= 74.81% | Robust= 74.48% | Time=11.5s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 129.77it/s, loss=0.1610]


  Epoch 15/15 | Train Loss=0.1682 | Train ACC= 74.58% | Val ACC= 74.84% | Robust= 74.47% | Time=11.1s

Training completed for all Text/Tabular datasets!


In [None]:
# @title
# Cell 11 – Main Text/Tabular Training Loop
import warnings
import time
from pathlib import Path
import torch
import torch.optim as optim

warnings.filterwarnings("ignore", category=UserWarning)

TABTEXT_GROUP = ["YAHOO", "CRITEO"]
SAVED_MODELS_TABTEXT = {}

# Phase-4 fine-tuning settings, keyed by dataset name.
#   lr_divisor      – fine-tune LR is hp["lr_student"] / lr_divisor
#   grad_noise      – forwarded to train_epoch_unified as `noise_multiplier`
#   embed_noise     – forwarded as `embedding_noise_level`
#   use_embed_noise – forwarded as `use_embedding_noise`
#   epochs          – number of fine-tuning epochs (no early stopping in Phase 4)
DP_FINETUNE_HP_TABTEXT = {
    "YAHOO":  {"lr_divisor": 25, "grad_noise": 4.0, "embed_noise": 0.3, "epochs": 20, "use_embed_noise": True},
    "CRITEO": {"lr_divisor": 10, "grad_noise": 1.2, "embed_noise": 0.1, "epochs": 15, "use_embed_noise": False},
}


def _run_epoch_tabtext(model, opt, tr_loader, te_loader, dataset_name, ep, n_epochs,
                       report_robust=False, **train_kwargs):
    """Train `model` for one epoch, evaluate it, and print the per-epoch log line.

    Args:
        model: model to train in place.
        opt: optimizer bound to `model`'s parameters.
        tr_loader, te_loader: loaders handed to `train_epoch_unified` /
            `evaluate_model_unified` respectively.
        dataset_name: dataset key forwarded to both helpers and used nowhere else.
        ep: zero-based epoch index (printed as `ep + 1`).
        n_epochs: total epoch count, only used for the log line.
        report_robust: when True, additionally run `evaluate_model_unified(..., robust=True)`
            and log `Robust=` instead of `Val Loss=`.
        **train_kwargs: forwarded verbatim to `train_epoch_unified`
            (e.g. `teacher=`, `use_at=`, `use_dp=`).

    Returns:
        `(val_acc, rob_acc)`; `rob_acc` is None when `report_robust` is False.
    """
    t0 = time.perf_counter()
    train_loss, train_acc = train_epoch_unified(model, tr_loader, opt, CONFIG, dataset_name, **train_kwargs)
    val_loss, val_acc, _ = evaluate_model_unified(model, te_loader, CONFIG, dataset_name)
    rob_acc = None
    if report_robust:
        _, rob_acc, _ = evaluate_model_unified(model, te_loader, CONFIG, dataset_name, robust=True)
    # The timer deliberately spans training AND evaluation, as in the per-phase logs.
    elapsed = time.perf_counter() - t0

    if report_robust:
        print(f"  Epoch {ep+1:02d}/{n_epochs} | Train Loss={train_loss:.4f} | Train ACC={train_acc:6.2f}% | Val ACC={val_acc:6.2f}% | Robust={rob_acc:6.2f}% | Time={elapsed:.1f}s")
    else:
        print(f"  Epoch {ep+1:02d}/{n_epochs} | Train Loss={train_loss:.4f} | Train ACC={train_acc:6.2f}% | Val Loss={val_loss:.4f} | Val ACC={val_acc:6.2f}% | Time={elapsed:.1f}s")
    return val_acc, rob_acc


def _fit_with_early_stopping_tabtext(model, lr, ckpt_path, tr_loader, te_loader, dataset_name,
                                     n_epochs, patience, select_on_robust=False, **train_kwargs):
    """Train `model` with AdamW, checkpoint the best epoch, and reload that checkpoint.

    The selection metric is validation accuracy, or robust accuracy when
    `select_on_robust` is True. The best score starts at -inf (not 0) so the
    first epoch always writes `ckpt_path`; otherwise a run whose metric never
    exceeds 0 would reload a missing or stale file from an earlier run.

    Epochs with index 0 and 1 never count against `patience` (warm-up).

    Args:
        model: freshly built model; trained in place.
        lr: AdamW learning rate.
        ckpt_path: where the best `state_dict` is written and later read back.
        tr_loader, te_loader, dataset_name, n_epochs: see `_run_epoch_tabtext`.
        patience: number of counted non-improving epochs that triggers early stopping.
        select_on_robust: pick the checkpoint on robust accuracy instead of validation accuracy.
        **train_kwargs: forwarded to `train_epoch_unified`.

    Returns:
        `model`, with the best checkpoint's weights loaded.
    """
    opt = optim.AdamW(model.parameters(), lr=lr)
    metric_label = "Robust ACC" if select_on_robust else "Val ACC"
    best_metric, stale_epochs = float("-inf"), 0

    for ep in range(n_epochs):
        val_acc, rob_acc = _run_epoch_tabtext(
            model, opt, tr_loader, te_loader, dataset_name, ep, n_epochs,
            report_robust=select_on_robust, **train_kwargs)
        metric = rob_acc if select_on_robust else val_acc

        if metric > best_metric:
            best_metric, stale_epochs = metric, 0
            torch.save(model.state_dict(), ckpt_path)
            print(f"    → Best model saved ({metric_label}: {metric:.2f}%)")
        elif ep > 1:
            stale_epochs += 1
        if stale_epochs >= patience:
            print("  Early stopping triggered.")
            break

    model.load_state_dict(torch.load(ckpt_path, map_location=CONFIG["DEVICE"]))
    return model


drive_results_dir_tt = Path(CONFIG["DRIVE_PATH"]) / "VFL_Results" / "TABULAR_TEXT"
drive_results_dir_tt.mkdir(parents=True, exist_ok=True)
print("Starting Text/Tabular VFL Training with Final Optimizations\n" + "="*70)

for dataset_name in TABTEXT_GROUP:
    # Setup
    loaders = TABULAR_TEXT_LOADERS[dataset_name]
    # NOTE(review): the "test" loader is used for checkpoint selection and early
    # stopping below, so reported accuracies are selected on the same data.
    tr_loader, te_loader = loaders["train"], loaders["test"]
    num_classes = loaders["num_classes"]
    epochs = CONFIG["EPOCHS_PER_DATASET"].get(dataset_name, CONFIG["DEFAULT_EPOCHS"])
    hp = CONFIG["DATASET_HP"][dataset_name]
    SAVED_MODELS_TABTEXT[dataset_name] = {}
    patience = 3
    builder_fn = build_vfl_model_tabular_text

    oa_ckpt = f"{dataset_name}_OA_best.pth"
    kdk_ckpt = f"{dataset_name}_KDk_best.pth"
    kdkat_ckpt = f"{dataset_name}_KDkAT_best.pth"
    kdkatdp_ckpt = f"{dataset_name}_KDkATDP_best.pth"

    # ---- Phase 1: OA (Teacher) ----
    print(f"\n{dataset_name} – Phase 1: Training OA (Teacher)")
    OA = builder_fn(dataset_name, num_classes)
    _fit_with_early_stopping_tabtext(
        OA, hp["lr_teacher"], drive_results_dir_tt / oa_ckpt,
        tr_loader, te_loader, dataset_name, epochs, patience)
    SAVED_MODELS_TABTEXT[dataset_name]["OA"] = oa_ckpt

    # ---- Phase 2: KDk ----
    # Fresh student trained with the Phase-1 model passed as `teacher`.
    print(f"\n{dataset_name} – Phase 2: Training KDk")
    KDk = builder_fn(dataset_name, num_classes)
    _fit_with_early_stopping_tabtext(
        KDk, hp["lr_student"], drive_results_dir_tt / kdk_ckpt,
        tr_loader, te_loader, dataset_name, epochs, patience,
        teacher=OA)
    SAVED_MODELS_TABTEXT[dataset_name]["KDk"] = kdk_ckpt

    # ---- Phase 3: KDk+AT ----
    # Fresh student again (not initialised from KDk); checkpoint chosen on robust accuracy.
    print(f"\n{dataset_name} – Phase 3: Training KDk+AT")
    KDkAT = builder_fn(dataset_name, num_classes)
    _fit_with_early_stopping_tabtext(
        KDkAT, hp["lr_student"], drive_results_dir_tt / kdkat_ckpt,
        tr_loader, te_loader, dataset_name, epochs, patience,
        select_on_robust=True, teacher=OA, use_at=True)
    SAVED_MODELS_TABTEXT[dataset_name]["KDk+AT"] = kdkat_ckpt

    # ---- Phase 4: KDk+AT+DP (with AGGRESSIVE HYPERPARAMETERS) ----
    # Fine-tunes a copy of the best Phase-3 weights with use_dp=True.
    print(f"\n{dataset_name} – Phase 4: Training KDk+AT+DP")
    KDkATDP = builder_fn(dataset_name, num_classes)
    KDkATDP.load_state_dict(torch.load(drive_results_dir_tt / kdkat_ckpt, map_location=CONFIG["DEVICE"]))

    # Any dataset other than YAHOO falls back to the CRITEO settings.
    dp_hp = DP_FINETUNE_HP_TABTEXT.get(dataset_name, DP_FINETUNE_HP_TABTEXT["CRITEO"])
    finetune_lr = hp["lr_student"] / dp_hp["lr_divisor"]
    grad_noise = dp_hp["grad_noise"]
    embed_noise = dp_hp["embed_noise"]
    finetune_epochs = dp_hp["epochs"]
    use_embed_noise_flag = dp_hp["use_embed_noise"]

    opt = optim.AdamW(KDkATDP.parameters(), lr=finetune_lr)
    # NOTE(review): embed_noise is logged even when use_embed_noise_flag is False (CRITEO).
    print(f"  Applying DP with grad_noise: {grad_noise}, embed_noise: {embed_noise}, LR: {finetune_lr:.2e}, Epochs: {finetune_epochs}")

    for ep in range(finetune_epochs):
        _run_epoch_tabtext(
            KDkATDP, opt, tr_loader, te_loader, dataset_name, ep, finetune_epochs,
            report_robust=True,
            teacher=OA, use_at=True, use_dp=True,
            noise_multiplier=grad_noise,
            use_embedding_noise=use_embed_noise_flag,
            embedding_noise_level=embed_noise)

    # No per-epoch selection in this phase: despite the "_best" suffix, the file
    # holds the weights from the final fine-tuning epoch.
    torch.save(KDkATDP.state_dict(), drive_results_dir_tt / kdkatdp_ckpt)
    SAVED_MODELS_TABTEXT[dataset_name]["KDk+AT+DP"] = kdkatdp_ckpt

    # Drop references before the next dataset so cached GPU memory can be released.
    del OA, KDk, KDkAT, KDkATDP, opt
    torch.cuda.empty_cache()

print("\n" + "="*70 + "\nTraining completed for all Text/Tabular datasets!")

Starting Text/Tabular VFL Training with Final Optimizations

YAHOO – Phase 1: Training OA (Teacher)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 178.23it/s, loss=1.1236]


  Epoch 01/10 | Train Loss=0.9977 | Train ACC= 69.01% | Val Loss=0.8636 | Val ACC= 72.15% | Time=5.1s
    → Best model saved (Val ACC: 72.15%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 158.61it/s, loss=0.9291]


  Epoch 02/10 | Train Loss=0.8240 | Train ACC= 73.30% | Val Loss=0.8450 | Val ACC= 72.67% | Time=5.5s
    → Best model saved (Val ACC: 72.67%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 182.73it/s, loss=0.8017]


  Epoch 03/10 | Train Loss=0.7690 | Train ACC= 74.89% | Val Loss=0.8456 | Val ACC= 72.45% | Time=4.8s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 169.90it/s, loss=0.7078]


  Epoch 04/10 | Train Loss=0.7275 | Train ACC= 75.82% | Val Loss=0.8612 | Val ACC= 72.38% | Time=5.3s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 166.15it/s, loss=0.6756]


  Epoch 05/10 | Train Loss=0.6908 | Train ACC= 76.79% | Val Loss=0.8633 | Val ACC= 72.52% | Time=5.2s
  Early stopping triggered.

YAHOO – Phase 2: Training KDk


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 149.28it/s, loss=0.4058]


  Epoch 01/10 | Train Loss=0.9128 | Train ACC= 64.75% | Val Loss=0.8892 | Val ACC= 71.62% | Time=6.4s
    → Best model saved (Val ACC: 71.62%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 121.38it/s, loss=0.4166]


  Epoch 02/10 | Train Loss=0.3981 | Train ACC= 72.66% | Val Loss=0.8622 | Val ACC= 72.29% | Time=7.0s
    → Best model saved (Val ACC: 72.29%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 152.48it/s, loss=0.3493]


  Epoch 03/10 | Train Loss=0.3474 | Train ACC= 73.57% | Val Loss=0.8507 | Val ACC= 72.54% | Time=5.7s
    → Best model saved (Val ACC: 72.54%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 134.05it/s, loss=0.2611]


  Epoch 04/10 | Train Loss=0.3249 | Train ACC= 74.31% | Val Loss=0.8481 | Val ACC= 72.54% | Time=6.4s
    → Best model saved (Val ACC: 72.54%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:04<00:00, 156.74it/s, loss=0.2988]


  Epoch 05/10 | Train Loss=0.3136 | Train ACC= 74.93% | Val Loss=0.8423 | Val ACC= 72.67% | Time=5.5s
    → Best model saved (Val ACC: 72.67%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 138.02it/s, loss=0.3130]


  Epoch 06/10 | Train Loss=0.3035 | Train ACC= 75.25% | Val Loss=0.8441 | Val ACC= 72.61% | Time=6.3s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 153.70it/s, loss=0.3077]


  Epoch 07/10 | Train Loss=0.2959 | Train ACC= 75.83% | Val Loss=0.8408 | Val ACC= 72.76% | Time=5.6s
    → Best model saved (Val ACC: 72.76%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 146.40it/s, loss=0.3464]


  Epoch 08/10 | Train Loss=0.2911 | Train ACC= 76.37% | Val Loss=0.8399 | Val ACC= 72.76% | Time=6.1s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 141.29it/s, loss=0.2419]


  Epoch 09/10 | Train Loss=0.2863 | Train ACC= 76.42% | Val Loss=0.8407 | Val ACC= 72.84% | Time=6.1s
    → Best model saved (Val ACC: 72.84%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:05<00:00, 151.54it/s, loss=0.1826]


  Epoch 10/10 | Train Loss=0.2811 | Train ACC= 76.83% | Val Loss=0.8396 | Val ACC= 72.73% | Time=5.8s

YAHOO – Phase 3: Training KDk+AT


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 109.11it/s, loss=0.4326]


  Epoch 01/10 | Train Loss=0.9852 | Train ACC= 59.92% | Val ACC= 71.17% | Robust= 68.05% | Time=8.8s
    → Best model saved (Robust ACC: 68.05%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 109.98it/s, loss=0.4243]


  Epoch 02/10 | Train Loss=0.4330 | Train ACC= 69.88% | Val ACC= 71.98% | Robust= 69.20% | Time=8.9s
    → Best model saved (Robust ACC: 69.20%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 118.42it/s, loss=0.3322]


  Epoch 03/10 | Train Loss=0.3782 | Train ACC= 71.11% | Val ACC= 72.27% | Robust= 69.64% | Time=8.2s
    → Best model saved (Robust ACC: 69.64%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 108.01it/s, loss=0.3991]


  Epoch 04/10 | Train Loss=0.3548 | Train ACC= 71.74% | Val ACC= 72.45% | Robust= 70.12% | Time=8.8s
    → Best model saved (Robust ACC: 70.12%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 110.39it/s, loss=0.3378]


  Epoch 05/10 | Train Loss=0.3435 | Train ACC= 72.19% | Val ACC= 72.61% | Robust= 70.12% | Time=9.0s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 119.37it/s, loss=0.3430]


  Epoch 06/10 | Train Loss=0.3321 | Train ACC= 72.75% | Val ACC= 72.65% | Robust= 70.34% | Time=8.1s
    → Best model saved (Robust ACC: 70.34%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 105.76it/s, loss=0.3072]


  Epoch 07/10 | Train Loss=0.3258 | Train ACC= 72.99% | Val ACC= 72.64% | Robust= 70.37% | Time=9.0s
    → Best model saved (Robust ACC: 70.37%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 112.60it/s, loss=0.3421]


  Epoch 08/10 | Train Loss=0.3186 | Train ACC= 73.48% | Val ACC= 72.75% | Robust= 70.43% | Time=9.1s
    → Best model saved (Robust ACC: 70.43%)


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 118.96it/s, loss=0.3403]


  Epoch 09/10 | Train Loss=0.3145 | Train ACC= 73.69% | Val ACC= 72.87% | Robust= 70.42% | Time=8.2s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 106.25it/s, loss=0.2893]


  Epoch 10/10 | Train Loss=0.3107 | Train ACC= 73.80% | Val ACC= 72.83% | Robust= 70.43% | Time=8.9s

YAHOO – Phase 4: Training KDk+AT+DP
  Applying DP with grad_noise: 4.0, embed_noise: 0.3, LR: 4.00e-06, Epochs: 20


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 110.94it/s, loss=0.2924]


  Epoch 01/20 | Train Loss=0.3132 | Train ACC= 73.89% | Val ACC= 72.80% | Robust= 70.53% | Time=9.2s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 113.71it/s, loss=0.3925]


  Epoch 02/20 | Train Loss=0.3126 | Train ACC= 73.95% | Val ACC= 72.76% | Robust= 70.58% | Time=8.5s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 104.44it/s, loss=0.3547]


  Epoch 03/20 | Train Loss=0.3105 | Train ACC= 73.86% | Val ACC= 72.86% | Robust= 70.58% | Time=9.1s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 113.93it/s, loss=0.2897]


  Epoch 04/20 | Train Loss=0.3118 | Train ACC= 73.94% | Val ACC= 72.74% | Robust= 70.51% | Time=9.0s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 113.71it/s, loss=0.3339]


  Epoch 05/20 | Train Loss=0.3102 | Train ACC= 74.10% | Val ACC= 72.81% | Robust= 70.61% | Time=8.5s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 104.53it/s, loss=0.3714]


  Epoch 06/20 | Train Loss=0.3112 | Train ACC= 73.96% | Val ACC= 72.72% | Robust= 70.48% | Time=9.1s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 112.34it/s, loss=0.2770]


  Epoch 07/20 | Train Loss=0.3096 | Train ACC= 74.08% | Val ACC= 72.80% | Robust= 70.56% | Time=9.0s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 108.14it/s, loss=0.3420]


  Epoch 08/20 | Train Loss=0.3107 | Train ACC= 73.95% | Val ACC= 72.84% | Robust= 70.55% | Time=8.8s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 102.74it/s, loss=0.2755]


  Epoch 09/20 | Train Loss=0.3097 | Train ACC= 74.03% | Val ACC= 72.67% | Robust= 70.56% | Time=9.2s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 112.02it/s, loss=0.3825]


  Epoch 10/20 | Train Loss=0.3098 | Train ACC= 73.90% | Val ACC= 72.80% | Robust= 70.38% | Time=9.1s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 113.11it/s, loss=0.4244]


  Epoch 11/20 | Train Loss=0.3094 | Train ACC= 74.10% | Val ACC= 72.84% | Robust= 70.48% | Time=8.5s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 103.36it/s, loss=0.3541]


  Epoch 12/20 | Train Loss=0.3088 | Train ACC= 74.11% | Val ACC= 72.78% | Robust= 70.49% | Time=9.2s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 111.17it/s, loss=0.3685]


  Epoch 13/20 | Train Loss=0.3087 | Train ACC= 74.10% | Val ACC= 72.80% | Robust= 70.33% | Time=9.2s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 115.11it/s, loss=0.3209]


  Epoch 14/20 | Train Loss=0.3096 | Train ACC= 74.08% | Val ACC= 72.78% | Robust= 70.38% | Time=8.4s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 104.62it/s, loss=0.3068]


  Epoch 15/20 | Train Loss=0.3078 | Train ACC= 74.19% | Val ACC= 72.87% | Robust= 70.48% | Time=9.1s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 111.01it/s, loss=0.3321]


  Epoch 16/20 | Train Loss=0.3083 | Train ACC= 74.14% | Val ACC= 72.78% | Robust= 70.39% | Time=9.2s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 114.41it/s, loss=0.3515]


  Epoch 17/20 | Train Loss=0.3085 | Train ACC= 74.27% | Val ACC= 72.92% | Robust= 70.50% | Time=8.4s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:07<00:00, 104.99it/s, loss=0.3221]


  Epoch 18/20 | Train Loss=0.3081 | Train ACC= 74.05% | Val ACC= 72.74% | Robust= 70.42% | Time=9.0s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 113.79it/s, loss=0.2546]


  Epoch 19/20 | Train Loss=0.3085 | Train ACC= 74.22% | Val ACC= 72.79% | Robust= 70.59% | Time=9.0s


Training YAHOO: 100%|█████████████████████████████████████████| 781/781 [00:06<00:00, 115.14it/s, loss=0.3920]


  Epoch 20/20 | Train Loss=0.3079 | Train ACC= 74.28% | Val ACC= 72.78% | Robust= 70.41% | Time=8.4s

CRITEO – Phase 1: Training OA (Teacher)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 182.20it/s, loss=0.5629]


  Epoch 01/15 | Train Loss=0.5274 | Train ACC= 75.01% | Val Loss=0.5220 | Val ACC= 75.16% | Time=7.3s
    → Best model saved (Val ACC: 75.16%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 199.75it/s, loss=0.4989]


  Epoch 02/15 | Train Loss=0.5196 | Train ACC= 75.27% | Val Loss=0.5184 | Val ACC= 75.28% | Time=6.7s
    → Best model saved (Val ACC: 75.28%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 177.20it/s, loss=0.6676]


  Epoch 03/15 | Train Loss=0.5179 | Train ACC= 75.46% | Val Loss=0.5177 | Val ACC= 75.33% | Time=7.5s
    → Best model saved (Val ACC: 75.33%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 198.14it/s, loss=0.5010]


  Epoch 04/15 | Train Loss=0.5171 | Train ACC= 75.55% | Val Loss=0.5173 | Val ACC= 75.38% | Time=6.7s
    → Best model saved (Val ACC: 75.38%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 186.06it/s, loss=0.5741]


  Epoch 05/15 | Train Loss=0.5159 | Train ACC= 75.63% | Val Loss=0.5164 | Val ACC= 75.50% | Time=7.3s
    → Best model saved (Val ACC: 75.50%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 196.05it/s, loss=0.4469]


  Epoch 06/15 | Train Loss=0.5155 | Train ACC= 75.74% | Val Loss=0.5142 | Val ACC= 75.55% | Time=6.8s
    → Best model saved (Val ACC: 75.55%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 182.96it/s, loss=0.5551]


  Epoch 07/15 | Train Loss=0.5151 | Train ACC= 75.71% | Val Loss=0.5154 | Val ACC= 75.72% | Time=7.5s
    → Best model saved (Val ACC: 75.72%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 192.11it/s, loss=0.4954]


  Epoch 08/15 | Train Loss=0.5146 | Train ACC= 75.76% | Val Loss=0.5148 | Val ACC= 75.63% | Time=6.9s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 184.09it/s, loss=0.5527]


  Epoch 09/15 | Train Loss=0.5142 | Train ACC= 75.73% | Val Loss=0.5159 | Val ACC= 75.58% | Time=7.4s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 198.86it/s, loss=0.4917]


  Epoch 10/15 | Train Loss=0.5137 | Train ACC= 75.79% | Val Loss=0.5150 | Val ACC= 75.73% | Time=6.7s
    → Best model saved (Val ACC: 75.73%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 182.51it/s, loss=0.4697]


  Epoch 11/15 | Train Loss=0.5139 | Train ACC= 75.84% | Val Loss=0.5147 | Val ACC= 75.66% | Time=7.5s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 196.79it/s, loss=0.4089]


  Epoch 12/15 | Train Loss=0.5134 | Train ACC= 75.93% | Val Loss=0.5149 | Val ACC= 75.61% | Time=6.8s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:06<00:00, 185.30it/s, loss=0.5632]


  Epoch 13/15 | Train Loss=0.5139 | Train ACC= 75.83% | Val Loss=0.5142 | Val ACC= 75.56% | Time=7.4s
  Early stopping triggered.

CRITEO – Phase 2: Training KDk


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 169.28it/s, loss=0.1248]


  Epoch 01/15 | Train Loss=0.1935 | Train ACC= 73.16% | Val Loss=0.5206 | Val ACC= 75.30% | Time=7.8s
    → Best model saved (Val ACC: 75.30%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:08<00:00, 154.69it/s, loss=0.1636]


  Epoch 02/15 | Train Loss=0.1636 | Train ACC= 75.34% | Val Loss=0.5206 | Val ACC= 75.48% | Time=8.5s
    → Best model saved (Val ACC: 75.48%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 169.29it/s, loss=0.1463]


  Epoch 03/15 | Train Loss=0.1618 | Train ACC= 75.49% | Val Loss=0.5176 | Val ACC= 75.57% | Time=7.8s
    → Best model saved (Val ACC: 75.57%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:08<00:00, 154.28it/s, loss=0.1661]


  Epoch 04/15 | Train Loss=0.1606 | Train ACC= 75.52% | Val Loss=0.5251 | Val ACC= 74.82% | Time=8.5s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 160.21it/s, loss=0.1729]


  Epoch 05/15 | Train Loss=0.1598 | Train ACC= 75.56% | Val Loss=0.5173 | Val ACC= 75.67% | Time=8.4s
    → Best model saved (Val ACC: 75.67%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 161.37it/s, loss=0.1736]


  Epoch 06/15 | Train Loss=0.1590 | Train ACC= 75.67% | Val Loss=0.5156 | Val ACC= 75.56% | Time=8.2s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:08<00:00, 153.48it/s, loss=0.1788]


  Epoch 07/15 | Train Loss=0.1584 | Train ACC= 75.61% | Val Loss=0.5165 | Val ACC= 75.47% | Time=8.6s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:07<00:00, 165.23it/s, loss=0.1719]


  Epoch 08/15 | Train Loss=0.1581 | Train ACC= 75.74% | Val Loss=0.5157 | Val ACC= 75.56% | Time=8.0s
  Early stopping triggered.

CRITEO – Phase 3: Training KDk+AT


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 122.16it/s, loss=0.1548]


  Epoch 01/15 | Train Loss=0.1884 | Train ACC= 74.08% | Val ACC= 74.82% | Robust= 74.41% | Time=11.6s
    → Best model saved (Robust ACC: 74.41%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 123.57it/s, loss=0.2063]


  Epoch 02/15 | Train Loss=0.1743 | Train ACC= 74.47% | Val ACC= 74.83% | Robust= 74.52% | Time=11.6s
    → Best model saved (Robust ACC: 74.52%)


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 121.63it/s, loss=0.1768]


  Epoch 03/15 | Train Loss=0.1721 | Train ACC= 74.42% | Val ACC= 75.31% | Robust= 74.31% | Time=11.6s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 127.14it/s, loss=0.1685]


  Epoch 04/15 | Train Loss=0.1710 | Train ACC= 74.38% | Val ACC= 75.33% | Robust= 74.28% | Time=11.5s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 130.64it/s, loss=0.2014]


  Epoch 05/15 | Train Loss=0.1702 | Train ACC= 74.50% | Val ACC= 75.27% | Robust= 74.38% | Time=11.2s
  Early stopping triggered.

CRITEO – Phase 4: Training KDk+AT+DP
  Applying DP with grad_noise: 1.2, embed_noise: 0.1, LR: 5.00e-05, Epochs: 15


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 123.13it/s, loss=0.1474]


  Epoch 01/15 | Train Loss=0.1724 | Train ACC= 74.45% | Val ACC= 75.13% | Robust= 74.30% | Time=11.4s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 121.23it/s, loss=0.1748]


  Epoch 02/15 | Train Loss=0.1724 | Train ACC= 74.44% | Val ACC= 75.13% | Robust= 74.33% | Time=11.6s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 121.33it/s, loss=0.1748]


  Epoch 03/15 | Train Loss=0.1722 | Train ACC= 74.37% | Val ACC= 75.03% | Robust= 74.36% | Time=11.6s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 121.30it/s, loss=0.1444]


  Epoch 04/15 | Train Loss=0.1721 | Train ACC= 74.44% | Val ACC= 75.32% | Robust= 74.22% | Time=11.6s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 124.84it/s, loss=0.1586]


  Epoch 05/15 | Train Loss=0.1720 | Train ACC= 74.38% | Val ACC= 75.29% | Robust= 74.30% | Time=11.6s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 128.95it/s, loss=0.1727]


  Epoch 06/15 | Train Loss=0.1717 | Train ACC= 74.39% | Val ACC= 75.27% | Robust= 74.26% | Time=11.4s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 124.80it/s, loss=0.1837]


  Epoch 07/15 | Train Loss=0.1716 | Train ACC= 74.41% | Val ACC= 75.12% | Robust= 74.26% | Time=11.3s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 122.29it/s, loss=0.1890]


  Epoch 08/15 | Train Loss=0.1716 | Train ACC= 74.47% | Val ACC= 75.29% | Robust= 74.31% | Time=11.5s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 121.67it/s, loss=0.1532]


  Epoch 09/15 | Train Loss=0.1714 | Train ACC= 74.41% | Val ACC= 75.28% | Robust= 74.33% | Time=11.6s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 120.52it/s, loss=0.1862]


  Epoch 10/15 | Train Loss=0.1713 | Train ACC= 74.38% | Val ACC= 74.92% | Robust= 74.54% | Time=11.7s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 124.39it/s, loss=0.2162]


  Epoch 11/15 | Train Loss=0.1713 | Train ACC= 74.39% | Val ACC= 75.29% | Robust= 74.27% | Time=11.7s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 126.50it/s, loss=0.1849]


  Epoch 12/15 | Train Loss=0.1713 | Train ACC= 74.45% | Val ACC= 75.33% | Robust= 74.42% | Time=11.6s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:09<00:00, 125.59it/s, loss=0.1866]


  Epoch 13/15 | Train Loss=0.1711 | Train ACC= 74.39% | Val ACC= 75.32% | Robust= 74.30% | Time=11.3s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 119.64it/s, loss=0.2065]


  Epoch 14/15 | Train Loss=0.1710 | Train ACC= 74.42% | Val ACC= 75.16% | Robust= 74.34% | Time=11.7s


Training CRITEO: 100%|██████████████████████████████████████| 1250/1250 [00:10<00:00, 119.87it/s, loss=0.1811]


  Epoch 15/15 | Train Loss=0.1709 | Train ACC= 74.48% | Val ACC= 75.28% | Robust= 74.38% | Time=11.7s

Training completed for all Text/Tabular datasets!


In [None]:
# Cell 7 - Focused Evaluation for Text/Tabular
from sklearn.metrics import f1_score, roc_auc_score
import pandas as pd
from IPython.display import display
import numpy as np
from pathlib import Path
import torch

# Accumulates one result row (dict) per (dataset, model variant); turned into a DataFrame at the end.
all_evaluation_results = []
print("--- FOCUSED EVALUATION INITIATED for Text/Tabular ---")

# Checkpoint filenames to evaluate, keyed by dataset and then by training variant.
SAVED_MODELS_TABTEXT = {
    "YAHOO": { "OA": "YAHOO_OA_best.pth", "KDk": "YAHOO_KDk_best.pth", "KDk+AT": "YAHOO_KDkAT_best.pth", "KDk+AT+DP": "YAHOO_KDkATDP_best.pth" },
    "CRITEO": { "OA": "CRITEO_OA_best.pth", "KDk": "CRITEO_KDk_best.pth", "KDk+AT": "CRITEO_KDkAT_best.pth", "KDk+AT+DP": "CRITEO_KDkATDP_best.pth" }
}

model_group = { "type": "TABULAR_TEXT", "models": SAVED_MODELS_TABTEXT, "loaders": TABULAR_TEXT_LOADERS, "builder": build_vfl_model_tabular_text }
model_type_name = model_group["type"]
# Checkpoints are read from, and the final CSV is written to, this directory.
drive_dir = Path(CONFIG["DRIVE_PATH"]) / "VFL_Results" / model_type_name

for dataset_name, models in model_group["models"].items():
    loaders = model_group["loaders"][dataset_name]
    tr_loader, te_loader = loaders["train"], loaders["test"]
    num_classes = loaders["num_classes"]
    builder_fn = model_group["builder"]

    for model_type, model_path in models.items():
        print(f"Evaluating {dataset_name} - {model_type}...")
        model = builder_fn(dataset_name, num_classes)
        # NOTE(review): torch.load unpickles the file; only load checkpoints produced by this
        # notebook. Consider weights_only=True if the installed PyTorch version supports it.
        model.load_state_dict(torch.load(drive_dir / model_path, map_location=CONFIG["DEVICE"]))
        # Make device placement explicit: the metric loop below moves inputs to CONFIG["DEVICE"],
        # so the model must live there too (a no-op if the builder already placed it).
        model.to(CONFIG["DEVICE"])

        # Only the second element of each returned triple is used; it is reported as top-1 accuracy (%).
        _, clean_acc_top1, _ = evaluate_model_unified(model, te_loader, CONFIG, dataset_name)
        _, robust_acc_top1, _ = evaluate_model_unified(model, te_loader, CONFIG, dataset_name, robust=True)

        # Collect class probabilities on the clean test set for F1 / AUC.
        # Force eval mode for this pass so dropout/batch-norm cannot perturb the probabilities,
        # then restore the previous mode so the privacy attacks see the model unchanged.
        was_training = model.training
        model.eval()
        y_true_list, y_probs_list = [], []
        with torch.no_grad():
            for (xa, xp), y in te_loader:
                xa, xp = xa.to(CONFIG["DEVICE"]), xp.to(CONFIG["DEVICE"])
                logits = model(xa, xp)
                # .cpu() keeps this working even if the loader yields labels on the GPU.
                y_true_list.append(y.cpu().numpy())
                y_probs_list.append(torch.softmax(logits, dim=1).cpu().numpy())
        model.train(was_training)
        # One concatenation per array instead of growing Python lists row by row.
        y_true, y_probs = np.concatenate(y_true_list), np.concatenate(y_probs_list)
        y_pred = np.argmax(y_probs, axis=1)

        # Binary tasks score the positive class (label 1); multi-class tasks use macro F1 and one-vs-rest AUC.
        f1 = f1_score(y_true, y_pred, average='binary' if num_classes == 2 else 'macro') * 100
        auc = roc_auc_score(y_true, y_probs[:, 1]) * 100 if num_classes == 2 else roc_auc_score(y_true, y_probs, multi_class='ovr') * 100

        # Run each attack type with a small budget (4 auxiliary batches, 2 attacker epochs).
        # Only the first returned value is kept; it is reported below as top-1 ASR (%).
        attack_results = {}
        for attack in ["passive", "direct", "active", "perturbed"]:
            a1, _ = run_privacy_attack_unified( model, tr_loader, te_loader, CONFIG, num_classes, attack_type=attack, aux_batches_limit=4, attacker_epochs=2)
            attack_results[attack] = {"top1": a1}

        # PLI = (1 - passive ASR) * (robust ACC / clean ACC) * 100.
        # As computed, a HIGHER value means a lower passive attack success rate and/or a smaller
        # robustness gap. max(..., 1) guards against division by a (near-)zero clean accuracy.
        pli = (1 - attack_results["passive"]["top1"]/100.0) * (robust_acc_top1 / max(clean_acc_top1,1)) * 100

        all_evaluation_results.append({
            "Dataset": dataset_name, "Model": model_type,
            "Clean ACC (Top-1 %)": round(clean_acc_top1, 2),
            "Robust ACC (Top-1 %)": round(robust_acc_top1, 2),
            "F1 Score (%)": round(f1, 2),
            "AUC (%)": round(auc, 2),
            "ASR_passive (Top-1 %)": round(attack_results["passive"]["top1"], 2),
            "ASR_direct (Top-1 %)": round(attack_results["direct"]["top1"], 2),
            "ASR_active (Top-1 %)": round(attack_results["active"]["top1"], 2),
            "ASR_perturbed (Top-1 %)": round(attack_results["perturbed"]["top1"], 2),
            "Privacy Leakage Index": round(pli, 2)
        })

# Persist the summary table next to the checkpoints and show it inline.
final_df = pd.DataFrame(all_evaluation_results)
final_path = drive_dir / "evaluation_results_TextTabular_final.csv"
final_df.to_csv(final_path, index=False)
print("\n--- EVALUATION COMPLETE ---")
display(final_df)
print(f"\nText/Tabular results saved to: {final_path}")

--- FOCUSED EVALUATION INITIATED for Text/Tabular ---
Evaluating YAHOO - OA...
Evaluating YAHOO - KDk...
Evaluating YAHOO - KDk+AT...
Evaluating YAHOO - KDk+AT+DP...
Evaluating CRITEO - OA...
Evaluating CRITEO - KDk...
Evaluating CRITEO - KDk+AT...
Evaluating CRITEO - KDk+AT+DP...

--- EVALUATION COMPLETE ---


Unnamed: 0,Dataset,Model,Clean ACC (Top-1 %),Robust ACC (Top-1 %),F1 Score (%),AUC (%),ASR_passive (Top-1 %),ASR_direct (Top-1 %),ASR_active (Top-1 %),ASR_perturbed (Top-1 %),Privacy Leakage Index
0,YAHOO,OA,72.67,67.49,72.05,95.13,15.15,19.73,20.13,13.08,78.8
1,YAHOO,KDk,72.84,66.02,72.28,95.2,18.2,23.88,10.98,12.45,74.14
2,YAHOO,KDk+AT,72.75,70.43,72.17,95.19,14.41,21.61,24.98,9.71,82.86
3,YAHOO,KDk+AT+DP,72.78,70.41,72.18,95.2,14.71,18.35,10.14,10.3,82.51
4,CRITEO,OA,75.74,71.58,29.84,70.87,66.75,73.84,62.09,41.97,31.43
5,CRITEO,KDk,75.67,72.69,23.01,70.44,56.25,73.14,70.25,74.3,42.03
6,CRITEO,KDk+AT,74.83,74.52,9.49,69.57,74.47,74.16,74.44,74.47,25.43
7,CRITEO,KDk+AT+DP,75.28,74.38,17.65,69.6,74.26,74.47,71.88,74.44,25.44



Text/Tabular results saved to: /content/VFL_Results/TABULAR_TEXT/evaluation_results_TextTabular_final.csv
