In [2]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import math

Z_DIM = 64
BETA = 1e-1
EPOCHS = 20
BATCH_SIZE = 128
LR = 1e-3
SEED = 42
NUM_WORKERS = 2
MC_EVAL = 1
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATA_ROOT = "./data"
IN_DIM = 28 * 28
NUM_CLASSES = 10


In [3]:
def seed_all(s: int) -> None:
    random.seed(s)
    torch.manual_seed(s)

class MLPEncoder(nn.Module):
    def __init__(self, in_dim: int, z_dim: int, hidden=(512, 256)) -> None:
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            layers += [nn.Linear(last, h), nn.ReLU(inplace=True)]
            last = h
        self.net = nn.Sequential(*layers)
        self.mu = nn.Linear(last, z_dim)
        self.logvar = nn.Linear(last, z_dim)

    def forward(self, x: torch.Tensor):
        h = self.net(x)
        return self.mu(h), self.logvar(h)

class Classifier(nn.Module):
    def __init__(self, z_dim: int, num_classes: int) -> None:
        super().__init__()
        self.head = nn.Linear(z_dim, num_classes)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.head(z)

class VIB(nn.Module):
    def __init__(self, in_dim: int, z_dim: int, num_classes: int, hidden=(512, 256)) -> None:
        super().__init__()
        self.encoder = MLPEncoder(in_dim, z_dim, hidden)
        self.classifier = Classifier(z_dim, num_classes)

    @staticmethod
    def reparam(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    @staticmethod
    def kl(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        return 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar).sum(dim=1)

    def forward(self, x: torch.Tensor):
        mu, logvar = self.encoder(x)
        z = self.reparam(mu, logvar)
        logits = self.classifier(z)
        return logits, mu, logvar

def get_mnist():
    t = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda t: t.view(-1))])
    tr = datasets.MNIST(DATA_ROOT, train=True, download=True, transform=t)
    te = datasets.MNIST(DATA_ROOT, train=False, download=True, transform=t)
    in_dim = 28 * 28
    num_classes = 10
    tr_loader = DataLoader(tr, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    te_loader = DataLoader(te, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    return tr_loader, te_loader, in_dim, num_classes


@torch.no_grad()
def evaluate(model: VIB, loader: DataLoader):
    model.eval()
    ce_list = []
    kl_list = []
    correct = 0
    total = 0
    for x, y in loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        logits_acc = 0
        kl_acc = 0
        for _ in range(max(1, MC_EVAL)):
            logits, mu, logvar = model(x)
            logits_acc = logits_acc + logits
            kl_acc = kl_acc + VIB.kl(mu, logvar)
        logits = logits_acc / max(1, MC_EVAL)
        kl = kl_acc / max(1, MC_EVAL)
        ce = F.cross_entropy(logits, y, reduction="none")
        ce_list.append(ce.detach())
        kl_list.append(kl.detach())
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    ce = torch.cat(ce_list).mean().item()
    kl = torch.cat(kl_list).mean().item()
    acc = correct / max(1, total)
    return ce, kl, acc


@torch.no_grad()
def estimate_mi_gaussian(z1: torch.Tensor, z2: torch.Tensor, eps: float = 1e-5) -> float:
    """
    z1: [N, d1] tensor of samples from model 1
    z2: [N, d2] tensor of samples from model 2
    Returns: scalar MI estimate (nats) under a joint-Gaussian plug-in model.
    """
    assert z1.shape[0] == z2.shape[0]
    X = torch.cat([z1, z2], dim=1)
    X = X - X.mean(dim=0, keepdim=True)

    N = X.shape[0]
    Sigma = (X.T @ X) / (N - 1)
    d1 = z1.shape[1]

    eye = torch.eye(Sigma.shape[0], device=Sigma.device)
    Sigma = Sigma + eps * eye

    S11 = Sigma[:d1, :d1]
    S22 = Sigma[d1:, d1:]

    # logdet sigma
    sign_full, logdet_full = torch.slogdet(Sigma)
    sign_11, logdet_11 = torch.slogdet(S11)
    sign_22, logdet_22 = torch.slogdet(S22)
    if (sign_full <= 0) or (sign_11 <= 0) or (sign_22 <= 0):
        raise ValueError("Non–PD covariance")

    mi = 0.5 * (logdet_11 + logdet_22 - logdet_full)  # nats
    return mi.item()



In [4]:
# Train teacher

seed_all(SEED)
tr_loader, te_loader, IN_DIM, NUM_CLASSES = get_mnist()
model = VIB(IN_DIM, Z_DIM, NUM_CLASSES).to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR)

train_acc, val_acc = [], []
train_loss, val_loss = [], []
train_ixz, val_ixz = [], []
train_iyz, val_iyz = [], []
H_Y = math.log(NUM_CLASSES) # we assume classes follow an uniform dist

for epoch in range(1, EPOCHS + 1):
    model.train()
    p = tqdm(tr_loader, desc=f"epoch {epoch}/{EPOCHS}")
    ce_sum, kl_sum, loss_sum, correct, total, n_batches = 0.0, 0.0, 0.0, 0, 0, 0
    for x, y in p:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        logits, mu, logvar = model(x)
        ce = F.cross_entropy(logits, y)
        kl = VIB.kl(mu, logvar).mean()
        loss = ce + BETA * kl
        opt.zero_grad()
        loss.backward()
        opt.step()
        p.set_postfix(loss=loss.item(), ce=ce.item(), kl=kl.item())
        ce_sum += ce.item()
        kl_sum += kl.item()
        loss_sum += loss.item()
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
        n_batches += 1
    ce_tr = ce_sum / n_batches
    kl_tr = kl_sum / n_batches
    loss_tr = loss_sum / n_batches
    acc_tr = correct / total
    train_ixz.append(kl_tr)
    train_iyz.append(H_Y - ce_tr)
    train_loss.append(loss_tr)
    train_acc.append(acc_tr)

    ce_v, kl_v, acc_v = evaluate(model, te_loader)
    val_ixz.append(kl_v)
    val_iyz.append(H_Y - ce_v)
    val_loss.append(ce_v + BETA * kl_v)
    val_acc.append(acc_v)

    print(f"val_acc={acc_v*100:.2f}% val_CE={ce_v:.4f} Ixz={kl_v:.4f} Iyz={H_Y - ce_v:.4f}")

# save teacher
torch.save({
    "model": model.state_dict(),
    "opt": opt.state_dict(),
    "ixz": train_ixz,
    "iyz": train_iyz,
    "loss": train_loss,
    "acc": train_acc,
}, "teacher.pt")

epoch 1/20: 100%|██████████| 469/469 [00:04<00:00, 108.70it/s, ce=0.666, kl=4.56, loss=1.12]


val_acc=87.60% val_CE=0.4793 Ixz=4.5962 Iyz=1.8233


epoch 2/20: 100%|██████████| 469/469 [00:04<00:00, 113.62it/s, ce=0.268, kl=4.82, loss=0.75]


val_acc=89.42% val_CE=0.3841 Ixz=4.8015 Iyz=1.9185


epoch 3/20: 100%|██████████| 469/469 [00:04<00:00, 113.06it/s, ce=0.385, kl=5.22, loss=0.907]


val_acc=90.77% val_CE=0.3281 Ixz=5.1142 Iyz=1.9744


epoch 4/20: 100%|██████████| 469/469 [00:04<00:00, 110.18it/s, ce=0.238, kl=4.81, loss=0.719]


val_acc=90.75% val_CE=0.3232 Ixz=4.8579 Iyz=1.9794


epoch 5/20: 100%|██████████| 469/469 [00:03<00:00, 118.11it/s, ce=0.235, kl=4.98, loss=0.733]


val_acc=91.23% val_CE=0.2938 Ixz=4.9091 Iyz=2.0088


epoch 6/20: 100%|██████████| 469/469 [00:04<00:00, 114.41it/s, ce=0.239, kl=4.87, loss=0.726]


val_acc=91.44% val_CE=0.2935 Ixz=4.8139 Iyz=2.0090


epoch 7/20: 100%|██████████| 469/469 [00:03<00:00, 120.40it/s, ce=0.302, kl=4.78, loss=0.78]


val_acc=90.93% val_CE=0.3033 Ixz=4.7021 Iyz=1.9993


epoch 8/20: 100%|██████████| 469/469 [00:04<00:00, 114.86it/s, ce=0.355, kl=4.64, loss=0.819]


val_acc=90.95% val_CE=0.3058 Ixz=4.4485 Iyz=1.9968


epoch 9/20: 100%|██████████| 469/469 [00:04<00:00, 114.20it/s, ce=0.274, kl=4.66, loss=0.741]


val_acc=92.02% val_CE=0.2691 Ixz=4.6689 Iyz=2.0334


epoch 10/20: 100%|██████████| 469/469 [00:03<00:00, 118.03it/s, ce=0.291, kl=4.68, loss=0.76]


val_acc=91.86% val_CE=0.2741 Ixz=4.5537 Iyz=2.0284


epoch 11/20: 100%|██████████| 469/469 [00:04<00:00, 115.37it/s, ce=0.297, kl=4.86, loss=0.783]


val_acc=92.77% val_CE=0.2536 Ixz=4.8415 Iyz=2.0489


epoch 12/20: 100%|██████████| 469/469 [00:04<00:00, 115.55it/s, ce=0.257, kl=4.71, loss=0.728]


val_acc=93.33% val_CE=0.2390 Ixz=4.6283 Iyz=2.0636


epoch 13/20: 100%|██████████| 469/469 [00:03<00:00, 118.25it/s, ce=0.114, kl=4.59, loss=0.574]


val_acc=92.82% val_CE=0.2573 Ixz=4.5191 Iyz=2.0453


epoch 14/20: 100%|██████████| 469/469 [00:04<00:00, 112.18it/s, ce=0.136, kl=4.55, loss=0.59]


val_acc=93.59% val_CE=0.2344 Ixz=4.4530 Iyz=2.0682


epoch 15/20: 100%|██████████| 469/469 [00:03<00:00, 117.51it/s, ce=0.104, kl=4.36, loss=0.539]


val_acc=93.36% val_CE=0.2353 Ixz=4.3354 Iyz=2.0673


epoch 16/20: 100%|██████████| 469/469 [00:04<00:00, 112.16it/s, ce=0.198, kl=4.34, loss=0.632]


val_acc=93.58% val_CE=0.2287 Ixz=4.3471 Iyz=2.0739


epoch 17/20: 100%|██████████| 469/469 [00:04<00:00, 115.88it/s, ce=0.177, kl=4.16, loss=0.593]


val_acc=93.88% val_CE=0.2228 Ixz=4.1512 Iyz=2.0798


epoch 18/20: 100%|██████████| 469/469 [00:03<00:00, 118.05it/s, ce=0.144, kl=4.32, loss=0.576]


val_acc=94.62% val_CE=0.2145 Ixz=4.1493 Iyz=2.0881


epoch 19/20: 100%|██████████| 469/469 [00:03<00:00, 117.93it/s, ce=0.0518, kl=4.25, loss=0.477]


val_acc=94.04% val_CE=0.2154 Ixz=4.1108 Iyz=2.0872


epoch 20/20: 100%|██████████| 469/469 [00:03<00:00, 119.43it/s, ce=0.111, kl=4.2, loss=0.531]


val_acc=94.74% val_CE=0.1992 Ixz=4.1287 Iyz=2.1034


We have two models $S(X, W_s) = D_s(z_s \sim \mathcal{N}(\mu_s(X), \log \sigma_s(X)))$ and $T(X, W_t) = D_t(z_t \sim \mathcal{N}(\mu_t(X), \log \sigma_t(X)))$.

We have that
$$
I(Z_t; Z_s) = \mathbb{E}[\log(p(Z_t|Z_s)) - \log(p(Z_t))] = \\ H(p(Z_t)) - \mathbb{E}[\log \left(\frac{p(Z_t|Z_s)}{q(Z_t|Z_s)}\right)] + \mathbb{E}[ \log q(Z_t|Z_s)] = \\ H(Z_t) + \mathbb{E}[ \log q(Z_t|Z_s)] + \mathbb{E}_{p(Z_s)} [D_{KL}(p(Z_t|Z_s) \| q(Z_t|Z_s))].
$$

Suppose that we want to maximize $I(Z_t; Z_s)$ (with $Z_s$ as a variable). That is equivalent to maximizing
$$
\mathbb{E}[ \log q(Z_t|Z_s)] + \mathbb{E}_{p(Z_s)} [D_{KL}(p(Z_t|Z_s) \| q(Z_t|Z_s))].
$$

Given the preceding equalities, we can just minimize $-\mathbb{E}[ \log q(Z_t|Z_s)] = O(s, q)$ by jointly optimizing $Z_s$ and $q$.

It is also easy to see that $O(s, q) \geq H(Z_t|Z_s)$, so by minimizing $O$, we minimize the entropy (we assume that $q$ has good representation capabilities).

Note that those variables are gaussian and nondeterministic, so that we can have nice behavior. How can that be related to the IBP? If we assume it is true, and our teacher is nice, then we can say that our teacher is good at the objective
$$
\max_{Z_t} I(Y; Z_t) - \beta I(X; Z_t).
$$

We want our student to also be good at it. We can cast that as wanting to minimize
$$
|I(Y; Z_t) - I(Y; S_t)|,\\
|I(X; Z_t) - I(Y; S_t)|.
$$

We claim that maximizing $I(Z_t; Z_s)$ can achieve that. Note that
$$
I(Y; Z_t, Z_s) = I(Y; Z_t | Z_s) + I(Y; Z_s) \\
I(Y; Z_t, Z_s) = I(Y; Z_s | Z_t) + I(Y; Z_t)
$$
so
$$
I(Y; Z_s) = I(Y; Z_t, Z_s) - I(Y; Z_t | Z_s) \\
I(Y; Z_t) = I(Y; Z_t, Z_s) - I(Y; Z_s | Z_t)
$$
hence
$$
I(Y; Z_t)  - I(Y; Z_s)   = I(Y; Z_s, Z_t) - I(Y; Z_s | Z_t) - I(Y; Z_s, Z_t) + I(Y; Z_t | Z_s) = \\ I(Y; Z_t | Z_s) - I(Y; Z_s | Z_t) \leq I(Y; Z_t | Z_s) = H(Z_t|Z_s) - H(Z_t|Y, Z_s) \leq H(Z_t|Z_s).
$$
(the last equality assumes positive entropy. Can I do that? I think so, as Z_s has noise).

Same can be done with $X$. The thing is that this can just make $Z_s$'s entropy/MI with $X$ and $Y$ big and miss the point.


In [8]:
import torch, torch.nn as nn, torch.nn.functional as F
from tqdm import tqdm

DEVICE = globals().get("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 100 #globals().get("EPOCHS", 20)
IN_DIM = globals()["IN_DIM"]
NUM_CLASSES = globals()["NUM_CLASSES"]
Z_DIM = globals()["Z_DIM"]
Z_DIM_STUDENT = globals().get("Z_DIM_STUDENT", Z_DIM)
LR_STUDENT = globals().get("LR_STUDENT", 1e-3)
LR_Q = globals().get("LR_Q", 1e-3)
BETA = globals().get("BETA", 1e-2)
EPS = 1e-5

def collect_mu_y(model, loader, device=DEVICE, max_batches=None):
    model.eval(); mus, ys = [], []; b = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            _, mu, _ = model(x)
            mus.append(mu.cpu()); ys.append(y.cpu()); b += 1
            if (max_batches is not None) and (b >= max_batches): break
    return torch.cat(mus, 0), torch.cat(ys, 0)

def collect_mu(model, loader, **kw): return collect_mu_y(model, loader, **kw)[0]

def collect_mu_y_all(model, loader, device=DEVICE):
    model.eval(); Zs, Ys = [], []
    for x, y in loader:
        x = x.to(device)
        _, mu, _ = model(x)
        Zs.append(mu.detach().cpu()); Ys.append(y.detach().cpu())
    return torch.cat(Zs,0), torch.cat(Ys,0)

def fit_ridge_classifier(Z, y, num_classes, lam=1e-2):
    # Z: [N,d], y: [N], returns W: [d,C], b: [C]
    N, d = Z.shape
    C = num_classes
    Y = torch.zeros(N, C, dtype=Z.dtype)
    Y[torch.arange(N), y] = 1.0
    Z1 = torch.cat([Z, torch.ones(N,1,dtype=Z.dtype)], dim=1)     # add bias
    I = torch.eye(d+1, dtype=Z.dtype); I[-1,-1] = 0.0             # no reg on bias
    A = Z1.T @ Z1 + lam * I
    Wb = torch.linalg.solve(A, Z1.T @ Y)                          # (d+1) x C
    W, b = Wb[:-1], Wb[-1]
    return W, b

@torch.no_grad()
def probe_accuracy(student, tr_loader, te_loader, num_classes, lam=1e-2):
    Z_tr, y_tr = collect_mu_y_all(student, tr_loader, device=DEVICE)
    Z_te, y_te = collect_mu_y_all(student, te_loader, device=DEVICE)
    W, b = fit_ridge_classifier(Z_tr, y_tr, num_classes, lam=lam)
    logits_tr = Z_tr @ W + b
    logits_te = Z_te @ W + b
    acc_tr = (logits_tr.argmax(1) == y_tr).float().mean().item()
    acc_te = (logits_te.argmax(1) == y_te).float().mean().item()
    return acc_tr, acc_te


@torch.no_grad()
def estimate_mi_gaussian_xy(A, B, eps=1e-5):
    X = torch.cat([A, B], dim=1).double()
    A = A.double(); B = B.double()
    def cov(m):
        c = torch.cov(m.T)
        return c + eps * torch.eye(c.shape[0], dtype=c.dtype)
    SA = cov(A); SB = cov(B); S = cov(X)
    _, logdetA = torch.linalg.slogdet(SA)
    _, logdetB = torch.linalg.slogdet(SB)
    _, logdetS = torch.linalg.slogdet(S)
    return 0.5 * (logdetA + logdetB - logdetS).item()

@torch.no_grad()
def estimate_mi_teacher_student_on_loader(teacher, student, loader, eps=1e-5, max_batches=None):
    zT = collect_mu(teacher, loader, device=DEVICE, max_batches=max_batches)
    zS = collect_mu(student, loader, device=DEVICE, max_batches=max_batches)
    return estimate_mi_gaussian_xy(zT, zS, eps=eps)

@torch.no_grad()
def estimate_Ixz_upper(model, loader, device=DEVICE, max_batches=None):
    model.eval(); kls = []; b = 0
    for x, _ in loader:
        x = x.to(device)
        _, mu, logvar = model(x)
        var = logvar.exp()
        kl = 0.5 * (mu.pow(2) + var - logvar - 1.0).sum(dim=1)
        kls.append(kl.detach().cpu()); b += 1
        if (max_batches is not None) and (b >= max_batches): break
    return torch.cat(kls, 0).mean().item()

@torch.no_grad()
def estimate_Iyz_gaussian_plugin(model, loader, num_classes, eps=1e-5, max_batches=None):
    Z, Y = collect_mu_y(model, loader, device=DEVICE, max_batches=max_batches)
    Z = Z.double(); d = Z.shape[1]
    def cov_eps(M):
        C = torch.cov(M.T)
        return C + eps * torch.eye(d, dtype=C.dtype)
    Sig = cov_eps(Z); _, logdet = torch.linalg.slogdet(Sig)
    H = 0.5 * (logdet + d * torch.log(torch.tensor(2.0 * torch.pi * torch.e, dtype=logdet.dtype)))
    N = float(Z.shape[0]); Hy_avg = 0.0
    for y in range(num_classes):
        idx = (Y == y).nonzero(as_tuple=False).squeeze(-1)
        if idx.numel() == 0: continue
        Zy = Z[idx]
        if Zy.shape[0] < d + 1:
            var = Zy.var(dim=0, unbiased=True) + eps
            logdet_y = torch.log(var).sum()
        else:
            Sigy = cov_eps(Zy); _, logdet_y = torch.linalg.slogdet(Sigy)
        Hy = 0.5 * (logdet_y + d * torch.log(torch.tensor(2.0 * torch.pi * torch.e, dtype=logdet.dtype)))
        py = idx.numel() / N
        Hy_avg += py * Hy
    return (H - Hy_avg).item()

@torch.no_grad()
def accuracy_on_loader(model, loader, device=DEVICE):
    model.eval(); correct = total = 0
    for x, y in loader:
        x = x.to(device); y = y.to(device)
        logits, _, _ = model(x)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item(); total += y.size(0)
    return correct / max(total, 1)

class ConditionalGaussian(nn.Module):
    def __init__(self, z_s_dim, z_t_dim, hidden=256):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(z_s_dim, hidden), nn.ReLU(inplace=True),
                                      nn.Linear(hidden, hidden), nn.ReLU(inplace=True))
        self.mu = nn.Linear(hidden, z_t_dim)
        self.logvar = nn.Linear(hidden, z_t_dim)
    def forward(self, z_s):
        h = self.backbone(z_s)
        mu = self.mu(h)
        logvar = self.logvar(h).clamp(-10.0, 10.0)
        return mu, logvar

def gaussian_nll(z, mu, logvar):
    var = logvar.exp()
    nll = 0.5 * (((z - mu) ** 2) / var + logvar + torch.log(torch.tensor(2.0 * torch.pi, device=z.device))).sum(dim=1)
    return nll.mean()

def kl_standard_normal(mu, logvar):
    var = logvar.exp()
    return 0.5 * (mu.pow(2) + var - logvar - 1.0).sum(dim=1).mean()

ckpt = torch.load("teacher.pt", map_location=DEVICE)
teacher = VIB(IN_DIM, Z_DIM, NUM_CLASSES).to(DEVICE)
teacher.load_state_dict(ckpt["model"])
teacher.eval()


for p in teacher.parameters(): p.requires_grad_(False)
# copy teacher's decoder to student
student = VIB(IN_DIM, Z_DIM_STUDENT, NUM_CLASSES).to(DEVICE)
student.classifier.load_state_dict(teacher.classifier.state_dict())
q_head  = ConditionalGaussian(Z_DIM_STUDENT, Z_DIM).to(DEVICE)
opt = torch.optim.Adam([
    {"params": student.parameters(), "lr": LR_STUDENT},
    {"params": q_head.parameters(),  "lr": LR_Q},
])

print("=== Teacher reference (before training) ===")
t_tr_acc = accuracy_on_loader(teacher, tr_loader)
t_te_acc = accuracy_on_loader(teacher, te_loader)
t_tr_Ixz = estimate_Ixz_upper(teacher, tr_loader)
t_te_Ixz = estimate_Ixz_upper(teacher, te_loader)
t_tr_Iyz = estimate_Iyz_gaussian_plugin(teacher, tr_loader, NUM_CLASSES, eps=EPS)
t_te_Iyz = estimate_Iyz_gaussian_plugin(teacher, te_loader,  NUM_CLASSES, eps=EPS)
print(f"acc_train={t_tr_acc*100:.2f}% acc_val={t_te_acc*100:.2f}% | Ixz_train={t_tr_Ixz:.4f} Ixz_val={t_te_Ixz:.4f} | Iyz_train={t_tr_Iyz:.4f} Iyz_val={t_te_Iyz:.4f}")

mi_train_teacher_student, mi_val_teacher_student = [], []
ixz_train, ixz_val, iyz_train, iyz_val = [], [], [], []
acc_train_hist, acc_val_hist = [], []

for epoch in range(1, EPOCHS + 1):
    student.train(); q_head.train()
    pbar = tqdm(tr_loader, desc=f"[MI + beta-Ixz] epoch {epoch}/{EPOCHS}")
    loss_sum = 0.0; n_batches = 0
    for x, y in pbar:
        if epoch == 1: break # measure things before train strart
        x = x.to(DEVICE)
        with torch.no_grad():
            _, mu_T, logvar_T = teacher(x)
            z_t = mu_T + (0.5 * logvar_T).exp() * torch.randn_like(mu_T)
        _, mu_S, logvar_S = student(x)
        z_s = mu_S + (0.5 * logvar_S).exp() * torch.randn_like(mu_S)
        mu_q, logvar_q = q_head(z_s)
        nll = gaussian_nll(z_t, mu_q, logvar_q)
        kl_ixz = kl_standard_normal(mu_S, logvar_S)
        loss = nll + BETA * kl_ixz
        opt.zero_grad(); loss.backward(); opt.step()
        loss_sum += float(loss.item()); n_batches += 1
        pbar.set_postfix(loss=f"{loss.item():.3f}", nll=f"{nll.item():.3f}", kl=f"{kl_ixz.item():.3f}")
    train_loss_epoch = loss_sum / max(n_batches, 1)

    student.eval(); q_head.eval()
    mi_tr = estimate_mi_teacher_student_on_loader(teacher, student, tr_loader, eps=EPS)
    mi_v  = estimate_mi_teacher_student_on_loader(teacher, student, te_loader, eps=EPS)
    ixz_tr = estimate_Ixz_upper(student, tr_loader)
    ixz_v  = estimate_Ixz_upper(student, te_loader)
    iyz_tr = estimate_Iyz_gaussian_plugin(student, tr_loader, NUM_CLASSES, eps=EPS)
    iyz_v  = estimate_Iyz_gaussian_plugin(student, te_loader,  NUM_CLASSES, eps=EPS)
    acc_tr, acc_te = probe_accuracy(student, tr_loader, te_loader, NUM_CLASSES, lam=1e-2)



    mi_train_teacher_student.append(mi_tr); mi_val_teacher_student.append(mi_v)
    ixz_train.append(ixz_tr); ixz_val.append(ixz_v)
    iyz_train.append(iyz_tr); iyz_val.append(iyz_v)
    acc_train_hist.append(acc_tr); acc_val_hist.append(acc_v)

    print(f"[epoch {epoch}] loss={train_loss_epoch:.4f} | beta={BETA:.4f} | "
          f"probe_acc_tr={acc_tr*100:.2f}% probe_acc_val={acc_te*100:.2f}% | "
          f"MI_tr={mi_tr:.4f} MI_val={mi_v:.4f} | Ixz_tr={ixz_tr:.4f} Ixz_val={ixz_v:.4f} | "
          f"Iyz_tr={iyz_tr:.4f} Iyz_val={iyz_v:.4f}")
torch.save({
    "student": student.state_dict(),
    "q_head": q_head.state_dict(),
    "mi_train_teacher_student": mi_train_teacher_student,
    "mi_val_teacher_student": mi_val_teacher_student,
    "Ixz_train": ixz_train,  "Ixz_val": ixz_val,
    "Iyz_train": iyz_train,  "Iyz_val": iyz_val,
    "acc_train": acc_train_hist, "acc_val": acc_val_hist,
    "beta": BETA,
    "z_dim_student": Z_DIM_STUDENT,
}, "student_mi_fixedbeta.pt")

print("Saved to student_mi_fixedbeta.pt")


=== Teacher reference (before training) ===
acc_train=96.74% acc_val=94.70% | Ixz_train=4.1696 Ixz_val=4.1287 | Iyz_train=34.0816 Iyz_val=32.0867


[MI + beta-Ixz] epoch 1/100:   0%|          | 0/469 [00:00<?, ?it/s]


[epoch 1] loss=0.0000 | beta=0.1000 | probe_acc_tr=76.85% probe_acc_val=77.03% | MI_tr=0.0249 MI_val=4.2134 | Ixz_tr=0.1109 Ixz_val=0.1116 | Iyz_tr=10.7890 Iyz_val=11.9203


[MI + beta-Ixz] epoch 2/100: 100%|██████████| 469/469 [00:04<00:00, 108.68it/s, kl=2.645, loss=91.362, nll=91.097]
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e40f2d393a0>Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e40f2d393a0>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/usr/lib/python3.12/multiprocessing/process.py", line 1

[epoch 2] loss=90.9441 | beta=0.1000 | probe_acc_tr=90.86% probe_acc_val=90.87% | MI_tr=0.0229 MI_val=10.4558 | Ixz_tr=3.2203 Ixz_val=3.2560 | Iyz_tr=31.6276 Iyz_val=32.8686


[MI + beta-Ixz] epoch 3/100:   0%|          | 1/469 [00:00<00:47,  9.94it/s, kl=3.734, loss=90.881, nll=90.507]Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e40f2d393a0><function _MultiProcessingDataLoaderIter.__del__ at 0x7e40f2d393a0>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, 

[epoch 3] loss=89.8039 | beta=0.1000 | probe_acc_tr=95.92% probe_acc_val=95.82% | MI_tr=0.0245 MI_val=14.5709 | Ixz_tr=6.8249 Ixz_val=6.8719 | Iyz_tr=40.7864 Iyz_val=43.0284


[MI + beta-Ixz] epoch 4/100: 100%|██████████| 469/469 [00:03<00:00, 118.01it/s, kl=7.005, loss=88.287, nll=87.587]


[epoch 4] loss=88.7611 | beta=0.1000 | probe_acc_tr=97.46% probe_acc_val=96.99% | MI_tr=0.0248 MI_val=16.1854 | Ixz_tr=6.8957 Ixz_val=6.8879 | Iyz_tr=41.9783 Iyz_val=43.6861


[MI + beta-Ixz] epoch 5/100: 100%|██████████| 469/469 [00:04<00:00, 109.91it/s, kl=6.183, loss=88.829, nll=88.210]


[epoch 5] loss=88.3161 | beta=0.1000 | probe_acc_tr=98.00% probe_acc_val=97.13% | MI_tr=0.0259 MI_val=15.9927 | Ixz_tr=5.9953 Ixz_val=5.9931 | Iyz_tr=40.8467 Iyz_val=42.4801


[MI + beta-Ixz] epoch 6/100: 100%|██████████| 469/469 [00:04<00:00, 108.96it/s, kl=5.356, loss=88.550, nll=88.014]


[epoch 6] loss=88.0870 | beta=0.1000 | probe_acc_tr=98.37% probe_acc_val=97.70% | MI_tr=0.0245 MI_val=16.0156 | Ixz_tr=5.4543 Ixz_val=5.4433 | Iyz_tr=39.2185 Iyz_val=40.3529


[MI + beta-Ixz] epoch 7/100: 100%|██████████| 469/469 [00:04<00:00, 106.05it/s, kl=5.288, loss=87.831, nll=87.302]


[epoch 7] loss=88.0023 | beta=0.1000 | probe_acc_tr=98.58% probe_acc_val=97.73% | MI_tr=0.0258 MI_val=15.2419 | Ixz_tr=5.3170 Ixz_val=5.2936 | Iyz_tr=36.3153 Iyz_val=37.2967


[MI + beta-Ixz] epoch 8/100: 100%|██████████| 469/469 [00:04<00:00, 111.38it/s, kl=5.111, loss=87.473, nll=86.962]


[epoch 8] loss=87.9755 | beta=0.1000 | probe_acc_tr=98.65% probe_acc_val=97.64% | MI_tr=0.0248 MI_val=15.3209 | Ixz_tr=5.1878 Ixz_val=5.1765 | Iyz_tr=35.9893 Iyz_val=36.9982


[MI + beta-Ixz] epoch 9/100: 100%|██████████| 469/469 [00:04<00:00, 111.18it/s, kl=5.074, loss=87.891, nll=87.384]


[epoch 9] loss=87.8800 | beta=0.1000 | probe_acc_tr=98.69% probe_acc_val=97.67% | MI_tr=0.0253 MI_val=14.8996 | Ixz_tr=5.0750 Ixz_val=5.0555 | Iyz_tr=35.8760 Iyz_val=36.3373


[MI + beta-Ixz] epoch 10/100: 100%|██████████| 469/469 [00:04<00:00, 109.94it/s, kl=4.891, loss=87.150, nll=86.661]


[epoch 10] loss=87.8857 | beta=0.1000 | probe_acc_tr=98.86% probe_acc_val=97.69% | MI_tr=0.0249 MI_val=15.2010 | Ixz_tr=4.8836 Ixz_val=4.8621 | Iyz_tr=35.7140 Iyz_val=36.2916


[MI + beta-Ixz] epoch 11/100: 100%|██████████| 469/469 [00:04<00:00, 111.64it/s, kl=5.087, loss=87.012, nll=86.503]


[epoch 11] loss=87.8694 | beta=0.1000 | probe_acc_tr=98.83% probe_acc_val=97.70% | MI_tr=0.0241 MI_val=15.2494 | Ixz_tr=4.9737 Ixz_val=4.9642 | Iyz_tr=35.2638 Iyz_val=35.9471


[MI + beta-Ixz] epoch 12/100: 100%|██████████| 469/469 [00:04<00:00, 109.76it/s, kl=5.015, loss=88.565, nll=88.064]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt

