In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets, transforms

import timm
from tqdm import tqdm
import os
import json
import datetime
import matplotlib.pyplot as plt
from typing import List, Optional

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
if device == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
GPU: NVIDIA GeForce RTX 4060 Ti


In [2]:
import random
GLOBAL_SEED = 2

def set_global_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # 让 cudnn 更可复现（会稍微慢一点）
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_global_seed(GLOBAL_SEED)

In [3]:
try:
    GLOBAL_SEED
except NameError:
    GLOBAL_SEED = 42

def get_dataset(dataset_name, data_dir="./data", img_size=224):
    name = dataset_name.lower()

    if name == "imagenet":
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        transform_val = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        train_set = datasets.ImageFolder(f"{data_dir}/train", transform=transform_train)
        val_set = datasets.ImageFolder(f"{data_dir}/val", transform=transform_val)
        num_classes = len(train_set.classes)

    elif name == "cifar10":
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.5, 0.5, 0.5),
                std=(0.5, 0.5, 0.5),
            )
        ])
        train_set = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
        val_set = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)
        num_classes = 10

    elif name == "cifar100":
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.5, 0.5, 0.5),
                std=(0.5, 0.5, 0.5),
            )
        ])
        train_set = datasets.CIFAR100(root=data_dir, train=True, download=True, transform=transform)
        val_set = datasets.CIFAR100(root=data_dir, train=False, download=True, transform=transform)
        num_classes = 100

    elif name == "svhn":
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.5, 0.5, 0.5),
                std=(0.5, 0.5, 0.5),
            )
        ])
        train_set = datasets.SVHN(root=data_dir, split='train', download=True, transform=transform)
        val_set   = datasets.SVHN(root=data_dir, split='test',  download=True, transform=transform)
        num_classes = 10

    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    return train_set, val_set, num_classes





def get_loaders(dataset_name, data_dir="./data", batch_size=64, img_size=224, num_workers=8):
    train_set, val_set, num_classes = get_dataset(dataset_name, data_dir, img_size)

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=(device == "cuda"),
    )
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=(device == "cuda"),
    )

    return train_loader, val_loader, num_classes


In [4]:
from tqdm import tqdm
import torch

def train_one_epoch(model, loader, criterion, optimizer, device, epoch, epochs, scaler=None):
    model.train()
    total_loss = 0.0
    correct = 0

    use_amp = scaler is not None

    loop = tqdm(loader, desc=f"Train [{epoch+1}/{epochs}]", ncols=100)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        if use_amp:
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()

        loop.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(loader)
    acc = correct / len(loader.dataset)
    return avg_loss, acc


def eval_one_epoch(model, loader, criterion, device, epoch, epochs):
    model.eval()
    total_loss = 0.0
    correct = 0

    loop = tqdm(loader, desc=f"Val   [{epoch+1}/{epochs}]", ncols=100)
    with torch.no_grad():
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()

            loop.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(loader)
    acc = correct / len(loader.dataset)
    return avg_loss, acc


In [5]:
# ============================================================
#  1. 谱分析基础工具：svd + effective rank + 画图
# ============================================================

def to_numpy(x):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.array(x)


def compute_singular_values_from_qk(
    q: torch.Tensor,
    k: torch.Tensor,
    scale: float = 1.0,
    top_k: int = 32,
    batch_agg: str = "mean",
) -> torch.Tensor:
    """
    给定某一层的 Q, K，计算每个 head 的 pre-softmax S = (Q * scale) @ K^T 的前 top_k 个奇异值。

    参数：
      q, k: 形状 (B, H, N, D_h)
      scale: 一般 = 1/sqrt(D_h)
      top_k: 只保留前 top_k 个最大奇异值
      batch_agg:
        - "mean": S = mean_b (Q_b K_b^T)
        - "first": 只用第一个 batch

    返回：
      svals: (H, top_k)
    """
    assert q.shape == k.shape, "q, k 必须同形状 (B, H, N, D_h)"
    B, H, N, Dh = q.shape
    top_k = min(top_k, N)

    s_list = []

    for h in range(H):
        if batch_agg == "first":
            q_h = q[0, h] * scale      # (N, D_h)
            k_h = k[0, h]              # (N, D_h)
            S = q_h @ k_h.transpose(0, 1)  # (N, N)
        else:
            S = None
            for b in range(B):
                q_bh = q[b, h] * scale
                k_bh = k[b, h]
                S_b = q_bh @ k_bh.transpose(0, 1)
                if S is None:
                    S = S_b
                else:
                    S = S + S_b
            S = S / float(B)

        svals_full = torch.linalg.svdvals(S)   # (N,)
        svals_head = svals_full[:top_k]        # (top_k,)
        s_list.append(svals_head)

    svals = torch.stack(s_list, dim=0)  # (H, top_k)
    return svals


def normalize_singular_values(
    svals: torch.Tensor,
    mode: str = "sum",
    eps: float = 1e-12,
) -> torch.Tensor:
    """
    对奇异值做归一化，方便热力图看“谱形状”。

    mode:
      - "sum": 除以 sum(s_i) → 类似概率分布
      - "max": 除以 max(s_i) → 最大值变 1
      - "none": 不归一化
    """
    if mode == "none":
        return svals

    if mode == "sum":
        denom = svals.sum(dim=-1, keepdim=True) + eps
    elif mode == "max":
        denom = svals.max(dim=-1, keepdim=True).values + eps
    else:
        raise ValueError(f"Unknown normalize mode: {mode}")

    return svals / denom


def compute_effective_rank(
    svals: torch.Tensor,
    eps: float = 1e-12,
) -> torch.Tensor:
    """
    根据奇异值计算 effective rank：

      p_i = s_i / sum_j s_j
      H = - sum_i p_i log p_i
      erank = exp(H)

    参数：
      svals: (..., K)
    返回：
      erank: (...,)
    """
    svals = torch.clamp(svals, min=eps)
    S = svals.sum(dim=-1, keepdim=True) + eps
    p = svals / S
    log_p = torch.log(p + eps)
    H = -(p * log_p).sum(dim=-1)
    erank = torch.exp(H)
    return erank


def stack_svals_over_epochs(svals_per_epoch):
    """
    svals_per_epoch: list of (H, K)
    返回: (E, H, K)
    """
    assert len(svals_per_epoch) > 0
    H, K = svals_per_epoch[0].shape
    for sv in svals_per_epoch:
        assert sv.shape == (H, K)
    return torch.stack(svals_per_epoch, dim=0)  # (E,H,K)


def plot_svals_heatmap_per_head(
    svals_epochs: torch.Tensor,
    head_idx: int,
    epoch_indices=None,
    normalize_mode: str = "sum",
    cmap: str = "viridis",
    save_path: str = None,
    title: str = None,
):
    """
    画某一个 head 的 singular value 热力图。

    svals_epochs: (E, H, K)
    """
    E, H, K = svals_epochs.shape
    assert 0 <= head_idx < H

    svals_head = svals_epochs[:, head_idx, :]  # (E, K)
    svals_head = normalize_singular_values(svals_head, mode=normalize_mode)
    data = to_numpy(svals_head).T  # (K, E)

    if epoch_indices is None:
        epoch_indices = list(range(1, E + 1))

    plt.figure(figsize=(6, 4))
    im = plt.imshow(
        data,
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        cmap=cmap,
    )
    plt.colorbar(im, fraction=0.046, pad=0.04)
    plt.xlabel("Epoch")
    plt.ylabel("Singular value index (1..K)")
    if title is None:
        title = f"Head {head_idx}: singular values heatmap"
    plt.title(title)

    xticks_pos = np.linspace(0, E - 1, num=min(E, 6), dtype=int)
    xticks_lbl = [str(epoch_indices[i]) for i in xticks_pos]
    plt.xticks(ticks=xticks_pos, labels=xticks_lbl)

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=150)
    plt.close()


def plot_svals_heatmap_head_mean(
    svals_epochs: torch.Tensor,
    epoch_indices=None,
    normalize_mode: str = "sum",
    cmap: str = "viridis",
    save_path: str = None,
    title: str = None,
):
    """
    对所有 head 的奇异值先归一化 + 求均值，然后画 head-mean 热力图。

    svals_epochs: (E, H, K)
    """
    E, H, K = svals_epochs.shape

    svals_norm = normalize_singular_values(svals_epochs, mode=normalize_mode)  # (E,H,K)
    svals_mean = svals_norm.mean(dim=1)  # (E,K)
    data = to_numpy(svals_mean).T  # (K,E)

    if epoch_indices is None:
        epoch_indices = list(range(1, E + 1))

    plt.figure(figsize=(6, 4))
    im = plt.imshow(
        data,
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        cmap=cmap,
    )
    plt.colorbar(im, fraction=0.046, pad=0.04)
    plt.xlabel("Epoch")
    plt.ylabel("Singular value index (1..K)")
    if title is None:
        title = "Head-mean singular values heatmap"
    plt.title(title)

    xticks_pos = np.linspace(0, E - 1, num=min(E, 6), dtype=int)
    xticks_lbl = [str(epoch_indices[i]) for i in xticks_pos]
    plt.xticks(ticks=xticks_pos, labels=xticks_lbl)

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=150)
    plt.close()


def plot_effective_rank_curves(
    svals_epochs: torch.Tensor,
    epoch_indices=None,
    save_path: str = None,
    title: str = None,
    show_per_head: bool = True,
    show_mean: bool = True,
):
    """
    画 effective rank 随 epoch 的变化曲线。

    svals_epochs: (E,H,K)
    """
    E, H, K = svals_epochs.shape

    if epoch_indices is None:
        epoch_indices = list(range(1, E + 1))

    erank_epochs = compute_effective_rank(svals_epochs)  # (E,H)
    erank_np = to_numpy(erank_epochs)  # (E,H)

    plt.figure(figsize=(6, 4))

    if show_per_head:
        for h in range(H):
            # 只给第一条线加 label，避免 legend 太乱
            lbl = f"head {h}" if h == 0 else None
            plt.plot(epoch_indices, erank_np[:, h], alpha=0.3, linewidth=1.0, label=lbl)

    if show_mean:
        mean_erank = erank_np.mean(axis=1)
        std_erank = erank_np.std(axis=1)
        plt.plot(epoch_indices, mean_erank, color="black", linewidth=2.0, label="mean erank")
        if H > 1:
            plt.fill_between(
                epoch_indices,
                mean_erank - std_erank,
                mean_erank + std_erank,
                color="gray",
                alpha=0.2,
                label="±1 std",
            )

    plt.xlabel("Epoch")
    plt.ylabel("Effective rank")
    if title is None:
        title = "Effective rank over epochs"
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path, dpi=150)
    plt.close()

# ============================================================
#  2. 从 ViT 指定 block 抓 Q/K → 奇异值
# ============================================================


# ============================================================
#  3. 训练 + 谱收集：单个实验 (MHA or TPA)
# ============================================================
def _collect_block_svals_one_epoch(
    model,
    val_loader,
    block_idx: int,
    top_k: int,
    num_batches_spec: int,
    device: str,
):
    """
    返回四个：(svals_qk, svals_q, svals_k, svals_v)
      - svals_qk: pre-softmax logits (QK^T) 的奇异值 (H, Kqk)
      - svals_q/k/v: Q/K/V (N x Dh) 的奇异值 (H, Kqkv)
    说明：SVD 放 CPU 上算，避免训练时 GPU 内存紧张导致 OOM。
    """
    model.eval()
    attn_mod = model.vit.blocks[block_idx].attn

    cached_x = {}

    def _hook(module, inputs, output):
        cached_x["x"] = inputs[0].detach()

    handle = attn_mod.register_forward_hook(_hook)

    qk_batches = []
    q_batches  = []
    k_batches  = []
    v_batches  = []
    batches_done = 0

    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)
            cached_x.clear()

            _ = model(images)

            if "x" not in cached_x:
                continue

            x = cached_x["x"]  # (B, N, C)
            B, N, C = x.shape

            # -------- 拿 q,k,v --------
            if hasattr(attn_mod, "qkv"):
                # timm MHA
                qkv = attn_mod.qkv(x)  # (B, N, 3*C)
                H = attn_mod.num_heads
                Dh = C // H
                qkv = qkv.reshape(B, N, 3, H, Dh).permute(2, 0, 3, 1, 4)  # (3, B, H, N, Dh)
                q, k, v = qkv[0], qkv[1], qkv[2]  # (B,H,N,Dh)
                scale = attn_mod.scale
            else:
                # TPA 系列
                q, k, v = attn_mod._make_qkv(x)  # (B,H,N,Dh)
                Dh = q.shape[-1]
                scale = attn_mod.scale if hasattr(attn_mod, "scale") else (Dh ** -0.5)

            H = q.shape[1]
            K_qk  = min(top_k, N)          # QK^T 是 N×N
            K_qkv = min(top_k, min(N, Dh)) # Q/K/V 是 N×Dh

            # -------- 1) QK^T 的奇异值 --------
            attn_logits = (q * scale) @ k.transpose(-2, -1)   # (B,H,N,N)
            attn_mean = attn_logits.mean(dim=0)               # (H,N,N)

            svals_qk = torch.zeros(H, K_qk, dtype=torch.float32)
            for h in range(H):
                A = attn_mean[h].float().cpu()
                sv = torch.linalg.svdvals(A)
                svals_qk[h] = sv[:K_qk]

            # -------- 2) Q/K/V 本体的奇异值：对 batch 平均后做 SVD --------
            q_mean = q.mean(dim=0)  # (H,N,Dh)
            k_mean = k.mean(dim=0)
            v_mean = v.mean(dim=0)

            svals_q = torch.zeros(H, K_qkv, dtype=torch.float32)
            svals_k = torch.zeros(H, K_qkv, dtype=torch.float32)
            svals_v = torch.zeros(H, K_qkv, dtype=torch.float32)

            for h in range(H):
                svq = torch.linalg.svdvals(q_mean[h].float().cpu())
                svk = torch.linalg.svdvals(k_mean[h].float().cpu())
                svv = torch.linalg.svdvals(v_mean[h].float().cpu())
                svals_q[h] = svq[:K_qkv]
                svals_k[h] = svk[:K_qkv]
                svals_v[h] = svv[:K_qkv]

            qk_batches.append(svals_qk)
            q_batches.append(svals_q)
            k_batches.append(svals_k)
            v_batches.append(svals_v)

            batches_done += 1
            if batches_done >= num_batches_spec:
                break

    handle.remove()

    if len(qk_batches) == 0:
        raise RuntimeError("在谱分析时没有拿到任何 batch 的 attention。")

    # 对 num_batches_spec 求平均：-> (H,K)
    svals_qk_epoch = torch.stack(qk_batches, dim=0).mean(dim=0)
    svals_q_epoch  = torch.stack(q_batches,  dim=0).mean(dim=0)
    svals_k_epoch  = torch.stack(k_batches,  dim=0).mean(dim=0)
    svals_v_epoch  = torch.stack(v_batches,  dim=0).mean(dim=0)

    return svals_qk_epoch, svals_q_epoch, svals_k_epoch, svals_v_epoch

def compute_singular_values_from_tokens(
    x: torch.Tensor,
    top_k: int = 32,
    batch_agg: str = "mean",
) -> torch.Tensor:
    """
    x: (B, H, N, D_h)  -> 对每个 head 的 (N, D_h) 做 SVD，取 top_k
    返回: (H, K)
    """
    B, H, N, Dh = x.shape
    K = min(top_k, N, Dh)

    s_list = []
    for h in range(H):
        if batch_agg == "first":
            X = x[0, h]  # (N, Dh)
        else:
            X = x[:, h].mean(dim=0)  # (N, Dh)

        svals_full = torch.linalg.svdvals(X.to(torch.float32))  # (min(N,Dh),)
        s_list.append(svals_full[:K])

    return torch.stack(s_list, dim=0)  # (H, K)



def _collect_block_qkv_svals_one_epoch(
    model,
    val_loader,
    block_idx: int,
    top_k: int,
    num_batches_spec: int,
    device: str,
):
    """
    返回四个谱：(H,K)
      - svals_qk: SVD of pre-softmax logits (q*scale)@k^T
      - svals_q : SVD of Q token matrix (N,Dh)
      - svals_k : SVD of K token matrix (N,Dh)
      - svals_v : SVD of V token matrix (N,Dh)
    """
    model.eval()
    attn_mod = model.vit.blocks[block_idx].attn
    cached_x = {}

    def _hook(module, inputs, output):
        cached_x["x"] = inputs[0].detach()

    handle = attn_mod.register_forward_hook(_hook)

    qk_list, q_list, k_list, v_list = [], [], [], []
    batches_done = 0

    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)
            cached_x.clear()
            _ = model(images)

            if "x" not in cached_x:
                continue

            x = cached_x["x"]  # (B,N,C)
            B, N, C = x.shape

            # ---- 拿 q,k,v ----
            if hasattr(attn_mod, "qkv"):
                qkv = attn_mod.qkv(x)
                H = attn_mod.num_heads
                Dh = C // H
                qkv = qkv.reshape(B, N, 3, H, Dh).permute(2, 0, 3, 1, 4)  # (3,B,H,N,Dh)
                q, k, v = qkv[0], qkv[1], qkv[2]
                scale = attn_mod.scale
            else:
                q, k, v = attn_mod._make_qkv(x)  # (B,H,N,Dh)
                Dh = q.shape[-1]
                scale = attn_mod.scale if hasattr(attn_mod, "scale") else (Dh ** -0.5)

            # ---- 统一 K：保证 Q/K/V 都输出同一个 K ----
            Kkeep = min(top_k, N, Dh)

            # 1) QK^T
            attn_logits = (q * scale) @ k.transpose(-2, -1)     # (B,H,N,N)
            attn_mean = attn_logits.mean(dim=0)                 # (H,N,N)
            svals_qk = torch.zeros(q.shape[1], Kkeep, device=attn_mean.device)
            for h in range(q.shape[1]):
                sv = torch.linalg.svdvals(attn_mean[h].to(torch.float32))
                svals_qk[h] = sv[:Kkeep]

            # 2) Q/K/V token matrix
            svals_q = compute_singular_values_from_tokens(q, top_k=Kkeep, batch_agg="mean")  # (H,K)
            svals_k = compute_singular_values_from_tokens(k, top_k=Kkeep, batch_agg="mean")
            svals_v = compute_singular_values_from_tokens(v, top_k=Kkeep, batch_agg="mean")

            qk_list.append(svals_qk)
            q_list.append(svals_q)
            k_list.append(svals_k)
            v_list.append(svals_v)

            batches_done += 1
            if batches_done >= num_batches_spec:
                break

    handle.remove()

    if len(qk_list) == 0:
        raise RuntimeError("在谱分析时没有拿到任何 batch 的 attention。")

    # 平均多个 batch（如果 num_batches_spec>1）
    s_qk = torch.stack(qk_list, dim=0).mean(dim=0)  # (H,K)
    s_q  = torch.stack(q_list,  dim=0).mean(dim=0)
    s_k  = torch.stack(k_list,  dim=0).mean(dim=0)
    s_v  = torch.stack(v_list,  dim=0).mean(dim=0)

    return s_qk, s_q, s_k, s_v

def run_small_spectrum_experiment(
    attn_type: str,
    total_epochs: int,
    block_idx: int,
    top_k: int,
    num_batches_spec: int,
    dataset_name: str,
    model_name: str,
    data_dir: str,
    img_size: int,
    batch_size: int,
    num_workers: int,
    lr: float,
    weight_decay: float,
    rank_q: int,
    rank_k: int,
    rank_v: int,
    device: str,
    mlp_ratio: float = 2.0,
    mlp_on: str="qkv",
    sinter_A: float = 5e-5,
    sinter_omega: float = 1e4,
):
    """
    小规模训练 + 谱分析（每个 epoch 做一次）：

      - attn_type:
          "mha" / "tpa" / "nonlinear_tpa" / "headwise_nonlinear_tpa" / "sinter_tpa"

      - 每个 epoch：
          1) train_one_epoch
          2) eval_one_epoch
          3) 在指定 block 上做谱分析，拿到 top-K singular values

      - 返回:
          model, hist = {
            "train_loss_curve": [...],
            "val_loss_curve":   [...],
            "train_acc_curve":  [...],
            "val_acc_curve":    [...],
            "best_val_acc":     float,

            # 兼容旧逻辑：默认等同于 qk 的谱 (E,H,K)
            "svals_epochs":     Tensor (E, H, K),

            # 新增：分别对 qk / q / k / v 的谱 (E,H,K)
            "svals_qk_epochs":  Tensor (E, H, K),
            "svals_q_epochs":   Tensor (E, H, K) 或 None,
            "svals_k_epochs":   Tensor (E, H, K) 或 None,
            "svals_v_epochs":   Tensor (E, H, K) 或 None,

            "total_params":     int,
            "kv_cost":          float,
            "num_heads":        int,
            "head_dim":         int,
            "mlp_ratio_used":   float 或 None,
          }
    """
    # ========= 数据 =========
    train_loader, val_loader, num_classes = get_loaders(
        dataset_name=dataset_name,
        data_dir=data_dir,
        batch_size=batch_size,
        img_size=img_size,
        num_workers=num_workers,
    )

    # ========= 构建模型 =========
    pretrained = False

    if attn_type == "mha":
        model = ViTClassifier(
            num_classes=num_classes,
            model_name=model_name,
            pretrained=pretrained,
            attn_type="mha",
        ).to(device)

    elif attn_type == "tpa":
        model = ViTClassifier(
            num_classes=num_classes,
            model_name=model_name,
            pretrained=pretrained,
            attn_type="tpa",
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
        ).to(device)

    elif attn_type == "nonlinear_tpa":
        if mlp_ratio is None:
            mlp_ratio = 1.0
        model = ViTClassifier(
            num_classes=num_classes,
            model_name=model_name,
            pretrained=pretrained,
            attn_type="nonlinear_tpa",
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            nonlinear_mlp_hidden_ratio=mlp_ratio,
            nonlinear_mlp_on = mlp_on,
        ).to(device)

    elif attn_type == "headwise_nonlinear_tpa":
        if mlp_ratio is None:
            mlp_ratio = 1.0
        model = ViTClassifier(
            num_classes=num_classes,
            model_name=model_name,
            pretrained=pretrained,
            attn_type="headwise_nonlinear_tpa",
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            nonlinear_mlp_hidden_ratio=mlp_ratio,
            nonlinear_mlp_on = mlp_on,
        ).to(device)

    elif attn_type == "sinter_tpa":
        if mlp_ratio is None:
            mlp_ratio = 1.0
        model = ViTClassifier(
            num_classes=num_classes,
            model_name=model_name,
            pretrained=pretrained,
            attn_type="sinter_tpa",
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            nonlinear_mlp_hidden_ratio=mlp_ratio,
            sinter_A=sinter_A,
            sinter_omega=sinter_omega,
        ).to(device)

    else:
        raise ValueError(
            f"Unknown attn_type: {attn_type}. "
            f"Expected one of ['mha','tpa','nonlinear_tpa','headwise_nonlinear_tpa','sinter_tpa']"
        )

    total_params = sum(p.numel() for p in model.parameters())

    # ========= KV cost（归一到 MHA=1）=========
    first_attn = model.vit.blocks[0].attn
    if hasattr(first_attn, "dim"):
        dim_attn = first_attn.dim
    else:
        dim_attn = first_attn.qkv.in_features

    num_heads = first_attn.num_heads
    head_dim = dim_attn // num_heads

    kv_mha = 2 * num_heads * head_dim  # = 2 * dim

    if attn_type == "mha":
        kv_cost = 1.0
    else:
        # TPA 系列（含 nonlinear/headwise/sinter）KV 估算公式和你之前一致
        kv_tpa = (rank_k + rank_v) * (num_heads + head_dim)
        kv_cost = kv_tpa / kv_mha

    # ========= 训练相关 =========
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_loss_curve = []
    val_loss_curve   = []
    train_acc_curve  = []
    val_acc_curve    = []
    best_val_acc     = 0.0

    # 每个 epoch 一份谱： (H,K)
    svals_qk_epochs_list = []
    svals_q_epochs_list  = []
    svals_k_epochs_list  = []
    svals_v_epochs_list  = []

    scaler = None  # 先保持不开 AMP（跟你现在谱实验一致）

    for epoch in range(total_epochs):
        # ---- train ----
        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
            scaler=scaler,
        )

        # ---- val ----
        val_loss, val_acc = eval_one_epoch(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
        )

        train_loss_curve.append(train_loss)
        val_loss_curve.append(val_loss)
        train_acc_curve.append(train_acc)
        val_acc_curve.append(val_acc)
        best_val_acc = max(best_val_acc, val_acc)

        print(
            f"[{attn_type}] Epoch {epoch+1}/{total_epochs} | "
            f"train loss: {train_loss:.4f}, train acc: {train_acc:.3f} | "
            f"val loss: {val_loss:.4f}, val acc: {val_acc:.3f}"
        )

        # ---- spectrum ----
        out = _collect_block_svals_one_epoch(
            model=model,
            val_loader=val_loader,
            block_idx=block_idx,
            top_k=top_k,
            num_batches_spec=num_batches_spec,
            device=device,
        )

        # 兼容两种返回：
        # 1) 新版： (svals_qk, svals_q, svals_k, svals_v)
        # 2) 旧版： svals_qk
        if isinstance(out, (tuple, list)) and len(out) == 4:
            svals_qk_epoch, svals_q_epoch, svals_k_epoch, svals_v_epoch = out
        else:
            svals_qk_epoch = out
            svals_q_epoch  = None
            svals_k_epoch  = None
            svals_v_epoch  = None

        svals_qk_epochs_list.append(svals_qk_epoch.detach().cpu())

        if svals_q_epoch is not None:
            svals_q_epochs_list.append(svals_q_epoch.detach().cpu())
        if svals_k_epoch is not None:
            svals_k_epochs_list.append(svals_k_epoch.detach().cpu())
        if svals_v_epoch is not None:
            svals_v_epochs_list.append(svals_v_epoch.detach().cpu())

    # ---- stack over epochs: (E,H,K) ----
    svals_qk_epochs = torch.stack(svals_qk_epochs_list, dim=0)  # (E,H,K)

    svals_q_epochs = torch.stack(svals_q_epochs_list, dim=0) if len(svals_q_epochs_list) > 0 else None
    svals_k_epochs = torch.stack(svals_k_epochs_list, dim=0) if len(svals_k_epochs_list) > 0 else None
    svals_v_epochs = torch.stack(svals_v_epochs_list, dim=0) if len(svals_v_epochs_list) > 0 else None

    # 兼容你后续 compute_effective_rank(svals_epochs)
    svals_epochs = svals_qk_epochs

    hist = {
        "train_loss_curve": train_loss_curve,
        "val_loss_curve":   val_loss_curve,
        "train_acc_curve":  train_acc_curve,
        "val_acc_curve":    val_acc_curve,
        "best_val_acc":     best_val_acc,

        "svals_epochs":     svals_epochs,

        "svals_qk_epochs":  svals_qk_epochs,
        "svals_q_epochs":   svals_q_epochs,
        "svals_k_epochs":   svals_k_epochs,
        "svals_v_epochs":   svals_v_epochs,

        "total_params":     total_params,
        "kv_cost":          kv_cost,
        "num_heads":        num_heads,
        "head_dim":         head_dim,
        "mlp_ratio_used":   mlp_ratio,
        "optimizer_state":  optimizer.state_dict(),
        "last_epoch": total_epochs - 1,
    }

    return model, hist




In [6]:
def train_model_MHA(
    model: nn.Module,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    total_epochs: int,
    scaler=None,
):
    """
    适用于 MHA（或任何普通单阶段模型）的训练函数。

    行为：
      - 连续跑 total_epochs 轮
      - 每轮：
          - 调用 train_one_epoch / eval_one_epoch
          - 记录 train/val loss & acc
      - 返回一个 dict，字段和其他 train_model_* 保持一致：
          {
            "train_loss_curve": [...],
            "val_loss_curve":   [...],
            "train_acc_curve":  [...],
            "val_acc_curve":    [...],
            "best_val_acc":     float,
          }
    """
    model.to(device)

    train_loss_curve = []
    val_loss_curve   = []
    train_acc_curve  = []
    val_acc_curve    = []
    best_val_acc     = 0.0

    use_amp = scaler is not None  # 预留，和其他函数保持一致

    for epoch in range(total_epochs):
        # 一轮训练
        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
            scaler=scaler,
        )

        # 一轮验证
        val_loss, val_acc = eval_one_epoch(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
        )

        train_loss_curve.append(train_loss)
        val_loss_curve.append(val_loss)
        train_acc_curve.append(train_acc)
        val_acc_curve.append(val_acc)
        best_val_acc = max(best_val_acc, val_acc)

        print(
            f"[MHA] Epoch {epoch+1}/{total_epochs} | "
            f"train loss: {train_loss:.4f}, train acc: {train_acc:.3f} | "
            f"val loss: {val_loss:.4f}, val acc: {val_acc:.3f}"
        )

    return {
        "train_loss_curve": train_loss_curve,
        "val_loss_curve":   val_loss_curve,
        "train_acc_curve":  train_acc_curve,
        "val_acc_curve":    val_acc_curve,
        "best_val_acc":     best_val_acc,
    }


TPA module

In [7]:
class ContextualCPAttention(nn.Module):
    """
    CP-style Tensor Product Attention，支持两种模式：
    - contextA = True :
        A_q(x), A_k(x), A_v(x) 是 token 因子 (B, N, R) —— contextual A
        B_q, B_k, B_v 是 head 因子 (R, H, D) —— non-contextual B
        Q = sum_r A_q(x)[r] * B_q[r]

    - contextA = False :
        A_q, A_k, A_v 是全局 token 因子 (R,) —— non-contextual A
        B_q(x), B_k(x), B_v(x) 是 head 因子 (B, N, R, H, D) —— contextual B
        Q = sum_r A_q[r] * B_q(x)[r]
    接口和 timm 的 Attention 一致：forward(x, attn_mask=None) -> (B, N, C)
    """
    def __init__(
        self,
        dim,
        num_heads: int = 8,
        rank_q: int = 16,
        rank_k: int = 16,
        rank_v: int = 16,
        qkv_bias: bool = True,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        contextA: bool = True,   # True: contextual A, False: contextual B
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.rank_q = rank_q
        self.rank_k = rank_k
        self.rank_v = rank_v
        self.contextA = contextA

        # ===== 模式一：contextual A, non-contextual B =====
        if self.contextA:
            # token 因子 A_q(x), A_k(x), A_v(x) : (B, N, R)
            self.A_q = nn.Linear(dim, rank_q, bias=qkv_bias)
            self.A_k = nn.Linear(dim, rank_k, bias=qkv_bias)
            self.A_v = nn.Linear(dim, rank_v, bias=qkv_bias)

            # head 因子 B_q, B_k, B_v : (R, H, D) 全局参数
            self.B_q = nn.Parameter(torch.empty(rank_q, num_heads, self.head_dim))
            self.B_k = nn.Parameter(torch.empty(rank_k, num_heads, self.head_dim))
            self.B_v = nn.Parameter(torch.empty(rank_v, num_heads, self.head_dim))

            nn.init.xavier_uniform_(self.B_q.view(rank_q, -1))
            nn.init.xavier_uniform_(self.B_k.view(rank_k, -1))
            nn.init.xavier_uniform_(self.B_v.view(rank_v, -1))

        # ===== 模式二：non-contextual A, contextual B =====
        else:
            # token 因子 A_q, A_k, A_v : (R,) 全局参数
            self.A_q = nn.Parameter(torch.empty(rank_q))
            self.A_k = nn.Parameter(torch.empty(rank_k))
            self.A_v = nn.Parameter(torch.empty(rank_v))

            nn.init.normal_(self.A_q, std=0.02)
            nn.init.normal_(self.A_k, std=0.02)
            nn.init.normal_(self.A_v, std=0.02)

            # head 因子 B_q(x), B_k(x), B_v(x) : (B, N, R, H, D)
            self.B_q = nn.Linear(dim, rank_q * num_heads * self.head_dim, bias=qkv_bias)
            self.B_k = nn.Linear(dim, rank_k * num_heads * self.head_dim, bias=qkv_bias)
            self.B_v = nn.Linear(dim, rank_v * num_heads * self.head_dim, bias=qkv_bias)

        self.scale = qk_scale or (self.head_dim ** -0.5)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _make_qkv(self, x):
        """
        x: (B, N, C)
        return q, k, v: (B, H, N, D)
        """
        B, N, C = x.shape

        if self.contextA:
            # ---------- contextual A, non-contextual B ----------
            q_tok = self.A_q(x)   # (B, N, R_q)
            k_tok = self.A_k(x)   # (B, N, R_k)
            v_tok = self.A_v(x)   # (B, N, R_v)

            # 外积重构 Q/K/V: bnr, r h d -> b n h d
            q = torch.einsum("bnr,rhd->bnhd", q_tok, self.B_q)
            k = torch.einsum("bnr,rhd->bnhd", k_tok, self.B_k)
            v = torch.einsum("bnr,rhd->bnhd", v_tok, self.B_v)

        else:
            # ---------- non-contextual A, contextual B ----------
            # A: (R,) -> (1,1,R) -> (B,N,R)
            q_tok = self.A_q.view(1, 1, self.rank_q).expand(B, N, -1)  # (B,N,R_q)
            k_tok = self.A_k.view(1, 1, self.rank_k).expand(B, N, -1)  # (B,N,R_k)
            v_tok = self.A_v.view(1, 1, self.rank_v).expand(B, N, -1)  # (B,N,R_v)

            # B(x): (B,N,R,H,D)
            Bq = self.B_q(x).view(B, N, self.rank_q, self.num_heads, self.head_dim)
            Bk = self.B_k(x).view(B, N, self.rank_k, self.num_heads, self.head_dim)
            Bv = self.B_v(x).view(B, N, self.rank_v, self.num_heads, self.head_dim)

            # bnr, bnrhd -> bnhd
            q = torch.einsum("bnr,bnrhd->bnhd", q_tok, Bq)
            k = torch.einsum("bnr,bnrhd->bnhd", k_tok, Bk)
            v = torch.einsum("bnr,bnrhd->bnhd", v_tok, Bv)

        # reshape 成 (B, H, N, D)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        return q, k, v

    def forward(self, x, attn_mask=None):
        """
        x: (B, N, C)
        attn_mask: 兼容 timm 接口，一般 ViT 不会用到
        """
        B, N, C = x.shape

        q, k, v = self._make_qkv(x)  # (B, H, N, D)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B, H, N, N)

        if attn_mask is not None:
            attn = attn + attn_mask

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = attn @ v  # (B, H, N, D)
        out = out.transpose(1, 2).reshape(B, N, C)  # (B, N, C)

        out = self.proj(out)
        out = self.proj_drop(out)
        return out


In [8]:
class TPAAttention(nn.Module):
    """
    True Tensor Product Attention with separate ranks R_q, R_k, R_v,
    按原文公式 100% 走「两向量外积」版本：

      对每个 token t：
        A_Q(x_t) ∈ R^{R_q × h},  B_Q(x_t) ∈ R^{R_q × d_h}
        Q_t = (1 / R_q) * A_Q(x_t)^T B_Q(x_t) ∈ R^{h × d_h}

      K, V 同理。
    """

    def __init__(
        self,
        dim,
        num_heads: int = 8,
        rank_q: int = 16,
        rank_k: int = 2,
        rank_v: int = 2,
        qkv_bias: bool = True,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.rank_q = rank_q
        self.rank_k = rank_k
        self.rank_v = rank_v

        # ========== A(x): head-dim factors, shape (R * h) ==========
        # 对应论文里的 A_Q(xt) ∈ R^{R_q×h} 展平之后的线性映射
        self.A_q = nn.Linear(dim, rank_q * num_heads, bias=qkv_bias)
        self.A_k = nn.Linear(dim, rank_k * num_heads, bias=qkv_bias)
        self.A_v = nn.Linear(dim, rank_v * num_heads, bias=qkv_bias)

        # ========== B(x): token-dim factors, shape (R * d_h) ==========
        # 对应论文里的 B_Q(xt) ∈ R^{R_q×d_h} 展平之后的线性映射
        self.B_q = nn.Linear(dim, rank_q * self.head_dim, bias=qkv_bias)
        self.B_k = nn.Linear(dim, rank_k * self.head_dim, bias=qkv_bias)
        self.B_v = nn.Linear(dim, rank_v * self.head_dim, bias=qkv_bias)

        # 缩放和输出
        self.scale = qk_scale or (self.head_dim ** -0.5)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _make_qkv(self, x):
        """
        x: (B, N, C = dim)

        返回:
          q, k, v: (B, H, N, D_h)
        """
        B, N, C = x.shape
        H, Dh = self.num_heads, self.head_dim

        # ---------- queries ----------
        Aq = self.A_q(x).view(B, N, self.rank_q, H)      # (B,N,R_q,H)
        Bq = self.B_q(x).view(B, N, self.rank_q, Dh)     # (B,N,R_q,D_h)
        # Qt: (B,N,H,D_h) = (1/R_q) * sum_r a_r ⊗ b_r
        Q = torch.einsum("bnrh,bnrd->bnhd", Aq, Bq) / float(self.rank_q)

        # ---------- keys ----------
        Ak = self.A_k(x).view(B, N, self.rank_k, H)      # (B,N,R_k,H)
        Bk = self.B_k(x).view(B, N, self.rank_k, Dh)     # (B,N,R_k,D_h)
        K = torch.einsum("bnrh,bnrd->bnhd", Ak, Bk) / float(self.rank_k)

        # ---------- values ----------
        Av = self.A_v(x).view(B, N, self.rank_v, H)      # (B,N,R_v,H)
        Bv = self.B_v(x).view(B, N, self.rank_v, Dh)     # (B,N,R_v,D_h)
        V = torch.einsum("bnrh,bnrd->bnhd", Av, Bv) / float(self.rank_v)

        # 现在 Q,K,V 形状是 (B, N, H, D_h)，需要转成 (B, H, N, D_h)
        Q = Q.permute(0, 2, 1, 3).contiguous()
        K = K.permute(0, 2, 1, 3).contiguous()
        V = V.permute(0, 2, 1, 3).contiguous()

        return Q, K, V

    def forward(self, x, attn_mask=None):
        """
        x: (B, N, C)
        返回: (B, N, C)
        """
        B, N, C = x.shape
        assert C == self.dim

        q, k, v = self._make_qkv(x)        # (B,H,N,D_h)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B,H,N,N)
        if attn_mask is not None:
            attn = attn + attn_mask
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = attn @ v                     # (B,H,N,D_h)
        out = out.transpose(1, 2).reshape(B, N, C)  # (B,N,C)

        out = self.proj(out)
        out = self.proj_drop(out)
        return out


def train_model_TPA(
    model: nn.Module,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    total_epochs: int,
    scaler=None,
):
    """
    适用于 TPA（或普通单阶段模型）的训练函数。

    行为：
      - 连续跑 total_epochs 轮
      - 每轮：
          - 调用 train_one_epoch / eval_one_epoch
          - 记录 train/val loss & acc
      - 返回一个 dict，字段和 train_model_two_stage_tucker 保持一致：
          {
            "train_loss_curve": [...],
            "val_loss_curve":   [...],
            "train_acc_curve":  [...],
            "val_acc_curve":    [...],
            "best_val_acc":     float,
          }
    """
    model.to(device)

    train_loss_curve = []
    val_loss_curve   = []
    train_acc_curve  = []
    val_acc_curve    = []
    best_val_acc     = 0.0

    use_amp = scaler is not None

    for epoch in range(total_epochs):
        # 一轮训练
        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
            scaler=scaler,
        )

        # 一轮验证
        val_loss, val_acc = eval_one_epoch(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
        )

        train_loss_curve.append(train_loss)
        val_loss_curve.append(val_loss)
        train_acc_curve.append(train_acc)
        val_acc_curve.append(val_acc)
        best_val_acc = max(best_val_acc, val_acc)

        print(
            f"[TPA] Epoch {epoch+1}/{total_epochs} | "
            f"train loss: {train_loss:.4f}, train acc: {train_acc:.3f} | "
            f"val loss: {val_loss:.4f}, val acc: {val_acc:.3f}"
        )

    return {
        "train_loss_curve": train_loss_curve,
        "val_loss_curve":   val_loss_curve,
        "train_acc_curve":  train_acc_curve,
        "val_acc_curve":    val_acc_curve,
        "best_val_acc":     best_val_acc,
    }


In [9]:

import torch
import torch.nn as nn

class NonlinearTPAAttention(nn.Module):
    """
    True TPA + shared MLP 非线性（可选作用于 Q/K/V 的任意组合）
      Q_lin = (1/R_q) * A_Q(x)^T B_Q(x)
      Q = f(Q_lin)

    通过 mlp_on 控制对哪些分支加 MLP：
      - "qkv" / "qk" / "kv" / "q" / "k" / "v" / "none"
    """

    def __init__(
        self,
        dim,
        num_heads: int = 8,
        rank_q: int = 16,
        rank_k: int = 2,
        rank_v: int = 2,
        qkv_bias: bool = True,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        mlp_hidden_ratio: float = 1.0,
        mlp_on: str = "qkv",   # ✅ 新增
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.rank_q = rank_q
        self.rank_k = rank_k
        self.rank_v = rank_v

        # ====== A(x) / B(x) ======
        self.A_q = nn.Linear(dim, rank_q * num_heads, bias=qkv_bias)
        self.A_k = nn.Linear(dim, rank_k * num_heads, bias=qkv_bias)
        self.A_v = nn.Linear(dim, rank_v * num_heads, bias=qkv_bias)

        self.B_q = nn.Linear(dim, rank_q * self.head_dim, bias=qkv_bias)
        self.B_k = nn.Linear(dim, rank_k * self.head_dim, bias=qkv_bias)
        self.B_v = nn.Linear(dim, rank_v * self.head_dim, bias=qkv_bias)

        # -------- MLP 选择逻辑（和你的 headwise 版本一致的风格）--------
        mlp_on = "none" if mlp_on is None else str(mlp_on).lower().strip()


        self.kv_shared = (mlp_on == "kv_shared")

        allowed = set("qkv")
        if mlp_on in ("", "none", "null", "no", "false", "0"):
            mlp_set = set()

        elif self.kv_shared:
            mlp_set = set("kv")

        else:
            mlp_set = set(mlp_on)
            if not mlp_set.issubset(allowed):
                raise ValueError(f"mlp_on must be subset of 'qkv' or 'none', got: {mlp_on}")
        self.mlp_on = mlp_set

        hidden_dim = int(self.head_dim * mlp_hidden_ratio)

        def make_mlp():
            return nn.Sequential(
                nn.Linear(self.head_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, self.head_dim),
            )

        # ✅ shared MLP：没选中的分支用 Identity，确保“数值不变”
        self.q_mlp = make_mlp() if "q" in self.mlp_on else nn.Identity()
        if self.kv_shared and ("k" in self.mlp_on) and ("v" in self.mlp_on):
            kv_mlp = make_mlp()
            self.k_mlp = kv_mlp
            self.v_mlp = kv_mlp
        else:
            self.k_mlp = make_mlp() if "k" in self.mlp_on else nn.Identity()
            self.v_mlp = make_mlp() if "v" in self.mlp_on else nn.Identity()

        # attention 输出
        self.scale = qk_scale or (self.head_dim ** -0.5)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _make_qkv(self, x):
        B, N, C = x.shape
        H, Dh = self.num_heads, self.head_dim

        # ---------- linear TPA ----------
        Aq = self.A_q(x).view(B, N, self.rank_q, H)
        Bq = self.B_q(x).view(B, N, self.rank_q, Dh)
        Q = torch.einsum("bnrh,bnrd->bnhd", Aq, Bq) / float(self.rank_q)

        Ak = self.A_k(x).view(B, N, self.rank_k, H)
        Bk = self.B_k(x).view(B, N, self.rank_k, Dh)
        K = torch.einsum("bnrh,bnrd->bnhd", Ak, Bk) / float(self.rank_k)

        Av = self.A_v(x).view(B, N, self.rank_v, H)
        Bv = self.B_v(x).view(B, N, self.rank_v, Dh)
        V = torch.einsum("bnrh,bnrd->bnhd", Av, Bv) / float(self.rank_v)

        Q = Q.permute(0, 2, 1, 3).contiguous()  # (B,H,N,Dh)
        K = K.permute(0, 2, 1, 3).contiguous()
        V = V.permute(0, 2, 1, 3).contiguous()

        # ---------- optional nonlinearity ----------
        Q = self.q_mlp(Q)
        K = self.k_mlp(K)
        V = self.v_mlp(V)

        return Q, K, V

    def forward(self, x, attn_mask=None):
        B, N, C = x.shape
        assert C == self.dim

        q, k, v = self._make_qkv(x)
        attn = (q * self.scale) @ k.transpose(-2, -1)
        if attn_mask is not None:
            attn = attn + attn_mask
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = attn @ v
        out = out.transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

    


    
def train_model_nonlinear_TPA(
    model: nn.Module,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    total_epochs: int,
    scaler=None,
):
    """
    适用于 NonlinearTPAAttention（或其它单阶段模型）的训练函数。

    行为和 train_model_TPA 一致，只是打印前缀换成 [NonlinearTPA]，
    方便你后面对比两条曲线。
    """
    model.to(device)

    train_loss_curve = []
    val_loss_curve   = []
    train_acc_curve  = []
    val_acc_curve    = []
    best_val_acc     = 0.0

    use_amp = scaler is not None  # 这里先保留这个变量，方便以后需要用

    for epoch in range(total_epochs):
        # 一轮训练
        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
            scaler=scaler,
        )

        # 一轮验证
        val_loss, val_acc = eval_one_epoch(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
        )

        train_loss_curve.append(train_loss)
        val_loss_curve.append(val_loss)
        train_acc_curve.append(train_acc)
        val_acc_curve.append(val_acc)
        best_val_acc = max(best_val_acc, val_acc)

        print(
            f"[NonlinearTPA] Epoch {epoch+1}/{total_epochs} | "
            f"train loss: {train_loss:.4f}, train acc: {train_acc:.3f} | "
            f"val loss: {val_loss:.4f}, val acc: {val_acc:.3f}"
        )

    return {
        "train_loss_curve": train_loss_curve,
        "val_loss_curve":   val_loss_curve,
        "train_acc_curve":  train_acc_curve,
        "val_acc_curve":    val_acc_curve,
        "best_val_acc":     best_val_acc,
    }


In [10]:
import torch
import torch.nn as nn


class HeadwiseNonlinearTPAAttention(nn.Module):
    """
    True TPA + 每个 head 独立的 MLP 非线性（可选作用于 Q/K/V 的任意组合）

    通过 mlp_on 控制对哪些分支加 MLP：
      - "qkv" / "qk" / "kv" / "q" / "k" / "v" / "none"
      - 额外支持："kv_shared"
          含义：每个 head 一个 MLP，但该 head 内 K/V 共享同一个 MLP
    """

    def __init__(
        self,
        dim,
        num_heads: int = 8,
        rank_q: int = 16,
        rank_k: int = 2,
        rank_v: int = 2,
        qkv_bias: bool = True,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        mlp_hidden_ratio: float = 1.0,
        mlp_on: str = "qkv",
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.rank_q = rank_q
        self.rank_k = rank_k
        self.rank_v = rank_v

        # -------- A(x) / B(x) --------
        self.A_q = nn.Linear(dim, rank_q * num_heads, bias=qkv_bias)
        self.A_k = nn.Linear(dim, rank_k * num_heads, bias=qkv_bias)
        self.A_v = nn.Linear(dim, rank_v * num_heads, bias=qkv_bias)

        self.B_q = nn.Linear(dim, rank_q * self.head_dim, bias=qkv_bias)
        self.B_k = nn.Linear(dim, rank_k * self.head_dim, bias=qkv_bias)
        self.B_v = nn.Linear(dim, rank_v * self.head_dim, bias=qkv_bias)

        # -------- MLP 选择逻辑（新增 kv_shared）--------
        mlp_on = "none" if mlp_on is None else str(mlp_on).lower().strip()
        self.kv_shared = (mlp_on == "kv_shared")

        allowed = set("qkv")
        if mlp_on in ("", "none", "null", "no", "false", "0"):
            mlp_set = set()
        elif self.kv_shared:
            mlp_set = set("kv")  # => {"k","v"}
        else:
            mlp_set = set(mlp_on)
            if not mlp_set.issubset(allowed):
                raise ValueError(f"mlp_on must be subset of 'qkv' or 'none' (or 'kv_shared'), got: {mlp_on}")
        self.mlp_on = mlp_set

        hidden_dim = max(1, int(self.head_dim * mlp_hidden_ratio))

        def make_mlp():
            return nn.Sequential(
                nn.Linear(self.head_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, self.head_dim),
            )

        # -------- 每个 head 一套 MLP；没选中的分支用 Identity --------
        # Q
        if "q" in self.mlp_on:
            self.q_mlps = nn.ModuleList([make_mlp() for _ in range(num_heads)])
        else:
            self.q_mlps = nn.ModuleList([nn.Identity() for _ in range(num_heads)])

        # K/V
        if self.kv_shared:
            # 关键：只注册一次 kv_mlps，避免重复参数引用
            self.kv_mlps = nn.ModuleList([make_mlp() for _ in range(num_heads)])
            self.k_mlps = None
            self.v_mlps = None
        else:
            self.kv_mlps = None
            if "k" in self.mlp_on:
                self.k_mlps = nn.ModuleList([make_mlp() for _ in range(num_heads)])
            else:
                self.k_mlps = nn.ModuleList([nn.Identity() for _ in range(num_heads)])

            if "v" in self.mlp_on:
                self.v_mlps = nn.ModuleList([make_mlp() for _ in range(num_heads)])
            else:
                self.v_mlps = nn.ModuleList([nn.Identity() for _ in range(num_heads)])

        # -------- attention 输出部分 --------
        self.scale = qk_scale or (self.head_dim ** -0.5)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _apply_headwise_mlp(self, x_bhnd, mlps: nn.ModuleList):
        """
        x_bhnd: (B, H, N, D_h)
        对每个 head 分别过对应的 mlp
        """
        B, H, N, Dh = x_bhnd.shape
        outs = []
        for h in range(H):
            outs.append(mlps[h](x_bhnd[:, h, :, :]))  # (B,N,Dh)
        return torch.stack(outs, dim=1)  # (B,H,N,Dh)

    def _make_qkv(self, x):
        B, N, C = x.shape
        H, Dh = self.num_heads, self.head_dim

        Aq = self.A_q(x).view(B, N, self.rank_q, H)
        Bq = self.B_q(x).view(B, N, self.rank_q, Dh)
        Q = torch.einsum("bnrh,bnrd->bnhd", Aq, Bq) / float(self.rank_q)

        Ak = self.A_k(x).view(B, N, self.rank_k, H)
        Bk = self.B_k(x).view(B, N, self.rank_k, Dh)
        K = torch.einsum("bnrh,bnrd->bnhd", Ak, Bk) / float(self.rank_k)

        Av = self.A_v(x).view(B, N, self.rank_v, H)
        Bv = self.B_v(x).view(B, N, self.rank_v, Dh)
        V = torch.einsum("bnrh,bnrd->bnhd", Av, Bv) / float(self.rank_v)

        Q = Q.permute(0, 2, 1, 3).contiguous()  # (B,H,N,Dh)
        K = K.permute(0, 2, 1, 3).contiguous()
        V = V.permute(0, 2, 1, 3).contiguous()

        # head-wise 非线性
        Q = self._apply_headwise_mlp(Q, self.q_mlps)

        if self.kv_shared:
            K = self._apply_headwise_mlp(K, self.kv_mlps)
            V = self._apply_headwise_mlp(V, self.kv_mlps)
        else:
            K = self._apply_headwise_mlp(K, self.k_mlps)
            V = self._apply_headwise_mlp(V, self.v_mlps)

        return Q, K, V

    def forward(self, x, attn_mask=None):
        B, N, C = x.shape
        assert C == self.dim

        q, k, v = self._make_qkv(x)
        attn = (q * self.scale) @ k.transpose(-2, -1)
        if attn_mask is not None:
            attn = attn + attn_mask
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = attn @ v
        out = out.transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out


In [11]:



class Sinter(nn.Module):
    """
    Sinter 激活函数：
        Sinter(x) = x + A * x * sin(omega * x)
                  = x * (1 + A * sin(omega * x))

    A 和 omega 按 LoRAN 论文给的量级设置为默认值：
      - A 很小（1e-5 量级），保证整体接近恒等映射；
      - omega 很大（1e4 量级），在输入空间引入高频微扰。
    """

    def __init__(self, A: float = 5e-5, omega: float = 1e4):
        super().__init__()
        self.A = A
        self.omega = omega

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: 任意形状，逐元素作用
        return x + self.A * x * torch.sin(self.omega * x)
        # 等价写法：return x * (1 + self.A * torch.sin(self.omega * x))


class SinterTPAAttention(nn.Module):
    """
    在原始 TPA 的基础上，加一层简单的 MLP + Sinter 非线性：
      Q_lin = (1/R_q) * A_Q(x)^T B_Q(x)
      Q = f_q(Q_lin)

    f_q / f_k / f_v 都是两层 MLP: D_h -> D_h_hidden -> D_h
    激活函数用 Sinter 而不是 GeLU。
    默认 D_h_hidden = D_h（最简单的版本）
    """

    def __init__(
        self,
        dim,
        num_heads: int = 8,
        rank_q: int = 16,
        rank_k: int = 2,
        rank_v: int = 2,
        qkv_bias: bool = True,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        mlp_hidden_ratio: float = 1.0,  # 隐藏层大小 = ratio * head_dim
        # Sinter 超参数
        sinter_A: float = 5e-5,
        sinter_omega: float = 1e4,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.rank_q = rank_q
        self.rank_k = rank_k
        self.rank_v = rank_v

        # ========== A(x): head-dim factors, shape (R * h) ==========
        self.A_q = nn.Linear(dim, rank_q * num_heads, bias=qkv_bias)
        self.A_k = nn.Linear(dim, rank_k * num_heads, bias=qkv_bias)
        self.A_v = nn.Linear(dim, rank_v * num_heads, bias=qkv_bias)

        # ========== B(x): token-dim factors, shape (R * d_h) ==========
        self.B_q = nn.Linear(dim, rank_q * self.head_dim, bias=qkv_bias)
        self.B_k = nn.Linear(dim, rank_k * self.head_dim, bias=qkv_bias)
        self.B_v = nn.Linear(dim, rank_v * self.head_dim, bias=qkv_bias)

        # ---------- 在 D_h 维度上做 MLP + Sinter 非线性 ----------
        hidden_dim = int(self.head_dim * mlp_hidden_ratio)

        def make_mlp():
            return nn.Sequential(
                nn.Linear(self.head_dim, hidden_dim),
                Sinter(A=sinter_A, omega=sinter_omega),
                nn.Linear(hidden_dim, self.head_dim),
            )

        # 可以选择 Q/K/V 各自一个 MLP，也可以视需要改为共享
        self.q_mlp = make_mlp()
        self.k_mlp = make_mlp()
        self.v_mlp = make_mlp()

        # 缩放和输出
        self.scale = qk_scale or (self.head_dim ** -0.5)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _make_qkv(self, x):
        """
        x: (B, N, C = dim)
        返回:
          q, k, v: (B, H, N, D_h)
        """
        B, N, C = x.shape
        H, Dh = self.num_heads, self.head_dim

        # ---------- linear TPA 部分 ----------
        # Q
        Aq = self.A_q(x).view(B, N, self.rank_q, H)      # (B,N,R_q,H)
        Bq = self.B_q(x).view(B, N, self.rank_q, Dh)     # (B,N,R_q,D_h)
        Q = torch.einsum("bnrh,bnrd->bnhd", Aq, Bq) / float(self.rank_q)

        # K
        Ak = self.A_k(x).view(B, N, self.rank_k, H)      # (B,N,R_k,H)
        Bk = self.B_k(x).view(B, N, self.rank_k, Dh)     # (B,N,R_k,D_h)
        K = torch.einsum("bnrh,bnrd->bnhd", Ak, Bk) / float(self.rank_k)

        # V
        Av = self.A_v(x).view(B, N, self.rank_v, H)      # (B,N,R_v,H)
        Bv = self.B_v(x).view(B, N, self.rank_v, Dh)     # (B,N,R_v,D_h)
        V = torch.einsum("bnrh,bnrd->bnhd", Av, Bv) / float(self.rank_v)

        # 现在 Q,K,V 形状是 (B, N, H, D_h)，需要转成 (B, H, N, D_h)
        Q = Q.permute(0, 2, 1, 3).contiguous()  # (B,H,N,D_h)
        K = K.permute(0, 2, 1, 3).contiguous()
        V = V.permute(0, 2, 1, 3).contiguous()

        # ---------- 在 D_h 维度加一层 MLP: f(Q_lin) ----------
        Q = self.q_mlp(Q)   # (B,H,N,D_h)
        K = self.k_mlp(K)
        V = self.v_mlp(V)

        return Q, K, V

    def forward(self, x, attn_mask=None):
        """
        x: (B, N, C)
        返回: (B, N, C)
        """
        B, N, C = x.shape
        assert C == self.dim

        q, k, v = self._make_qkv(x)        # (B,H,N,D_h)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B,H,N,N)
        if attn_mask is not None:
            attn = attn + attn_mask
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = attn @ v                     # (B,H,N,D_h)
        out = out.transpose(1, 2).reshape(B, N, C)  # (B,N,C)

        out = self.proj(out)
        out = self.proj_drop(out)
        return out
    


def train_model_sinter_TPA(
    model: nn.Module,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    total_epochs: int,
    scaler=None,
):
    """
    适用于 SinterTPAAttention（或其它单阶段模型）的训练函数。

    跟 train_model_nonlinear_TPA 一样，只是打印前缀换成 [SinterTPA]，
    方便对比曲线。
    """
    model.to(device)

    train_loss_curve = []
    val_loss_curve   = []
    train_acc_curve  = []
    val_acc_curve    = []
    best_val_acc     = 0.0

    use_amp = scaler is not None  # 预留，之后如果要改 AMP 逻辑可以用

    for epoch in range(total_epochs):
        # 一轮训练
        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
            scaler=scaler,
        )

        # 一轮验证
        val_loss, val_acc = eval_one_epoch(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
        )

        train_loss_curve.append(train_loss)
        val_loss_curve.append(val_loss)
        train_acc_curve.append(train_acc)
        val_acc_curve.append(val_acc)
        best_val_acc = max(best_val_acc, val_acc)

        print(
            f"[SinterTPA] Epoch {epoch+1}/{total_epochs} | "
            f"train loss: {train_loss:.4f}, train acc: {train_acc:.3f} | "
            f"val loss: {val_loss:.4f}, val acc: {val_acc:.3f}"
        )

    return {
        "train_loss_curve": train_loss_curve,
        "val_loss_curve":   val_loss_curve,
        "train_acc_curve":  train_acc_curve,
        "val_acc_curve":    val_acc_curve,
        "best_val_acc":     best_val_acc,
    }


In [12]:


class ContextualTuckerTPAAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        rank_q: int = 16,
        rank_k: int = 2,
        rank_v: int = 2,
        rank_head: int = None,
        rank_channel: int = None,
        qkv_bias: bool = True,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # -------- rank 解析（和你原来的逻辑一样）---------
        if rank_q is None:
            if rank_head is not None:
                rank_q = rank_head
            elif rank_channel is not None:
                rank_q = rank_channel
            else:
                raise ValueError("rank_q is None and no rank_head/rank_channel provided.")
        if rank_k is None:
            rank_k = rank_q
        if rank_v is None:
            rank_v = rank_q

        self.rank_q = int(rank_q)
        self.rank_k = int(rank_k)
        self.rank_v = int(rank_v)

        Rq, Rk, Rv = self.rank_q, self.rank_k, self.rank_v
        H, Dh = self.num_heads, self.head_dim

        # Stage1: 先走 CP / TPA
        self.cp_only = True

        # 缩放因子，和 TPAAttention 一致： /R
        self.scale_q_tucker = 1.0 / float(Rq)
        self.scale_k_tucker = 1.0 / float(Rk)
        self.scale_v_tucker = 1.0 / float(Rv)

        # ====== Q 分支 A/B/G ======
        # 和 TPAAttention 完全相同的映射维度：
        #   A: (B,N,R,H)  ← out_features = R*H
        #   B: (B,N,R,Dh) ← out_features = R*Dh
        self.q_A_proj = nn.Linear(dim, Rq * H,  bias=qkv_bias)
        self.q_B_proj = nn.Linear(dim, Rq * Dh, bias=qkv_bias)
        self.q_G_proj = nn.Linear(dim, Rq * Rq, bias=qkv_bias)

        # ====== K 分支 ======
        self.k_A_proj = nn.Linear(dim, Rk * H,  bias=qkv_bias)
        self.k_B_proj = nn.Linear(dim, Rk * Dh, bias=qkv_bias)
        self.k_G_proj = nn.Linear(dim, Rk * Rk, bias=qkv_bias)

        # ====== V 分支 ======
        self.v_A_proj = nn.Linear(dim, Rv * H,  bias=qkv_bias)
        self.v_B_proj = nn.Linear(dim, Rv * Dh, bias=qkv_bias)
        self.v_G_proj = nn.Linear(dim, Rv * Rv, bias=qkv_bias)

        self.scale = qk_scale or (self.head_dim ** -0.5)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self._reset_parameters()

    def _reset_parameters(self):
        """
        为了让 Stage1 尽量贴近 TPAAttention：
          - A/B 使用 nn.Linear 的默认初始化（和 TPAAttention 一样）
          - 只对 G 的线性层做特殊 init
        """
        for m in [self.q_G_proj, self.k_G_proj, self.v_G_proj]:
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            # 初期减小 G 的影响（Stage2 用）
            m.weight.data.mul_(0.1)

    # ====== CP / TPA 分支：和 TPAAttention 完全同构 ======
    @staticmethod
    def _cp_from_factors(A, B, scale: float):
        """
        A: (B, N, R, H)
        B: (B, N, R, D_h)

        Q[b,n,h,d] = (1/R) * Σ_r A[b,n,r,h] * B[b,n,r,d]
        """
        Q = torch.einsum("bnrh,bnrd->bnhd", A, B)  # (B,N,H,Dh)
        Q = Q * scale
        Q = Q.permute(0, 2, 1, 3).contiguous()     # (B,H,N,Dh)
        return Q

    # ====== Tucker 分支：A^T G B 版本（为了兼容 Stage1 的排布） ======
    @staticmethod
    def _tucker_from_factors(A, G, B, scale: float):
        """
        A: (B, N, R, H)   对应 A(x) ∈ R^{R×H}
        G: (B, N, R, R)
        B: (B, N, R, D_h)

        Q = A^T G B ∈ R^{H×D_h}
        """
        # M = A^T G  → (B,N,H,R)
        # M[b,n,h,q] = Σ_p A[b,n,p,h] * G[b,n,p,q]
        M = torch.einsum("bnph,bnpq->bnhq", A, G)

        # Q = M B    → (B,N,H,Dh)
        # Q[b,n,h,d] = Σ_q M[b,n,h,q] * B[b,n,q,d]
        Q = torch.einsum("bnhq,bnqd->bnhd", M, B)

        Q = Q * scale
        Q = Q.permute(0, 2, 1, 3).contiguous()  # (B,H,N,Dh)
        return Q

    def _make_qkv(self, x):
        """
        x: (B, N, C)
        返回 q,k,v: (B, H, N, D_h)
        """
        B, N, C = x.shape
        H, Dh = self.num_heads, self.head_dim
        Rq, Rk, Rv = self.rank_q, self.rank_k, self.rank_v

        # ---- Q ----
        qA = self.q_A_proj(x).view(B, N, Rq, H)   # (B,N,R,H)  ✅ 与 TPAAttention 对齐
        qB = self.q_B_proj(x).view(B, N, Rq, Dh)  # (B,N,R,Dh)
        if self.cp_only:
            q = self._cp_from_factors(qA, qB, self.scale_q_tucker)
        else:
            qG = self.q_G_proj(x).view(B, N, Rq, Rq)
            q = self._tucker_from_factors(qA, qG, qB, self.scale_q_tucker)

        # ---- K ----
        kA = self.k_A_proj(x).view(B, N, Rk, H)
        kB = self.k_B_proj(x).view(B, N, Rk, Dh)
        if self.cp_only:
            k = self._cp_from_factors(kA, kB, self.scale_k_tucker)
        else:
            kG = self.k_G_proj(x).view(B, N, Rk, Rk)
            k = self._tucker_from_factors(kA, kG, kB, self.scale_k_tucker)

        # ---- V ----
        vA = self.v_A_proj(x).view(B, N, Rv, H)
        vB = self.v_B_proj(x).view(B, N, Rv, Dh)
        if self.cp_only:
            v = self._cp_from_factors(vA, vB, self.scale_v_tucker)
        else:
            vG = self.v_G_proj(x).view(B, N, Rv, Rv)
            v = self._tucker_from_factors(vA, vG, vB, self.scale_v_tucker)

        return q, k, v

    def forward(self, x, attn_mask=None):
        B, N, C = x.shape
        assert C == self.dim

        q, k, v = self._make_qkv(x)  # (B,H,N,Dh)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B,H,N,N)
        if attn_mask is not None:
            attn = attn + attn_mask

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = attn @ v  # (B,H,N,Dh)
        out = out.transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out



# ===================== 工具函数：切换模式 / 控制参数训练 =====================

def set_tucker_cp_mode(model: nn.Module, cp_only: bool):
    """
    在整个 model 里，把所有 
      的 cp_only 统一设置。

    Phase 1（TPA 阶段）: cp_only = True
    Phase 2（Tucker 阶段）: cp_only = False
    """
    for m in model.modules():
        if isinstance(m, ContextualTuckerTPAAttention):
            m.cp_only = cp_only


def set_tucker_ab_requires_grad(model: nn.Module, requires_grad: bool):
    """
    控制所有 Tucker 注意力里的 A/B 投影是否参与训练。
    即：
      - q_A_proj, q_B_proj
      - k_A_proj, k_B_proj
      - v_A_proj, v_B_proj
    """
    for m in model.modules():
        if isinstance(m, ContextualTuckerTPAAttention):
            for proj in [
                m.q_A_proj, m.q_B_proj,
                m.k_A_proj, m.k_B_proj,
                m.v_A_proj, m.v_B_proj,
            ]:
                for p in proj.parameters():
                    p.requires_grad = requires_grad


def set_tucker_core_requires_grad(model: nn.Module, requires_grad: bool):
    """
    控制所有 Tucker 注意力里的 G 部分是否参与训练：
      - q_G_proj, k_G_proj, v_G_proj
    """
    for m in model.modules():
        if isinstance(m, ContextualTuckerTPAAttention):
            for proj in [m.q_G_proj, m.k_G_proj, m.v_G_proj]:
                for p in proj.parameters():
                    p.requires_grad = requires_grad


def init_tucker_core_as_identity(model: nn.Module):
    """
    在「Tucker 阶段开始」时调用：
      把所有 G 的线性映射初始化成恒等核 I：

        G_Q(x) ≡ I_{R_q},  G_K(x) ≡ I_{R_k},  G_V(x) ≡ I_{R_v}

    实现方式：
      - 把 q_G_proj / k_G_proj / v_G_proj 的 weight 置 0
      - 把 bias 置成 vec(I_R)
      这样对任何 x 都有 G(x)=I，从而保证：

        A(x) G(x) B(x)^T = A(x) B(x)^T

      => Phase2 初始时，函数形式与 Phase1 完全一致（真·等价 TPA）。
    """
    for m in model.modules():
        if isinstance(m, ContextualTuckerTPAAttention):
            with torch.no_grad():
                for proj, R in [
                    (m.q_G_proj, m.rank_q),
                    (m.k_G_proj, m.rank_k),
                    (m.v_G_proj, m.rank_v),
                ]:
                    if proj is None:
                        continue

                    # W = 0
                    proj.weight.zero_()

                    # 偏置存在则置 0，然后写成 vec(I)
                    if proj.bias is not None:
                        proj.bias.zero_()
                        eye = torch.eye(R, device=proj.weight.device, dtype=proj.weight.dtype)
                        proj.bias.copy_(eye.reshape(-1))


# ===================== 两阶段训练函数 =====================

def train_model_two_stage_tucker(
    model: nn.Module,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    total_epochs: int,
    ab_epochs: int = 40,   # 第一阶段：只训 A/B（CP/TPA）
    scaler=None,
    is_tucker_model: bool = False,
):
    """
    通用 train 函数：

      - 对 MHA / 纯 TPA model：is_tucker_model=False
            → 普通训练 total_epochs 轮。

      - 对 Tucker-TPA：is_tucker_model=True
            → 两阶段：
                * epoch < ab_epochs:
                    cp_only = True
                    A/B 可训练，G 不训练   （函数族 = TPA）
                * epoch == ab_epochs:
                    切换到 Tucker：
                      - cp_only = False
                      - 重置 G 为恒等核 I
                      - 冻结 A/B，只训 G
                * epoch > ab_epochs:
                    继续 Tucker 阶段
    """
    train_loss_curve = []
    val_loss_curve   = []
    train_acc_curve  = []
    val_acc_curve    = []
    best_val_acc     = 0.0

    for epoch in range(total_epochs):

        # ====== 阶段切换逻辑 ======
        if is_tucker_model:
            if epoch < ab_epochs:
                # Phase 1: CP/TPA —— 训 A/B，G 不参与训练
                set_tucker_cp_mode(model, cp_only=True)
                set_tucker_ab_requires_grad(model, True)
                set_tucker_core_requires_grad(model, False)
                phase_tag = "TPA-phase(AB)"

            elif epoch == ab_epochs:
                # Phase 2 开始：把 G 初始化为 I，切换到 Tucker 形式，只训 G
                set_tucker_cp_mode(model, cp_only=False)
                set_tucker_ab_requires_grad(model, False)
                set_tucker_core_requires_grad(model, True)
                init_tucker_core_as_identity(model)
                phase_tag = "Tucker-phase(G, init=I)"

            else:
                # Phase 2 后续：继续 Tucker，只训 G
                set_tucker_cp_mode(model, cp_only=False)
                set_tucker_ab_requires_grad(model, False)
                set_tucker_core_requires_grad(model, True)
                phase_tag = "Tucker-phase(G)"
        else:
            # 非 Tucker 模型：普通训练
            phase_tag = "normal"

        # ====== 一轮训练 & 验证 ======
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, total_epochs, scaler
        )
        val_loss, val_acc = eval_one_epoch(
            model, val_loader, criterion, device, epoch, total_epochs
        )

        train_loss_curve.append(train_loss)
        val_loss_curve.append(val_loss)
        train_acc_curve.append(train_acc)
        val_acc_curve.append(val_acc)
        best_val_acc = max(best_val_acc, val_acc)

        print(
            f"[Epoch {epoch+1}/{total_epochs} | phase={phase_tag}] "
            f"train loss: {train_loss:.4f}, train acc: {train_acc:.3f} | "
            f"val loss: {val_loss:.4f}, val acc: {val_acc:.3f}"
        )

    return {
        "train_loss_curve": train_loss_curve,
        "val_loss_curve": val_loss_curve,
        "train_acc_curve": train_acc_curve,
        "val_acc_curve": val_acc_curve,
        "best_val_acc": best_val_acc,
    }




In [13]:
def set_tucker_cp_mode(model: nn.Module, cp_only: bool):
    """
    在整个 model 里，把所有 ContextualTuckerTPAAttention 的 cp_only 统一设置。

    Phase 1（TPA 阶段）: cp_only = True
    Phase 2（Tucker 阶段）: cp_only = False
    """
    for m in model.modules():
        if isinstance(m, ContextualTuckerTPAAttention):
            m.cp_only = cp_only


def set_tucker_ab_requires_grad(model: nn.Module, requires_grad: bool):
    """
    控制所有 Tucker 注意力里的 A/B 投影是否参与训练：
      - q_A_proj, q_B_proj
      - k_A_proj, k_B_proj
      - v_A_proj, v_B_proj
    """
    for m in model.modules():
        if isinstance(m, ContextualTuckerTPAAttention):
            for proj in [
                m.q_A_proj, m.q_B_proj,
                m.k_A_proj, m.k_B_proj,
                m.v_A_proj, m.v_B_proj,
            ]:
                for p in proj.parameters():
                    p.requires_grad = requires_grad


def set_tucker_core_requires_grad(model: nn.Module, requires_grad: bool):
    """
    控制所有 Tucker 注意力里的 G 部分是否参与训练：
      - q_G_proj, k_G_proj, v_G_proj
    """
    for m in model.modules():
        if isinstance(m, ContextualTuckerTPAAttention):
            for proj in [m.q_G_proj, m.k_G_proj, m.v_G_proj]:
                for p in proj.parameters():
                    p.requires_grad = requires_grad


def init_tucker_core_as_identity(model: nn.Module):
    """
    在「Tucker 第二阶段开始」时调用：
      把所有 G 的线性映射初始化成恒等核 I：

        G_Q(x) ≡ I_{R_q},  G_K(x) ≡ I_{R_k},  G_V(x) ≡ I_{R_v}

    做法：
      - weight 全 0
      - bias 写成 vec(I_R)
    这样对任何 x 都有 G(x)=I，从而保证：
        A(x)^T G(x) B(x) = A(x)^T B(x)
      → 第二阶段一开始与第一阶段前向完全等价。
    """
    for m in model.modules():
        if isinstance(m, ContextualTuckerTPAAttention):
            with torch.no_grad():
                for proj, R in [
                    (m.q_G_proj, m.rank_q),
                    (m.k_G_proj, m.rank_k),
                    (m.v_G_proj, m.rank_v),
                ]:
                    if proj is None:
                        continue

                    # W = 0
                    proj.weight.zero_()

                    if proj.bias is not None:
                        proj.bias.zero_()
                        eye = torch.eye(
                            R,
                            device=proj.weight.device,
                            dtype=proj.weight.dtype,
                        )
                        proj.bias.copy_(eye.reshape(-1))  # vec(I_R)

def train_model_two_stage_tucker(
    model: nn.Module,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    total_epochs: int,
    ab_epochs: int = 40,   # 第一阶段：只训 A/B（CP/TPA）
    scaler=None,
    is_tucker_model: bool = False,
):
    """
    通用两阶段训练函数：

      - 对非 Tucker 模型（例如普通 MHA / 纯 TPA）：
            is_tucker_model = False
        → 完全等同于单阶段训练，跑 total_epochs 轮。

      - 对使用 ContextualTuckerTPAAttention 的模型：
            is_tucker_model = True
        → 两阶段：
            * epoch < ab_epochs:
                - cp_only = True      （纯 TPA / CP）
                - A/B 可训练
                - G 不训练
            * epoch == ab_epochs:
                - cp_only = False     （切换到 Tucker）
                - 初始化 G(x) ≡ I     （前向与上个 epoch 完全等价）
                - 冻结 A/B，只训 G
            * epoch > ab_epochs:
                - cp_only = False
                - A/B 仍冻结
                - 继续只训 G
    """
    model.to(device)

    train_loss_curve = []
    val_loss_curve   = []
    train_acc_curve  = []
    val_acc_curve    = []
    best_val_acc     = 0.0

    for epoch in range(total_epochs):

        # ================== 阶段切换逻辑 ==================
        if is_tucker_model:
            if epoch < ab_epochs:
                # Phase 1: TPA / CP —— 训 A/B，G 不参与训练
                set_tucker_cp_mode(model, cp_only=True)
                set_tucker_ab_requires_grad(model, True)
                set_tucker_core_requires_grad(model, False)
                phase_tag = "TPA-phase(AB)"

            elif epoch == ab_epochs:
                # Phase 2 刚开始：切到 Tucker（A^T G B）
                #   - cp_only=False
                #   - G 初始化为恒等核 I
                #   - 冻结 A/B，只训 G
                set_tucker_cp_mode(model, cp_only=False)
                init_tucker_core_as_identity(model)
                set_tucker_ab_requires_grad(model, False)
                set_tucker_core_requires_grad(model, True)
                phase_tag = "Tucker-phase(G, init=I)"

            else:
                # Phase 2 后续：继续 Tucker，只训 G
                set_tucker_cp_mode(model, cp_only=False)
                set_tucker_ab_requires_grad(model, False)
                set_tucker_core_requires_grad(model, True)
                phase_tag = "Tucker-phase(G)"
        else:
            phase_tag = "normal"

        # ================== 一轮训练 + 验证 ==================
        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
            scaler=scaler,
        )

        val_loss, val_acc = eval_one_epoch(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device,
            epoch=epoch,
            epochs=total_epochs,
        )

        train_loss_curve.append(train_loss)
        val_loss_curve.append(val_loss)
        train_acc_curve.append(train_acc)
        val_acc_curve.append(val_acc)
        best_val_acc = max(best_val_acc, val_acc)

        print(
            f"[Epoch {epoch+1}/{total_epochs} | phase={phase_tag}] "
            f"train loss: {train_loss:.4f}, train acc: {train_acc:.3f} | "
            f"val loss: {val_loss:.4f}, val acc: {val_acc:.3f}"
        )

    return {
        "train_loss_curve": train_loss_curve,
        "val_loss_curve":   val_loss_curve,
        "train_acc_curve":  train_acc_curve,
        "val_acc_curve":    val_acc_curve,
        "best_val_acc":     best_val_acc,
    }


In [14]:
import math
import torch
import torch.nn as nn


class StaticCoreTuckerTPAAttention(nn.Module):
    """
    Tucker-style Tensor Product Attention with:
      - token-dependent A_t(x), B_t(x)
      - *static* global cores G_q, G_k, G_v (non-contextual)

    对每个 token t:
        Q_t = A_t(x) · G_q_eff · B_t(x)^T

    其中 G_q_eff 可以通过 tucker_lambda 在 diag(G) 和 full(G) 之间渐进过渡：
        G_eff = λ * diag(G_raw) + (1-λ) * G_raw
      - λ = 1 → CP/TPA 风格 (只用 diag)
      - λ = 0 → full Tucker
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        rank_head: int = 2,     # R_h
        rank_channel: int = 4,  # R_d
        qkv_bias: bool = True,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        progressive_g: bool = True,
        init_lambda: float = 1.0,   # 默认一开始 λ=1：CP/TPA 模式
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.rank_head = rank_head
        self.rank_channel = rank_channel

        Rh, Rd = rank_head, rank_channel
        H, Dh = num_heads, self.head_dim

        # λ ∈ [0,1] : 1 → diag, 0 → full
        self.progressive_g = progressive_g
        self.tucker_lambda = float(init_lambda)

        # Tucker 输出缩放，防止 Q/K 范数随 rank 爆炸
        self.tucker_scale = 1.0 / math.sqrt(rank_head * rank_channel)

        # ---- Q 的 A/B (token-dependent) ----
        self.q_A_proj = nn.Linear(dim, H * Rh, bias=qkv_bias)   # (B,N,H,Rh)
        self.q_B_proj = nn.Linear(dim, Dh * Rd, bias=qkv_bias)  # (B,N,Dh,Rd)
        # ---- Q 的 static core G_q ----
        self.G_q = nn.Parameter(torch.randn(Rh, Rd) * 0.02)

        # ---- K 的 A/B + static core G_k ----
        self.k_A_proj = nn.Linear(dim, H * Rh, bias=qkv_bias)
        self.k_B_proj = nn.Linear(dim, Dh * Rd, bias=qkv_bias)
        self.G_k = nn.Parameter(torch.randn(Rh, Rd) * 0.02)

        # ---- V 的 A/B + static core G_v ----
        self.v_A_proj = nn.Linear(dim, H * Rh, bias=qkv_bias)
        self.v_B_proj = nn.Linear(dim, Dh * Rd, bias=qkv_bias)
        self.G_v = nn.Parameter(torch.randn(Rh, Rd) * 0.02)

        # 标准 attention 组件
        self.scale = qk_scale or (self.head_dim ** -0.5)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self._reset_parameters()

    def _reset_parameters(self):
        # A/B 全部 xavier，bias 置 0
        for m in [
            self.q_A_proj, self.q_B_proj,
            self.k_A_proj, self.k_B_proj,
            self.v_A_proj, self.v_B_proj,
        ]:
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

        # static G 用较小 init，避免一开始爆
        for G in [self.G_q, self.G_k, self.G_v]:
            nn.init.xavier_uniform_(G)
            G.data.mul_(0.1)

    # ========== 关键：对 static G 做 λ 控制 ==========

    def _core_with_lambda(self, G_raw: torch.Tensor) -> torch.Tensor:
        """
        G_raw: (R_h, R_d)

        返回 G_eff:
          - 如果 progressive_g=False 或 lambda<=0: 返回 full G
          - 如果 lambda>=1: 只保留 diag(G_raw)
          - 否则: λ*diag(G_raw) + (1-λ)*G_raw
        """
        if (not self.progressive_g) or self.tucker_lambda <= 0.0:
            return G_raw

        lam = float(self.tucker_lambda)
        lam = max(0.0, min(1.0, lam))
        if lam == 0.0:
            return G_raw

        Rh, Rd = G_raw.shape
        k = min(Rh, Rd)

        diag_matrix = torch.zeros_like(G_raw)
        idx = torch.arange(k, device=G_raw.device)
        diag_matrix[idx, idx] = G_raw[idx, idx]

        G_eff = lam * diag_matrix + (1.0 - lam) * G_raw
        return G_eff

    @staticmethod
    def _tucker_from_factors(A, G, B, scale: float):
        """
        A: (B, N, H, R_h)
        G: (R_h, R_d)      or broadcasted to (B,N,R_h,R_d)
        B: (B, N, D_h, R_d)

        return:
          out: (B, H, N, D_h)
        """
        Bsz, N, H, Rh = A.shape
        _, _, Dh, Rd = B.shape

        if G.dim() == 2:
            # (Rh,Rd) -> (1,1,Rh,Rd) -> (B,N,Rh,Rd)
            Gb = G.view(1, 1, Rh, Rd).expand(Bsz, N, Rh, Rd)
        else:
            Gb = G  # 已经是 (B,N,Rh,Rd)

        # (B,N,H,R_h) × (B,N,R_h,R_d) -> (B,N,H,R_d)
        T = torch.einsum("bnhp,bnpq->bnhq", A, Gb)
        # (B,N,H,R_d) × (B,N,D_h,R_d) -> (B,N,H,D_h)
        Q = torch.einsum("bnhq,bndq->bnhd", T, B)
        Q = Q * scale
        Q = Q.permute(0, 2, 1, 3).contiguous()  # (B,H,N,D_h)
        return Q

    def _make_qkv(self, x):
        """
        x: (B, N, C)
        return q, k, v: (B, H, N, D_h)
        """
        Bsz, N, C = x.shape
        H, Dh = self.num_heads, self.head_dim
        Rh, Rd = self.rank_head, self.rank_channel

        # ========= Q =========
        qA = self.q_A_proj(x).view(Bsz, N, H, Rh)
        qB = self.q_B_proj(x).view(Bsz, N, Dh, Rd)
        Gq_eff = self._core_with_lambda(self.G_q)  # (Rh,Rd)
        q = self._tucker_from_factors(qA, Gq_eff, qB, self.tucker_scale)

        # ========= K =========
        kA = self.k_A_proj(x).view(Bsz, N, H, Rh)
        kB = self.k_B_proj(x).view(Bsz, N, Dh, Rd)
        Gk_eff = self._core_with_lambda(self.G_k)
        k = self._tucker_from_factors(kA, Gk_eff, kB, self.tucker_scale)

        # ========= V =========
        vA = self.v_A_proj(x).view(Bsz, N, H, Rh)
        vB = self.v_B_proj(x).view(Bsz, N, Dh, Rd)
        Gv_eff = self._core_with_lambda(self.G_v)
        v = self._tucker_from_factors(vA, Gv_eff, vB, self.tucker_scale)

        return q, k, v

    def forward(self, x, attn_mask=None):
        Bsz, N, C = x.shape
        assert C == self.dim, f"[StaticCoreTuckerTPAAttention] expected dim={self.dim}, got {C}"

        q, k, v = self._make_qkv(x)  # (B, H, N, D_h)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B, H, N, N)
        if attn_mask is not None:
            attn = attn + attn_mask

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = attn @ v  # (B,H,N,D_h)
        out = out.transpose(1, 2).reshape(Bsz, N, C)

        out = self.proj(out)
        out = self.proj_drop(out)
        return out


In [15]:
import math
import torch
import torch.nn as nn

class FocusContextualTPAAttention(nn.Module):
    """
    full + low-rank (TPA) 双分支注意力 + 空间 gate：
    - full 分支：正常 MHA
    - low-rank 分支：True TPA（两个向量的 outer product）：
        Q_t = (1/R) * sum_r a_r(x_t) ⊗ b_r(x_t)
      并支持两种 ablation：
        * contextA = True : contextual A, non-contextual B
        * contextA = False: non-contextual A, contextual B
    - gate: 用 full 分支的 attention + 2D 空间权重，决定每个 token
            更偏 full 还是偏 low-rank
    同时维护 alpha 的 running mean，方便估算 KV cache。
    """
    def __init__(
        self,
        dim,
        num_heads=8,
        rank=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        grid_size=(14, 14),  # ViT tiny: (224,16) -> (14,14)
        gamma=1.0,           # 空间邻域加权强度
        sigma=None,          # 高斯宽度
        gate_scale=1.0,      # s_norm 放大多少再过 sigmoid
        contextA=True,       # True: contextual A, False: contextual B
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim 必须能整除 num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.rank = rank
        self.contextA = contextA

        self.scale = qk_scale or self.head_dim ** -0.5

        # ===== full 分支：标准 qkv =====
        self.qkv_full = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # ===== low-rank 分支：True TPA（两向量 outer product）=====
        H = self.num_heads
        Dh = self.head_dim
        R = self.rank

        if self.contextA:
            # ------- contextual A(x), non-contextual B -------
            # A_q/A_k/A_v(x): (B,N,R*H) -> (B,N,R,H)
            self.A_q = nn.Linear(dim, R * H, bias=qkv_bias)
            self.A_k = nn.Linear(dim, R * H, bias=qkv_bias)
            self.A_v = nn.Linear(dim, R * H, bias=qkv_bias)

            # B_q/B_k/B_v: (R,D_h) 全局参数（non-contextual）
            self.B_q = nn.Parameter(torch.empty(R, Dh))
            self.B_k = nn.Parameter(torch.empty(R, Dh))
            self.B_v = nn.Parameter(torch.empty(R, Dh))

            nn.init.xavier_uniform_(self.B_q)
            nn.init.xavier_uniform_(self.B_k)
            nn.init.xavier_uniform_(self.B_v)

        else:
            # ------- non-contextual A, contextual B(x) -------
            # A_q/A_k/A_v: (R,H) 全局参数
            self.A_q = nn.Parameter(torch.empty(R, H))
            self.A_k = nn.Parameter(torch.empty(R, H))
            self.A_v = nn.Parameter(torch.empty(R, H))

            nn.init.xavier_uniform_(self.A_q)
            nn.init.xavier_uniform_(self.A_k)
            nn.init.xavier_uniform_(self.A_v)

            # B_q/B_k/B_v(x): (B,N,R*D_h) -> (B,N,R,D_h)
            self.B_q = nn.Linear(dim, R * Dh, bias=qkv_bias)
            self.B_k = nn.Linear(dim, R * Dh, bias=qkv_bias)
            self.B_v = nn.Linear(dim, R * Dh, bias=qkv_bias)

        # 输出投影（和原来一样）
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # ===== 空间 gate（保持原样） =====
        self.grid_size = grid_size
        Hp, Wp = grid_size
        num_patches = Hp * Wp
        self.num_patches = num_patches
        self.gamma = gamma
        self.gate_scale = gate_scale

        # 预计算 2D 距离矩阵 -> 空间权重 w_ij
        coords = []
        for y in range(Hp):
            for x in range(Wp):
                coords.append((x, y))
        coords = torch.tensor(coords, dtype=torch.float32)  # (P,2)
        dists = torch.cdist(coords, coords, p=2)            # (P,P)

        if sigma is None:
            sigma = max(Hp, Wp) / 2.0

        w = 1.0 + gamma * torch.exp(- (dists ** 2) / (2 * sigma ** 2))  # (P,P)
        self.register_buffer("spatial_weight", w)

        self.eps = 1e-6
        self.last_alpha = None

        # 统计 alpha 的 running mean（用来估算 KV cache）
        self.register_buffer("alpha_running_sum", torch.zeros(1))
        self.register_buffer("alpha_count", torch.zeros(1))

    # === low-rank 分支：构造 Q/K/V（True TPA） ===
    def _low_rank_qkv(self, x):
        """
        x: (B, N, C)
        return q, k, v: (B, H, N, D_h)
        """
        B, N, C = x.shape
        H, Dh, R = self.num_heads, self.head_dim, self.rank

        if self.contextA:
            # ---------- contextual A(x) , non-contextual B ----------
            # A(x): (B,N,R,H)
            Aq = self.A_q(x).view(B, N, R, H)
            Ak = self.A_k(x).view(B, N, R, H)
            Av = self.A_v(x).view(B, N, R, H)

            # B: (R,D_h)
            Bq = self.B_q               # (R,D_h)
            Bk = self.B_k
            Bv = self.B_v

            # Q/K/V: (B,N,H,D_h) = (1/R) * sum_r a_r ⊗ b_r
            q = torch.einsum("bnrh,rd->bnhd", Aq, Bq) / float(R)
            k = torch.einsum("bnrh,rd->bnhd", Ak, Bk) / float(R)
            v = torch.einsum("bnrh,rd->bnhd", Av, Bv) / float(R)

        else:
            # ---------- non-contextual A , contextual B(x) ----------
            # A: (R,H)
            Aq = self.A_q              # (R,H)
            Ak = self.A_k
            Av = self.A_v

            # B(x): (B,N,R,D_h)
            Bq = self.B_q(x).view(B, N, R, Dh)
            Bk = self.B_k(x).view(B, N, R, Dh)
            Bv = self.B_v(x).view(B, N, R, Dh)

            # Q/K/V: (B,N,H,D_h) = (1/R) * sum_r a_r ⊗ b_r
            q = torch.einsum("rh,bnrd->bnhd", Aq, Bq) / float(R)
            k = torch.einsum("rh,bnrd->bnhd", Ak, Bk) / float(R)
            v = torch.einsum("rh,bnrd->bnhd", Av, Bv) / float(R)

        # 变成 (B,H,N,D_h)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        return q, k, v

    # === 从 full attention + 2D 位置 得到 gate alpha（保持原样） ===
    def _compute_gate(self, attn_full, gate_strength=1.0):
        """
        attn_full: (B, H, N, N)，已经 softmax
        return alpha_final: (B, N, 1)
        """
        B, H, N, _ = attn_full.shape

        # 对 head 平均: (B,N,N)
        attn_mean = attn_full.mean(dim=1)

        # u[b,i] = sum_j attn[j,i]  (被多少 token 看)
        u = attn_mean.sum(dim=1)  # (B,N)

        if N == self.num_patches + 1:
            has_cls = True
            u_cls = u[:, :1]
            u_patch = u[:, 1:]
        elif N == self.num_patches:
            has_cls = False
            u_patch = u
            u_cls = None
        else:
            # 防御式：假定前面有一些非 patch token
            has_cls = True
            extra = N - self.num_patches
            u_cls = u[:, :extra]
            u_patch = u[:, extra:]

        # 空间加权: s = u_patch @ W^T
        s_patch = torch.matmul(u_patch, self.spatial_weight.t())  # (B,P)

        mean = s_patch.mean(dim=-1, keepdim=True)
        std = s_patch.std(dim=-1, keepdim=True) + self.eps
        s_norm = (s_patch - mean) / std

        alpha_patch = torch.sigmoid(self.gate_scale * s_norm)  # (B,P)

        if has_cls:
            alpha_full = torch.ones(
                B, self.num_patches + 1, 1,
                device=attn_full.device,
                dtype=attn_full.dtype,
            )
            alpha_full[:, 0, 0] = 1.0  # cls 全 full
            alpha_full[:, 1:, 0] = alpha_patch
        else:
            alpha_full = alpha_patch.unsqueeze(-1)  # (B,P,1)

        # gate_strength: 0 -> 全部 1, 1 -> 用 alpha_full
        if gate_strength < 1.0:
            alpha_final = 1.0 - gate_strength * (1.0 - alpha_full)
        else:
            alpha_final = alpha_full

        self.last_alpha = alpha_final.detach()

        # 更新 running mean
        self.alpha_running_sum += alpha_final.detach().sum()
        self.alpha_count += torch.tensor(
            alpha_final.numel(),
            device=alpha_final.device,
            dtype=self.alpha_running_sum.dtype,
        )

        return alpha_final  # (B,N,1)

    def reset_alpha_stats(self):
        self.alpha_running_sum.zero_()
        self.alpha_count.zero_()

    def get_mean_alpha(self):
        if self.alpha_count.item() <= 0:
            return None
        return (self.alpha_running_sum / self.alpha_count).item()

    def forward(self, x, attn_mask=None, gate_strength=1.0):
        B, N, C = x.shape

        # ===== full 分支 =====
        qkv = self.qkv_full(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3,B,H,N,D)
        q_full, k_full, v_full = qkv[0], qkv[1], qkv[2]

        attn_scores_full = (q_full * self.scale) @ k_full.transpose(-2, -1)
        if attn_mask is not None:
            attn_scores_full = attn_scores_full + attn_mask
        attn_full = attn_scores_full.softmax(dim=-1)
        attn_full = self.attn_drop(attn_full)

        # gate
        alpha = self._compute_gate(attn_full, gate_strength=gate_strength)  # (B,N,1)

        out_full = attn_full @ v_full  # (B,H,N,D)
        out_full = out_full.transpose(1, 2).reshape(B, N, C)  # (B,N,C)

        # ===== low-rank 分支（True TPA）=====
        q_low, k_low, v_low = self._low_rank_qkv(x)
        attn_scores_low = (q_low * self.scale) @ k_low.transpose(-2, -1)
        if attn_mask is not None:
            attn_scores_low = attn_scores_low + attn_mask
        attn_low = attn_scores_low.softmax(dim=-1)
        attn_low = self.attn_drop(attn_low)

        out_low = attn_low @ v_low
        out_low = out_low.transpose(1, 2).reshape(B, N, C)

        # ===== gate 混合 =====
        if gate_strength <= 0.0:
            out = out_full
        else:
            out = alpha * out_full + (1.0 - alpha) * out_low

        out = self.proj(out)
        out = self.proj_drop(out)
        return out


In [16]:
# B 也 contextual 的 Focus-TPA（A、B 都是 contextual，两向量 outer product 版本）
import math
import torch
import torch.nn as nn
class FocusTPAAttention(nn.Module):
    """
    full + low-rank (TPA) 双分支注意力 + 空间 gate：
    - full 分支：正常 MHA
    - low-rank 分支：True TPA（两个向量的 outer product）：
        对每个 token t:
          A_Q(x_t) ∈ R^{R × H},  B_Q(x_t) ∈ R^{R × D_h}
          Q_t = (1/R) * sum_r A_Q(x_t)[r] ⊗ B_Q(x_t)[r] ∈ R^{H×D_h}
      K, V 同理，A、B 都是 contextual（依赖 x_t）
    - gate: 用 full 分支的 attention + 2D 空间权重，决定每个 token
            更偏 full 还是偏 low-rank
    同时维护 alpha 的 running mean，方便估算 KV cache。
    """
    def __init__(
        self,
        dim,
        num_heads=8,
        rank=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        grid_size=(14, 14),  # ViT tiny: (224,16) -> (14,14)
        gamma=1.0,           # 空间邻域加权强度
        sigma=None,          # 高斯宽度
        gate_scale=1.0,      # s_norm 放大多少再过 sigmoid
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim 必须能整除 num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.rank = rank

        self.scale = qk_scale or self.head_dim ** -0.5

        # ===== full 分支：标准 qkv =====
        self.qkv_full = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # ===== low-rank 分支：True TPA（A、B 都 contextual，两个向量 outer product）=====
        H = self.num_heads
        Dh = self.head_dim
        R = self.rank

        # A_q/A_k/A_v(x): (B,N,R*H) -> (B,N,R,H)
        self.A_q = nn.Linear(dim, R * H, bias=qkv_bias)
        self.A_k = nn.Linear(dim, R * H, bias=qkv_bias)
        self.A_v = nn.Linear(dim, R * H, bias=qkv_bias)

        # B_q/B_k/B_v(x): (B,N,R*D_h) -> (B,N,R,D_h)
        self.B_q = nn.Linear(dim, R * Dh, bias=qkv_bias)
        self.B_k = nn.Linear(dim, R * Dh, bias=qkv_bias)
        self.B_v = nn.Linear(dim, R * Dh, bias=qkv_bias)

        # 输出投影
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # ===== 空间 gate =====
        self.grid_size = grid_size
        Hp, Wp = grid_size
        num_patches = Hp * Wp
        self.num_patches = num_patches
        self.gamma = gamma
        self.gate_scale = gate_scale

        # 预计算 2D 距离矩阵 -> 空间权重 w_ij
        coords = []
        for y in range(Hp):
            for x in range(Wp):
                coords.append((x, y))
        coords = torch.tensor(coords, dtype=torch.float32)  # (P,2)
        dists = torch.cdist(coords, coords, p=2)            # (P,P)

        if sigma is None:
            sigma = max(Hp, Wp) / 2.0

        w = 1.0 + gamma * torch.exp(- (dists ** 2) / (2 * sigma ** 2))  # (P,P)
        self.register_buffer("spatial_weight", w)

        self.eps = 1e-6
        self.last_alpha = None

        # 统计 alpha 的 running mean（用来估算 KV cache）
        self.register_buffer("alpha_running_sum", torch.zeros(1))
        self.register_buffer("alpha_count", torch.zeros(1))

        # 模块级 gate 强度（默认 1.0，训练时会被外部调度）
        self.gate_strength = 1.0

    # === low-rank 分支：构造 Q/K/V（原版 TPA，两向量 outer product） ===
    def _low_rank_qkv(self, x):
        """
        x: (B, N, C)
        return q, k, v: (B, H, N, D)
        """
        B, N, C = x.shape
        H, Dh, R = self.num_heads, self.head_dim, self.rank

        # ---------- Q ----------
        Aq = self.A_q(x).view(B, N, R, H)   # (B,N,R,H)
        Bq = self.B_q(x).view(B, N, R, Dh)  # (B,N,R,D_h)
        # (B,N,H,D_h) = (1/R) * sum_r a_r ⊗ b_r
        q = torch.einsum("bnrh,bnrd->bnhd", Aq, Bq) / float(R)

        # ---------- K ----------
        Ak = self.A_k(x).view(B, N, R, H)
        Bk = self.B_k(x).view(B, N, R, Dh)
        k = torch.einsum("bnrh,bnrd->bnhd", Ak, Bk) / float(R)

        # ---------- V ----------
        Av = self.A_v(x).view(B, N, R, H)
        Bv = self.B_v(x).view(B, N, R, Dh)
        v = torch.einsum("bnrh,bnrd->bnhd", Av, Bv) / float(R)

        # reshape 成 (B, H, N, D)
        q = q.permute(0, 2, 1, 3)  # (B,H,N,D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        return q, k, v

    # === 从 full attention + 2D 位置 得到 gate alpha ===
    def _compute_gate(self, attn_full, gate_strength=1.0):
        """
        attn_full: (B, H, N, N)，已经 softmax
        return alpha_final: (B, N, 1)
        """
        B, H, N, _ = attn_full.shape

        # 对 head 平均: (B,N,N)
        attn_mean = attn_full.mean(dim=1)

        # u[b,i] = sum_j attn[j,i]  (被多少 token 看)
        u = attn_mean.sum(dim=1)  # (B,N)

        if N == self.num_patches + 1:
            has_cls = True
            u_cls = u[:, :1]
            u_patch = u[:, 1:]
        elif N == self.num_patches:
            has_cls = False
            u_patch = u
            u_cls = None
        else:
            # 防御式：假定前面有一些非 patch token
            has_cls = True
            extra = N - self.num_patches
            u_cls = u[:, :extra]
            u_patch = u[:, extra:]

        # 空间加权: s = u_patch @ W^T
        s_patch = torch.matmul(u_patch, self.spatial_weight.t())  # (B,P)

        mean = s_patch.mean(dim=-1, keepdim=True)
        std = s_patch.std(dim=-1, keepdim=True) + self.eps
        s_norm = (s_patch - mean) / std

        alpha_patch = torch.sigmoid(self.gate_scale * s_norm)  # (B,P)

        if has_cls:
            alpha_full = torch.ones(
                B, self.num_patches + 1, 1,
                device=attn_full.device,
                dtype=attn_full.dtype,
            )
            alpha_full[:, 0, 0] = 1.0  # cls 全 full
            alpha_full[:, 1:, 0] = alpha_patch
        else:
            alpha_full = alpha_patch.unsqueeze(-1)  # (B,P,1)

        # gate_strength: 0 -> 全部 1, 1 -> 用 alpha_full
        if gate_strength < 1.0:
            alpha_final = 1.0 - gate_strength * (1.0 - alpha_full)
        else:
            alpha_final = alpha_full

        self.last_alpha = alpha_final.detach()

        # 更新 running mean
        self.alpha_running_sum += alpha_final.detach().sum()
        self.alpha_count += torch.tensor(
            alpha_final.numel(),
            device=alpha_final.device,
            dtype=self.alpha_running_sum.dtype,
        )

        return alpha_final  # (B,N,1)

    def reset_alpha_stats(self):
        self.alpha_running_sum.zero_()
        self.alpha_count.zero_()

    def get_mean_alpha(self):
        if self.alpha_count.item() <= 0:
            return None
        return (self.alpha_running_sum / self.alpha_count).item()

    def forward(self, x, attn_mask=None, gate_strength=None):
        B, N, C = x.shape

        # 如果调用方没传 gate_strength，就用模块自己的值（由训练循环控制）
        if gate_strength is None:
            gate_strength = float(self.gate_strength)

        # ===== full 分支 =====
        qkv = self.qkv_full(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3,B,H,N,D)
        q_full, k_full, v_full = qkv[0], qkv[1], qkv[2]

        attn_scores_full = (q_full * self.scale) @ k_full.transpose(-2, -1)
        if attn_mask is not None:
            attn_scores_full = attn_scores_full + attn_mask
        attn_full = attn_scores_full.softmax(dim=-1)
        attn_full = self.attn_drop(attn_full)

        # gate
        alpha = self._compute_gate(attn_full, gate_strength=gate_strength)  # (B,N,1)

        out_full = attn_full @ v_full  # (B,H,N,D)
        out_full = out_full.transpose(1, 2).reshape(B, N, C)  # (B,N,C)

        # ===== low-rank 分支（TPA）=====
        q_low, k_low, v_low = self._low_rank_qkv(x)
        attn_scores_low = (q_low * self.scale) @ k_low.transpose(-2, -1)
        if attn_mask is not None:
            attn_scores_low = attn_scores_low + attn_mask
        attn_low = attn_scores_low.softmax(dim=-1)
        attn_low = self.attn_drop(attn_low)

        out_low = attn_low @ v_low
        out_low = out_low.transpose(1, 2).reshape(B, N, C)

        # ===== gate 混合 =====
        if gate_strength <= 0.0:
            out = out_full
        else:
            out = alpha * out_full + (1.0 - alpha) * out_low

        out = self.proj(out)
        out = self.proj_drop(out)
        return out
def set_focus_tpa_gate_strength(model: nn.Module, gate_strength: float):
    """
    在整个模型里，把所有 FocusTPAAttention 的 gate_strength 设为同一个值。

    gate_strength 含义：
      - 0.0  → out ≈ 纯 full MHA 分支
      - 1.0  → 使用空间 gate：alpha * full + (1-alpha) * low-rank
      - 中间值 → 从纯 full 逐渐过渡到 gated 混合
    """
    from inspect import isclass
    # 防止名字没导入的问题，兼容你之前定义的 FocusTPAAttention
    try:
        FocusClass = FocusTPAAttention
    except NameError:
        raise RuntimeError("请先在上面的 cell 中定义 FocusTPAAttention 类。")

    for m in model.modules():
        if isinstance(m, FocusClass):
            m.gate_strength = float(gate_strength)


def train_model_focus_tpa(
    model: nn.Module,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    total_epochs: int,
    warmup_full_epochs: int = 10,   # 阶段1：纯 full MHA
    ramp_gate_epochs: int = 10,     # 阶段2：gate_strength 线性从 0→1
    scaler=None,
):
    """
    适用于 FocusTPAAttention 的训练函数（full + low-rank True TPA + 空间 gate）。

    三个阶段：
      - Epoch < warmup_full_epochs:
          gate_strength = 0.0       → 输出完全来自 full 分支
      - warmup_full_epochs <= Epoch < warmup_full_epochs + ramp_gate_epochs:
          gate_strength 线性从 0→1 → 逐步引入 low-rank 分支
      - 之后:
          gate_strength = 1.0       → 使用空间 gate 的完整混合
    """
    model.to(device)

    train_loss_curve = []
    val_loss_curve   = []
    train_acc_curve  = []
    val_acc_curve    = []
    best_val_acc     = 0.0

    for epoch in range(total_epochs):
        # ====== 决定本 epoch 的 gate_strength & phase 标记 ======
        if epoch < warmup_full_epochs:
            gate_strength = 0.0
            phase_tag = "focus-warmup(full-only)"
        elif epoch < warmup_full_epochs + ramp_gate_epochs and ramp_gate_epochs > 0:
            # 线性从 0 → 1
            progress = (epoch - warmup_full_epochs + 1) / float(ramp_gate_epochs)
            gate_strength = max(0.0, min(1.0, progress))
            phase_tag = f"focus-ramp(g={gate_strength:.2f})"
        else:
            gate_strength = 1.0
            phase_tag = "focus-mix(full+lowrank)"

        # 把 gate_strength 写进所有 FocusTPAAttention 模块
        set_focus_tpa_gate_strength(model, gate_strength)

        # ====== 一轮训练 & 验证 ======
        # 这里要和你原来 MHA/TPA/Tucker 用的 train_one_epoch / eval_one_epoch 保持一致
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, total_epochs, scaler
        )
        val_loss, val_acc = eval_one_epoch(
            model, val_loader, criterion, device, epoch, total_epochs
        )

        train_loss_curve.append(train_loss)
        val_loss_curve.append(val_loss)
        train_acc_curve.append(train_acc)
        val_acc_curve.append(val_acc)
        best_val_acc = max(best_val_acc, val_acc)

        print(
            f"[FocusTPA] Epoch {epoch+1}/{total_epochs} | phase={phase_tag} | "
            f"train loss: {train_loss:.4f}, train acc: {train_acc:.3f} | "
            f"val loss: {val_loss:.4f}, val acc: {val_acc:.3f}"
        )

    return {
        "train_loss_curve": train_loss_curve,
        "val_loss_curve":   val_loss_curve,
        "train_acc_curve":  train_acc_curve,
        "val_acc_curve":    val_acc_curve,
        "best_val_acc":     best_val_acc,
    }


In [17]:
from timm.models.vision_transformer import VisionTransformer
import torch.nn as nn

# 记得这几个类已经在别的文件里定义好：
# - TPAAttention
# - ContextualCPAttention   （如果你还用的话）
# - FocusTPAAttention
# - FocusContextualTPAAttention
# - ContextualTuckerTPAAttention
# - StaticCoreTuckerTPAAttention  （如果还需要静态版）


def replace_vit_attn_with_tpa_true(
    vit_model: VisionTransformer,
    rank_q: int,
    rank_k: int,
    rank_v: int,
):
    """
    把 timm ViT 的每个 block.attn 换成：论文原版 True TPA (A/B 都 contextual, 两向量外积)
    使用 TPAAttention（你之前那版实现）。
    """
    for blk in vit_model.blocks:
        old_attn = blk.attn

        dim = old_attn.qkv.in_features
        num_heads = old_attn.num_heads
        attn_drop = float(old_attn.attn_drop.p)
        proj_drop = float(old_attn.proj_drop.p)
        qkv_bias = old_attn.qkv.bias is not None
        qk_scale = old_attn.scale if hasattr(old_attn, "scale") else None

        blk.attn = TPAAttention(
            dim=dim,
            num_heads=num_heads,
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
def replace_vit_attn_with_nonlinear_tpa(
    vit_model,
    rank_q: int,
    rank_k: int,
    rank_v: int,
    mlp_hidden_ratio: float = 1.0,
    mlp_on: str = "qkv",   # ✅ 新增
):
    for blk in vit_model.blocks:
        old_attn = blk.attn

        dim = old_attn.qkv.in_features
        num_heads = old_attn.num_heads
        attn_drop = float(old_attn.attn_drop.p)
        proj_drop = float(old_attn.proj_drop.p)
        qkv_bias = old_attn.qkv.bias is not None
        qk_scale = old_attn.scale if hasattr(old_attn, "scale") else None

        blk.attn = NonlinearTPAAttention(
            dim=dim,
            num_heads=num_heads,
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            mlp_hidden_ratio=mlp_hidden_ratio,
            mlp_on=mlp_on,   # ✅ 关键
        )

        
def replace_vit_attn_with_headwise_nonlinear_tpa(
    vit_model,
    rank_q: int,
    rank_k: int,
    rank_v: int,
    mlp_hidden_ratio: float = 1.0,
    mlp_on: str = "qkv",
):
    
    for blk in vit_model.blocks:
        old_attn = blk.attn

        dim = old_attn.qkv.in_features
        num_heads = old_attn.num_heads
        attn_drop = float(old_attn.attn_drop.p)
        proj_drop = float(old_attn.proj_drop.p)
        qkv_bias = old_attn.qkv.bias is not None
        qk_scale = old_attn.scale if hasattr(old_attn, "scale") else None

        blk.attn = HeadwiseNonlinearTPAAttention(
            dim=dim,
            num_heads=num_heads,
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            mlp_hidden_ratio=mlp_hidden_ratio,
            mlp_on=mlp_on,
        )



def replace_vit_attn_with_sinter_tpa(
    vit_model: VisionTransformer,
    rank_q: int,
    rank_k: int,
    rank_v: int,
    mlp_hidden_ratio: float = 1.0,
    sinter_A: float = 5e-5,
    sinter_omega: float = 1e4,
):
    """
    把 timm ViT 的每个 block.attn 换成 SinterTPAAttention 版本的 TPA：
      Q_lin = 1/R * A^T B
      Q = MLP_Sinter(Q_lin)

    参数
    ----
    rank_q, rank_k, rank_v : TPA 的 rank
    mlp_hidden_ratio       : hidden_dim = mlp_hidden_ratio * head_dim
    sinter_A, sinter_omega : Sinter 激活的超参数
    """
    for blk in vit_model.blocks:
        old_attn = blk.attn

        dim = old_attn.qkv.in_features
        num_heads = old_attn.num_heads
        attn_drop = float(old_attn.attn_drop.p)
        proj_drop = float(old_attn.proj_drop.p)
        qkv_bias = old_attn.qkv.bias is not None
        qk_scale = old_attn.scale if hasattr(old_attn, "scale") else None

        blk.attn = SinterTPAAttention(
            dim=dim,
            num_heads=num_heads,
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            mlp_hidden_ratio=mlp_hidden_ratio,
            sinter_A=sinter_A,
            sinter_omega=sinter_omega,
        )




def replace_vit_attn_with_contextual_tpa(
    vit_model: VisionTransformer,
    rank_q: int,
    rank_k: int,
    rank_v: int,
    contextA: bool = True,
):
    """
    把 timm ViT 的每个 block.attn 换成：ContextualCPAttention（简化 CP 版 TPA）
      - contextA=True  -> contextual A, non-contextual B
      - contextA=False -> non-contextual A, contextual B
    """
    for blk in vit_model.blocks:
        old_attn = blk.attn

        dim = old_attn.qkv.in_features
        num_heads = old_attn.num_heads
        attn_drop = float(old_attn.attn_drop.p)
        proj_drop = float(old_attn.proj_drop.p)
        qkv_bias = old_attn.qkv.bias is not None
        qk_scale = old_attn.scale if hasattr(old_attn, "scale") else None

        blk.attn = ContextualCPAttention(
            dim=dim,
            num_heads=num_heads,
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            contextA=contextA,
        )


def replace_vit_attn_with_focus_tpa(
    vit_model: VisionTransformer,
    rank: int = 2,
    gamma: float = 1.0,
    sigma=None,
    gate_scale: float = 1.0,
):
    """
    把 timm ViT 的每个 block.attn 换成：FocusTPAAttention
    （full MHA + low-rank True TPA + spatial gate）
    """
    # timm 的 ViT 有 patch_embed.grid_size
    grid_size = vit_model.patch_embed.grid_size  # (Hp, Wp)

    for blk in vit_model.blocks:
        old_attn = blk.attn

        dim = old_attn.qkv.in_features
        num_heads = old_attn.num_heads
        attn_drop = float(old_attn.attn_drop.p)
        proj_drop = float(old_attn.proj_drop.p)
        qkv_bias = old_attn.qkv.bias is not None
        qk_scale = old_attn.scale if hasattr(old_attn, "scale") else None

        blk.attn = FocusTPAAttention(
            dim=dim,
            num_heads=num_heads,
            rank=rank,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            grid_size=grid_size,
            gamma=gamma,
            sigma=sigma,
            gate_scale=gate_scale,
        )



def replace_vit_attn_with_focus_contextual_tpa(
    vit_model: VisionTransformer,
    rank: int = 2,
    gamma: float = 1.0,
    sigma=None,
    gate_scale: float = 1.0,
    contextA: bool = True,
):
    """
    把 timm ViT 的每个 block.attn 换成：FocusContextualTPAAttention
    （full MHA + contextual TPA + spatial gate）
    """
    grid_size = vit_model.patch_embed.grid_size  # (Hp, Wp)

    for blk in vit_model.blocks:
        old_attn = blk.attn

        dim = old_attn.qkv.in_features
        num_heads = old_attn.num_heads
        attn_drop = float(old_attn.attn_drop.p)
        proj_drop = float(old_attn.proj_drop.p)
        qkv_bias = old_attn.qkv.bias is not None
        qk_scale = old_attn.scale if hasattr(old_attn, "scale") else None

        blk.attn = FocusContextualTPAAttention(
            dim=dim,
            num_heads=num_heads,
            rank=rank,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            grid_size=grid_size,
            gamma=gamma,
            sigma=sigma,
            gate_scale=gate_scale,
            contextA=contextA,
        )


# ================== ✅ 重点：新的 Tucker 版本 ==================

def replace_vit_attn_with_contextual_tucker_tpa(
    vit_model: VisionTransformer,
    rank_q: int,
    rank_k: int,
    rank_v: int,
):
    """
    把 timm ViT 的每个 block.attn 换成：ContextualTuckerTPAAttention（统一版）

    接口对齐为 rank_q / rank_k / rank_v，
    这样你可以做到：

        - TPA:    rank_q, rank_k, rank_v
        - Tucker: rank_q, rank_k, rank_v（完全一致）

    并且在两阶段训练里：
      - Phase 1: cp_only=True, 只用 A/B（函数族 = TPA）
      - Phase 2: cp_only=False, 用 A/G/B^T（Tucker 提升 expressive power）
    """
    for blk in vit_model.blocks:
        old_attn = blk.attn

        if not hasattr(old_attn, "qkv"):
            raise ValueError("当前 block.attn 没有 qkv 属性，可能不是标准 timm ViT Attention。")

        dim = old_attn.qkv.in_features
        num_heads = old_attn.num_heads
        attn_drop = float(old_attn.attn_drop.p)
        proj_drop = float(old_attn.proj_drop.p)
        qkv_bias = old_attn.qkv.bias is not None
        qk_scale = getattr(old_attn, "scale", None)

        blk.attn = ContextualTuckerTPAAttention(
            dim=dim,
            num_heads=num_heads,
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )


def replace_vit_attn_with_static_tucker_tpa(
    vit_model: VisionTransformer,
    rank_q: int,
    rank_k: int,
    rank_v: int,
    progressive_g: bool = True,
    init_lambda: float = 1.0,
):
    """
    如果你还保留 StaticCoreTuckerTPAAttention 的话，
    这里也可以按新的 rank_q / rank_k / rank_v 接口改。
    不用的话可以整个函数删掉。
    """
    for blk in vit_model.blocks:
        old_attn = blk.attn

        if hasattr(old_attn, "qkv"):
            dim = old_attn.qkv.in_features
            qkv_bias = old_attn.qkv.bias is not None
        else:
            dim = old_attn.dim
            qkv_bias = True

        num_heads = old_attn.num_heads
        attn_drop = float(old_attn.attn_drop.p)
        proj_drop = float(old_attn.proj_drop.p)
        qk_scale = getattr(old_attn, "scale", None)

        blk.attn = StaticCoreTuckerTPAAttention(
            dim=dim,
            num_heads=num_heads,
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            progressive_g=progressive_g,
            init_lambda=init_lambda,
        )


In [18]:

class ViTClassifier(nn.Module):
    def __init__(
        self,
        num_classes,
        model_name: str = "vit_tiny_patch16_224",
        pretrained: bool = False,

        # 选择注意力类型：
        # "mha" / "tpa" / "nonlinear_tpa" / "sinter_tpa" /
        # "contextual_tpa" / "focus_tpa" / "focus_contextual_tpa" / "tucker_tpa"
        attn_type: str = "mha",

        # ====== TPA / Contextual-TPA 的 rank 设置 ======
        rank_q: int = 16,
        rank_k: int = 2,
        rank_v: int = 2,
        contextual_tpa_A: bool = True,   # 对 contextual_tpa: True=contextual A, False=contextual B

        # ====== Nonlinear / Sinter TPA 的 MLP 设置 ======
        nonlinear_mlp_hidden_ratio: float = 2.0,  # hidden_dim = ratio * head_dim

        # Sinter 激活的超参（只在 attn_type == "sinter_tpa" 时使用）
        sinter_A: float = 5e-5,
        sinter_omega: float = 1e4,

        # ====== Focus-(Contextual)-TPA 的参数 ======
        focus_rank: int = 2,
        focus_gamma: float = 1.0,
        focus_gate_scale: float = 1.0,
        focus_contextual_A: bool = True, # 对 focus_contextual_tpa：同上，控制 A/B 谁 contextual

        # ====== Tucker-TPA 的 rank 设置 ======
        # 对于新的 contextual Tucker-TPA：
        #   - 实际只用到 rank_head / rank_channel
        #   - rank_token 在这里可以暂时忽略（兼容原配置）
        rank_token: int = 16,
        rank_head: int = 8,
        rank_channel: int = 8,

        nonlinear_mlp_on: str = "qkv"
    ):
        super().__init__()

        # 创建基础 ViT 模型
        self.vit = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
        )

        # 根据 attn_type 替换 attention
        if attn_type == "mha":
            # 原生 MHA，什么都不做
            pass

        elif attn_type == "tpa":
            # True TPA（A/B 全 contextual，两向量 outer product）
            replace_vit_attn_with_tpa_true(
                self.vit,
                rank_q=rank_q,
                rank_k=rank_k,
                rank_v=rank_v,
            )

        elif attn_type == "nonlinear_tpa":
            # True TPA + head 内一层 MLP 非线性：Q = MLP(1/R A^T B)
            replace_vit_attn_with_nonlinear_tpa(
                self.vit,
                rank_q=rank_q,
                rank_k=rank_k,
                rank_v=rank_v,
                mlp_hidden_ratio=nonlinear_mlp_hidden_ratio,
                mlp_on=nonlinear_mlp_on,
            )
        elif attn_type == "headwise_nonlinear_tpa":
            replace_vit_attn_with_headwise_nonlinear_tpa(
        self.vit,
        rank_q=rank_q,
        rank_k=rank_k,
        rank_v=rank_v,
        mlp_hidden_ratio=nonlinear_mlp_hidden_ratio,
        mlp_on=nonlinear_mlp_on,
    )


        elif attn_type == "sinter_tpa":
            # True TPA + head 内一层 MLP(Sinter)：Q = MLP_Sinter(1/R A^T B)
            replace_vit_attn_with_sinter_tpa(
                self.vit,
                rank_q=rank_q,
                rank_k=rank_k,
                rank_v=rank_v,
                mlp_hidden_ratio=nonlinear_mlp_hidden_ratio,
                sinter_A=sinter_A,
                sinter_omega=sinter_omega,
            )

        elif attn_type == "contextual_tpa":
            # 简化版 Contextual TPA（CP 形式）
            replace_vit_attn_with_contextual_tpa(
                self.vit,
                rank_q=rank_q,
                rank_k=rank_k,
                rank_v=rank_v,
                contextA=contextual_tpa_A,
            )

        elif attn_type == "focus_tpa":
            # Focus-TPA（full + low-rank True TPA）
            replace_vit_attn_with_focus_tpa(
                self.vit,
                rank=focus_rank,
                gamma=focus_gamma,
                sigma=None,
                gate_scale=focus_gate_scale,
            )

        elif attn_type == "focus_contextual_tpa":
            # Focus-Contextual-TPA（full + 低秩 contextual TPA）
            replace_vit_attn_with_focus_contextual_tpa(
                self.vit,
                rank=focus_rank,
                gamma=focus_gamma,
                sigma=None,
                gate_scale=focus_gate_scale,
                contextA=focus_contextual_A,
            )

        elif attn_type == "tucker_tpa":
            # Tucker-style Contextual Tensor Product Attention
            replace_vit_attn_with_contextual_tucker_tpa(
                self.vit,
                rank_q=rank_q,
                rank_k=rank_k,
                rank_v=rank_v,
                # 如果你新版本的函数需要 rank_head / rank_channel，
                # 也可以在这里一起传进去：
                # rank_head=rank_head,
                # rank_channel=rank_channel,
            )

        else:
            raise ValueError(
                f"Unknown attn_type: {attn_type}. "
                f"Expected one of ['mha','tpa','nonlinear_tpa','sinter_tpa',"
                f"'contextual_tpa','focus_tpa','focus_contextual_tpa','tucker_tpa']"
            )

    def forward(self, x):
        return self.vit(x)

In [19]:
# ========= Multi-dataset ablation: CIFAR10 -> CIFAR100 =========
import os
import json
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.optim import AdamW

# ------------------ safety: keep deterministic-ish (if you already have set_global_seed) ------------------
try:
    set_global_seed(GLOBAL_SEED)
except Exception:
    pass

# ===================== 通用配置 =====================
data_dir = "./data"
img_size = 224
num_workers = 2
weight_decay = 0.05
pretrained = False

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = False  # 保持你之前的设置

# 模型
model_tag  = "tiny"
model_name = "vit_tiny_patch16_224"

# 优化
batch_size = 128
lr = 3e-4

# ===== epoch & rank（按你要求固定）=====
rank_q, rank_k, rank_v = 16, 2, 2

# ===== 你要求的 ratio 配置（修正 RATIO_MAP）=====
#   KV (indep): 1.5
#   KV_shared: 3.0
#   QKV: 1.0
#   Q/K/V 单独：3.0
RATIO_MAP = {
    "q": 3.0,
    "k": 3.0,
    "v": 3.0,
    "kv": 1.5,
    "kv_shared": 3.0,
    "qkv": 1.0,

    # NEW: head-wise
    "hw_kv": 0.5,
    "hw_kv_shared": 1.0,
    "hw_qkv": 0.33,
}

DATASETS = ["cifar10", "cifar100"]

# ===== 每个数据集的 epoch 数（按你要求：CIFAR10=40, CIFAR100=20）=====
DATASET_EPOCHS = {
    "cifar10": 40,
    "cifar100": 20,
}

# ===== 实验列表（两套数据集一致，按你指定顺序）=====
# 注意：MHA 的 attn_type 这里先写占位符 "__MHA__"，下面会自动探测实际字符串
EXPERIMENT_LIST = [
    {"name": "MHA_baseline",   "attn_type": "__MHA__",         "mlp_on": "none",      "mlp_ratio": 0.0},
    {"name": "TPA_r1622",      "attn_type": "tpa",             "mlp_on": "none",      "mlp_ratio": 0.0},

    {"name": "NonlinearTPA_KV",        "attn_type": "nonlinear_tpa", "mlp_on": "kv",        "mlp_ratio": RATIO_MAP["kv"]},
    {"name": "NonlinearTPA_KV_shared", "attn_type": "nonlinear_tpa", "mlp_on": "kv_shared", "mlp_ratio": RATIO_MAP["kv_shared"]},

    # NEW: head-wise KV variants
    {"name": "NonlinearTPA_HW_KV",        "attn_type": "headwise_nonlinear_tpa", "mlp_on": "qkv",        "mlp_ratio": RATIO_MAP["hw_qkv"]},
    {"name": "NonlinearTPA_HW_KV",        "attn_type": "headwise_nonlinear_tpa", "mlp_on": "kv",        "mlp_ratio": RATIO_MAP["hw_kv"]},
    {"name": "NonlinearTPA_HW_KV_shared", "attn_type": "headwise_nonlinear_tpa", "mlp_on": "kv_shared", "mlp_ratio": RATIO_MAP["hw_kv_shared"]},

    {"name": "NonlinearTPA_QKV",       "attn_type": "nonlinear_tpa", "mlp_on": "qkv",       "mlp_ratio": RATIO_MAP["qkv"]},

    {"name": "NonlinearTPA_Q",         "attn_type": "nonlinear_tpa", "mlp_on": "q",         "mlp_ratio": RATIO_MAP["q"]},
    {"name": "NonlinearTPA_K",         "attn_type": "nonlinear_tpa", "mlp_on": "k",         "mlp_ratio": RATIO_MAP["k"]},
    {"name": "NonlinearTPA_V",         "attn_type": "nonlinear_tpa", "mlp_on": "v",         "mlp_ratio": RATIO_MAP["v"]},
]

EXPERIMENTS = {
    "cifar10":  EXPERIMENT_LIST,
    "cifar100": EXPERIMENT_LIST,
}

# ------------------ helper: top-k val accuracy (for table) ------------------
@torch.no_grad()
def compute_val_topk_acc(model: nn.Module, loader, device: str, k: int = 5) -> float:
    model.eval()
    correct = 0
    total = 0
    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        logits = model(images)
        _, pred = logits.topk(k, dim=1, largest=True, sorted=True)  # (B,k)
        correct += pred.eq(targets.view(-1, 1)).any(dim=1).sum().item()
        total += targets.numel()
    return correct / max(total, 1)

def _ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)
    return path

def _write_json(path: str, obj):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)

def _save_table_csv_md(df: pd.DataFrame, csv_path: str, md_path: str):
    df.to_csv(csv_path, index=False)
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(df.to_markdown(index=False))

def _make_results_table(rows: list, params_tpa_ref: int):
    out = []
    for r in rows:
        params_m = r["params"] / 1e6
        delta_m  = (r["params"] - params_tpa_ref) / 1e6
        out.append({
            "Dataset": r["dataset"],
            "Method":  r["name"],
            "Attn":    r["attn_type"],
            "MLP_on":  r.get("mlp_on", "none"),
            "MLP_ratio": r.get("mlp_ratio", None),
            "MLP_hidden": r.get("mlp_hidden_dim", None),
            "Params(M)": round(params_m, 3),
            "ΔParams vs TPA(M)": round(delta_m, 3),
            "Best epoch": int(r["best_epoch"] + 1),
            "Top-1 Val Acc (best)": round(float(r["best_val_acc"]), 4),
            "Top-5 Val Acc (final)": (None if r.get("val_top5_acc_final") is None else round(float(r["val_top5_acc_final"]), 4)),
        })
    return pd.DataFrame(out)

def _plot_combined_curves(dataset_dir: str, dataset_name: str, rows: list):
    # 1) val acc
    plt.figure()
    for r in rows:
        epochs_range = range(1, len(r["val_acc_curve"]) + 1)
        plt.plot(epochs_range, r["val_acc_curve"], label=r["name"])
    plt.xlabel("Epoch")
    plt.ylabel("Val accuracy")
    plt.title(f"{dataset_name} - {model_tag}: Val accuracy comparison")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(dataset_dir, "combined_val_acc.png"))
    plt.close()

    # 2) val loss
    plt.figure()
    for r in rows:
        epochs_range = range(1, len(r["val_loss_curve"]) + 1)
        plt.plot(epochs_range, r["val_loss_curve"], label=r["name"])
    plt.xlabel("Epoch")
    plt.ylabel("Val loss")
    plt.title(f"{dataset_name} - {model_tag}: Val loss comparison")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(dataset_dir, "combined_val_loss.png"))
    plt.close()

# ===== 解析 GLOBAL_SEED（用于 run 文件夹命名 & ckpt）=====
try:
    _SEED_TAG = int(GLOBAL_SEED)
except Exception:
    _SEED_TAG = "NA"

# ===== 自动探测 MHA 的 attn_type 字符串（避免你项目里命名不一致导致崩）=====
def _resolve_mha_attn_type(num_classes: int):
    # 你项目里常见命名候选（如不匹配，你可以把你项目里实际用的字符串加到这里）
    candidates = ["mha", "mhsa", "vanilla", "baseline", "standard", "attn"]
    for cand in candidates:
        try:
            m = ViTClassifier(
                num_classes=num_classes,
                model_name=model_name,
                pretrained=pretrained,
                attn_type=cand,
                rank_q=rank_q, rank_k=rank_k, rank_v=rank_v,
            ).to(device)
            del m
            if device == "cuda":
                torch.cuda.empty_cache()
            return cand
        except Exception:
            continue
    raise RuntimeError(
        "Cannot resolve MHA attn_type. Please add the correct string to candidates in _resolve_mha_attn_type()."
    )

# ===================== result 目录 / run 文件夹 =====================
_ensure_dir("result")

timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = (
    f"ablation_seed{_SEED_TAG}_{timestamp}_"
    f"{model_tag}_{model_name}_"
    f"EpC10{DATASET_EPOCHS['cifar10']}_C100{DATASET_EPOCHS['cifar100']}_"
    f"bs{batch_size}_lr{lr}_wd{weight_decay}_"
    f"rq{rank_q}_rk{rank_k}_rv{rank_v}_"
    f"ratios_QKV{RATIO_MAP['qkv']}_KV{RATIO_MAP['kv']}_KVsh{RATIO_MAP['kv_shared']}_single{RATIO_MAP['q']}"
)
run_dir = _ensure_dir(os.path.join("result", run_name))
checkpoints_root = _ensure_dir(os.path.join(run_dir, "checkpoints"))
print(f"\n\n######################## Run dir: {run_dir} ########################")

hparams = {
    "datasets": DATASETS,
    "dataset_epochs": DATASET_EPOCHS,
    "model_tag": model_tag,
    "model_name": model_name,
    "img_size": img_size,
    "batch_size": batch_size,
    "lr": lr,
    "weight_decay": weight_decay,
    "pretrained": pretrained,
    "rank_q": rank_q,
    "rank_k": rank_k,
    "rank_v": rank_v,
    "ratio_map": RATIO_MAP,
    "experiments": EXPERIMENTS,
    "timestamp": timestamp,
    "GLOBAL_SEED": _SEED_TAG,
}
_write_json(os.path.join(run_dir, "hparams.json"), hparams)

global_rows = []

# ===================== 主循环：dataset × experiment =====================
for dataset_name in DATASETS:
    total_epochs = DATASET_EPOCHS[dataset_name]
    print(f"\n\n==================== DATASET: {dataset_name} (epochs={total_epochs}) ====================")

    dataset_dir = _ensure_dir(os.path.join(run_dir, dataset_name))
    _write_json(os.path.join(dataset_dir, "hparams.json"), {**hparams, "dataset_name": dataset_name, "total_epochs": total_epochs})

    # 先拿 num_classes
    train_loader, val_loader, num_classes = get_loaders(
        dataset_name=dataset_name,
        data_dir=data_dir,
        batch_size=batch_size,
        img_size=img_size,
        num_workers=num_workers,
    )

    # 解析 MHA 的 attn_type（只做一次）
    mha_attn_type = _resolve_mha_attn_type(num_classes)
    print(f"[Info] Resolved MHA attn_type = {mha_attn_type}")

    # baseline TPA params (per dataset：因为 num_classes 会影响分类头参数)
    tpa_ref_model = ViTClassifier(
        num_classes=num_classes,
        model_name=model_name,
        pretrained=pretrained,
        attn_type="tpa",
        rank_q=rank_q,
        rank_k=rank_k,
        rank_v=rank_v,
    ).to(device)
    params_tpa_ref = sum(p.numel() for p in tpa_ref_model.parameters())
    del tpa_ref_model
    if device == "cuda":
        torch.cuda.empty_cache()

    dataset_rows = []

    for exp in EXPERIMENTS[dataset_name]:
        exp_name  = exp["name"]
        attn_type = exp["attn_type"]
        mlp_on    = exp["mlp_on"]
        mlp_ratio = exp["mlp_ratio"]

        # 替换 MHA 占位符
        if attn_type == "__MHA__":
            attn_type = mha_attn_type

        print(f"\n-------------------- [{dataset_name}] Experiment: {exp_name} --------------------")
        exp_dir = _ensure_dir(os.path.join(dataset_dir, exp_name))
        exp_ckpt_dir = _ensure_dir(os.path.join(exp_dir, "checkpoints"))

        # 训练（严格保持你现成的 run_small_spectrum_experiment 调用方式）
        model, hist = run_small_spectrum_experiment(
            dataset_name=dataset_name,
            num_workers=2,
            model_name=model_name,
            data_dir=data_dir,
            batch_size=batch_size,
            img_size=img_size,
            lr=lr,
            weight_decay=weight_decay,

            device=device,
            attn_type=attn_type,
            total_epochs=total_epochs,
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            mlp_ratio=mlp_ratio,
            mlp_on=mlp_on,
            block_idx=5,
            top_k=8,
            num_batches_spec=1,
        )

        # final top-5
        try:
            val_top5 = compute_val_topk_acc(model, val_loader, device=device, k=5)
        except Exception:
            val_top5 = None

        train_loss_curve = hist["train_loss_curve"]
        val_loss_curve   = hist["val_loss_curve"]
        train_acc_curve  = hist["train_acc_curve"]
        val_acc_curve    = hist["val_acc_curve"]
        best_val_acc     = float(hist["best_val_acc"])
        best_epoch       = int(np.argmax(val_acc_curve))

        total_params = sum(p.numel() for p in model.parameters())
        first_attn = model.vit.blocks[0].attn
        dim_attn = getattr(first_attn, "dim", None) or first_attn.qkv.in_features
        num_heads_attn = first_attn.num_heads
        head_dim_attn = dim_attn // num_heads_attn

        # 只有 nonlinear_tpa 才有 mlp_hidden_dim
        mlp_hidden_dim = None
        if attn_type == "nonlinear_tpa":
            try:
                mlp_hidden_dim = int(head_dim_attn * float(mlp_ratio))
            except Exception:
                mlp_hidden_dim = None

        # log
        log_lines = []
        for ep in range(total_epochs):
            log_lines.append(
                f"Epoch {ep+1}/{total_epochs} | "
                f"train loss: {train_loss_curve[ep]:.4f}, train acc: {train_acc_curve[ep]:.3f} | "
                f"val loss: {val_loss_curve[ep]:.4f}, val acc: {val_acc_curve[ep]:.3f}"
            )
        with open(os.path.join(exp_dir, "log.txt"), "w", encoding="utf-8") as f:
            f.write("\n".join(log_lines))

        # ===== checkpoint：保存尽可能完整的状态，支持“等价续训” =====
        ckpt = {
            "model_state": model.state_dict(),
            "optimizer_state": hist.get("optimizer_state", None),
            "scheduler_state": hist.get("scheduler_state", None),  # 如果你的训练函数有返回就会存
            "amp_scaler_state": hist.get("amp_scaler_state", None),  # use_amp=False 时一般为 None

            "last_epoch": hist.get("last_epoch", None),
            "dataset": dataset_name,
            "attn_type": attn_type,
            "exp_name": exp_name,
            "rank_q": rank_q, "rank_k": rank_k, "rank_v": rank_v,
            "mlp_on": mlp_on,
            "mlp_ratio": mlp_ratio,
            "total_epochs": total_epochs,
            "GLOBAL_SEED": _SEED_TAG,

            # RNG states（尽量保证 resume “工程等价”）
            "numpy_rng_state": np.random.get_state(),
            "torch_rng_state": torch.get_rng_state(),
            "cuda_rng_state": (torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None),

            # metrics/history
            "best_val_acc": best_val_acc,
            "train_loss_curve": train_loss_curve,
            "val_loss_curve": val_loss_curve,
            "train_acc_curve": train_acc_curve,
            "val_acc_curve": val_acc_curve,
            "val_top5_acc_final": val_top5,
            "hparams": hparams,

            # 方便复现实验/续训时重建同构调用
            "train_call_kwargs": {
                "dataset_name": dataset_name,
                "num_workers": 2,
                "model_name": model_name,
                "data_dir": data_dir,
                "batch_size": batch_size,
                "img_size": img_size,
                "lr": lr,
                "weight_decay": weight_decay,
                "device": device,
                "attn_type": attn_type,
                "total_epochs": total_epochs,
                "rank_q": rank_q,
                "rank_k": rank_k,
                "rank_v": rank_v,
                "mlp_ratio": mlp_ratio,
                "mlp_on": mlp_on,
                "block_idx": 5,
                "top_k": 8,
                "num_batches_spec": 1,
            }
        }

        # 1) 保持你原逻辑：exp_dir/ckpt.pt
        torch.save(ckpt, os.path.join(exp_dir, "ckpt.pt"))
        # 2) 按你要求：全部也落在 checkpoints 目录里（每个 exp 一份）
        torch.save(ckpt, os.path.join(exp_ckpt_dir, "ckpt.pt"))
        # 3) run 级别汇总 checkpoints（便于你后续脚本统一读取）
        torch.save(ckpt, os.path.join(checkpoints_root, f"{dataset_name}__{exp_name}__seed{_SEED_TAG}.pt"))

        # plots（保持原有逻辑）
        epochs_range = range(1, total_epochs + 1)

        plt.figure()
        plt.plot(epochs_range, train_loss_curve, label="train loss")
        plt.plot(epochs_range, val_loss_curve, label="val loss")
        plt.xlabel("Epoch"); plt.ylabel("Loss")
        plt.title(f"{dataset_name} - {model_tag} - {exp_name}: loss convergence")
        plt.grid(True); plt.legend(); plt.tight_layout()
        plt.savefig(os.path.join(exp_dir, "loss_curve.png"))
        plt.close()

        plt.figure()
        plt.plot(epochs_range, val_acc_curve, label=f"{exp_name} val acc")
        plt.xlabel("Epoch"); plt.ylabel("Val accuracy")
        plt.title(f"{dataset_name} - {model_tag} - {exp_name}: val acc")
        plt.grid(True); plt.legend(); plt.tight_layout()
        plt.savefig(os.path.join(exp_dir, "val_acc_curve.png"))
        plt.close()

        # update rows
        row = {
            "dataset": dataset_name,
            "name": exp_name,
            "attn_type": attn_type,
            "mlp_on": mlp_on,
            "mlp_ratio": (None if attn_type != "nonlinear_tpa" else float(mlp_ratio)),
            "mlp_hidden_dim": mlp_hidden_dim,
            "params": total_params,
            "params_vs_tpa": total_params - params_tpa_ref,
            "best_val_acc": best_val_acc,
            "best_epoch": best_epoch,
            "val_top5_acc_final": val_top5,
            "train_loss_curve": train_loss_curve,
            "val_loss_curve": val_loss_curve,
            "train_acc_curve": train_acc_curve,
            "val_acc_curve": val_acc_curve,
            "GLOBAL_SEED": _SEED_TAG,
        }
        dataset_rows.append(row)
        global_rows.append(row)

        # dataset summary/table/combined plots（每个 experiment 完成就刷新一次，防中途挂）
        dataset_summary = []
        for r in dataset_rows:
            dataset_summary.append({
                k: r[k] for k in [
                    "dataset","name","attn_type","mlp_on","mlp_ratio","mlp_hidden_dim",
                    "params","params_vs_tpa","best_val_acc","best_epoch","val_top5_acc_final",
                    "train_loss_curve","val_loss_curve","train_acc_curve","val_acc_curve","GLOBAL_SEED"
                ]
            })
        _write_json(os.path.join(dataset_dir, "summary.json"), dataset_summary)

        df_dataset = _make_results_table(dataset_rows, params_tpa_ref=params_tpa_ref)
        _save_table_csv_md(
            df_dataset,
            csv_path=os.path.join(dataset_dir, "results_table.csv"),
            md_path=os.path.join(dataset_dir, "results_table.md"),
        )
        _plot_combined_curves(dataset_dir, dataset_name, dataset_rows)

        # global 同步写一份（同样防 crash）
        global_summary = []
        for r in global_rows:
            global_summary.append({
                k: r[k] for k in [
                    "dataset","name","attn_type","mlp_on","mlp_ratio","mlp_hidden_dim",
                    "params","params_vs_tpa","best_val_acc","best_epoch","val_top5_acc_final",
                    "train_loss_curve","val_loss_curve","train_acc_curve","val_acc_curve","GLOBAL_SEED"
                ]
            })
        _write_json(os.path.join(run_dir, "summary.json"), global_summary)

        pd.DataFrame([{
            "Dataset": r["dataset"],
            "Method": r["name"],
            "Attn": r["attn_type"],
            "MLP_on": r.get("mlp_on","none"),
            "MLP_ratio": r.get("mlp_ratio", None),
            "MLP_hidden": r.get("mlp_hidden_dim", None),
            "Params(M)": round(r["params"]/1e6, 3),
            "Top-1 Val Acc (best)": round(float(r["best_val_acc"]), 4),
            "Best epoch": int(r["best_epoch"] + 1),
            "Top-5 Val Acc (final)": (None if r.get("val_top5_acc_final") is None else round(float(r["val_top5_acc_final"]), 4)),
            "GLOBAL_SEED": r.get("GLOBAL_SEED", None),
        } for r in global_rows]).to_csv(os.path.join(run_dir, "results_table.csv"), index=False)




######################## Run dir: result\ablation_seed2_20260105_113812_tiny_vit_tiny_patch16_224_EpC1040_C10020_bs128_lr0.0003_wd0.05_rq16_rk2_rv2_ratios_QKV1.0_KV1.5_KVsh3.0_single3.0 ########################


[Info] Resolved MHA attn_type = mha

-------------------- [cifar10] Experiment: MHA_baseline --------------------


Train [1/40]: 100%|████████████████████████████████████| 391/391 [01:33<00:00,  4.19it/s, loss=1.68]
Val   [1/40]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.41it/s, loss=1.66]


[mha] Epoch 1/40 | train loss: 1.8298, train acc: 0.319 | val loss: 1.6354, val acc: 0.393


Train [2/40]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.27it/s, loss=1.64]
Val   [2/40]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.60it/s, loss=1.42]


[mha] Epoch 2/40 | train loss: 1.5606, train acc: 0.426 | val loss: 1.4687, val acc: 0.467


Train [3/40]: 100%|████████████████████████████████████| 391/391 [01:30<00:00,  4.33it/s, loss=1.28]
Val   [3/40]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.66it/s, loss=1.18]


[mha] Epoch 3/40 | train loss: 1.3918, train acc: 0.494 | val loss: 1.3726, val acc: 0.500


Train [4/40]: 100%|████████████████████████████████████| 391/391 [01:29<00:00,  4.35it/s, loss=1.12]
Val   [4/40]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.64it/s, loss=1.04]


[mha] Epoch 4/40 | train loss: 1.2542, train acc: 0.549 | val loss: 1.2154, val acc: 0.560


Train [5/40]: 100%|████████████████████████████████████| 391/391 [01:30<00:00,  4.33it/s, loss=0.99]
Val   [5/40]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.61it/s, loss=0.86]


[mha] Epoch 5/40 | train loss: 1.1626, train acc: 0.582 | val loss: 1.1725, val acc: 0.574


Train [6/40]: 100%|████████████████████████████████████| 391/391 [01:30<00:00,  4.33it/s, loss=1.18]
Val   [6/40]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.63it/s, loss=1.01]


[mha] Epoch 6/40 | train loss: 1.0966, train acc: 0.609 | val loss: 1.0860, val acc: 0.612


Train [7/40]: 100%|████████████████████████████████████| 391/391 [01:30<00:00,  4.34it/s, loss=1.01]
Val   [7/40]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.20it/s, loss=0.98]


[mha] Epoch 7/40 | train loss: 1.0471, train acc: 0.625 | val loss: 1.0367, val acc: 0.623


Train [8/40]: 100%|███████████████████████████████████| 391/391 [01:30<00:00,  4.32it/s, loss=0.736]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:11<00:00,  6.90it/s, loss=0.975]


[mha] Epoch 8/40 | train loss: 0.9841, train acc: 0.648 | val loss: 1.0054, val acc: 0.635


Train [9/40]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.28it/s, loss=0.953]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.32it/s, loss=0.966]


[mha] Epoch 9/40 | train loss: 0.9424, train acc: 0.662 | val loss: 1.0324, val acc: 0.631


Train [10/40]: 100%|██████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=0.873]
Val   [10/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.55it/s, loss=0.756]


[mha] Epoch 10/40 | train loss: 0.8990, train acc: 0.680 | val loss: 0.9290, val acc: 0.667


Train [11/40]: 100%|███████████████████████████████████| 391/391 [01:30<00:00,  4.30it/s, loss=1.07]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.57it/s, loss=0.672]


[mha] Epoch 11/40 | train loss: 0.8479, train acc: 0.699 | val loss: 0.9372, val acc: 0.665


Train [12/40]: 100%|██████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=0.765]
Val   [12/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.56it/s, loss=0.548]


[mha] Epoch 12/40 | train loss: 0.8159, train acc: 0.710 | val loss: 0.8627, val acc: 0.693


Train [13/40]: 100%|██████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=0.677]
Val   [13/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.59it/s, loss=0.659]


[mha] Epoch 13/40 | train loss: 0.7739, train acc: 0.724 | val loss: 0.8466, val acc: 0.698


Train [14/40]: 100%|██████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=0.681]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.59it/s, loss=0.556]


[mha] Epoch 14/40 | train loss: 0.7435, train acc: 0.737 | val loss: 0.8343, val acc: 0.701


Train [15/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.30it/s, loss=0.967]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.60it/s, loss=0.535]


[mha] Epoch 15/40 | train loss: 0.7099, train acc: 0.748 | val loss: 0.8849, val acc: 0.691


Train [16/40]: 100%|████████████████████████████████████| 391/391 [01:30<00:00,  4.31it/s, loss=0.5]
Val   [16/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.53it/s, loss=0.512]


[mha] Epoch 16/40 | train loss: 0.6807, train acc: 0.760 | val loss: 0.8083, val acc: 0.719


Train [17/40]: 100%|███████████████████████████████████| 391/391 [01:30<00:00,  4.31it/s, loss=0.72]
Val   [17/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.65it/s, loss=0.729]


[mha] Epoch 17/40 | train loss: 0.6531, train acc: 0.768 | val loss: 0.7968, val acc: 0.723


Train [18/40]: 100%|██████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=0.743]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.40it/s, loss=0.481]


[mha] Epoch 18/40 | train loss: 0.6212, train acc: 0.781 | val loss: 0.8111, val acc: 0.716


Train [19/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.33it/s, loss=0.395]
Val   [19/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.30it/s, loss=0.546]


[mha] Epoch 19/40 | train loss: 0.5943, train acc: 0.790 | val loss: 0.7795, val acc: 0.735


Train [20/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.32it/s, loss=0.734]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.66it/s, loss=0.455]


[mha] Epoch 20/40 | train loss: 0.5636, train acc: 0.802 | val loss: 0.7690, val acc: 0.734


Train [21/40]: 100%|██████████████████████████████████| 391/391 [01:31<00:00,  4.27it/s, loss=0.702]
Val   [21/40]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.42it/s, loss=0.46]


[mha] Epoch 21/40 | train loss: 0.5367, train acc: 0.810 | val loss: 0.8132, val acc: 0.722


Train [22/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.31it/s, loss=0.489]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.56it/s, loss=0.475]


[mha] Epoch 22/40 | train loss: 0.5058, train acc: 0.821 | val loss: 0.7944, val acc: 0.729


Train [23/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.30it/s, loss=0.457]
Val   [23/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.28it/s, loss=0.388]


[mha] Epoch 23/40 | train loss: 0.4806, train acc: 0.830 | val loss: 0.7956, val acc: 0.733


Train [24/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.31it/s, loss=0.486]
Val   [24/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.56it/s, loss=0.455]


[mha] Epoch 24/40 | train loss: 0.4618, train acc: 0.836 | val loss: 0.8186, val acc: 0.734


Train [25/40]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=0.55]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.46it/s, loss=0.296]


[mha] Epoch 25/40 | train loss: 0.4402, train acc: 0.844 | val loss: 0.8059, val acc: 0.736


Train [26/40]: 100%|████████████████████████████████████| 391/391 [01:30<00:00,  4.30it/s, loss=0.3]
Val   [26/40]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.42it/s, loss=0.49]


[mha] Epoch 26/40 | train loss: 0.4042, train acc: 0.857 | val loss: 0.8173, val acc: 0.741


Train [27/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.32it/s, loss=0.424]
Val   [27/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.48it/s, loss=0.507]


[mha] Epoch 27/40 | train loss: 0.3878, train acc: 0.863 | val loss: 0.8005, val acc: 0.744


Train [28/40]: 100%|██████████████████████████████████| 391/391 [01:31<00:00,  4.28it/s, loss=0.445]
Val   [28/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.57it/s, loss=0.609]


[mha] Epoch 28/40 | train loss: 0.3574, train acc: 0.872 | val loss: 0.8190, val acc: 0.739


Train [29/40]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=0.27]
Val   [29/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.34it/s, loss=0.572]


[mha] Epoch 29/40 | train loss: 0.3326, train acc: 0.881 | val loss: 0.8310, val acc: 0.744


Train [30/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.30it/s, loss=0.214]
Val   [30/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.58it/s, loss=0.278]


[mha] Epoch 30/40 | train loss: 0.3110, train acc: 0.889 | val loss: 0.8082, val acc: 0.754


Train [31/40]: 100%|██████████████████████████████████| 391/391 [01:32<00:00,  4.25it/s, loss=0.313]
Val   [31/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.59it/s, loss=0.384]


[mha] Epoch 31/40 | train loss: 0.2875, train acc: 0.897 | val loss: 0.8439, val acc: 0.750


Train [32/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.32it/s, loss=0.363]
Val   [32/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.60it/s, loss=0.359]


[mha] Epoch 32/40 | train loss: 0.2700, train acc: 0.903 | val loss: 0.8480, val acc: 0.747


Train [33/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.30it/s, loss=0.264]
Val   [33/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.60it/s, loss=0.525]


[mha] Epoch 33/40 | train loss: 0.2340, train acc: 0.916 | val loss: 0.9085, val acc: 0.746


Train [34/40]: 100%|██████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=0.368]
Val   [34/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.55it/s, loss=0.515]


[mha] Epoch 34/40 | train loss: 0.2319, train acc: 0.917 | val loss: 0.9285, val acc: 0.748


Train [35/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.30it/s, loss=0.391]
Val   [35/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.63it/s, loss=0.546]


[mha] Epoch 35/40 | train loss: 0.2090, train acc: 0.925 | val loss: 0.9633, val acc: 0.734


Train [36/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.31it/s, loss=0.233]
Val   [36/40]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.63it/s, loss=0.49]


[mha] Epoch 36/40 | train loss: 0.1910, train acc: 0.933 | val loss: 0.9562, val acc: 0.743


Train [37/40]: 100%|██████████████████████████████████| 391/391 [01:31<00:00,  4.30it/s, loss=0.234]
Val   [37/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.53it/s, loss=0.265]


[mha] Epoch 37/40 | train loss: 0.1864, train acc: 0.934 | val loss: 0.9772, val acc: 0.745


Train [38/40]: 100%|█████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=0.0739]
Val   [38/40]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.59it/s, loss=0.32]


[mha] Epoch 38/40 | train loss: 0.1617, train acc: 0.943 | val loss: 1.0114, val acc: 0.742


Train [39/40]: 100%|█████████████████████████████████| 391/391 [01:30<00:00,  4.32it/s, loss=0.0832]
Val   [39/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.58it/s, loss=0.414]


[mha] Epoch 39/40 | train loss: 0.1577, train acc: 0.944 | val loss: 1.0167, val acc: 0.743


Train [40/40]: 100%|██████████████████████████████████| 391/391 [01:30<00:00,  4.30it/s, loss=0.266]
Val   [40/40]: 100%|████████████████████████████████████| 79/79 [00:10<00:00,  7.57it/s, loss=0.407]


[mha] Epoch 40/40 | train loss: 0.1465, train acc: 0.947 | val loss: 1.0247, val acc: 0.745

-------------------- [cifar10] Experiment: TPA_r1622 --------------------


Train [1/40]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=1.56]
Val   [1/40]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.49it/s, loss=1.27]


[tpa] Epoch 1/40 | train loss: 1.7359, train acc: 0.353 | val loss: 1.5742, val acc: 0.423


Train [2/40]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=1.48]
Val   [2/40]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.46it/s, loss=1.15]


[tpa] Epoch 2/40 | train loss: 1.4755, train acc: 0.462 | val loss: 1.3794, val acc: 0.497


Train [3/40]: 100%|████████████████████████████████████| 391/391 [02:28<00:00,  2.63it/s, loss=1.41]
Val   [3/40]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=1.16]


[tpa] Epoch 3/40 | train loss: 1.3342, train acc: 0.516 | val loss: 1.3128, val acc: 0.525


Train [4/40]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=1.29]
Val   [4/40]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=1.44]


[tpa] Epoch 4/40 | train loss: 1.2422, train acc: 0.553 | val loss: 1.2981, val acc: 0.529


Train [5/40]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=1.16]
Val   [5/40]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=1.01]


[tpa] Epoch 5/40 | train loss: 1.1593, train acc: 0.582 | val loss: 1.1230, val acc: 0.595


Train [6/40]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=1.07]
Val   [6/40]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=1.17]


[tpa] Epoch 6/40 | train loss: 1.0861, train acc: 0.611 | val loss: 1.1449, val acc: 0.591


Train [7/40]: 100%|███████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.981]
Val   [7/40]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.46it/s, loss=1.01]


[tpa] Epoch 7/40 | train loss: 1.0279, train acc: 0.633 | val loss: 1.0773, val acc: 0.613


Train [8/40]: 100%|███████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.852]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.48it/s, loss=0.823]


[tpa] Epoch 8/40 | train loss: 0.9651, train acc: 0.655 | val loss: 1.0165, val acc: 0.635


Train [9/40]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=1.03]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.49it/s, loss=0.917]


[tpa] Epoch 9/40 | train loss: 0.9120, train acc: 0.674 | val loss: 1.0141, val acc: 0.643


Train [10/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.791]
Val   [10/40]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.48it/s, loss=0.76]


[tpa] Epoch 10/40 | train loss: 0.8699, train acc: 0.689 | val loss: 0.9527, val acc: 0.660


Train [11/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.762]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=0.724]


[tpa] Epoch 11/40 | train loss: 0.8289, train acc: 0.703 | val loss: 0.9215, val acc: 0.674


Train [12/40]: 100%|██████████████████████████████████| 391/391 [02:28<00:00,  2.63it/s, loss=0.765]
Val   [12/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.52it/s, loss=0.842]


[tpa] Epoch 12/40 | train loss: 0.7874, train acc: 0.720 | val loss: 0.9380, val acc: 0.662


Train [13/40]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.8]
Val   [13/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.49it/s, loss=0.735]


[tpa] Epoch 13/40 | train loss: 0.7473, train acc: 0.733 | val loss: 0.9126, val acc: 0.683


Train [14/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.564]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=0.989]


[tpa] Epoch 14/40 | train loss: 0.7076, train acc: 0.747 | val loss: 0.9058, val acc: 0.688


Train [15/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.727]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=0.974]


[tpa] Epoch 15/40 | train loss: 0.6690, train acc: 0.761 | val loss: 0.8724, val acc: 0.694


Train [16/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.699]
Val   [16/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.52it/s, loss=0.649]


[tpa] Epoch 16/40 | train loss: 0.6348, train acc: 0.772 | val loss: 0.8944, val acc: 0.693


Train [17/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.874]
Val   [17/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.42it/s, loss=0.634]


[tpa] Epoch 17/40 | train loss: 0.6081, train acc: 0.784 | val loss: 0.8937, val acc: 0.693


Train [18/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.479]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.49it/s, loss=0.591]


[tpa] Epoch 18/40 | train loss: 0.5600, train acc: 0.800 | val loss: 0.8358, val acc: 0.716


Train [19/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.772]
Val   [19/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.48it/s, loss=0.512]


[tpa] Epoch 19/40 | train loss: 0.5233, train acc: 0.813 | val loss: 0.9035, val acc: 0.694


Train [20/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.489]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=0.474]


[tpa] Epoch 20/40 | train loss: 0.4920, train acc: 0.825 | val loss: 0.8651, val acc: 0.712


Train [21/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.443]
Val   [21/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.49it/s, loss=0.465]


[tpa] Epoch 21/40 | train loss: 0.4550, train acc: 0.838 | val loss: 0.8783, val acc: 0.716


Train [22/40]: 100%|██████████████████████████████████| 391/391 [02:28<00:00,  2.63it/s, loss=0.591]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=0.593]


[tpa] Epoch 22/40 | train loss: 0.4213, train acc: 0.849 | val loss: 0.9414, val acc: 0.708


Train [23/40]: 100%|██████████████████████████████████| 391/391 [02:28<00:00,  2.63it/s, loss=0.503]
Val   [23/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.53it/s, loss=0.809]


[tpa] Epoch 23/40 | train loss: 0.4002, train acc: 0.856 | val loss: 0.9572, val acc: 0.699


Train [24/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.386]
Val   [24/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=0.617]


[tpa] Epoch 24/40 | train loss: 0.3560, train acc: 0.873 | val loss: 0.9476, val acc: 0.710


Train [25/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.356]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=0.565]


[tpa] Epoch 25/40 | train loss: 0.3297, train acc: 0.881 | val loss: 0.9817, val acc: 0.710


Train [26/40]: 100%|██████████████████████████████████| 391/391 [02:28<00:00,  2.63it/s, loss=0.295]
Val   [26/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.52it/s, loss=0.711]


[tpa] Epoch 26/40 | train loss: 0.3117, train acc: 0.888 | val loss: 0.9561, val acc: 0.723


Train [27/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.206]
Val   [27/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.48it/s, loss=0.492]


[tpa] Epoch 27/40 | train loss: 0.2747, train acc: 0.902 | val loss: 0.9944, val acc: 0.716


Train [28/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.181]
Val   [28/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.48it/s, loss=0.507]


[tpa] Epoch 28/40 | train loss: 0.2539, train acc: 0.910 | val loss: 1.0144, val acc: 0.718


Train [29/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.291]
Val   [29/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.48it/s, loss=0.615]


[tpa] Epoch 29/40 | train loss: 0.2380, train acc: 0.914 | val loss: 1.0415, val acc: 0.713


Train [30/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.372]
Val   [30/40]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.49it/s, loss=1.08]


[tpa] Epoch 30/40 | train loss: 0.2268, train acc: 0.919 | val loss: 1.1053, val acc: 0.714


Train [31/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.186]
Val   [31/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.49it/s, loss=0.737]


[tpa] Epoch 31/40 | train loss: 0.1925, train acc: 0.932 | val loss: 1.1432, val acc: 0.706


Train [32/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.214]
Val   [32/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.49it/s, loss=0.689]


[tpa] Epoch 32/40 | train loss: 0.1842, train acc: 0.934 | val loss: 1.1424, val acc: 0.715


Train [33/40]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.2]
Val   [33/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.48it/s, loss=0.515]


[tpa] Epoch 33/40 | train loss: 0.1784, train acc: 0.936 | val loss: 1.1918, val acc: 0.710


Train [34/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.207]
Val   [34/40]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.52it/s, loss=0.81]


[tpa] Epoch 34/40 | train loss: 0.1637, train acc: 0.942 | val loss: 1.1703, val acc: 0.717


Train [35/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.132]
Val   [35/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.52it/s, loss=0.835]


[tpa] Epoch 35/40 | train loss: 0.1445, train acc: 0.948 | val loss: 1.2090, val acc: 0.710


Train [36/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.141]
Val   [36/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=0.453]


[tpa] Epoch 36/40 | train loss: 0.1473, train acc: 0.949 | val loss: 1.2330, val acc: 0.715


Train [37/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.248]
Val   [37/40]: 100%|████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=0.541]


[tpa] Epoch 37/40 | train loss: 0.1446, train acc: 0.949 | val loss: 1.1964, val acc: 0.721


Train [38/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=0.103]
Val   [38/40]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.49it/s, loss=1.22]


[tpa] Epoch 38/40 | train loss: 0.1324, train acc: 0.954 | val loss: 1.2328, val acc: 0.709


Train [39/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.135]
Val   [39/40]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=1.05]


[tpa] Epoch 39/40 | train loss: 0.1077, train acc: 0.963 | val loss: 1.2777, val acc: 0.706


Train [40/40]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.126]
Val   [40/40]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.49it/s, loss=1.06]


[tpa] Epoch 40/40 | train loss: 0.1320, train acc: 0.953 | val loss: 1.2661, val acc: 0.711

-------------------- [cifar10] Experiment: NonlinearTPA_KV --------------------


Train [1/40]: 100%|████████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=1.53]
Val   [1/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=1.28]


[nonlinear_tpa] Epoch 1/40 | train loss: 1.7839, train acc: 0.328 | val loss: 1.6287, val acc: 0.388


Train [2/40]: 100%|████████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=1.59]
Val   [2/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=1.34]


[nonlinear_tpa] Epoch 2/40 | train loss: 1.5038, train acc: 0.446 | val loss: 1.4506, val acc: 0.471


Train [3/40]: 100%|████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=1.45]
Val   [3/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=1.11]


[nonlinear_tpa] Epoch 3/40 | train loss: 1.3709, train acc: 0.501 | val loss: 1.3891, val acc: 0.496


Train [4/40]: 100%|█████████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=1.2]
Val   [4/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=1.05]


[nonlinear_tpa] Epoch 4/40 | train loss: 1.2734, train acc: 0.539 | val loss: 1.2779, val acc: 0.537


Train [5/40]: 100%|████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=1.15]
Val   [5/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.788]


[nonlinear_tpa] Epoch 5/40 | train loss: 1.1933, train acc: 0.571 | val loss: 1.1859, val acc: 0.573


Train [6/40]: 100%|████████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=1.04]
Val   [6/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=0.812]


[nonlinear_tpa] Epoch 6/40 | train loss: 1.1221, train acc: 0.596 | val loss: 1.1117, val acc: 0.601


Train [7/40]: 100%|████████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=1.15]
Val   [7/40]: 100%|███████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=1.1]


[nonlinear_tpa] Epoch 7/40 | train loss: 1.0603, train acc: 0.620 | val loss: 1.0781, val acc: 0.610


Train [8/40]: 100%|███████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.971]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.878]


[nonlinear_tpa] Epoch 8/40 | train loss: 1.0038, train acc: 0.640 | val loss: 1.0474, val acc: 0.625


Train [9/40]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.842]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.706]


[nonlinear_tpa] Epoch 9/40 | train loss: 0.9442, train acc: 0.661 | val loss: 1.0180, val acc: 0.634


Train [10/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.801]
Val   [10/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.693]


[nonlinear_tpa] Epoch 10/40 | train loss: 0.9009, train acc: 0.678 | val loss: 1.0054, val acc: 0.644


Train [11/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.793]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.841]


[nonlinear_tpa] Epoch 11/40 | train loss: 0.8566, train acc: 0.692 | val loss: 0.9111, val acc: 0.676


Train [12/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.626]
Val   [12/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=0.85]


[nonlinear_tpa] Epoch 12/40 | train loss: 0.8176, train acc: 0.706 | val loss: 0.9383, val acc: 0.668


Train [13/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.699]
Val   [13/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=0.623]


[nonlinear_tpa] Epoch 13/40 | train loss: 0.7833, train acc: 0.720 | val loss: 0.9096, val acc: 0.678


Train [14/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.748]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.568]


[nonlinear_tpa] Epoch 14/40 | train loss: 0.7523, train acc: 0.730 | val loss: 0.8866, val acc: 0.687


Train [15/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.33it/s, loss=0.793]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.584]


[nonlinear_tpa] Epoch 15/40 | train loss: 0.7124, train acc: 0.745 | val loss: 0.8716, val acc: 0.690


Train [16/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.691]
Val   [16/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.441]


[nonlinear_tpa] Epoch 16/40 | train loss: 0.6828, train acc: 0.755 | val loss: 0.8555, val acc: 0.703


Train [17/40]: 100%|███████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.68]
Val   [17/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.553]


[nonlinear_tpa] Epoch 17/40 | train loss: 0.6503, train acc: 0.770 | val loss: 0.8615, val acc: 0.699


Train [18/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.419]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.621]


[nonlinear_tpa] Epoch 18/40 | train loss: 0.6151, train acc: 0.780 | val loss: 0.8813, val acc: 0.702


Train [19/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.625]
Val   [19/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.565]


[nonlinear_tpa] Epoch 19/40 | train loss: 0.5863, train acc: 0.789 | val loss: 0.8487, val acc: 0.711


Train [20/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.655]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.617]


[nonlinear_tpa] Epoch 20/40 | train loss: 0.5670, train acc: 0.798 | val loss: 0.8542, val acc: 0.710


Train [21/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.698]
Val   [21/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.87]


[nonlinear_tpa] Epoch 21/40 | train loss: 0.5231, train acc: 0.815 | val loss: 0.8608, val acc: 0.709


Train [22/40]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.45]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=0.513]


[nonlinear_tpa] Epoch 22/40 | train loss: 0.5038, train acc: 0.820 | val loss: 0.8702, val acc: 0.706


Train [23/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.441]
Val   [23/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.621]


[nonlinear_tpa] Epoch 23/40 | train loss: 0.4664, train acc: 0.835 | val loss: 0.8712, val acc: 0.712


Train [24/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.439]
Val   [24/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=0.511]


[nonlinear_tpa] Epoch 24/40 | train loss: 0.4406, train acc: 0.842 | val loss: 0.8858, val acc: 0.710


Train [25/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.574]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.556]


[nonlinear_tpa] Epoch 25/40 | train loss: 0.4127, train acc: 0.852 | val loss: 0.9115, val acc: 0.714


Train [26/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.375]
Val   [26/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.884]


[nonlinear_tpa] Epoch 26/40 | train loss: 0.3810, train acc: 0.864 | val loss: 0.9440, val acc: 0.710


Train [27/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.33it/s, loss=0.362]
Val   [27/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.686]


[nonlinear_tpa] Epoch 27/40 | train loss: 0.3529, train acc: 0.873 | val loss: 0.9393, val acc: 0.719


Train [28/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.33it/s, loss=0.224]
Val   [28/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.559]


[nonlinear_tpa] Epoch 28/40 | train loss: 0.3349, train acc: 0.881 | val loss: 0.9164, val acc: 0.718


Train [29/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.321]
Val   [29/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.582]


[nonlinear_tpa] Epoch 29/40 | train loss: 0.3025, train acc: 0.891 | val loss: 0.9671, val acc: 0.723


Train [30/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.302]
Val   [30/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.414]


[nonlinear_tpa] Epoch 30/40 | train loss: 0.2814, train acc: 0.898 | val loss: 0.9839, val acc: 0.724


Train [31/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.333]
Val   [31/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.964]


[nonlinear_tpa] Epoch 31/40 | train loss: 0.2609, train acc: 0.907 | val loss: 1.0889, val acc: 0.713


Train [32/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.376]
Val   [32/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.47]


[nonlinear_tpa] Epoch 32/40 | train loss: 0.2404, train acc: 0.913 | val loss: 1.0539, val acc: 0.709


Train [33/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.341]
Val   [33/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.76]


[nonlinear_tpa] Epoch 33/40 | train loss: 0.2200, train acc: 0.922 | val loss: 1.0986, val acc: 0.709


Train [34/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.156]
Val   [34/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=0.552]


[nonlinear_tpa] Epoch 34/40 | train loss: 0.2044, train acc: 0.929 | val loss: 1.1090, val acc: 0.708


Train [35/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.318]
Val   [35/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.517]


[nonlinear_tpa] Epoch 35/40 | train loss: 0.1887, train acc: 0.933 | val loss: 1.1070, val acc: 0.716


Train [36/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.218]
Val   [36/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.952]


[nonlinear_tpa] Epoch 36/40 | train loss: 0.1843, train acc: 0.935 | val loss: 1.1698, val acc: 0.703


Train [37/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.194]
Val   [37/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.965]


[nonlinear_tpa] Epoch 37/40 | train loss: 0.1734, train acc: 0.939 | val loss: 1.1488, val acc: 0.710


Train [38/40]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.33it/s, loss=0.268]
Val   [38/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.68]


[nonlinear_tpa] Epoch 38/40 | train loss: 0.1562, train acc: 0.944 | val loss: 1.1857, val acc: 0.711


Train [39/40]: 100%|█████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.0593]
Val   [39/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=1.12]


[nonlinear_tpa] Epoch 39/40 | train loss: 0.1538, train acc: 0.945 | val loss: 1.2153, val acc: 0.714


Train [40/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.33it/s, loss=0.196]
Val   [40/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.341]


[nonlinear_tpa] Epoch 40/40 | train loss: 0.1432, train acc: 0.949 | val loss: 1.1702, val acc: 0.718

-------------------- [cifar10] Experiment: NonlinearTPA_KV_shared --------------------


Train [1/40]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.65]
Val   [1/40]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=1.46]


[nonlinear_tpa] Epoch 1/40 | train loss: 1.7898, train acc: 0.329 | val loss: 1.5724, val acc: 0.413


Train [2/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.44]
Val   [2/40]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=1.35]


[nonlinear_tpa] Epoch 2/40 | train loss: 1.4695, train acc: 0.461 | val loss: 1.3912, val acc: 0.496


Train [3/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.26]
Val   [3/40]: 100%|███████████████████████████████████████| 79/79 [00:19<00:00,  3.99it/s, loss=1.4]


[nonlinear_tpa] Epoch 3/40 | train loss: 1.3289, train acc: 0.517 | val loss: 1.3360, val acc: 0.522


Train [4/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.15it/s, loss=1.33]
Val   [4/40]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  3.99it/s, loss=1.13]


[nonlinear_tpa] Epoch 4/40 | train loss: 1.2380, train acc: 0.552 | val loss: 1.2222, val acc: 0.555


Train [5/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.97]
Val   [5/40]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=1.08]


[nonlinear_tpa] Epoch 5/40 | train loss: 1.1558, train acc: 0.585 | val loss: 1.1520, val acc: 0.589


Train [6/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.23]
Val   [6/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.99it/s, loss=0.997]


[nonlinear_tpa] Epoch 6/40 | train loss: 1.0943, train acc: 0.606 | val loss: 1.0929, val acc: 0.603


Train [7/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.11]
Val   [7/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.722]


[nonlinear_tpa] Epoch 7/40 | train loss: 1.0412, train acc: 0.625 | val loss: 1.0557, val acc: 0.617


Train [8/40]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=1.02]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=0.863]


[nonlinear_tpa] Epoch 8/40 | train loss: 0.9759, train acc: 0.649 | val loss: 1.0383, val acc: 0.629


Train [9/40]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.825]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.96it/s, loss=0.973]


[nonlinear_tpa] Epoch 9/40 | train loss: 0.9374, train acc: 0.662 | val loss: 0.9624, val acc: 0.656


Train [10/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.824]
Val   [10/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=0.645]


[nonlinear_tpa] Epoch 10/40 | train loss: 0.8986, train acc: 0.678 | val loss: 0.9558, val acc: 0.656


Train [11/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.845]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:20<00:00,  3.95it/s, loss=0.741]


[nonlinear_tpa] Epoch 11/40 | train loss: 0.8554, train acc: 0.695 | val loss: 0.9375, val acc: 0.664


Train [12/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.949]
Val   [12/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=0.846]


[nonlinear_tpa] Epoch 12/40 | train loss: 0.8098, train acc: 0.709 | val loss: 0.9242, val acc: 0.669


Train [13/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.532]
Val   [13/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.95it/s, loss=0.46]


[nonlinear_tpa] Epoch 13/40 | train loss: 0.7802, train acc: 0.720 | val loss: 0.8834, val acc: 0.686


Train [14/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.647]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:20<00:00,  3.95it/s, loss=0.782]


[nonlinear_tpa] Epoch 14/40 | train loss: 0.7448, train acc: 0.736 | val loss: 0.8737, val acc: 0.689


Train [15/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.855]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.634]


[nonlinear_tpa] Epoch 15/40 | train loss: 0.7139, train acc: 0.744 | val loss: 0.8363, val acc: 0.705


Train [16/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.592]
Val   [16/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.512]


[nonlinear_tpa] Epoch 16/40 | train loss: 0.6841, train acc: 0.753 | val loss: 0.8353, val acc: 0.706


Train [17/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.408]
Val   [17/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.677]


[nonlinear_tpa] Epoch 17/40 | train loss: 0.6454, train acc: 0.769 | val loss: 0.8763, val acc: 0.692


Train [18/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.681]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=0.857]


[nonlinear_tpa] Epoch 18/40 | train loss: 0.6116, train acc: 0.780 | val loss: 0.8528, val acc: 0.706


Train [19/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.437]
Val   [19/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.873]


[nonlinear_tpa] Epoch 19/40 | train loss: 0.5819, train acc: 0.792 | val loss: 0.8574, val acc: 0.709


Train [20/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.437]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.776]


[nonlinear_tpa] Epoch 20/40 | train loss: 0.5584, train acc: 0.800 | val loss: 0.8464, val acc: 0.708


Train [21/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.299]
Val   [21/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.96it/s, loss=0.896]


[nonlinear_tpa] Epoch 21/40 | train loss: 0.5288, train acc: 0.810 | val loss: 0.8316, val acc: 0.723


Train [22/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.465]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=0.671]


[nonlinear_tpa] Epoch 22/40 | train loss: 0.5016, train acc: 0.822 | val loss: 0.8282, val acc: 0.720


Train [23/40]: 100%|███████████████████████████████████| 391/391 [03:01<00:00,  2.15it/s, loss=0.53]
Val   [23/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.00it/s, loss=0.696]


[nonlinear_tpa] Epoch 23/40 | train loss: 0.4666, train acc: 0.834 | val loss: 0.8588, val acc: 0.719


Train [24/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.414]
Val   [24/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.914]


[nonlinear_tpa] Epoch 24/40 | train loss: 0.4435, train acc: 0.842 | val loss: 0.8529, val acc: 0.721


Train [25/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.593]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.657]


[nonlinear_tpa] Epoch 25/40 | train loss: 0.4110, train acc: 0.853 | val loss: 0.8555, val acc: 0.729


Train [26/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.352]
Val   [26/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.96it/s, loss=0.672]


[nonlinear_tpa] Epoch 26/40 | train loss: 0.3850, train acc: 0.861 | val loss: 0.8718, val acc: 0.722


Train [27/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.257]
Val   [27/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.801]


[nonlinear_tpa] Epoch 27/40 | train loss: 0.3585, train acc: 0.871 | val loss: 0.9387, val acc: 0.714


Train [28/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.429]
Val   [28/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.656]


[nonlinear_tpa] Epoch 28/40 | train loss: 0.3368, train acc: 0.881 | val loss: 0.9471, val acc: 0.719


Train [29/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.268]
Val   [29/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.52]


[nonlinear_tpa] Epoch 29/40 | train loss: 0.3121, train acc: 0.888 | val loss: 0.9119, val acc: 0.724


Train [30/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.364]
Val   [30/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=0.396]


[nonlinear_tpa] Epoch 30/40 | train loss: 0.2907, train acc: 0.895 | val loss: 0.9465, val acc: 0.726


Train [31/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.345]
Val   [31/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.96it/s, loss=0.569]


[nonlinear_tpa] Epoch 31/40 | train loss: 0.2763, train acc: 0.902 | val loss: 1.0006, val acc: 0.721


Train [32/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.276]
Val   [32/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.452]


[nonlinear_tpa] Epoch 32/40 | train loss: 0.2415, train acc: 0.914 | val loss: 1.0232, val acc: 0.725


Train [33/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.198]
Val   [33/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.96it/s, loss=0.718]


[nonlinear_tpa] Epoch 33/40 | train loss: 0.2346, train acc: 0.916 | val loss: 1.0235, val acc: 0.712


Train [34/40]: 100%|███████████████████████████████████| 391/391 [03:01<00:00,  2.15it/s, loss=0.22]
Val   [34/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.96it/s, loss=0.526]


[nonlinear_tpa] Epoch 34/40 | train loss: 0.2121, train acc: 0.924 | val loss: 1.0369, val acc: 0.728


Train [35/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.202]
Val   [35/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.878]


[nonlinear_tpa] Epoch 35/40 | train loss: 0.1999, train acc: 0.928 | val loss: 1.0317, val acc: 0.731


Train [36/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.129]
Val   [36/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=0.306]


[nonlinear_tpa] Epoch 36/40 | train loss: 0.1870, train acc: 0.934 | val loss: 1.1228, val acc: 0.719


Train [37/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.165]
Val   [37/40]: 100%|████████████████████████████████████| 79/79 [00:20<00:00,  3.94it/s, loss=0.732]


[nonlinear_tpa] Epoch 37/40 | train loss: 0.1855, train acc: 0.934 | val loss: 1.0630, val acc: 0.726


Train [38/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.238]
Val   [38/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.502]


[nonlinear_tpa] Epoch 38/40 | train loss: 0.1619, train acc: 0.941 | val loss: 1.1489, val acc: 0.712


Train [39/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.265]
Val   [39/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=0.523]


[nonlinear_tpa] Epoch 39/40 | train loss: 0.1603, train acc: 0.943 | val loss: 1.0854, val acc: 0.731


Train [40/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.169]
Val   [40/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=0.589]


[nonlinear_tpa] Epoch 40/40 | train loss: 0.1436, train acc: 0.949 | val loss: 1.1432, val acc: 0.725

-------------------- [cifar10] Experiment: NonlinearTPA_HW_KV --------------------


Train [1/40]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.74]
Val   [1/40]: 100%|███████████████████████████████████████| 79/79 [00:19<00:00,  4.11it/s, loss=1.3]


[headwise_nonlinear_tpa] Epoch 1/40 | train loss: 1.8435, train acc: 0.308 | val loss: 1.6631, val acc: 0.382


Train [2/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.27]
Val   [2/40]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=1.12]


[headwise_nonlinear_tpa] Epoch 2/40 | train loss: 1.5094, train acc: 0.445 | val loss: 1.4737, val acc: 0.464


Train [3/40]: 100%|█████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.5]
Val   [3/40]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=1.38]


[headwise_nonlinear_tpa] Epoch 3/40 | train loss: 1.3664, train acc: 0.502 | val loss: 1.3246, val acc: 0.527


Train [4/40]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.24]
Val   [4/40]: 100%|███████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=1.1]


[headwise_nonlinear_tpa] Epoch 4/40 | train loss: 1.2738, train acc: 0.539 | val loss: 1.2522, val acc: 0.549


Train [5/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.15it/s, loss=1.04]
Val   [5/40]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=1.01]


[headwise_nonlinear_tpa] Epoch 5/40 | train loss: 1.2138, train acc: 0.564 | val loss: 1.1891, val acc: 0.574


Train [6/40]: 100%|█████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.2]
Val   [6/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=0.888]


[headwise_nonlinear_tpa] Epoch 6/40 | train loss: 1.1433, train acc: 0.587 | val loss: 1.1682, val acc: 0.581


Train [7/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.33]
Val   [7/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=0.801]


[headwise_nonlinear_tpa] Epoch 7/40 | train loss: 1.0992, train acc: 0.603 | val loss: 1.1452, val acc: 0.593


Train [8/40]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.14]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.791]


[headwise_nonlinear_tpa] Epoch 8/40 | train loss: 1.0536, train acc: 0.621 | val loss: 1.0824, val acc: 0.612


Train [9/40]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.898]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=0.832]


[headwise_nonlinear_tpa] Epoch 9/40 | train loss: 1.0079, train acc: 0.636 | val loss: 1.0453, val acc: 0.623


Train [10/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.915]
Val   [10/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.837]


[headwise_nonlinear_tpa] Epoch 10/40 | train loss: 0.9642, train acc: 0.653 | val loss: 1.0220, val acc: 0.634


Train [11/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.876]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.818]


[headwise_nonlinear_tpa] Epoch 11/40 | train loss: 0.9229, train acc: 0.668 | val loss: 0.9856, val acc: 0.644


Train [12/40]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.88]
Val   [12/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.717]


[headwise_nonlinear_tpa] Epoch 12/40 | train loss: 0.8808, train acc: 0.684 | val loss: 0.9753, val acc: 0.655


Train [13/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.859]
Val   [13/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.779]


[headwise_nonlinear_tpa] Epoch 13/40 | train loss: 0.8500, train acc: 0.693 | val loss: 0.9986, val acc: 0.648


Train [14/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.772]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.728]


[headwise_nonlinear_tpa] Epoch 14/40 | train loss: 0.8229, train acc: 0.704 | val loss: 0.9126, val acc: 0.675


Train [15/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.727]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.11it/s, loss=0.771]


[headwise_nonlinear_tpa] Epoch 15/40 | train loss: 0.7809, train acc: 0.720 | val loss: 0.9299, val acc: 0.676


Train [16/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.615]
Val   [16/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.582]


[headwise_nonlinear_tpa] Epoch 16/40 | train loss: 0.7519, train acc: 0.729 | val loss: 0.9204, val acc: 0.677


Train [17/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.828]
Val   [17/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=0.66]


[headwise_nonlinear_tpa] Epoch 17/40 | train loss: 0.7197, train acc: 0.741 | val loss: 0.9173, val acc: 0.681


Train [18/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.587]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.11it/s, loss=0.883]


[headwise_nonlinear_tpa] Epoch 18/40 | train loss: 0.6823, train acc: 0.754 | val loss: 0.9462, val acc: 0.671


Train [19/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.633]
Val   [19/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=0.553]


[headwise_nonlinear_tpa] Epoch 19/40 | train loss: 0.6476, train acc: 0.766 | val loss: 0.8984, val acc: 0.692


Train [20/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.533]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.723]


[headwise_nonlinear_tpa] Epoch 20/40 | train loss: 0.6235, train acc: 0.776 | val loss: 0.9177, val acc: 0.688


Train [21/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.733]
Val   [21/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=0.663]


[headwise_nonlinear_tpa] Epoch 21/40 | train loss: 0.5877, train acc: 0.790 | val loss: 0.8542, val acc: 0.704


Train [22/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.585]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.787]


[headwise_nonlinear_tpa] Epoch 22/40 | train loss: 0.5641, train acc: 0.797 | val loss: 0.8771, val acc: 0.704


Train [23/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.585]
Val   [23/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.956]


[headwise_nonlinear_tpa] Epoch 23/40 | train loss: 0.5274, train acc: 0.812 | val loss: 0.9056, val acc: 0.700


Train [24/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.494]
Val   [24/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=0.525]


[headwise_nonlinear_tpa] Epoch 24/40 | train loss: 0.4955, train acc: 0.823 | val loss: 0.8714, val acc: 0.708


Train [25/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.608]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.919]


[headwise_nonlinear_tpa] Epoch 25/40 | train loss: 0.4641, train acc: 0.834 | val loss: 0.9317, val acc: 0.703


Train [26/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.453]
Val   [26/40]: 100%|████████████████████████████████████████| 79/79 [00:19<00:00,  4.11it/s, loss=1]


[headwise_nonlinear_tpa] Epoch 26/40 | train loss: 0.4386, train acc: 0.843 | val loss: 0.9270, val acc: 0.708


Train [27/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.385]
Val   [27/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.11it/s, loss=0.748]


[headwise_nonlinear_tpa] Epoch 27/40 | train loss: 0.4160, train acc: 0.850 | val loss: 0.9598, val acc: 0.700


Train [28/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.15it/s, loss=0.476]
Val   [28/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.765]


[headwise_nonlinear_tpa] Epoch 28/40 | train loss: 0.3849, train acc: 0.862 | val loss: 0.9657, val acc: 0.702


Train [29/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.15it/s, loss=0.411]
Val   [29/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=0.658]


[headwise_nonlinear_tpa] Epoch 29/40 | train loss: 0.3564, train acc: 0.871 | val loss: 0.9746, val acc: 0.712


Train [30/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.285]
Val   [30/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.542]


[headwise_nonlinear_tpa] Epoch 30/40 | train loss: 0.3288, train acc: 0.881 | val loss: 1.0079, val acc: 0.708


Train [31/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.258]
Val   [31/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=0.818]


[headwise_nonlinear_tpa] Epoch 31/40 | train loss: 0.3158, train acc: 0.886 | val loss: 1.0221, val acc: 0.706


Train [32/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.212]
Val   [32/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=1.01]


[headwise_nonlinear_tpa] Epoch 32/40 | train loss: 0.2867, train acc: 0.896 | val loss: 1.0441, val acc: 0.707


Train [33/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.213]
Val   [33/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=0.608]


[headwise_nonlinear_tpa] Epoch 33/40 | train loss: 0.2573, train acc: 0.907 | val loss: 1.0613, val acc: 0.710


Train [34/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.312]
Val   [34/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=0.963]


[headwise_nonlinear_tpa] Epoch 34/40 | train loss: 0.2441, train acc: 0.912 | val loss: 1.0572, val acc: 0.715


Train [35/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.345]
Val   [35/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=0.944]


[headwise_nonlinear_tpa] Epoch 35/40 | train loss: 0.2169, train acc: 0.923 | val loss: 1.1233, val acc: 0.708


Train [36/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.119]
Val   [36/40]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.5]


[headwise_nonlinear_tpa] Epoch 36/40 | train loss: 0.2158, train acc: 0.921 | val loss: 1.0858, val acc: 0.709


Train [37/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.234]
Val   [37/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.652]


[headwise_nonlinear_tpa] Epoch 37/40 | train loss: 0.1864, train acc: 0.934 | val loss: 1.1651, val acc: 0.702


Train [38/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.206]
Val   [38/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=0.873]


[headwise_nonlinear_tpa] Epoch 38/40 | train loss: 0.1851, train acc: 0.934 | val loss: 1.1877, val acc: 0.709


Train [39/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.236]
Val   [39/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.11it/s, loss=1.26]


[headwise_nonlinear_tpa] Epoch 39/40 | train loss: 0.1734, train acc: 0.938 | val loss: 1.2060, val acc: 0.707


Train [40/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.219]
Val   [40/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=0.952]


[headwise_nonlinear_tpa] Epoch 40/40 | train loss: 0.1607, train acc: 0.942 | val loss: 1.1871, val acc: 0.709

-------------------- [cifar10] Experiment: NonlinearTPA_HW_KV --------------------


Train [1/40]: 100%|████████████████████████████████████| 391/391 [03:00<00:00,  2.17it/s, loss=1.61]
Val   [1/40]: 100%|███████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=1.3]


[headwise_nonlinear_tpa] Epoch 1/40 | train loss: 1.8047, train acc: 0.319 | val loss: 1.5776, val acc: 0.416


Train [2/40]: 100%|████████████████████████████████████| 391/391 [03:00<00:00,  2.17it/s, loss=1.38]
Val   [2/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=1.29]


[headwise_nonlinear_tpa] Epoch 2/40 | train loss: 1.4813, train acc: 0.458 | val loss: 1.4083, val acc: 0.487


Train [3/40]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=1.51]
Val   [3/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=1.06]


[headwise_nonlinear_tpa] Epoch 3/40 | train loss: 1.3431, train acc: 0.512 | val loss: 1.3057, val acc: 0.530


Train [4/40]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=1.42]
Val   [4/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.962]


[headwise_nonlinear_tpa] Epoch 4/40 | train loss: 1.2604, train acc: 0.545 | val loss: 1.2196, val acc: 0.564


Train [5/40]: 100%|████████████████████████████████████| 391/391 [03:00<00:00,  2.17it/s, loss=1.17]
Val   [5/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.964]


[headwise_nonlinear_tpa] Epoch 5/40 | train loss: 1.1910, train acc: 0.569 | val loss: 1.1615, val acc: 0.589


Train [6/40]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=1.11]
Val   [6/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=0.915]


[headwise_nonlinear_tpa] Epoch 6/40 | train loss: 1.1215, train acc: 0.596 | val loss: 1.1329, val acc: 0.588


Train [7/40]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=1.07]
Val   [7/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=0.756]


[headwise_nonlinear_tpa] Epoch 7/40 | train loss: 1.0645, train acc: 0.618 | val loss: 1.1127, val acc: 0.597


Train [8/40]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=1.17]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.872]


[headwise_nonlinear_tpa] Epoch 8/40 | train loss: 1.0019, train acc: 0.640 | val loss: 1.0513, val acc: 0.626


Train [9/40]: 100%|███████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.801]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.927]


[headwise_nonlinear_tpa] Epoch 9/40 | train loss: 0.9570, train acc: 0.657 | val loss: 1.0298, val acc: 0.631


Train [10/40]: 100%|██████████████████████████████████| 391/391 [03:00<00:00,  2.17it/s, loss=0.914]
Val   [10/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.824]


[headwise_nonlinear_tpa] Epoch 10/40 | train loss: 0.9118, train acc: 0.671 | val loss: 0.9737, val acc: 0.650


Train [11/40]: 100%|██████████████████████████████████| 391/391 [02:58<00:00,  2.19it/s, loss=0.994]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.727]


[headwise_nonlinear_tpa] Epoch 11/40 | train loss: 0.8704, train acc: 0.689 | val loss: 0.9590, val acc: 0.655


Train [12/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.804]
Val   [12/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.614]


[headwise_nonlinear_tpa] Epoch 12/40 | train loss: 0.8298, train acc: 0.704 | val loss: 0.9246, val acc: 0.671


Train [13/40]: 100%|███████████████████████████████████| 391/391 [02:58<00:00,  2.19it/s, loss=0.71]
Val   [13/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.593]


[headwise_nonlinear_tpa] Epoch 13/40 | train loss: 0.7907, train acc: 0.717 | val loss: 0.9331, val acc: 0.666


Train [14/40]: 100%|██████████████████████████████████| 391/391 [02:58<00:00,  2.19it/s, loss=0.831]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.676]


[headwise_nonlinear_tpa] Epoch 14/40 | train loss: 0.7561, train acc: 0.729 | val loss: 0.9182, val acc: 0.680


Train [15/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=0.734]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=0.726]


[headwise_nonlinear_tpa] Epoch 15/40 | train loss: 0.7249, train acc: 0.741 | val loss: 0.8981, val acc: 0.683


Train [16/40]: 100%|██████████████████████████████████| 391/391 [02:58<00:00,  2.19it/s, loss=0.895]
Val   [16/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.673]


[headwise_nonlinear_tpa] Epoch 16/40 | train loss: 0.6898, train acc: 0.754 | val loss: 0.9224, val acc: 0.679


Train [17/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.606]
Val   [17/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.807]


[headwise_nonlinear_tpa] Epoch 17/40 | train loss: 0.6613, train acc: 0.764 | val loss: 0.8982, val acc: 0.684


Train [18/40]: 100%|██████████████████████████████████| 391/391 [03:00<00:00,  2.17it/s, loss=0.598]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.712]


[headwise_nonlinear_tpa] Epoch 18/40 | train loss: 0.6273, train acc: 0.776 | val loss: 0.8923, val acc: 0.696


Train [19/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=0.476]
Val   [19/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.855]


[headwise_nonlinear_tpa] Epoch 19/40 | train loss: 0.5894, train acc: 0.789 | val loss: 0.8813, val acc: 0.705


Train [20/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.396]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.347]


[headwise_nonlinear_tpa] Epoch 20/40 | train loss: 0.5604, train acc: 0.801 | val loss: 0.8592, val acc: 0.707


Train [21/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.595]
Val   [21/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.68]


[headwise_nonlinear_tpa] Epoch 21/40 | train loss: 0.5324, train acc: 0.810 | val loss: 0.8504, val acc: 0.716


Train [22/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.621]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.581]


[headwise_nonlinear_tpa] Epoch 22/40 | train loss: 0.4934, train acc: 0.823 | val loss: 0.8895, val acc: 0.706


Train [23/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.552]
Val   [23/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.648]


[headwise_nonlinear_tpa] Epoch 23/40 | train loss: 0.4618, train acc: 0.834 | val loss: 0.9029, val acc: 0.706


Train [24/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=0.234]
Val   [24/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.558]


[headwise_nonlinear_tpa] Epoch 24/40 | train loss: 0.4325, train acc: 0.846 | val loss: 0.9256, val acc: 0.708


Train [25/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.515]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.419]


[headwise_nonlinear_tpa] Epoch 25/40 | train loss: 0.4042, train acc: 0.856 | val loss: 0.9214, val acc: 0.710


Train [26/40]: 100%|██████████████████████████████████| 391/391 [02:58<00:00,  2.19it/s, loss=0.384]
Val   [26/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.697]


[headwise_nonlinear_tpa] Epoch 26/40 | train loss: 0.3815, train acc: 0.864 | val loss: 0.9212, val acc: 0.718


Train [27/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=0.368]
Val   [27/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.48]


[headwise_nonlinear_tpa] Epoch 27/40 | train loss: 0.3450, train acc: 0.877 | val loss: 0.9642, val acc: 0.712


Train [28/40]: 100%|██████████████████████████████████| 391/391 [02:58<00:00,  2.19it/s, loss=0.417]
Val   [28/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.664]


[headwise_nonlinear_tpa] Epoch 28/40 | train loss: 0.3149, train acc: 0.889 | val loss: 0.9683, val acc: 0.716


Train [29/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=0.221]
Val   [29/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.44]


[headwise_nonlinear_tpa] Epoch 29/40 | train loss: 0.3065, train acc: 0.891 | val loss: 0.9621, val acc: 0.717


Train [30/40]: 100%|██████████████████████████████████| 391/391 [03:00<00:00,  2.17it/s, loss=0.442]
Val   [30/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.502]


[headwise_nonlinear_tpa] Epoch 30/40 | train loss: 0.2744, train acc: 0.904 | val loss: 1.0403, val acc: 0.697


Train [31/40]: 100%|██████████████████████████████████| 391/391 [02:58<00:00,  2.19it/s, loss=0.364]
Val   [31/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.259]


[headwise_nonlinear_tpa] Epoch 31/40 | train loss: 0.2565, train acc: 0.909 | val loss: 1.1146, val acc: 0.706


Train [32/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=0.185]
Val   [32/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.674]


[headwise_nonlinear_tpa] Epoch 32/40 | train loss: 0.2368, train acc: 0.916 | val loss: 1.0615, val acc: 0.717


Train [33/40]: 100%|██████████████████████████████████| 391/391 [02:58<00:00,  2.19it/s, loss=0.258]
Val   [33/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.375]


[headwise_nonlinear_tpa] Epoch 33/40 | train loss: 0.2197, train acc: 0.921 | val loss: 1.1039, val acc: 0.708


Train [34/40]: 100%|██████████████████████████████████| 391/391 [02:58<00:00,  2.19it/s, loss=0.235]
Val   [34/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.539]


[headwise_nonlinear_tpa] Epoch 34/40 | train loss: 0.1977, train acc: 0.930 | val loss: 1.1549, val acc: 0.708


Train [35/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.292]
Val   [35/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.87]


[headwise_nonlinear_tpa] Epoch 35/40 | train loss: 0.1927, train acc: 0.932 | val loss: 1.1140, val acc: 0.719


Train [36/40]: 100%|███████████████████████████████████| 391/391 [03:00<00:00,  2.17it/s, loss=0.05]
Val   [36/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.603]


[headwise_nonlinear_tpa] Epoch 36/40 | train loss: 0.1693, train acc: 0.939 | val loss: 1.1679, val acc: 0.711


Train [37/40]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.239]
Val   [37/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.658]


[headwise_nonlinear_tpa] Epoch 37/40 | train loss: 0.1749, train acc: 0.937 | val loss: 1.1982, val acc: 0.710


Train [38/40]: 100%|██████████████████████████████████| 391/391 [02:58<00:00,  2.19it/s, loss=0.265]
Val   [38/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=0.713]


[headwise_nonlinear_tpa] Epoch 38/40 | train loss: 0.1550, train acc: 0.945 | val loss: 1.2343, val acc: 0.705


Train [39/40]: 100%|██████████████████████████████████| 391/391 [03:00<00:00,  2.17it/s, loss=0.177]
Val   [39/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.654]


[headwise_nonlinear_tpa] Epoch 39/40 | train loss: 0.1474, train acc: 0.948 | val loss: 1.2036, val acc: 0.709


Train [40/40]: 100%|█████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=0.0997]
Val   [40/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.684]


[headwise_nonlinear_tpa] Epoch 40/40 | train loss: 0.1423, train acc: 0.950 | val loss: 1.2200, val acc: 0.708

-------------------- [cifar10] Experiment: NonlinearTPA_HW_KV_shared --------------------


Train [1/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.65]
Val   [1/40]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=1.31]


[headwise_nonlinear_tpa] Epoch 1/40 | train loss: 1.8303, train acc: 0.312 | val loss: 1.6514, val acc: 0.381


Train [2/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.15it/s, loss=1.46]
Val   [2/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=1.11]


[headwise_nonlinear_tpa] Epoch 2/40 | train loss: 1.5216, train acc: 0.441 | val loss: 1.4924, val acc: 0.452


Train [3/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.15it/s, loss=1.37]
Val   [3/40]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=1.12]


[headwise_nonlinear_tpa] Epoch 3/40 | train loss: 1.3677, train acc: 0.503 | val loss: 1.3722, val acc: 0.507


Train [4/40]: 100%|████████████████████████████████████| 391/391 [03:01<00:00,  2.15it/s, loss=1.16]
Val   [4/40]: 100%|█████████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=1]


[headwise_nonlinear_tpa] Epoch 4/40 | train loss: 1.2844, train acc: 0.535 | val loss: 1.2659, val acc: 0.552


Train [5/40]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.967]
Val   [5/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.853]


[headwise_nonlinear_tpa] Epoch 5/40 | train loss: 1.2019, train acc: 0.564 | val loss: 1.1843, val acc: 0.572


Train [6/40]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.25]
Val   [6/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.844]


[headwise_nonlinear_tpa] Epoch 6/40 | train loss: 1.1360, train acc: 0.591 | val loss: 1.1166, val acc: 0.600


Train [7/40]: 100%|███████████████████████████████████| 391/391 [03:01<00:00,  2.15it/s, loss=0.957]
Val   [7/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=0.805]


[headwise_nonlinear_tpa] Epoch 7/40 | train loss: 1.0721, train acc: 0.613 | val loss: 1.0939, val acc: 0.613


Train [8/40]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.947]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=0.902]


[headwise_nonlinear_tpa] Epoch 8/40 | train loss: 1.0181, train acc: 0.634 | val loss: 1.0375, val acc: 0.627


Train [9/40]: 100%|███████████████████████████████████| 391/391 [03:01<00:00,  2.15it/s, loss=0.833]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=0.835]


[headwise_nonlinear_tpa] Epoch 9/40 | train loss: 0.9571, train acc: 0.657 | val loss: 1.0504, val acc: 0.625


Train [10/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.938]
Val   [10/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.969]


[headwise_nonlinear_tpa] Epoch 10/40 | train loss: 0.9169, train acc: 0.671 | val loss: 1.0157, val acc: 0.641


Train [11/40]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=1.01]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.706]


[headwise_nonlinear_tpa] Epoch 11/40 | train loss: 0.8677, train acc: 0.689 | val loss: 0.9443, val acc: 0.665


Train [12/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.812]
Val   [12/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.63]


[headwise_nonlinear_tpa] Epoch 12/40 | train loss: 0.8263, train acc: 0.707 | val loss: 0.9112, val acc: 0.674


Train [13/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.801]
Val   [13/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=0.606]


[headwise_nonlinear_tpa] Epoch 13/40 | train loss: 0.7894, train acc: 0.719 | val loss: 0.9008, val acc: 0.677


Train [14/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.901]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.727]


[headwise_nonlinear_tpa] Epoch 14/40 | train loss: 0.7457, train acc: 0.735 | val loss: 0.8925, val acc: 0.692


Train [15/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.655]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.581]


[headwise_nonlinear_tpa] Epoch 15/40 | train loss: 0.7158, train acc: 0.745 | val loss: 0.9304, val acc: 0.678


Train [16/40]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.85]
Val   [16/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=1.02]


[headwise_nonlinear_tpa] Epoch 16/40 | train loss: 0.6896, train acc: 0.755 | val loss: 0.9081, val acc: 0.687


Train [17/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.656]
Val   [17/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.543]


[headwise_nonlinear_tpa] Epoch 17/40 | train loss: 0.6448, train acc: 0.769 | val loss: 0.8665, val acc: 0.698


Train [18/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.564]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=0.892]


[headwise_nonlinear_tpa] Epoch 18/40 | train loss: 0.6195, train acc: 0.779 | val loss: 0.8870, val acc: 0.697


Train [19/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.533]
Val   [19/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.09it/s, loss=0.75]


[headwise_nonlinear_tpa] Epoch 19/40 | train loss: 0.5954, train acc: 0.788 | val loss: 0.8653, val acc: 0.701


Train [20/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.494]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.619]


[headwise_nonlinear_tpa] Epoch 20/40 | train loss: 0.5522, train acc: 0.804 | val loss: 0.8658, val acc: 0.707


Train [21/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.447]
Val   [21/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.691]


[headwise_nonlinear_tpa] Epoch 21/40 | train loss: 0.5299, train acc: 0.810 | val loss: 0.9354, val acc: 0.693


Train [22/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.612]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.785]


[headwise_nonlinear_tpa] Epoch 22/40 | train loss: 0.4908, train acc: 0.824 | val loss: 0.8846, val acc: 0.714


Train [23/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.714]
Val   [23/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.92]


[headwise_nonlinear_tpa] Epoch 23/40 | train loss: 0.4641, train acc: 0.833 | val loss: 0.8890, val acc: 0.707


Train [24/40]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.49]
Val   [24/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.492]


[headwise_nonlinear_tpa] Epoch 24/40 | train loss: 0.4341, train acc: 0.845 | val loss: 0.8922, val acc: 0.714


Train [25/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.547]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.755]


[headwise_nonlinear_tpa] Epoch 25/40 | train loss: 0.4024, train acc: 0.856 | val loss: 0.9549, val acc: 0.706


Train [26/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.387]
Val   [26/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.711]


[headwise_nonlinear_tpa] Epoch 26/40 | train loss: 0.3780, train acc: 0.864 | val loss: 0.9379, val acc: 0.715


Train [27/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.441]
Val   [27/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.869]


[headwise_nonlinear_tpa] Epoch 27/40 | train loss: 0.3489, train acc: 0.875 | val loss: 0.9976, val acc: 0.706


Train [28/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.283]
Val   [28/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.783]


[headwise_nonlinear_tpa] Epoch 28/40 | train loss: 0.3324, train acc: 0.880 | val loss: 1.0133, val acc: 0.703


Train [29/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.414]
Val   [29/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.868]


[headwise_nonlinear_tpa] Epoch 29/40 | train loss: 0.2979, train acc: 0.895 | val loss: 1.0169, val acc: 0.707


Train [30/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.324]
Val   [30/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.705]


[headwise_nonlinear_tpa] Epoch 30/40 | train loss: 0.2733, train acc: 0.902 | val loss: 0.9914, val acc: 0.716


Train [31/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.339]
Val   [31/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=1.16]


[headwise_nonlinear_tpa] Epoch 31/40 | train loss: 0.2602, train acc: 0.907 | val loss: 1.0351, val acc: 0.707


Train [32/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.297]
Val   [32/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=1.31]


[headwise_nonlinear_tpa] Epoch 32/40 | train loss: 0.2373, train acc: 0.915 | val loss: 1.0838, val acc: 0.711


Train [33/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.277]
Val   [33/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=1.35]


[headwise_nonlinear_tpa] Epoch 33/40 | train loss: 0.2207, train acc: 0.921 | val loss: 1.1428, val acc: 0.711


Train [34/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.276]
Val   [34/40]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=1.32]


[headwise_nonlinear_tpa] Epoch 34/40 | train loss: 0.2088, train acc: 0.926 | val loss: 1.1269, val acc: 0.708


Train [35/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.268]
Val   [35/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=1.34]


[headwise_nonlinear_tpa] Epoch 35/40 | train loss: 0.1825, train acc: 0.935 | val loss: 1.1287, val acc: 0.711


Train [36/40]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=0.208]
Val   [36/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.918]


[headwise_nonlinear_tpa] Epoch 36/40 | train loss: 0.1786, train acc: 0.936 | val loss: 1.1939, val acc: 0.703


Train [37/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.193]
Val   [37/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=1.03]


[headwise_nonlinear_tpa] Epoch 37/40 | train loss: 0.1613, train acc: 0.943 | val loss: 1.2213, val acc: 0.710


Train [38/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.192]
Val   [38/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=1.19]


[headwise_nonlinear_tpa] Epoch 38/40 | train loss: 0.1612, train acc: 0.943 | val loss: 1.1845, val acc: 0.707


Train [39/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.199]
Val   [39/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.944]


[headwise_nonlinear_tpa] Epoch 39/40 | train loss: 0.1527, train acc: 0.946 | val loss: 1.2479, val acc: 0.710


Train [40/40]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.144]
Val   [40/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=1.09]


[headwise_nonlinear_tpa] Epoch 40/40 | train loss: 0.1404, train acc: 0.951 | val loss: 1.2340, val acc: 0.711

-------------------- [cifar10] Experiment: NonlinearTPA_QKV --------------------


Train [1/40]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=1.64]
Val   [1/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=1.55]


[nonlinear_tpa] Epoch 1/40 | train loss: 1.8285, train acc: 0.310 | val loss: 1.6434, val acc: 0.380


Train [2/40]: 100%|████████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=1.31]
Val   [2/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=1.47]


[nonlinear_tpa] Epoch 2/40 | train loss: 1.5219, train acc: 0.443 | val loss: 1.4304, val acc: 0.486


Train [3/40]: 100%|█████████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=1.5]
Val   [3/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=1.12]


[nonlinear_tpa] Epoch 3/40 | train loss: 1.3598, train acc: 0.506 | val loss: 1.3175, val acc: 0.521


Train [4/40]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=1.28]
Val   [4/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.953]


[nonlinear_tpa] Epoch 4/40 | train loss: 1.2625, train acc: 0.543 | val loss: 1.2354, val acc: 0.549


Train [5/40]: 100%|████████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=1.23]
Val   [5/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=1.15]


[nonlinear_tpa] Epoch 5/40 | train loss: 1.1865, train acc: 0.569 | val loss: 1.2382, val acc: 0.555


Train [6/40]: 100%|████████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=1.25]
Val   [6/40]: 100%|███████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=1.2]


[nonlinear_tpa] Epoch 6/40 | train loss: 1.1255, train acc: 0.592 | val loss: 1.1744, val acc: 0.584


Train [7/40]: 100%|████████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=1.25]
Val   [7/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=1.06]


[nonlinear_tpa] Epoch 7/40 | train loss: 1.0675, train acc: 0.615 | val loss: 1.1478, val acc: 0.591


Train [8/40]: 100%|███████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.887]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.908]


[nonlinear_tpa] Epoch 8/40 | train loss: 1.0235, train acc: 0.631 | val loss: 1.0595, val acc: 0.614


Train [9/40]: 100%|████████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=1.08]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.684]


[nonlinear_tpa] Epoch 9/40 | train loss: 0.9727, train acc: 0.651 | val loss: 1.0207, val acc: 0.633


Train [10/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.873]
Val   [10/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.793]


[nonlinear_tpa] Epoch 10/40 | train loss: 0.9271, train acc: 0.668 | val loss: 1.0216, val acc: 0.635


Train [11/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.29it/s, loss=0.954]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=0.955]


[nonlinear_tpa] Epoch 11/40 | train loss: 0.8883, train acc: 0.681 | val loss: 0.9688, val acc: 0.653


Train [12/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.782]
Val   [12/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.798]


[nonlinear_tpa] Epoch 12/40 | train loss: 0.8524, train acc: 0.694 | val loss: 0.9824, val acc: 0.651


Train [13/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.29it/s, loss=0.947]
Val   [13/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.724]


[nonlinear_tpa] Epoch 13/40 | train loss: 0.8214, train acc: 0.704 | val loss: 0.9207, val acc: 0.674


Train [14/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.851]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.686]


[nonlinear_tpa] Epoch 14/40 | train loss: 0.7834, train acc: 0.719 | val loss: 0.8879, val acc: 0.689


Train [15/40]: 100%|███████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.74]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.787]


[nonlinear_tpa] Epoch 15/40 | train loss: 0.7492, train acc: 0.731 | val loss: 0.9301, val acc: 0.671


Train [16/40]: 100%|███████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.84]
Val   [16/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.16it/s, loss=0.696]


[nonlinear_tpa] Epoch 16/40 | train loss: 0.7189, train acc: 0.742 | val loss: 0.8946, val acc: 0.683


Train [17/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.839]
Val   [17/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.712]


[nonlinear_tpa] Epoch 17/40 | train loss: 0.6905, train acc: 0.752 | val loss: 0.9120, val acc: 0.679


Train [18/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.614]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.625]


[nonlinear_tpa] Epoch 18/40 | train loss: 0.6571, train acc: 0.764 | val loss: 0.8758, val acc: 0.691


Train [19/40]: 100%|████████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.6]
Val   [19/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.581]


[nonlinear_tpa] Epoch 19/40 | train loss: 0.6226, train acc: 0.776 | val loss: 0.8534, val acc: 0.710


Train [20/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.29it/s, loss=0.415]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.693]


[nonlinear_tpa] Epoch 20/40 | train loss: 0.5906, train acc: 0.788 | val loss: 0.8578, val acc: 0.710


Train [21/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.527]
Val   [21/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.48]


[nonlinear_tpa] Epoch 21/40 | train loss: 0.5591, train acc: 0.801 | val loss: 0.8807, val acc: 0.709


Train [22/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.524]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.473]


[nonlinear_tpa] Epoch 22/40 | train loss: 0.5355, train acc: 0.808 | val loss: 0.8685, val acc: 0.706


Train [23/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.535]
Val   [23/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.561]


[nonlinear_tpa] Epoch 23/40 | train loss: 0.4996, train acc: 0.820 | val loss: 0.8725, val acc: 0.716


Train [24/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.383]
Val   [24/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.77]


[nonlinear_tpa] Epoch 24/40 | train loss: 0.4714, train acc: 0.832 | val loss: 0.8579, val acc: 0.717


Train [25/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.559]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.527]


[nonlinear_tpa] Epoch 25/40 | train loss: 0.4428, train acc: 0.842 | val loss: 0.8676, val acc: 0.720


Train [26/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.584]
Val   [26/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.514]


[nonlinear_tpa] Epoch 26/40 | train loss: 0.4130, train acc: 0.852 | val loss: 0.9378, val acc: 0.709


Train [27/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.351]
Val   [27/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.604]


[nonlinear_tpa] Epoch 27/40 | train loss: 0.3842, train acc: 0.861 | val loss: 0.9087, val acc: 0.714


Train [28/40]: 100%|██████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=0.607]
Val   [28/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.503]


[nonlinear_tpa] Epoch 28/40 | train loss: 0.3594, train acc: 0.871 | val loss: 0.9464, val acc: 0.705


Train [29/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.415]
Val   [29/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.888]


[nonlinear_tpa] Epoch 29/40 | train loss: 0.3334, train acc: 0.879 | val loss: 0.9702, val acc: 0.717


Train [30/40]: 100%|██████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=0.229]
Val   [30/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.16it/s, loss=0.788]


[nonlinear_tpa] Epoch 30/40 | train loss: 0.3020, train acc: 0.892 | val loss: 0.9902, val acc: 0.715


Train [31/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.253]
Val   [31/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.802]


[nonlinear_tpa] Epoch 31/40 | train loss: 0.2872, train acc: 0.897 | val loss: 1.0038, val acc: 0.714


Train [32/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.315]
Val   [32/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.764]


[nonlinear_tpa] Epoch 32/40 | train loss: 0.2675, train acc: 0.904 | val loss: 0.9794, val acc: 0.719


Train [33/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.256]
Val   [33/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=1.01]


[nonlinear_tpa] Epoch 33/40 | train loss: 0.2473, train acc: 0.911 | val loss: 1.0721, val acc: 0.709


Train [34/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.238]
Val   [34/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.585]


[nonlinear_tpa] Epoch 34/40 | train loss: 0.2320, train acc: 0.917 | val loss: 1.0845, val acc: 0.720


Train [35/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.166]
Val   [35/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=0.699]


[nonlinear_tpa] Epoch 35/40 | train loss: 0.2123, train acc: 0.925 | val loss: 1.0979, val acc: 0.718


Train [36/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.313]
Val   [36/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.446]


[nonlinear_tpa] Epoch 36/40 | train loss: 0.1894, train acc: 0.932 | val loss: 1.1062, val acc: 0.718


Train [37/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.179]
Val   [37/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.348]


[nonlinear_tpa] Epoch 37/40 | train loss: 0.1854, train acc: 0.934 | val loss: 1.1454, val acc: 0.722


Train [38/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.131]
Val   [38/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=0.67]


[nonlinear_tpa] Epoch 38/40 | train loss: 0.1695, train acc: 0.940 | val loss: 1.2423, val acc: 0.703


Train [39/40]: 100%|███████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.22]
Val   [39/40]: 100%|████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=0.544]


[nonlinear_tpa] Epoch 39/40 | train loss: 0.1602, train acc: 0.944 | val loss: 1.1822, val acc: 0.715


Train [40/40]: 100%|██████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=0.207]
Val   [40/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=0.423]


[nonlinear_tpa] Epoch 40/40 | train loss: 0.1531, train acc: 0.946 | val loss: 1.2086, val acc: 0.708

-------------------- [cifar10] Experiment: NonlinearTPA_Q --------------------


Train [1/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.63]
Val   [1/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=1.39]


[nonlinear_tpa] Epoch 1/40 | train loss: 1.7419, train acc: 0.352 | val loss: 1.5654, val acc: 0.430


Train [2/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.52]
Val   [2/40]: 100%|███████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=1.3]


[nonlinear_tpa] Epoch 2/40 | train loss: 1.4645, train acc: 0.466 | val loss: 1.4690, val acc: 0.460


Train [3/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.25]
Val   [3/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.15]


[nonlinear_tpa] Epoch 3/40 | train loss: 1.3262, train acc: 0.520 | val loss: 1.2707, val acc: 0.550


Train [4/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.13]
Val   [4/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.14]


[nonlinear_tpa] Epoch 4/40 | train loss: 1.2187, train acc: 0.558 | val loss: 1.1938, val acc: 0.570


Train [5/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.903]
Val   [5/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=1.14]


[nonlinear_tpa] Epoch 5/40 | train loss: 1.1321, train acc: 0.595 | val loss: 1.1329, val acc: 0.592


Train [6/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.04]
Val   [6/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.998]


[nonlinear_tpa] Epoch 6/40 | train loss: 1.0565, train acc: 0.620 | val loss: 1.0715, val acc: 0.617


Train [7/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.16]
Val   [7/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.906]


[nonlinear_tpa] Epoch 7/40 | train loss: 0.9899, train acc: 0.645 | val loss: 0.9984, val acc: 0.646


Train [8/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.871]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.948]


[nonlinear_tpa] Epoch 8/40 | train loss: 0.9326, train acc: 0.666 | val loss: 1.0032, val acc: 0.643


Train [9/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.787]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.755]


[nonlinear_tpa] Epoch 9/40 | train loss: 0.8821, train acc: 0.685 | val loss: 0.9815, val acc: 0.644


Train [10/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.01]
Val   [10/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.86]


[nonlinear_tpa] Epoch 10/40 | train loss: 0.8389, train acc: 0.700 | val loss: 0.9407, val acc: 0.671


Train [11/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.743]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.865]


[nonlinear_tpa] Epoch 11/40 | train loss: 0.7930, train acc: 0.716 | val loss: 0.9338, val acc: 0.675


Train [12/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.697]
Val   [12/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.901]


[nonlinear_tpa] Epoch 12/40 | train loss: 0.7565, train acc: 0.731 | val loss: 0.9045, val acc: 0.680


Train [13/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.543]
Val   [13/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.941]


[nonlinear_tpa] Epoch 13/40 | train loss: 0.7147, train acc: 0.744 | val loss: 0.8726, val acc: 0.694


Train [14/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.561]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.795]


[nonlinear_tpa] Epoch 14/40 | train loss: 0.6823, train acc: 0.757 | val loss: 0.8505, val acc: 0.699


Train [15/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.596]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.646]


[nonlinear_tpa] Epoch 15/40 | train loss: 0.6368, train acc: 0.773 | val loss: 0.8860, val acc: 0.693


Train [16/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.715]
Val   [16/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.25it/s, loss=0.656]


[nonlinear_tpa] Epoch 16/40 | train loss: 0.5988, train acc: 0.787 | val loss: 0.8625, val acc: 0.704


Train [17/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.605]
Val   [17/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.663]


[nonlinear_tpa] Epoch 17/40 | train loss: 0.5752, train acc: 0.795 | val loss: 0.8870, val acc: 0.699


Train [18/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.594]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.797]


[nonlinear_tpa] Epoch 18/40 | train loss: 0.5257, train acc: 0.814 | val loss: 0.9256, val acc: 0.693


Train [19/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.57]
Val   [19/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.678]


[nonlinear_tpa] Epoch 19/40 | train loss: 0.4922, train acc: 0.825 | val loss: 0.9064, val acc: 0.707


Train [20/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.669]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.556]


[nonlinear_tpa] Epoch 20/40 | train loss: 0.4626, train acc: 0.834 | val loss: 0.9191, val acc: 0.703


Train [21/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.597]
Val   [21/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.474]


[nonlinear_tpa] Epoch 21/40 | train loss: 0.4215, train acc: 0.848 | val loss: 0.8896, val acc: 0.711


Train [22/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.489]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.866]


[nonlinear_tpa] Epoch 22/40 | train loss: 0.3876, train acc: 0.862 | val loss: 0.9435, val acc: 0.710


Train [23/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.375]
Val   [23/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.683]


[nonlinear_tpa] Epoch 23/40 | train loss: 0.3495, train acc: 0.875 | val loss: 0.9867, val acc: 0.706


Train [24/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.474]
Val   [24/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.598]


[nonlinear_tpa] Epoch 24/40 | train loss: 0.3198, train acc: 0.884 | val loss: 0.9954, val acc: 0.705


Train [25/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.388]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.776]


[nonlinear_tpa] Epoch 25/40 | train loss: 0.2921, train acc: 0.896 | val loss: 1.0467, val acc: 0.708


Train [26/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.332]
Val   [26/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.915]


[nonlinear_tpa] Epoch 26/40 | train loss: 0.2578, train acc: 0.908 | val loss: 1.0559, val acc: 0.714


Train [27/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.311]
Val   [27/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.578]


[nonlinear_tpa] Epoch 27/40 | train loss: 0.2509, train acc: 0.910 | val loss: 1.0673, val acc: 0.708


Train [28/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.203]
Val   [28/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.669]


[nonlinear_tpa] Epoch 28/40 | train loss: 0.2252, train acc: 0.919 | val loss: 1.1076, val acc: 0.701


Train [29/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.355]
Val   [29/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.839]


[nonlinear_tpa] Epoch 29/40 | train loss: 0.2020, train acc: 0.927 | val loss: 1.1364, val acc: 0.711


Train [30/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.182]
Val   [30/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.488]


[nonlinear_tpa] Epoch 30/40 | train loss: 0.1872, train acc: 0.933 | val loss: 1.1405, val acc: 0.714


Train [31/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.303]
Val   [31/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.537]


[nonlinear_tpa] Epoch 31/40 | train loss: 0.1679, train acc: 0.941 | val loss: 1.2045, val acc: 0.708


Train [32/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.338]
Val   [32/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.695]


[nonlinear_tpa] Epoch 32/40 | train loss: 0.1628, train acc: 0.943 | val loss: 1.2396, val acc: 0.708


Train [33/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.181]
Val   [33/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.561]


[nonlinear_tpa] Epoch 33/40 | train loss: 0.1558, train acc: 0.944 | val loss: 1.2523, val acc: 0.706


Train [34/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.186]
Val   [34/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.848]


[nonlinear_tpa] Epoch 34/40 | train loss: 0.1518, train acc: 0.948 | val loss: 1.2100, val acc: 0.708


Train [35/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.112]
Val   [35/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.933]


[nonlinear_tpa] Epoch 35/40 | train loss: 0.1280, train acc: 0.955 | val loss: 1.2418, val acc: 0.712


Train [36/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.228]
Val   [36/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.23]


[nonlinear_tpa] Epoch 36/40 | train loss: 0.1276, train acc: 0.956 | val loss: 1.2913, val acc: 0.700


Train [37/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.109]
Val   [37/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.603]


[nonlinear_tpa] Epoch 37/40 | train loss: 0.1123, train acc: 0.960 | val loss: 1.3115, val acc: 0.715


Train [38/40]: 100%|█████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.0777]
Val   [38/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.443]


[nonlinear_tpa] Epoch 38/40 | train loss: 0.1221, train acc: 0.956 | val loss: 1.2975, val acc: 0.704


Train [39/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.15]
Val   [39/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.835]


[nonlinear_tpa] Epoch 39/40 | train loss: 0.1096, train acc: 0.961 | val loss: 1.3367, val acc: 0.711


Train [40/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.171]
Val   [40/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.766]


[nonlinear_tpa] Epoch 40/40 | train loss: 0.1151, train acc: 0.959 | val loss: 1.3505, val acc: 0.714

-------------------- [cifar10] Experiment: NonlinearTPA_K --------------------


Train [1/40]: 100%|█████████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=1.4]
Val   [1/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=1.45]


[nonlinear_tpa] Epoch 1/40 | train loss: 1.7571, train acc: 0.342 | val loss: 1.5801, val acc: 0.423


Train [2/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=1.48]
Val   [2/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.27]


[nonlinear_tpa] Epoch 2/40 | train loss: 1.4870, train acc: 0.457 | val loss: 1.4041, val acc: 0.493


Train [3/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=1.48]
Val   [3/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=1.15]


[nonlinear_tpa] Epoch 3/40 | train loss: 1.3467, train acc: 0.510 | val loss: 1.3414, val acc: 0.517


Train [4/40]: 100%|████████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=1.21]
Val   [4/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.29]


[nonlinear_tpa] Epoch 4/40 | train loss: 1.2489, train acc: 0.548 | val loss: 1.2673, val acc: 0.543


Train [5/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.889]
Val   [5/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.16]


[nonlinear_tpa] Epoch 5/40 | train loss: 1.1735, train acc: 0.576 | val loss: 1.1984, val acc: 0.568


Train [6/40]: 100%|███████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.927]
Val   [6/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.847]


[nonlinear_tpa] Epoch 6/40 | train loss: 1.0957, train acc: 0.607 | val loss: 1.0985, val acc: 0.612


Train [7/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.879]
Val   [7/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.06]


[nonlinear_tpa] Epoch 7/40 | train loss: 1.0318, train acc: 0.631 | val loss: 1.0854, val acc: 0.610


Train [8/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.88]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.807]


[nonlinear_tpa] Epoch 8/40 | train loss: 0.9759, train acc: 0.650 | val loss: 1.0129, val acc: 0.639


Train [9/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.885]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.675]


[nonlinear_tpa] Epoch 9/40 | train loss: 0.9295, train acc: 0.668 | val loss: 0.9725, val acc: 0.650


Train [10/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.92]
Val   [10/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.948]


[nonlinear_tpa] Epoch 10/40 | train loss: 0.8892, train acc: 0.682 | val loss: 0.9723, val acc: 0.648


Train [11/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.789]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.679]


[nonlinear_tpa] Epoch 11/40 | train loss: 0.8419, train acc: 0.697 | val loss: 0.9567, val acc: 0.662


Train [12/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.709]
Val   [12/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.649]


[nonlinear_tpa] Epoch 12/40 | train loss: 0.8000, train acc: 0.715 | val loss: 0.9160, val acc: 0.680


Train [13/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.601]
Val   [13/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.68]


[nonlinear_tpa] Epoch 13/40 | train loss: 0.7631, train acc: 0.729 | val loss: 0.9181, val acc: 0.683


Train [14/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.608]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.739]


[nonlinear_tpa] Epoch 14/40 | train loss: 0.7271, train acc: 0.741 | val loss: 0.9330, val acc: 0.678


Train [15/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.831]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.678]


[nonlinear_tpa] Epoch 15/40 | train loss: 0.6923, train acc: 0.753 | val loss: 0.9308, val acc: 0.675


Train [16/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.852]
Val   [16/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.825]


[nonlinear_tpa] Epoch 16/40 | train loss: 0.6617, train acc: 0.765 | val loss: 0.8944, val acc: 0.690


Train [17/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.508]
Val   [17/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.483]


[nonlinear_tpa] Epoch 17/40 | train loss: 0.6188, train acc: 0.779 | val loss: 0.9016, val acc: 0.693


Train [18/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.574]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.603]


[nonlinear_tpa] Epoch 18/40 | train loss: 0.5920, train acc: 0.790 | val loss: 0.8678, val acc: 0.702


Train [19/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.565]
Val   [19/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.728]


[nonlinear_tpa] Epoch 19/40 | train loss: 0.5599, train acc: 0.799 | val loss: 0.8353, val acc: 0.716


Train [20/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.691]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.636]


[nonlinear_tpa] Epoch 20/40 | train loss: 0.5256, train acc: 0.812 | val loss: 0.8885, val acc: 0.705


Train [21/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.495]
Val   [21/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.575]


[nonlinear_tpa] Epoch 21/40 | train loss: 0.4912, train acc: 0.826 | val loss: 0.8743, val acc: 0.708


Train [22/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.406]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.499]


[nonlinear_tpa] Epoch 22/40 | train loss: 0.4619, train acc: 0.835 | val loss: 0.8908, val acc: 0.712


Train [23/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.506]
Val   [23/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.749]


[nonlinear_tpa] Epoch 23/40 | train loss: 0.4323, train acc: 0.847 | val loss: 0.9520, val acc: 0.698


Train [24/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.421]
Val   [24/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.537]


[nonlinear_tpa] Epoch 24/40 | train loss: 0.4078, train acc: 0.854 | val loss: 0.9356, val acc: 0.708


Train [25/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.472]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.892]


[nonlinear_tpa] Epoch 25/40 | train loss: 0.3745, train acc: 0.865 | val loss: 0.9272, val acc: 0.710


Train [26/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.415]
Val   [26/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.567]


[nonlinear_tpa] Epoch 26/40 | train loss: 0.3388, train acc: 0.880 | val loss: 0.9827, val acc: 0.707


Train [27/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.359]
Val   [27/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.675]


[nonlinear_tpa] Epoch 27/40 | train loss: 0.3206, train acc: 0.885 | val loss: 0.9940, val acc: 0.714


Train [28/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.381]
Val   [28/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.45]


[nonlinear_tpa] Epoch 28/40 | train loss: 0.2939, train acc: 0.894 | val loss: 0.9942, val acc: 0.712


Train [29/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.346]
Val   [29/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.579]


[nonlinear_tpa] Epoch 29/40 | train loss: 0.2655, train acc: 0.905 | val loss: 1.0615, val acc: 0.712


Train [30/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.218]
Val   [30/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.44]


[nonlinear_tpa] Epoch 30/40 | train loss: 0.2491, train acc: 0.912 | val loss: 1.0272, val acc: 0.712


Train [31/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.174]
Val   [31/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.512]


[nonlinear_tpa] Epoch 31/40 | train loss: 0.2344, train acc: 0.917 | val loss: 1.0881, val acc: 0.710


Train [32/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.325]
Val   [32/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.913]


[nonlinear_tpa] Epoch 32/40 | train loss: 0.2094, train acc: 0.925 | val loss: 1.1318, val acc: 0.712


Train [33/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.291]
Val   [33/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.937]


[nonlinear_tpa] Epoch 33/40 | train loss: 0.1949, train acc: 0.931 | val loss: 1.1503, val acc: 0.712


Train [34/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.29]
Val   [34/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.914]


[nonlinear_tpa] Epoch 34/40 | train loss: 0.1856, train acc: 0.934 | val loss: 1.1875, val acc: 0.713


Train [35/40]: 100%|██████████████████████████████████| 391/391 [02:47<00:00,  2.34it/s, loss=0.178]
Val   [35/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.755]


[nonlinear_tpa] Epoch 35/40 | train loss: 0.1752, train acc: 0.936 | val loss: 1.1923, val acc: 0.708


Train [36/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.214]
Val   [36/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.614]


[nonlinear_tpa] Epoch 36/40 | train loss: 0.1586, train acc: 0.944 | val loss: 1.1979, val acc: 0.717


Train [37/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.142]
Val   [37/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.319]


[nonlinear_tpa] Epoch 37/40 | train loss: 0.1516, train acc: 0.947 | val loss: 1.2514, val acc: 0.703


Train [38/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.132]
Val   [38/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.831]


[nonlinear_tpa] Epoch 38/40 | train loss: 0.1506, train acc: 0.948 | val loss: 1.2601, val acc: 0.709


Train [39/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.202]
Val   [39/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=1.05]


[nonlinear_tpa] Epoch 39/40 | train loss: 0.1368, train acc: 0.951 | val loss: 1.2262, val acc: 0.712


Train [40/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.237]
Val   [40/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.601]


[nonlinear_tpa] Epoch 40/40 | train loss: 0.1355, train acc: 0.952 | val loss: 1.2645, val acc: 0.706

-------------------- [cifar10] Experiment: NonlinearTPA_V --------------------


Train [1/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.62]
Val   [1/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=1.27]


[nonlinear_tpa] Epoch 1/40 | train loss: 1.7810, train acc: 0.332 | val loss: 1.5638, val acc: 0.419


Train [2/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.51]
Val   [2/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=1.34]


[nonlinear_tpa] Epoch 2/40 | train loss: 1.4755, train acc: 0.460 | val loss: 1.4235, val acc: 0.491


Train [3/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.27]
Val   [3/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=1.21]


[nonlinear_tpa] Epoch 3/40 | train loss: 1.3512, train acc: 0.510 | val loss: 1.3180, val acc: 0.524


Train [4/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.16]
Val   [4/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.19]


[nonlinear_tpa] Epoch 4/40 | train loss: 1.2605, train acc: 0.545 | val loss: 1.2404, val acc: 0.554


Train [5/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.35]
Val   [5/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.703]


[nonlinear_tpa] Epoch 5/40 | train loss: 1.1717, train acc: 0.581 | val loss: 1.1794, val acc: 0.572


Train [6/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.07]
Val   [6/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.934]


[nonlinear_tpa] Epoch 6/40 | train loss: 1.0996, train acc: 0.605 | val loss: 1.1207, val acc: 0.602


Train [7/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.09]
Val   [7/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.731]


[nonlinear_tpa] Epoch 7/40 | train loss: 1.0419, train acc: 0.625 | val loss: 1.0770, val acc: 0.616


Train [8/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.09]
Val   [8/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.737]


[nonlinear_tpa] Epoch 8/40 | train loss: 0.9827, train acc: 0.648 | val loss: 1.0224, val acc: 0.631


Train [9/40]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.02]
Val   [9/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.781]


[nonlinear_tpa] Epoch 9/40 | train loss: 0.9370, train acc: 0.665 | val loss: 0.9962, val acc: 0.641


Train [10/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.03]
Val   [10/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.657]


[nonlinear_tpa] Epoch 10/40 | train loss: 0.8938, train acc: 0.680 | val loss: 0.9613, val acc: 0.656


Train [11/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.95]
Val   [11/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.775]


[nonlinear_tpa] Epoch 11/40 | train loss: 0.8500, train acc: 0.696 | val loss: 0.9365, val acc: 0.665


Train [12/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.718]
Val   [12/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.927]


[nonlinear_tpa] Epoch 12/40 | train loss: 0.8083, train acc: 0.711 | val loss: 0.9381, val acc: 0.661


Train [13/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.766]
Val   [13/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.689]


[nonlinear_tpa] Epoch 13/40 | train loss: 0.7829, train acc: 0.721 | val loss: 0.9397, val acc: 0.663


Train [14/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.564]
Val   [14/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=0.798]


[nonlinear_tpa] Epoch 14/40 | train loss: 0.7428, train acc: 0.735 | val loss: 0.9009, val acc: 0.678


Train [15/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.632]
Val   [15/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.25it/s, loss=0.653]


[nonlinear_tpa] Epoch 15/40 | train loss: 0.7078, train acc: 0.748 | val loss: 0.9300, val acc: 0.674


Train [16/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.73]
Val   [16/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.441]


[nonlinear_tpa] Epoch 16/40 | train loss: 0.6730, train acc: 0.761 | val loss: 0.8906, val acc: 0.687


Train [17/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.662]
Val   [17/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.833]


[nonlinear_tpa] Epoch 17/40 | train loss: 0.6454, train acc: 0.768 | val loss: 0.9293, val acc: 0.678


Train [18/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.761]
Val   [18/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.646]


[nonlinear_tpa] Epoch 18/40 | train loss: 0.6077, train acc: 0.784 | val loss: 0.8909, val acc: 0.700


Train [19/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.545]
Val   [19/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.961]


[nonlinear_tpa] Epoch 19/40 | train loss: 0.5772, train acc: 0.793 | val loss: 0.9148, val acc: 0.692


Train [20/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=0.536]
Val   [20/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.687]


[nonlinear_tpa] Epoch 20/40 | train loss: 0.5450, train acc: 0.807 | val loss: 0.9206, val acc: 0.698


Train [21/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.374]
Val   [21/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.882]


[nonlinear_tpa] Epoch 21/40 | train loss: 0.5099, train acc: 0.817 | val loss: 0.8887, val acc: 0.709


Train [22/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.57]
Val   [22/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.584]


[nonlinear_tpa] Epoch 22/40 | train loss: 0.4769, train acc: 0.829 | val loss: 0.9055, val acc: 0.709


Train [23/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.574]
Val   [23/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.679]


[nonlinear_tpa] Epoch 23/40 | train loss: 0.4446, train acc: 0.842 | val loss: 0.9422, val acc: 0.701


Train [24/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.403]
Val   [24/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.604]


[nonlinear_tpa] Epoch 24/40 | train loss: 0.4114, train acc: 0.855 | val loss: 0.9165, val acc: 0.713


Train [25/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.493]
Val   [25/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.786]


[nonlinear_tpa] Epoch 25/40 | train loss: 0.3773, train acc: 0.868 | val loss: 0.9708, val acc: 0.701


Train [26/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.437]
Val   [26/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.627]


[nonlinear_tpa] Epoch 26/40 | train loss: 0.3532, train acc: 0.873 | val loss: 0.9631, val acc: 0.707


Train [27/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.392]
Val   [27/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.778]


[nonlinear_tpa] Epoch 27/40 | train loss: 0.3319, train acc: 0.882 | val loss: 1.0065, val acc: 0.703


Train [28/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.385]
Val   [28/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=0.753]


[nonlinear_tpa] Epoch 28/40 | train loss: 0.3034, train acc: 0.891 | val loss: 1.0494, val acc: 0.702


Train [29/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.16]
Val   [29/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=0.638]


[nonlinear_tpa] Epoch 29/40 | train loss: 0.2797, train acc: 0.899 | val loss: 1.0692, val acc: 0.705


Train [30/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.37]
Val   [30/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.19]


[nonlinear_tpa] Epoch 30/40 | train loss: 0.2650, train acc: 0.906 | val loss: 1.0764, val acc: 0.707


Train [31/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.359]
Val   [31/40]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.7]


[nonlinear_tpa] Epoch 31/40 | train loss: 0.2378, train acc: 0.914 | val loss: 1.1389, val acc: 0.706


Train [32/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.354]
Val   [32/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.918]


[nonlinear_tpa] Epoch 32/40 | train loss: 0.2240, train acc: 0.920 | val loss: 1.1642, val acc: 0.696


Train [33/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.186]
Val   [33/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.901]


[nonlinear_tpa] Epoch 33/40 | train loss: 0.2132, train acc: 0.924 | val loss: 1.1590, val acc: 0.708


Train [34/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.251]
Val   [34/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.22]


[nonlinear_tpa] Epoch 34/40 | train loss: 0.1890, train acc: 0.933 | val loss: 1.2036, val acc: 0.711


Train [35/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.259]
Val   [35/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.25]


[nonlinear_tpa] Epoch 35/40 | train loss: 0.1892, train acc: 0.933 | val loss: 1.2215, val acc: 0.704


Train [36/40]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.16]
Val   [36/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.21]


[nonlinear_tpa] Epoch 36/40 | train loss: 0.1682, train acc: 0.941 | val loss: 1.2379, val acc: 0.705


Train [37/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.192]
Val   [37/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.893]


[nonlinear_tpa] Epoch 37/40 | train loss: 0.1653, train acc: 0.941 | val loss: 1.2614, val acc: 0.694


Train [38/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.141]
Val   [38/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=1.62]


[nonlinear_tpa] Epoch 38/40 | train loss: 0.1565, train acc: 0.944 | val loss: 1.2733, val acc: 0.706


Train [39/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.219]
Val   [39/40]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.25it/s, loss=1.66]


[nonlinear_tpa] Epoch 39/40 | train loss: 0.1381, train acc: 0.951 | val loss: 1.3532, val acc: 0.703


Train [40/40]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.168]
Val   [40/40]: 100%|████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=0.683]


[nonlinear_tpa] Epoch 40/40 | train loss: 0.1405, train acc: 0.949 | val loss: 1.2685, val acc: 0.707


[Info] Resolved MHA attn_type = mha

-------------------- [cifar100] Experiment: MHA_baseline --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=3.47]
Val   [1/20]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.68it/s, loss=3.62]


[mha] Epoch 1/20 | train loss: 4.0273, train acc: 0.080 | val loss: 3.7114, val acc: 0.124


Train [2/20]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=3.38]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.58it/s, loss=3.34]


[mha] Epoch 2/20 | train loss: 3.5347, train acc: 0.154 | val loss: 3.3815, val acc: 0.186


Train [3/20]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=3.21]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.69it/s, loss=3.16]


[mha] Epoch 3/20 | train loss: 3.2327, train acc: 0.207 | val loss: 3.1042, val acc: 0.235


Train [4/20]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=2.82]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.68it/s, loss=2.74]


[mha] Epoch 4/20 | train loss: 2.9714, train acc: 0.257 | val loss: 2.9345, val acc: 0.270


Train [5/20]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.30it/s, loss=2.83]
Val   [5/20]: 100%|███████████████████████████████████████| 79/79 [00:10<00:00,  7.67it/s, loss=2.7]


[mha] Epoch 5/20 | train loss: 2.7556, train acc: 0.298 | val loss: 2.7184, val acc: 0.304


Train [6/20]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=2.59]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.66it/s, loss=2.53]


[mha] Epoch 6/20 | train loss: 2.5914, train acc: 0.332 | val loss: 2.6115, val acc: 0.334


Train [7/20]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=2.24]
Val   [7/20]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.62it/s, loss=2.61]


[mha] Epoch 7/20 | train loss: 2.4410, train acc: 0.364 | val loss: 2.5188, val acc: 0.347


Train [8/20]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=2.48]
Val   [8/20]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.67it/s, loss=2.46]


[mha] Epoch 8/20 | train loss: 2.2989, train acc: 0.394 | val loss: 2.3738, val acc: 0.385


Train [9/20]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=1.99]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.63it/s, loss=2.13]


[mha] Epoch 9/20 | train loss: 2.1806, train acc: 0.419 | val loss: 2.3736, val acc: 0.384


Train [10/20]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=1.81]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.67it/s, loss=2.17]


[mha] Epoch 10/20 | train loss: 2.0611, train acc: 0.444 | val loss: 2.3174, val acc: 0.399


Train [11/20]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=1.95]
Val   [11/20]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.61it/s, loss=1.98]


[mha] Epoch 11/20 | train loss: 1.9445, train acc: 0.469 | val loss: 2.2967, val acc: 0.400


Train [12/20]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=1.73]
Val   [12/20]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.69it/s, loss=2.12]


[mha] Epoch 12/20 | train loss: 1.8078, train acc: 0.502 | val loss: 2.2254, val acc: 0.426


Train [13/20]: 100%|████████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=1.5]
Val   [13/20]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.68it/s, loss=1.77]


[mha] Epoch 13/20 | train loss: 1.6770, train acc: 0.533 | val loss: 2.2249, val acc: 0.415


Train [14/20]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=1.48]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.66it/s, loss=1.92]


[mha] Epoch 14/20 | train loss: 1.5501, train acc: 0.562 | val loss: 2.2213, val acc: 0.431


Train [15/20]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.28it/s, loss=1.39]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.67it/s, loss=1.66]


[mha] Epoch 15/20 | train loss: 1.4082, train acc: 0.597 | val loss: 2.2593, val acc: 0.430


Train [16/20]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=1.55]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.66it/s, loss=1.57]


[mha] Epoch 16/20 | train loss: 1.2562, train acc: 0.635 | val loss: 2.2198, val acc: 0.454


Train [17/20]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=1.14]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.68it/s, loss=2.02]


[mha] Epoch 17/20 | train loss: 1.0853, train acc: 0.679 | val loss: 2.3288, val acc: 0.435


Train [18/20]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=1.09]
Val   [18/20]: 100%|██████████████████████████████████████| 79/79 [00:10<00:00,  7.66it/s, loss=1.7]


[mha] Epoch 18/20 | train loss: 0.9332, train acc: 0.724 | val loss: 2.3765, val acc: 0.437


Train [19/20]: 100%|███████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=1.28]
Val   [19/20]: 100%|█████████████████████████████████████| 79/79 [00:10<00:00,  7.64it/s, loss=1.47]


[mha] Epoch 19/20 | train loss: 0.7680, train acc: 0.771 | val loss: 2.3989, val acc: 0.447


Train [20/20]: 100%|██████████████████████████████████| 391/391 [01:31<00:00,  4.29it/s, loss=0.541]
Val   [20/20]: 100%|████████████████████████████████████████| 79/79 [00:10<00:00,  7.67it/s, loss=2]


[mha] Epoch 20/20 | train loss: 0.6016, train acc: 0.819 | val loss: 2.5758, val acc: 0.434

-------------------- [cifar100] Experiment: TPA_r1622 --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.62it/s, loss=3.63]
Val   [1/20]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=3.49]


[tpa] Epoch 1/20 | train loss: 3.8945, train acc: 0.106 | val loss: 3.5531, val acc: 0.160


Train [2/20]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=3.21]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=3.32]


[tpa] Epoch 2/20 | train loss: 3.3647, train acc: 0.190 | val loss: 3.2058, val acc: 0.223


Train [3/20]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=2.89]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.52it/s, loss=2.79]


[tpa] Epoch 3/20 | train loss: 3.0654, train acc: 0.243 | val loss: 2.9772, val acc: 0.272


Train [4/20]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=2.85]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=2.51]


[tpa] Epoch 4/20 | train loss: 2.8446, train acc: 0.283 | val loss: 2.8082, val acc: 0.300


Train [5/20]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=2.53]
Val   [5/20]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.53it/s, loss=2.29]


[tpa] Epoch 5/20 | train loss: 2.6346, train acc: 0.325 | val loss: 2.6767, val acc: 0.319


Train [6/20]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=2.41]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=2.35]


[tpa] Epoch 6/20 | train loss: 2.4680, train acc: 0.358 | val loss: 2.5850, val acc: 0.345


Train [7/20]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=2.42]
Val   [7/20]: 100%|███████████████████████████████████████| 79/79 [00:17<00:00,  4.53it/s, loss=2.2]


[tpa] Epoch 7/20 | train loss: 2.3150, train acc: 0.389 | val loss: 2.5131, val acc: 0.357


Train [8/20]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=2.15]
Val   [8/20]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.54it/s, loss=2.41]


[tpa] Epoch 8/20 | train loss: 2.1639, train acc: 0.423 | val loss: 2.4909, val acc: 0.367


Train [9/20]: 100%|████████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=2.21]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.53it/s, loss=1.88]


[tpa] Epoch 9/20 | train loss: 2.0260, train acc: 0.451 | val loss: 2.4039, val acc: 0.385


Train [10/20]: 100%|███████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.94]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=2.46]


[tpa] Epoch 10/20 | train loss: 1.8809, train acc: 0.487 | val loss: 2.4492, val acc: 0.381


Train [11/20]: 100%|███████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=2.21]
Val   [11/20]: 100%|██████████████████████████████████████| 79/79 [00:17<00:00,  4.53it/s, loss=1.9]


[tpa] Epoch 11/20 | train loss: 1.7238, train acc: 0.522 | val loss: 2.5022, val acc: 0.380


Train [12/20]: 100%|███████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.86]
Val   [12/20]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.52it/s, loss=1.94]


[tpa] Epoch 12/20 | train loss: 1.5666, train acc: 0.557 | val loss: 2.4439, val acc: 0.388


Train [13/20]: 100%|███████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.44]
Val   [13/20]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=2.21]


[tpa] Epoch 13/20 | train loss: 1.3834, train acc: 0.608 | val loss: 2.4730, val acc: 0.400


Train [14/20]: 100%|███████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.16]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.52it/s, loss=2.43]


[tpa] Epoch 14/20 | train loss: 1.2049, train acc: 0.653 | val loss: 2.5521, val acc: 0.392


Train [15/20]: 100%|███████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.21]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=2.15]


[tpa] Epoch 15/20 | train loss: 1.0423, train acc: 0.693 | val loss: 2.7073, val acc: 0.381


Train [16/20]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.793]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=2.11]


[tpa] Epoch 16/20 | train loss: 0.8576, train acc: 0.748 | val loss: 2.7269, val acc: 0.389


Train [17/20]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.827]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=2.25]


[tpa] Epoch 17/20 | train loss: 0.6977, train acc: 0.792 | val loss: 2.8811, val acc: 0.386


Train [18/20]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.712]
Val   [18/20]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.50it/s, loss=2.78]


[tpa] Epoch 18/20 | train loss: 0.5556, train acc: 0.838 | val loss: 3.0665, val acc: 0.374


Train [19/20]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.548]
Val   [19/20]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=2.36]


[tpa] Epoch 19/20 | train loss: 0.4285, train acc: 0.874 | val loss: 3.1458, val acc: 0.385


Train [20/20]: 100%|██████████████████████████████████| 391/391 [02:29<00:00,  2.61it/s, loss=0.338]
Val   [20/20]: 100%|█████████████████████████████████████| 79/79 [00:17<00:00,  4.51it/s, loss=3.07]


[tpa] Epoch 20/20 | train loss: 0.3548, train acc: 0.895 | val loss: 3.3161, val acc: 0.381

-------------------- [cifar100] Experiment: NonlinearTPA_KV --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=3.55]
Val   [1/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=3.61]


[nonlinear_tpa] Epoch 1/20 | train loss: 3.9854, train acc: 0.089 | val loss: 3.6425, val acc: 0.136


Train [2/20]: 100%|████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=3.47]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=3.33]


[nonlinear_tpa] Epoch 2/20 | train loss: 3.4576, train acc: 0.169 | val loss: 3.3508, val acc: 0.196


Train [3/20]: 100%|████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=3.15]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=3.11]


[nonlinear_tpa] Epoch 3/20 | train loss: 3.1766, train acc: 0.219 | val loss: 3.0874, val acc: 0.241


Train [4/20]: 100%|█████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=2.9]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.66]


[nonlinear_tpa] Epoch 4/20 | train loss: 2.9211, train acc: 0.268 | val loss: 2.8589, val acc: 0.287


Train [5/20]: 100%|████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=2.74]
Val   [5/20]: 100%|███████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=2.4]


[nonlinear_tpa] Epoch 5/20 | train loss: 2.7185, train acc: 0.307 | val loss: 2.7381, val acc: 0.313


Train [6/20]: 100%|████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=2.54]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.46]


[nonlinear_tpa] Epoch 6/20 | train loss: 2.5556, train acc: 0.340 | val loss: 2.5971, val acc: 0.335


Train [7/20]: 100%|█████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=2.4]
Val   [7/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=2.38]


[nonlinear_tpa] Epoch 7/20 | train loss: 2.4234, train acc: 0.368 | val loss: 2.5251, val acc: 0.354


Train [8/20]: 100%|████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=2.06]
Val   [8/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.31]


[nonlinear_tpa] Epoch 8/20 | train loss: 2.2981, train acc: 0.392 | val loss: 2.4188, val acc: 0.376


Train [9/20]: 100%|████████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=1.95]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.26]


[nonlinear_tpa] Epoch 9/20 | train loss: 2.1660, train acc: 0.419 | val loss: 2.3876, val acc: 0.383


Train [10/20]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.33it/s, loss=1.83]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=1.96]


[nonlinear_tpa] Epoch 10/20 | train loss: 2.0566, train acc: 0.443 | val loss: 2.3409, val acc: 0.388


Train [11/20]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=1.94]
Val   [11/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.92]


[nonlinear_tpa] Epoch 11/20 | train loss: 1.9356, train acc: 0.473 | val loss: 2.3416, val acc: 0.392


Train [12/20]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=2.14]
Val   [12/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=2.1]


[nonlinear_tpa] Epoch 12/20 | train loss: 1.8187, train acc: 0.499 | val loss: 2.3110, val acc: 0.409


Train [13/20]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=1.58]
Val   [13/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=2.12]


[nonlinear_tpa] Epoch 13/20 | train loss: 1.6856, train acc: 0.530 | val loss: 2.3024, val acc: 0.411


Train [14/20]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=1.43]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=1.74]


[nonlinear_tpa] Epoch 14/20 | train loss: 1.5607, train acc: 0.559 | val loss: 2.3097, val acc: 0.417


Train [15/20]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=1.63]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=1.61]


[nonlinear_tpa] Epoch 15/20 | train loss: 1.4049, train acc: 0.601 | val loss: 2.3671, val acc: 0.410


Train [16/20]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=1.33]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=2.06]


[nonlinear_tpa] Epoch 16/20 | train loss: 1.2476, train acc: 0.638 | val loss: 2.4142, val acc: 0.416


Train [17/20]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=1.32]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.14]


[nonlinear_tpa] Epoch 17/20 | train loss: 1.0944, train acc: 0.678 | val loss: 2.4496, val acc: 0.417


Train [18/20]: 100%|███████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=1.09]
Val   [18/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=1.82]


[nonlinear_tpa] Epoch 18/20 | train loss: 0.9202, train acc: 0.727 | val loss: 2.5886, val acc: 0.408


Train [19/20]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.918]
Val   [19/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=1.68]


[nonlinear_tpa] Epoch 19/20 | train loss: 0.7734, train acc: 0.767 | val loss: 2.6579, val acc: 0.409


Train [20/20]: 100%|██████████████████████████████████| 391/391 [02:48<00:00,  2.32it/s, loss=0.582]
Val   [20/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.08]


[nonlinear_tpa] Epoch 20/20 | train loss: 0.6268, train acc: 0.811 | val loss: 2.7350, val acc: 0.417

-------------------- [cifar100] Experiment: NonlinearTPA_KV_shared --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=3.68]
Val   [1/20]: 100%|███████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=3.7]


[nonlinear_tpa] Epoch 1/20 | train loss: 4.0021, train acc: 0.087 | val loss: 3.6820, val acc: 0.130


Train [2/20]: 100%|█████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.9]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=3.48]


[nonlinear_tpa] Epoch 2/20 | train loss: 3.4853, train acc: 0.164 | val loss: 3.3531, val acc: 0.186


Train [3/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.97]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=3.17]


[nonlinear_tpa] Epoch 3/20 | train loss: 3.1738, train acc: 0.220 | val loss: 3.0920, val acc: 0.240


Train [4/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.64]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=2.71]


[nonlinear_tpa] Epoch 4/20 | train loss: 2.9149, train acc: 0.268 | val loss: 2.8712, val acc: 0.279


Train [5/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.79]
Val   [5/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=2.61]


[nonlinear_tpa] Epoch 5/20 | train loss: 2.7240, train acc: 0.306 | val loss: 2.6960, val acc: 0.313


Train [6/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.98]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.01it/s, loss=2.46]


[nonlinear_tpa] Epoch 6/20 | train loss: 2.5608, train acc: 0.339 | val loss: 2.6370, val acc: 0.326


Train [7/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.44]
Val   [7/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.00it/s, loss=2.42]


[nonlinear_tpa] Epoch 7/20 | train loss: 2.4225, train acc: 0.368 | val loss: 2.5660, val acc: 0.341


Train [8/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.31]
Val   [8/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  3.99it/s, loss=2.15]


[nonlinear_tpa] Epoch 8/20 | train loss: 2.2901, train acc: 0.395 | val loss: 2.4445, val acc: 0.367


Train [9/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.14it/s, loss=2.14]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  3.99it/s, loss=2.25]


[nonlinear_tpa] Epoch 9/20 | train loss: 2.1837, train acc: 0.417 | val loss: 2.4603, val acc: 0.365


Train [10/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.3]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.00it/s, loss=2.41]


[nonlinear_tpa] Epoch 10/20 | train loss: 2.0482, train acc: 0.446 | val loss: 2.3961, val acc: 0.387


Train [11/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.67]
Val   [11/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.99it/s, loss=2.21]


[nonlinear_tpa] Epoch 11/20 | train loss: 1.9315, train acc: 0.475 | val loss: 2.3615, val acc: 0.391


Train [12/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.08]
Val   [12/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.99it/s, loss=1.98]


[nonlinear_tpa] Epoch 12/20 | train loss: 1.8060, train acc: 0.504 | val loss: 2.3537, val acc: 0.398


Train [13/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.64]
Val   [13/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=1.99]


[nonlinear_tpa] Epoch 13/20 | train loss: 1.6880, train acc: 0.532 | val loss: 2.3120, val acc: 0.413


Train [14/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.87]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.00it/s, loss=1.79]


[nonlinear_tpa] Epoch 14/20 | train loss: 1.5475, train acc: 0.566 | val loss: 2.3375, val acc: 0.410


Train [15/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.47]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:20<00:00,  3.86it/s, loss=2.04]


[nonlinear_tpa] Epoch 15/20 | train loss: 1.4010, train acc: 0.601 | val loss: 2.4316, val acc: 0.401


Train [16/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.67]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.97it/s, loss=2.34]


[nonlinear_tpa] Epoch 16/20 | train loss: 1.2491, train acc: 0.639 | val loss: 2.5211, val acc: 0.398


Train [17/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.31]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=2.24]


[nonlinear_tpa] Epoch 17/20 | train loss: 1.0908, train acc: 0.681 | val loss: 2.5393, val acc: 0.410


Train [18/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.26]
Val   [18/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.99it/s, loss=2.36]


[nonlinear_tpa] Epoch 18/20 | train loss: 0.9231, train acc: 0.726 | val loss: 2.6393, val acc: 0.404


Train [19/20]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.763]
Val   [19/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=2.38]


[nonlinear_tpa] Epoch 19/20 | train loss: 0.7837, train acc: 0.766 | val loss: 2.7022, val acc: 0.411


Train [20/20]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.778]
Val   [20/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  3.98it/s, loss=2.81]


[nonlinear_tpa] Epoch 20/20 | train loss: 0.6227, train acc: 0.812 | val loss: 2.8056, val acc: 0.410

-------------------- [cifar100] Experiment: NonlinearTPA_HW_KV --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=3.72]
Val   [1/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=3.78]


[headwise_nonlinear_tpa] Epoch 1/20 | train loss: 4.0307, train acc: 0.080 | val loss: 3.7232, val acc: 0.114


Train [2/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=3.34]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=3.37]


[headwise_nonlinear_tpa] Epoch 2/20 | train loss: 3.5016, train acc: 0.161 | val loss: 3.3193, val acc: 0.196


Train [3/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=3.04]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=3.06]


[headwise_nonlinear_tpa] Epoch 3/20 | train loss: 3.2000, train acc: 0.215 | val loss: 3.0573, val acc: 0.243


Train [4/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.72]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=2.83]


[headwise_nonlinear_tpa] Epoch 4/20 | train loss: 2.9264, train acc: 0.267 | val loss: 2.8491, val acc: 0.287


Train [5/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.65]
Val   [5/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=2.68]


[headwise_nonlinear_tpa] Epoch 5/20 | train loss: 2.7109, train acc: 0.308 | val loss: 2.6934, val acc: 0.315


Train [6/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.47]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=2.36]


[headwise_nonlinear_tpa] Epoch 6/20 | train loss: 2.5518, train acc: 0.341 | val loss: 2.5513, val acc: 0.343


Train [7/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.46]
Val   [7/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=2.45]


[headwise_nonlinear_tpa] Epoch 7/20 | train loss: 2.4198, train acc: 0.367 | val loss: 2.5461, val acc: 0.348


Train [8/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.13]
Val   [8/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=2.32]


[headwise_nonlinear_tpa] Epoch 8/20 | train loss: 2.2977, train acc: 0.394 | val loss: 2.4326, val acc: 0.371


Train [9/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=2.19]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=2.33]


[headwise_nonlinear_tpa] Epoch 9/20 | train loss: 2.1712, train acc: 0.422 | val loss: 2.4057, val acc: 0.377


Train [10/20]: 100%|████████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.9]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=2.19]


[headwise_nonlinear_tpa] Epoch 10/20 | train loss: 2.0542, train acc: 0.447 | val loss: 2.3679, val acc: 0.385


Train [11/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.86]
Val   [11/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=2.31]


[headwise_nonlinear_tpa] Epoch 11/20 | train loss: 1.9363, train acc: 0.471 | val loss: 2.3577, val acc: 0.391


Train [12/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.95]
Val   [12/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=2.39]


[headwise_nonlinear_tpa] Epoch 12/20 | train loss: 1.8124, train acc: 0.501 | val loss: 2.3073, val acc: 0.402


Train [13/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.59]
Val   [13/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=2.11]


[headwise_nonlinear_tpa] Epoch 13/20 | train loss: 1.6815, train acc: 0.531 | val loss: 2.3139, val acc: 0.411


Train [14/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.62]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.12it/s, loss=2.11]


[headwise_nonlinear_tpa] Epoch 14/20 | train loss: 1.5373, train acc: 0.564 | val loss: 2.3262, val acc: 0.413


Train [15/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.56]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=2.79]


[headwise_nonlinear_tpa] Epoch 15/20 | train loss: 1.3808, train acc: 0.605 | val loss: 2.3712, val acc: 0.416


Train [16/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.44]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=2.18]


[headwise_nonlinear_tpa] Epoch 16/20 | train loss: 1.2275, train acc: 0.644 | val loss: 2.4443, val acc: 0.412


Train [17/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.02]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=2.44]


[headwise_nonlinear_tpa] Epoch 17/20 | train loss: 1.0504, train acc: 0.691 | val loss: 2.5160, val acc: 0.413


Train [18/20]: 100%|███████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=1.18]
Val   [18/20]: 100%|██████████████████████████████████████| 79/79 [00:19<00:00,  4.15it/s, loss=2.4]


[headwise_nonlinear_tpa] Epoch 18/20 | train loss: 0.8909, train acc: 0.734 | val loss: 2.5744, val acc: 0.419


Train [19/20]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.721]
Val   [19/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=2.24]


[headwise_nonlinear_tpa] Epoch 19/20 | train loss: 0.7045, train acc: 0.789 | val loss: 2.7120, val acc: 0.414


Train [20/20]: 100%|██████████████████████████████████| 391/391 [03:03<00:00,  2.13it/s, loss=0.618]
Val   [20/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.14it/s, loss=2.49]


[headwise_nonlinear_tpa] Epoch 20/20 | train loss: 0.5605, train acc: 0.833 | val loss: 2.7930, val acc: 0.417

-------------------- [cifar100] Experiment: NonlinearTPA_HW_KV --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=3.55]
Val   [1/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=3.58]


[headwise_nonlinear_tpa] Epoch 1/20 | train loss: 4.0312, train acc: 0.082 | val loss: 3.6850, val acc: 0.134


Train [2/20]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=3.05]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=3.36]


[headwise_nonlinear_tpa] Epoch 2/20 | train loss: 3.4667, train acc: 0.167 | val loss: 3.3290, val acc: 0.195


Train [3/20]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=2.77]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=3.23]


[headwise_nonlinear_tpa] Epoch 3/20 | train loss: 3.1733, train acc: 0.219 | val loss: 3.0819, val acc: 0.236


Train [4/20]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=2.64]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=2.86]


[headwise_nonlinear_tpa] Epoch 4/20 | train loss: 2.9289, train acc: 0.266 | val loss: 2.8689, val acc: 0.283


Train [5/20]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=2.58]
Val   [5/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.71]


[headwise_nonlinear_tpa] Epoch 5/20 | train loss: 2.7355, train acc: 0.304 | val loss: 2.7253, val acc: 0.311


Train [6/20]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=2.69]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.44]


[headwise_nonlinear_tpa] Epoch 6/20 | train loss: 2.5840, train acc: 0.334 | val loss: 2.6598, val acc: 0.325


Train [7/20]: 100%|█████████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=2.3]
Val   [7/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=2.69]


[headwise_nonlinear_tpa] Epoch 7/20 | train loss: 2.4411, train acc: 0.364 | val loss: 2.6441, val acc: 0.328


Train [8/20]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=1.93]
Val   [8/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=2.85]


[headwise_nonlinear_tpa] Epoch 8/20 | train loss: 2.3142, train acc: 0.389 | val loss: 2.5606, val acc: 0.348


Train [9/20]: 100%|█████████████████████████████████████| 391/391 [03:00<00:00,  2.17it/s, loss=2.3]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.46]


[headwise_nonlinear_tpa] Epoch 9/20 | train loss: 2.2018, train acc: 0.413 | val loss: 2.4928, val acc: 0.362


Train [10/20]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=1.8]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.32]


[headwise_nonlinear_tpa] Epoch 10/20 | train loss: 2.0803, train acc: 0.441 | val loss: 2.3736, val acc: 0.386


Train [11/20]: 100%|███████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=2.01]
Val   [11/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=2.56]


[headwise_nonlinear_tpa] Epoch 11/20 | train loss: 1.9553, train acc: 0.469 | val loss: 2.3993, val acc: 0.392


Train [12/20]: 100%|███████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=1.72]
Val   [12/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.27]


[headwise_nonlinear_tpa] Epoch 12/20 | train loss: 1.8229, train acc: 0.498 | val loss: 2.3320, val acc: 0.406


Train [13/20]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=1.8]
Val   [13/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=2.33]


[headwise_nonlinear_tpa] Epoch 13/20 | train loss: 1.6893, train acc: 0.532 | val loss: 2.3690, val acc: 0.404


Train [14/20]: 100%|████████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=1.7]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=2.69]


[headwise_nonlinear_tpa] Epoch 14/20 | train loss: 1.5476, train acc: 0.563 | val loss: 2.4290, val acc: 0.401


Train [15/20]: 100%|███████████████████████████████████| 391/391 [02:59<00:00,  2.18it/s, loss=1.37]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=2.62]


[headwise_nonlinear_tpa] Epoch 15/20 | train loss: 1.3925, train acc: 0.601 | val loss: 2.4365, val acc: 0.410


Train [16/20]: 100%|███████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=1.03]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=2.57]


[headwise_nonlinear_tpa] Epoch 16/20 | train loss: 1.2307, train acc: 0.644 | val loss: 2.5189, val acc: 0.402


Train [17/20]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=0.969]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=2.33]


[headwise_nonlinear_tpa] Epoch 17/20 | train loss: 1.0716, train acc: 0.682 | val loss: 2.6130, val acc: 0.402


Train [18/20]: 100%|███████████████████████████████████| 391/391 [03:00<00:00,  2.17it/s, loss=1.12]
Val   [18/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=2.46]


[headwise_nonlinear_tpa] Epoch 18/20 | train loss: 0.8990, train acc: 0.736 | val loss: 2.6616, val acc: 0.405


Train [19/20]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=0.923]
Val   [19/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=2.9]


[headwise_nonlinear_tpa] Epoch 19/20 | train loss: 0.7491, train acc: 0.776 | val loss: 2.8037, val acc: 0.395


Train [20/20]: 100%|██████████████████████████████████| 391/391 [02:59<00:00,  2.17it/s, loss=0.555]
Val   [20/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=3.04]


[headwise_nonlinear_tpa] Epoch 20/20 | train loss: 0.5910, train acc: 0.824 | val loss: 2.9265, val acc: 0.398

-------------------- [cifar100] Experiment: NonlinearTPA_HW_KV_shared --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=3.67]
Val   [1/20]: 100%|███████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=3.6]


[headwise_nonlinear_tpa] Epoch 1/20 | train loss: 4.0204, train acc: 0.084 | val loss: 3.6758, val acc: 0.132


Train [2/20]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=3.43]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=3.36]


[headwise_nonlinear_tpa] Epoch 2/20 | train loss: 3.4971, train acc: 0.163 | val loss: 3.3427, val acc: 0.195


Train [3/20]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=3.47]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=3.52]


[headwise_nonlinear_tpa] Epoch 3/20 | train loss: 3.2082, train acc: 0.214 | val loss: 3.1625, val acc: 0.224


Train [4/20]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=2.74]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.74]


[headwise_nonlinear_tpa] Epoch 4/20 | train loss: 2.9670, train acc: 0.260 | val loss: 2.9014, val acc: 0.275


Train [5/20]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=2.52]
Val   [5/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.57]


[headwise_nonlinear_tpa] Epoch 5/20 | train loss: 2.7633, train acc: 0.296 | val loss: 2.7234, val acc: 0.314


Train [6/20]: 100%|█████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=2.2]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=2.45]


[headwise_nonlinear_tpa] Epoch 6/20 | train loss: 2.5885, train acc: 0.331 | val loss: 2.6232, val acc: 0.329


Train [7/20]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=2.22]
Val   [7/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.51]


[headwise_nonlinear_tpa] Epoch 7/20 | train loss: 2.4471, train acc: 0.361 | val loss: 2.5775, val acc: 0.346


Train [8/20]: 100%|████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=2.12]
Val   [8/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.20it/s, loss=1.99]


[headwise_nonlinear_tpa] Epoch 8/20 | train loss: 2.3128, train acc: 0.389 | val loss: 2.4757, val acc: 0.357


Train [9/20]: 100%|█████████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=2.5]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.18]


[headwise_nonlinear_tpa] Epoch 9/20 | train loss: 2.1876, train acc: 0.415 | val loss: 2.4118, val acc: 0.378


Train [10/20]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.84]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.19]


[headwise_nonlinear_tpa] Epoch 10/20 | train loss: 2.0652, train acc: 0.443 | val loss: 2.3927, val acc: 0.386


Train [11/20]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.91]
Val   [11/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=2.37]


[headwise_nonlinear_tpa] Epoch 11/20 | train loss: 1.9385, train acc: 0.473 | val loss: 2.3446, val acc: 0.395


Train [12/20]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=2.12]
Val   [12/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=2.15]


[headwise_nonlinear_tpa] Epoch 12/20 | train loss: 1.8071, train acc: 0.503 | val loss: 2.3814, val acc: 0.396


Train [13/20]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.74]
Val   [13/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=1.94]


[headwise_nonlinear_tpa] Epoch 13/20 | train loss: 1.6762, train acc: 0.532 | val loss: 2.3621, val acc: 0.402


Train [14/20]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.59]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.35]


[headwise_nonlinear_tpa] Epoch 14/20 | train loss: 1.5351, train acc: 0.564 | val loss: 2.3620, val acc: 0.411


Train [15/20]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.91]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.65]


[headwise_nonlinear_tpa] Epoch 15/20 | train loss: 1.3881, train acc: 0.601 | val loss: 2.3776, val acc: 0.411


Train [16/20]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.29]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.16]


[headwise_nonlinear_tpa] Epoch 16/20 | train loss: 1.2348, train acc: 0.642 | val loss: 2.4528, val acc: 0.411


Train [17/20]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.11]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=2.59]


[headwise_nonlinear_tpa] Epoch 17/20 | train loss: 1.0666, train acc: 0.686 | val loss: 2.5375, val acc: 0.409


Train [18/20]: 100%|███████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=1.13]
Val   [18/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=2.69]


[headwise_nonlinear_tpa] Epoch 18/20 | train loss: 0.9048, train acc: 0.730 | val loss: 2.6134, val acc: 0.407


Train [19/20]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.748]
Val   [19/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=2.61]


[headwise_nonlinear_tpa] Epoch 19/20 | train loss: 0.7444, train acc: 0.776 | val loss: 2.7174, val acc: 0.408


Train [20/20]: 100%|██████████████████████████████████| 391/391 [03:02<00:00,  2.14it/s, loss=0.721]
Val   [20/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.3]


[headwise_nonlinear_tpa] Epoch 20/20 | train loss: 0.5989, train acc: 0.820 | val loss: 2.8484, val acc: 0.409

-------------------- [cifar100] Experiment: NonlinearTPA_QKV --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=3.52]
Val   [1/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=3.71]


[nonlinear_tpa] Epoch 1/20 | train loss: 4.0331, train acc: 0.082 | val loss: 3.6597, val acc: 0.135


Train [2/20]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=3.21]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=3.24]


[nonlinear_tpa] Epoch 2/20 | train loss: 3.4920, train acc: 0.166 | val loss: 3.3724, val acc: 0.187


Train [3/20]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=3.09]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=3.19]


[nonlinear_tpa] Epoch 3/20 | train loss: 3.1931, train acc: 0.219 | val loss: 3.0907, val acc: 0.241


Train [4/20]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=3.01]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=2.83]


[nonlinear_tpa] Epoch 4/20 | train loss: 2.9521, train acc: 0.264 | val loss: 2.9001, val acc: 0.274


Train [5/20]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=2.32]
Val   [5/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=2.74]


[nonlinear_tpa] Epoch 5/20 | train loss: 2.7670, train acc: 0.297 | val loss: 2.7569, val acc: 0.309


Train [6/20]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=2.63]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=2.45]


[nonlinear_tpa] Epoch 6/20 | train loss: 2.5995, train acc: 0.330 | val loss: 2.6684, val acc: 0.317


Train [7/20]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=2.19]
Val   [7/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=2.31]


[nonlinear_tpa] Epoch 7/20 | train loss: 2.4588, train acc: 0.360 | val loss: 2.5391, val acc: 0.345


Train [8/20]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=2.67]
Val   [8/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.16it/s, loss=2.36]


[nonlinear_tpa] Epoch 8/20 | train loss: 2.3369, train acc: 0.386 | val loss: 2.4945, val acc: 0.355


Train [9/20]: 100%|████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=2.33]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.19it/s, loss=2.35]


[nonlinear_tpa] Epoch 9/20 | train loss: 2.2162, train acc: 0.409 | val loss: 2.3914, val acc: 0.384


Train [10/20]: 100%|███████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=2.22]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=2.16]


[nonlinear_tpa] Epoch 10/20 | train loss: 2.0852, train acc: 0.439 | val loss: 2.3964, val acc: 0.383


Train [11/20]: 100%|███████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=1.84]
Val   [11/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.16it/s, loss=1.86]


[nonlinear_tpa] Epoch 11/20 | train loss: 1.9715, train acc: 0.463 | val loss: 2.3656, val acc: 0.382


Train [12/20]: 100%|██████████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=2]
Val   [12/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=2.26]


[nonlinear_tpa] Epoch 12/20 | train loss: 1.8399, train acc: 0.495 | val loss: 2.3148, val acc: 0.406


Train [13/20]: 100%|███████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=1.67]
Val   [13/20]: 100%|█████████████████████████████████████| 79/79 [00:19<00:00,  4.13it/s, loss=1.93]


[nonlinear_tpa] Epoch 13/20 | train loss: 1.7058, train acc: 0.525 | val loss: 2.3538, val acc: 0.399


Train [14/20]: 100%|███████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=1.77]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.17it/s, loss=2.05]


[nonlinear_tpa] Epoch 14/20 | train loss: 1.5607, train acc: 0.559 | val loss: 2.3586, val acc: 0.405


Train [15/20]: 100%|███████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=1.64]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.29]


[nonlinear_tpa] Epoch 15/20 | train loss: 1.4031, train acc: 0.599 | val loss: 2.3708, val acc: 0.415


Train [16/20]: 100%|███████████████████████████████████| 391/391 [02:50<00:00,  2.30it/s, loss=1.44]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.33]


[nonlinear_tpa] Epoch 16/20 | train loss: 1.2376, train acc: 0.642 | val loss: 2.4507, val acc: 0.409


Train [17/20]: 100%|███████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=1.51]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.08]


[nonlinear_tpa] Epoch 17/20 | train loss: 1.0625, train acc: 0.687 | val loss: 2.5159, val acc: 0.410


Train [18/20]: 100%|███████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=0.82]
Val   [18/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=1.99]


[nonlinear_tpa] Epoch 18/20 | train loss: 0.8796, train acc: 0.738 | val loss: 2.6119, val acc: 0.407


Train [19/20]: 100%|██████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=0.719]
Val   [19/20]: 100%|████████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2]


[nonlinear_tpa] Epoch 19/20 | train loss: 0.6942, train acc: 0.792 | val loss: 2.7544, val acc: 0.399


Train [20/20]: 100%|██████████████████████████████████| 391/391 [02:49<00:00,  2.30it/s, loss=0.662]
Val   [20/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.18it/s, loss=2.38]


[nonlinear_tpa] Epoch 20/20 | train loss: 0.5520, train acc: 0.835 | val loss: 2.8719, val acc: 0.403

-------------------- [cifar100] Experiment: NonlinearTPA_Q --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=3.68]
Val   [1/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=3.73]


[nonlinear_tpa] Epoch 1/20 | train loss: 3.8931, train acc: 0.105 | val loss: 3.5973, val acc: 0.149


Train [2/20]: 100%|████████████████████████████████████| 391/391 [02:45<00:00,  2.36it/s, loss=2.92]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=3.37]


[nonlinear_tpa] Epoch 2/20 | train loss: 3.3744, train acc: 0.188 | val loss: 3.2804, val acc: 0.203


Train [3/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=3.04]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=3.01]


[nonlinear_tpa] Epoch 3/20 | train loss: 3.0975, train acc: 0.237 | val loss: 3.0056, val acc: 0.261


Train [4/20]: 100%|████████████████████████████████████| 391/391 [02:45<00:00,  2.36it/s, loss=2.93]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.78]


[nonlinear_tpa] Epoch 4/20 | train loss: 2.8469, train acc: 0.283 | val loss: 2.8074, val acc: 0.299


Train [5/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.63]
Val   [5/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.58]


[nonlinear_tpa] Epoch 5/20 | train loss: 2.6388, train acc: 0.324 | val loss: 2.6774, val acc: 0.317


Train [6/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.37]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.31]


[nonlinear_tpa] Epoch 6/20 | train loss: 2.4682, train acc: 0.357 | val loss: 2.5503, val acc: 0.343


Train [7/20]: 100%|█████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.6]
Val   [7/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.49]


[nonlinear_tpa] Epoch 7/20 | train loss: 2.3198, train acc: 0.388 | val loss: 2.5130, val acc: 0.357


Train [8/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.23]
Val   [8/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.28]


[nonlinear_tpa] Epoch 8/20 | train loss: 2.1816, train acc: 0.416 | val loss: 2.4242, val acc: 0.378


Train [9/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.72]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.25it/s, loss=2.17]


[nonlinear_tpa] Epoch 9/20 | train loss: 2.0491, train acc: 0.448 | val loss: 2.3924, val acc: 0.385


Train [10/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.87]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.51]


[nonlinear_tpa] Epoch 10/20 | train loss: 1.9083, train acc: 0.479 | val loss: 2.3942, val acc: 0.387


Train [11/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.85]
Val   [11/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.25]


[nonlinear_tpa] Epoch 11/20 | train loss: 1.7793, train acc: 0.509 | val loss: 2.3755, val acc: 0.398


Train [12/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.66]
Val   [12/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.44]


[nonlinear_tpa] Epoch 12/20 | train loss: 1.6288, train acc: 0.547 | val loss: 2.4173, val acc: 0.397


Train [13/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.78]
Val   [13/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.5]


[nonlinear_tpa] Epoch 13/20 | train loss: 1.4795, train acc: 0.585 | val loss: 2.4122, val acc: 0.407


Train [14/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.15]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.09]


[nonlinear_tpa] Epoch 14/20 | train loss: 1.3283, train acc: 0.619 | val loss: 2.4869, val acc: 0.399


Train [15/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.16]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.42]


[nonlinear_tpa] Epoch 15/20 | train loss: 1.1617, train acc: 0.665 | val loss: 2.5184, val acc: 0.404


Train [16/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.06]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.25it/s, loss=2.13]


[nonlinear_tpa] Epoch 16/20 | train loss: 1.0016, train acc: 0.707 | val loss: 2.5999, val acc: 0.405


Train [17/20]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.742]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.49]


[nonlinear_tpa] Epoch 17/20 | train loss: 0.8285, train acc: 0.755 | val loss: 2.6900, val acc: 0.403


Train [18/20]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.698]
Val   [18/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.47]


[nonlinear_tpa] Epoch 18/20 | train loss: 0.6943, train acc: 0.794 | val loss: 2.7947, val acc: 0.402


Train [19/20]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.768]
Val   [19/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.64]


[nonlinear_tpa] Epoch 19/20 | train loss: 0.5469, train acc: 0.838 | val loss: 2.9558, val acc: 0.388


Train [20/20]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.432]
Val   [20/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=3.19]


[nonlinear_tpa] Epoch 20/20 | train loss: 0.4347, train acc: 0.873 | val loss: 3.0438, val acc: 0.396

-------------------- [cifar100] Experiment: NonlinearTPA_K --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=3.46]
Val   [1/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=3.56]


[nonlinear_tpa] Epoch 1/20 | train loss: 3.9049, train acc: 0.104 | val loss: 3.5674, val acc: 0.156


Train [2/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=3.37]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=3.26]


[nonlinear_tpa] Epoch 2/20 | train loss: 3.3861, train acc: 0.185 | val loss: 3.2702, val acc: 0.209


Train [3/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.93]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=3.19]


[nonlinear_tpa] Epoch 3/20 | train loss: 3.0846, train acc: 0.239 | val loss: 3.0235, val acc: 0.257


Train [4/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.64]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.25it/s, loss=3.08]


[nonlinear_tpa] Epoch 4/20 | train loss: 2.8456, train acc: 0.281 | val loss: 2.8357, val acc: 0.294


Train [5/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.59]
Val   [5/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=2.63]


[nonlinear_tpa] Epoch 5/20 | train loss: 2.6552, train acc: 0.319 | val loss: 2.7412, val acc: 0.309


Train [6/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.41]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.62]


[nonlinear_tpa] Epoch 6/20 | train loss: 2.4867, train acc: 0.355 | val loss: 2.5728, val acc: 0.341


Train [7/20]: 100%|█████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.6]
Val   [7/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.25it/s, loss=2.67]


[nonlinear_tpa] Epoch 7/20 | train loss: 2.3458, train acc: 0.384 | val loss: 2.5388, val acc: 0.351


Train [8/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.01]
Val   [8/20]: 100%|███████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.6]


[nonlinear_tpa] Epoch 8/20 | train loss: 2.2042, train acc: 0.415 | val loss: 2.5354, val acc: 0.353


Train [9/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.21]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.25it/s, loss=2.25]


[nonlinear_tpa] Epoch 9/20 | train loss: 2.0758, train acc: 0.442 | val loss: 2.4451, val acc: 0.374


Train [10/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.06]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=2.06]


[nonlinear_tpa] Epoch 10/20 | train loss: 1.9397, train acc: 0.471 | val loss: 2.4444, val acc: 0.380


Train [11/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.34it/s, loss=1.82]
Val   [11/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.21]


[nonlinear_tpa] Epoch 11/20 | train loss: 1.8098, train acc: 0.504 | val loss: 2.4373, val acc: 0.380


Train [12/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.7]
Val   [12/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.39]


[nonlinear_tpa] Epoch 12/20 | train loss: 1.6654, train acc: 0.535 | val loss: 2.3930, val acc: 0.401


Train [13/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.51]
Val   [13/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=2.11]


[nonlinear_tpa] Epoch 13/20 | train loss: 1.5129, train acc: 0.572 | val loss: 2.4550, val acc: 0.394


Train [14/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.51]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=1.97]


[nonlinear_tpa] Epoch 14/20 | train loss: 1.3633, train acc: 0.610 | val loss: 2.4886, val acc: 0.395


Train [15/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.29]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=1.75]


[nonlinear_tpa] Epoch 15/20 | train loss: 1.2008, train acc: 0.653 | val loss: 2.5428, val acc: 0.393


Train [16/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.33]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.21]


[nonlinear_tpa] Epoch 16/20 | train loss: 1.0454, train acc: 0.694 | val loss: 2.6135, val acc: 0.393


Train [17/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.02]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=1.38]


[nonlinear_tpa] Epoch 17/20 | train loss: 0.8631, train acc: 0.746 | val loss: 2.7241, val acc: 0.391


Train [18/20]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.855]
Val   [18/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.15]


[nonlinear_tpa] Epoch 18/20 | train loss: 0.7033, train acc: 0.792 | val loss: 2.8105, val acc: 0.402


Train [19/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.94]
Val   [19/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=1.99]


[nonlinear_tpa] Epoch 19/20 | train loss: 0.5738, train acc: 0.831 | val loss: 3.0088, val acc: 0.383


Train [20/20]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.801]
Val   [20/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.21it/s, loss=2.03]


[nonlinear_tpa] Epoch 20/20 | train loss: 0.4722, train acc: 0.860 | val loss: 3.1157, val acc: 0.391

-------------------- [cifar100] Experiment: NonlinearTPA_V --------------------


Train [1/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=3.59]
Val   [1/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=3.74]


[nonlinear_tpa] Epoch 1/20 | train loss: 3.9856, train acc: 0.090 | val loss: 3.6581, val acc: 0.139


Train [2/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=3.04]
Val   [2/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=3.45]


[nonlinear_tpa] Epoch 2/20 | train loss: 3.4666, train acc: 0.168 | val loss: 3.3087, val acc: 0.201


Train [3/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.89]
Val   [3/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.25it/s, loss=3.14]


[nonlinear_tpa] Epoch 3/20 | train loss: 3.1839, train acc: 0.220 | val loss: 3.0729, val acc: 0.241


Train [4/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.82]
Val   [4/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.62]


[nonlinear_tpa] Epoch 4/20 | train loss: 2.9514, train acc: 0.262 | val loss: 2.8821, val acc: 0.280


Train [5/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.29]
Val   [5/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.59]


[nonlinear_tpa] Epoch 5/20 | train loss: 2.7475, train acc: 0.301 | val loss: 2.7643, val acc: 0.303


Train [6/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.61]
Val   [6/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.68]


[nonlinear_tpa] Epoch 6/20 | train loss: 2.5747, train acc: 0.335 | val loss: 2.6599, val acc: 0.316


Train [7/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.41]
Val   [7/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=2.19]


[nonlinear_tpa] Epoch 7/20 | train loss: 2.4269, train acc: 0.365 | val loss: 2.5273, val acc: 0.348


Train [8/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.33]
Val   [8/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.28]


[nonlinear_tpa] Epoch 8/20 | train loss: 2.2978, train acc: 0.394 | val loss: 2.4965, val acc: 0.358


Train [9/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=2.22]
Val   [9/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.31]


[nonlinear_tpa] Epoch 9/20 | train loss: 2.1816, train acc: 0.418 | val loss: 2.4532, val acc: 0.369


Train [10/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.98]
Val   [10/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=2.09]


[nonlinear_tpa] Epoch 10/20 | train loss: 2.0488, train acc: 0.447 | val loss: 2.3719, val acc: 0.389


Train [11/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.81]
Val   [11/20]: 100%|██████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.4]


[nonlinear_tpa] Epoch 11/20 | train loss: 1.9309, train acc: 0.474 | val loss: 2.3679, val acc: 0.394


Train [12/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.74]
Val   [12/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.22]


[nonlinear_tpa] Epoch 12/20 | train loss: 1.7999, train acc: 0.503 | val loss: 2.3967, val acc: 0.386


Train [13/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.32]
Val   [13/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=1.85]


[nonlinear_tpa] Epoch 13/20 | train loss: 1.6625, train acc: 0.535 | val loss: 2.3502, val acc: 0.407


Train [14/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.82]
Val   [14/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.02]


[nonlinear_tpa] Epoch 14/20 | train loss: 1.5281, train acc: 0.568 | val loss: 2.3854, val acc: 0.406


Train [15/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.44]
Val   [15/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.25it/s, loss=1.93]


[nonlinear_tpa] Epoch 15/20 | train loss: 1.3753, train acc: 0.607 | val loss: 2.3872, val acc: 0.415


Train [16/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.34]
Val   [16/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.13]


[nonlinear_tpa] Epoch 16/20 | train loss: 1.2331, train acc: 0.644 | val loss: 2.4264, val acc: 0.415


Train [17/20]: 100%|████████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.1]
Val   [17/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2.38]


[nonlinear_tpa] Epoch 17/20 | train loss: 1.0789, train acc: 0.683 | val loss: 2.5136, val acc: 0.412


Train [18/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.15]
Val   [18/20]: 100%|████████████████████████████████████████| 79/79 [00:18<00:00,  4.24it/s, loss=2]


[nonlinear_tpa] Epoch 18/20 | train loss: 0.9090, train acc: 0.731 | val loss: 2.5957, val acc: 0.416


Train [19/20]: 100%|███████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=1.04]
Val   [19/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.23it/s, loss=2.18]


[nonlinear_tpa] Epoch 19/20 | train loss: 0.7692, train acc: 0.770 | val loss: 2.7115, val acc: 0.406


Train [20/20]: 100%|██████████████████████████████████| 391/391 [02:46<00:00,  2.35it/s, loss=0.797]
Val   [20/20]: 100%|█████████████████████████████████████| 79/79 [00:18<00:00,  4.22it/s, loss=1.97]


[nonlinear_tpa] Epoch 20/20 | train loss: 0.6333, train acc: 0.809 | val loss: 2.7883, val acc: 0.414


import os
import json
import datetime
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.optim import AdamW

# ==== 通用配置 ====
data_dir = "./data"
img_size = 224
num_workers = 8
weight_decay = 0.05
pretrained = False

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = False      # 先关 AMP，等结构都稳定后再考虑开

# 数据集 & 模型
dataset_name = "cifar100"     # 换数据集只改这里
model_tag   = "tiny"
model_name  = "vit_tiny_patch16_224"
batch_size  = 128
lr          = 3e-4

# ====== epoch 设置 ======
epochs_base         = 30   # TPA / NonlinearTPA / SinterTPA 的总 epoch
epochs_mha          = 100  # MHA 的总 epoch
tucker_ab_epochs    = 11   # TuckerTPA 阶段 1：训练 A/B 的 epoch 数
tucker_total_epochs = 20   # TuckerTPA 总 epoch

# ====== rank 设置：共享单一 R ======
R_shared = 16
tucker_rank_head    = R_shared
tucker_rank_channel = R_shared

# TPA / NonlinearTPA / SinterTPA 的 rank 设置
rank_q = 16
rank_k = 2
rank_v = 2

# ====== Sinter 超参数（扰动） ======
sinter_A = 5e-5
sinter_omega = 1e4

# ====== 实验列表 ======
# 顺序：
#   1) SinterTPA (ratio=1.0, 30ep)
#   2) NonlinearTPA (ratio=1.0, 30ep)
#   3) MHA (100ep)
#   4) SinterTPA (ratio=2.0, 30ep)
#   5) NonlinearTPA (ratio=2.0, 30ep)
#   6) SinterTPA (ratio=3.0, 30ep)
#   7) NonlinearTPA (ratio=3.0, 30ep)
experiments = [
    {
        "name": "SinterTPA_ratio1.0",
        "attn_type": "sinter_tpa",
        "total_epochs": epochs_base,
        "mlp_ratio": 1.0,
    },
    {
        "name": "NonlinearTPA_ratio1.0",
        "attn_type": "nonlinear_tpa",
        "total_epochs": epochs_base,
        "mlp_ratio": 1.0,
    },
    {
        "name": "MHA",
        "attn_type": "mha",
        "total_epochs": epochs_mha,
        "mlp_ratio": None,
    },
    {
        "name": "SinterTPA_ratio2.0",
        "attn_type": "sinter_tpa",
        "total_epochs": epochs_base,
        "mlp_ratio": 2.0,
    },
    {
        "name": "NonlinearTPA_ratio2.0",
        "attn_type": "nonlinear_tpa",
        "total_epochs": epochs_base,
        "mlp_ratio": 2.0,
    },
    {
        "name": "SinterTPA_ratio3.0",
        "attn_type": "sinter_tpa",
        "total_epochs": epochs_base,
        "mlp_ratio": 3.0,
    },
    {
        "name": "NonlinearTPA_ratio3.0",
        "attn_type": "nonlinear_tpa",
        "total_epochs": epochs_base,
        "mlp_ratio": 3.0,
    },
]

# ============== 准备 result 目录和本次 run 的文件夹 ==============

os.makedirs("result", exist_ok=True)

run_name = (
    f"{dataset_name}_{model_tag}_{model_name}"
    f"_BaseEp{epochs_base}_MHAEp{epochs_mha}"
    f"_bs{batch_size}_lr{lr}_wd{weight_decay}"
    f"_rq{rank_q}_rk{rank_k}_rv{rank_v}"
    f"_attns_SinterTPA_NonlinearTPA_MHA"
)

run_dir = os.path.join("result", run_name)
os.makedirs(run_dir, exist_ok=True)

print(f"\n\n######################## Run dir: {run_dir} ########################")

# 把本次实验的超参数也存一个 json 方便以后查
tested_mlp_ratios = sorted(
    {e["mlp_ratio"] for e in experiments if e.get("mlp_ratio") is not None}
)

hparams = {
    "dataset_name": dataset_name,
    "model_tag": model_tag,
    "model_name": model_name,
    "epochs_base": epochs_base,
    "epochs_mha": epochs_mha,
    "tucker_total_epochs": tucker_total_epochs,
    "tucker_ab_epochs": tucker_ab_epochs,
    "batch_size": batch_size,
    "lr": lr,
    "weight_decay": weight_decay,
    "R_shared": R_shared,
    "rank_q": rank_q,
    "rank_k": rank_k,
    "rank_v": rank_v,
    "tucker_rank_head": tucker_rank_head,
    "tucker_rank_channel": tucker_rank_channel,
    "sinter_A": sinter_A,
    "sinter_omega": sinter_omega,
    "nonlinear_mlp_hidden_ratios": tested_mlp_ratios,
    "experiments": [e["name"] for e in experiments],
    "timestamp": datetime.datetime.now().isoformat(),
}

with open(os.path.join(run_dir, "hparams.json"), "w") as f:
    json.dump(hparams, f, indent=2)

# ============== 数据加载 ==============

print(f"\n\n######################## Experiments on {model_tag} ({model_name}) / {dataset_name} ########################")

train_loader, val_loader, num_classes = get_loaders(
    dataset_name=dataset_name,
    data_dir=data_dir,
    batch_size=batch_size,
    img_size=img_size,
    num_workers=num_workers,
)

# ============== 1) Baseline MHA 参数量 / 维度信息 ==============

print("\n==================== Baseline MHA model ====================")
baseline_model = ViTClassifier(
    num_classes=num_classes,
    model_name=model_name,
    pretrained=pretrained,
    attn_type="mha",
).to(device)

baseline_params = sum(p.numel() for p in baseline_model.parameters())
print(f"[{model_tag}] Baseline MHA params: {baseline_params / 1e6:.2f}M")

first_attn_baseline = baseline_model.vit.blocks[0].attn
if hasattr(first_attn_baseline, "dim"):
    dim = first_attn_baseline.dim
else:
    dim = first_attn_baseline.qkv.in_features

num_heads = first_attn_baseline.num_heads
head_dim = dim // num_heads
print(f"[{model_tag}] dim={dim}, num_heads={num_heads}, head_dim={head_dim}")

del baseline_model
if device == "cuda":
    torch.cuda.empty_cache()

# ============== 2) 依次跑多个实验 ==============

results_all = []      # 用来存最后总曲线图的信息

for exp in experiments:
    exp_name     = exp["name"]
    attn_type    = exp["attn_type"]
    total_epochs = exp["total_epochs"]
    mlp_ratio    = exp.get("mlp_ratio", None)

    print(f"\n==================== Experiment: {model_tag}-{exp_name} ====================")

    # 为当前实验建一个子目录
    exp_dir = os.path.join(run_dir, exp_name)
    os.makedirs(exp_dir, exist_ok=True)

    # ---------- 构建模型 ----------
    if attn_type == "mha":
        model = ViTClassifier(
            num_classes=num_classes,
            model_name=model_name,
            pretrained=pretrained,
            attn_type="mha",
        ).to(device)

    elif attn_type == "tpa":
        # True TPA（A/B 都 contextual）
        model = ViTClassifier(
            num_classes=num_classes,
            model_name=model_name,
            pretrained=pretrained,
            attn_type="tpa",
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
        ).to(device)

    elif attn_type == "nonlinear_tpa":
        # 非线性 TPA：Q = MLP(1/R * A^T B)
        model = ViTClassifier(
            num_classes=num_classes,
            model_name=model_name,
            pretrained=pretrained,
            attn_type="nonlinear_tpa",
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            nonlinear_mlp_hidden_ratio=mlp_ratio if mlp_ratio is not None else 1.0,
        ).to(device)

    elif attn_type == "sinter_tpa":
        # Sinter TPA：Q = MLP_Sinter(1/R * A^T B)
        model = ViTClassifier(
            num_classes=num_classes,
            model_name=model_name,
            pretrained=pretrained,
            attn_type="sinter_tpa",
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
            nonlinear_mlp_hidden_ratio=mlp_ratio if mlp_ratio is not None else 1.0,
            sinter_A=sinter_A,
            sinter_omega=sinter_omega,
        ).to(device)

    elif attn_type == "tucker_tpa":
        model = ViTClassifier(
            num_classes=num_classes,
            model_name=model_name,
            pretrained=pretrained,
            attn_type="tucker_tpa",
            rank_q=rank_q,
            rank_k=rank_k,
            rank_v=rank_v,
        ).to(device)

    else:
        raise ValueError(f"Unexpected attn_type: {attn_type}")

    total_params = sum(p.numel() for p in model.parameters())
    print(
        f"[{model_tag}-{exp_name}] Total params: {total_params / 1e6:.2f}M "
        f"(diff vs MHA: {(total_params - baseline_params)/1e6:.3f}M)"
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    if use_amp and device == "cuda":
        scaler = torch.cuda.amp.GradScaler()
    else:
        scaler = None

    # ---------- 调用对应的训练函数 ----------
    if attn_type == "tucker_tpa":
        history = train_model_two_stage_tucker(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            total_epochs=total_epochs,
            ab_epochs=tucker_ab_epochs,
            scaler=scaler,
            is_tucker_model=True,
        )

    elif attn_type == "tpa":
        history = train_model_TPA(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            total_epochs=total_epochs,
            scaler=scaler,
        )

    elif attn_type == "nonlinear_tpa":
        history = train_model_nonlinear_TPA(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            total_epochs=total_epochs,
            scaler=scaler,
        )

    elif attn_type == "sinter_tpa":
        history = train_model_sinter_TPA(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            total_epochs=total_epochs,
            scaler=scaler,
        )

    elif attn_type == "mha":
        history = train_model_MHA(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            total_epochs=total_epochs,
            scaler=scaler,
        )

    else:
        raise ValueError(f"Unknown attn_type: {attn_type}")

    train_loss_curve = history["train_loss_curve"]
    val_loss_curve   = history["val_loss_curve"]
    train_acc_curve  = history["train_acc_curve"]
    val_acc_curve    = history["val_acc_curve"]
    best_val_acc     = history["best_val_acc"]

    print(f">>> [{model_tag}-{exp_name}] best val acc: {best_val_acc:.4f}")

    # ---------- KV cost 简易估算 ----------
    first_attn = model.vit.blocks[0].attn
    if hasattr(first_attn, "dim"):
        dim_attn = first_attn.dim
    else:
        dim_attn = first_attn.qkv.in_features
    num_heads_attn = first_attn.num_heads
    head_dim_attn = dim_attn // num_heads_attn

    kv_mha = 2 * num_heads_attn * head_dim_attn  # = 2 * dim

    if attn_type == "mha":
        normalized_kv_cost = 1.0
    elif attn_type in ("tpa", "nonlinear_tpa", "sinter_tpa"):
        kv_tpa = (rank_k + rank_v) * (num_heads_attn + head_dim_attn)
        normalized_kv_cost = kv_tpa / kv_mha
    elif attn_type == "tucker_tpa":
        normalized_kv_cost = 1.0
    else:
        normalized_kv_cost = 1.0

    # ---------- 把当前实验结果存起来 ----------
    results_all.append(
        {
            "name": exp_name,
            "attn_type": attn_type,
            "mlp_ratio": mlp_ratio,
            "params": total_params,
            "best_val_acc": best_val_acc,
            "train_loss_curve": train_loss_curve,
            "val_loss_curve": val_loss_curve,
            "train_acc_curve": train_acc_curve,
            "val_acc_curve": val_acc_curve,
            "kv_cost": normalized_kv_cost,
        }
    )

    # ---------- 写 log ----------
    log_lines = []
    # 第一行：模型参数量信息
    log_lines.append(
        f"Total params: {total_params} ({total_params/1e6:.2f}M), "
        f"diff vs baseline MHA: {total_params - baseline_params} "
        f"({(total_params - baseline_params)/1e6:.3f}M)"
    )
    if mlp_ratio is not None:
        log_lines.append(f"nonlinear_mlp_hidden_ratio: {mlp_ratio}")
    log_lines.append(f"attn_type: {attn_type}")
    log_lines.append(f"kv_cost (normalized to MHA=1): {normalized_kv_cost:.4f}")
    log_lines.append("")  # 空行分隔

    for ep in range(total_epochs):
        log_lines.append(
            f"Epoch {ep+1}/{total_epochs} | "
            f"train loss: {train_loss_curve[ep]:.4f}, train acc: {train_acc_curve[ep]:.3f} | "
            f"val loss: {val_loss_curve[ep]:.4f}, val acc: {val_acc_curve[ep]:.3f}"
        )

    exp_log_path = os.path.join(
        exp_dir,
        f"log_{dataset_name}_{model_tag}_{exp_name}_totEp{total_epochs}_lr{lr}_R{R_shared}_rq{rank_q}_rh{tucker_rank_head}_rd{tucker_rank_channel}.txt",
    )
    with open(exp_log_path, "w") as f:
        f.write("\n".join(log_lines))

    # ---------- 保存最终 checkpoint（最新一版） ----------
    ckpt_name = (
        f"ckpt_{dataset_name}_{model_tag}_{exp_name}"
        f"_totEp{total_epochs}_lr{lr}_R{R_shared}_rq{rank_q}_rh{tucker_rank_head}_rd{tucker_rank_channel}.pt"
    )
    ckpt_path = os.path.join(exp_dir, ckpt_name)

    torch.save(
        {
            "attn_type": attn_type,
            "total_epochs": total_epochs,
            "mlp_ratio": mlp_ratio,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "best_val_acc": best_val_acc,
            "train_loss_curve": train_loss_curve,
            "val_loss_curve": val_loss_curve,
            "train_acc_curve": train_acc_curve,
            "val_acc_curve": val_acc_curve,
            "hparams": hparams,
        },
        ckpt_path,
    )

    # ---------- 画当前实验的 loss 收敛图 ----------
    epochs_range = range(1, total_epochs + 1)

    plt.figure()
    plt.plot(epochs_range, train_loss_curve, label="train loss")
    plt.plot(epochs_range, val_loss_curve, label="val loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{dataset_name} - {model_tag} - {exp_name}: loss convergence")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    loss_fig_path = os.path.join(
        exp_dir,
        f"loss_curve_{dataset_name}_{model_tag}_{exp_name}_totEp{total_epochs}.png",
    )
    plt.savefig(loss_fig_path)
    plt.show()

    # ---------- 当前实验的 val acc 曲线 ----------
    plt.figure()
    plt.plot(epochs_range, val_acc_curve, label=f"{exp_name} val acc")
    plt.xlabel("Epoch")
    plt.ylabel("Val accuracy")
    plt.title(f"{dataset_name} - {model_tag} - {exp_name}: val acc")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    acc_fig_path = os.path.join(
        exp_dir,
        f"val_acc_curve_{dataset_name}_{model_tag}_{exp_name}_totEp{total_epochs}.png",
    )
    plt.savefig(acc_fig_path)
    plt.show()

    # ---------- 截至当前 experiment 的所有实验对比图（val acc & val loss） ----------
    # Val accuracy comparison
    plt.figure()
    for res in results_all:
        total_epochs_exp = len(res["val_acc_curve"])
        epochs_range_all = range(1, total_epochs_exp + 1)
        label = res["name"]
        if res.get("mlp_ratio") is not None:
            label += f" (ratio={res['mlp_ratio']})"
        label += f" (KV={res['kv_cost']:.2f})"
        plt.plot(epochs_range_all, res["val_acc_curve"], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Val accuracy")
    plt.title(f"{dataset_name} - {model_tag}: Val acc comparison (up to {exp_name})")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    combined_acc_path = os.path.join(
        run_dir,
        f"combined_val_acc_upto_{exp_name}_BaseEp{epochs_base}_MHAEp{epochs_mha}_R{R_shared}.png",
    )
    plt.savefig(combined_acc_path)
    plt.show()

    # Val loss comparison
    plt.figure()
    for res in results_all:
        total_epochs_exp = len(res["val_loss_curve"])
        epochs_range_all = range(1, total_epochs_exp + 1)
        label = res["name"]
        if res.get("mlp_ratio") is not None:
            label += f" (ratio={res['mlp_ratio']})"
        label += f" (KV={res['kv_cost']:.2f})"
        plt.plot(epochs_range_all, res["val_loss_curve"], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Val loss")
    plt.title(f"{dataset_name} - {model_tag}: Val loss comparison (up to {exp_name})")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    combined_loss_path = os.path.join(
        run_dir,
        f"combined_val_loss_upto_{exp_name}_BaseEp{epochs_base}_MHAEp{epochs_mha}_R{R_shared}.png",
    )
    plt.savefig(combined_loss_path)
    plt.show()

    del model
    if device == "cuda":
        torch.cuda.empty_cache()

# ============== 3) 写 summary.json ==============

summary = []
for res in results_all:
    summary.append(
        {
            "name": res["name"],
            "attn_type": res["attn_type"],
            "mlp_ratio": res["mlp_ratio"],
            "params_million": res["params"] / 1e6,
            "best_val_acc": float(res["best_val_acc"]),
            "kv_cost": float(res["kv_cost"]),
            "epochs": len(res["val_acc_curve"]),
        }
    )

with open(os.path.join(run_dir, "summary.json"), "w") as f:
    json.dump(summary, f, indent=2)
