In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Mini-batch size shared by the training and test loaders.
batch_size = 128

# The only preprocessing step is the conversion of each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Both MNIST splits are stored under ./data and downloaded if missing.
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Only the training stream is shuffled; the test stream keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

print("Training samples:", len(train_dataset))
print("Test samples:", len(test_dataset))

Training samples: 60000
Test samples: 10000


In [2]:
def estimate_mi_binned(y, c, n_bins=10, vmin=-3.0, vmax=3.0, n_classes=10):
    """
    Estimate the mutual information I(y; c) (in nats) by naive histogram binning.

    Args:
        y: tensor of scalar activations, one per sample (flattened to 1-D).
        c: tensor of integer class labels, one per sample, in [0, n_classes).
        n_bins: number of equal-width bins spanning [vmin, vmax].
        vmin, vmax: histogram range; values outside it are clamped into the
            first / last bin.
        n_classes: number of label values.

    Returns:
        A 0-dim tensor on y's device holding
        sum_{b,k} p(b,k) * (log p(b,k) - log p(b) - log p(k)).
        An empty batch yields 0.

    Raises:
        ValueError: if y and c do not contain the same number of samples.
        IndexError: if a label falls outside [0, n_classes).

    We do NOT attempt a real gradient wrt y here (the gradient is effectively
    zero except on bin boundaries): the histogram is built from hard bin
    indices, so the result is detached from y.
    """
    # Hard binning carries no gradient, so drop any autograd history up front.
    y = y.detach().reshape(-1)
    device = y.device
    B = y.numel()
    if B == 0:
        return torch.zeros((), device=device)

    labels = c.reshape(-1).to(device=device, dtype=torch.long)
    if labels.numel() != B:
        raise ValueError(
            f"y and c must have the same number of samples, got {B} and {labels.numel()}"
        )
    if labels.min().item() < 0 or labels.max().item() >= n_classes:
        raise IndexError(f"class labels must lie in [0, {n_classes})")

    edges = torch.linspace(vmin, vmax, n_bins + 1, device=device)
    bin_idx = torch.bucketize(y, edges) - 1
    bin_idx = torch.clamp(bin_idx, 0, n_bins - 1)

    # Joint histogram in one shot: encode (bin, class) as a flat index and
    # count occurrences instead of looping over samples with .item().
    flat_idx = bin_idx * n_classes + labels
    joint_counts = torch.bincount(flat_idx, minlength=n_bins * n_classes)
    joint_counts = joint_counts.reshape(n_bins, n_classes).to(edges.dtype)
    joint_prob = joint_counts / float(B)

    p_bin = joint_prob.sum(dim=1, keepdim=True)  # (n_bins, 1)
    p_c = joint_prob.sum(dim=0, keepdim=True)    # (1, n_classes)

    # Only non-empty cells contribute (0 * log 0 is taken as 0).
    mask = joint_prob > 0
    p_bc = joint_prob[mask]
    # Expand the marginals to the joint's shape BEFORE masking so that all
    # three vectors are 1-D and aligned cell by cell (no accidental (n, n)
    # broadcast between a column vector and a row vector).
    p_b = p_bin.expand_as(joint_prob)[mask]
    p_cc = p_c.expand_as(joint_prob)[mask]

    mi_val = (p_bc * (torch.log(p_bc) - torch.log(p_b) - torch.log(p_cc))).sum()
    return mi_val

In [3]:
class LocalLayerBlock(nn.Module):
    """
    Bank of K scalar neurons stored in a single ``nn.Linear``.

    Neuron i computes y_i = w_i^T x + b_i, where w_i is row i of the linear
    weight and b_i is entry i of its bias. Keeping all neurons in one linear
    layer lets callers evaluate either the whole layer or a single neuron.
    """
    def __init__(self, in_dim, K=10):
        super().__init__()
        # Weight matrix has shape (K, in_dim): one row per neuron.
        self.linear = nn.Linear(in_features=in_dim, out_features=K, bias=True)

    def forward(self, x):
        """
        Evaluate every neuron: x of shape (B, in_dim) -> output of shape (B, K).
        """
        return self.linear(x)

    def compute_redundancy(self, i, j):
        """
        Return the dot product of weight rows i and j as a 0-dim tensor.
        Intended for pairs with i != j.
        """
        weight = self.linear.weight  # (K, in_dim)
        return torch.dot(weight[i], weight[j])

    def forward_neuron(self, x, i):
        """
        Evaluate only neuron i: x of shape (B, in_dim) -> output of shape (B,).
        """
        # Keep the selected row and bias entry as (1, in_dim) and (1,) so the
        # product comes out as a (B, 1) column.
        row = self.linear.weight[i].unsqueeze(0)
        offset = self.linear.bias[i].unsqueeze(0)
        column = torch.matmul(x, row.transpose(0, 1)) + offset
        # Drop the singleton neuron axis.
        return column.squeeze(1)

In [4]:
def local_update_layer(
    layer_block,
    train_loader,
    lr=1e-3,
    lambda_reg=0.01,
    epochs=3,
    n_bins=10,
    device='cpu'
):
    """
    Train a LocalLayerBlock-style layer with a per-neuron 'local' update rule.

    For each mini-batch (X, c) and each neuron i in [0..K-1]:
      1) y_i = w_i^T X + b_i
      2) L_i = -I(y_i; c) + lambda * sum_{j != i} dot(w_i, w_j)
      3) back-propagate L_i, keep only the gradient of row i (w_i, b_i),
         and take one SGD step.

    Note: as long as estimate_mi_binned builds its histogram from hard bin
    counts, the MI term contributes to the reported loss value but carries no
    gradient; the parameter update is then driven by the redundancy term only.

    Args:
        layer_block: module exposing ``linear`` (an nn.Linear with K output
            rows) and ``forward_neuron(X, i)``.
        train_loader: iterable of (X, labels) batches; X is flattened to
            (B, -1) before use.
        lr: SGD learning rate.
        lambda_reg: weight of the redundancy penalty.
        epochs: number of passes over train_loader.
        n_bins: number of histogram bins passed to estimate_mi_binned.
        device: device the layer and every batch are moved to.

    Returns:
        The same layer_block, updated in place.
    """
    layer_block.to(device)
    optimizer = torch.optim.SGD(layer_block.parameters(), lr=lr)

    # Resolve K once, outside the data loop, so it is defined even when the
    # loader yields no batches.
    K_ = layer_block.linear.weight.shape[0]

    for epoch in range(epochs):
        layer_block.train()
        total_loss = 0.0
        num_updates = 0

        for X, labels in train_loader:
            X = X.view(X.size(0), -1).to(device)
            labels = labels.to(device)

            # One independent update per neuron i.
            for i in range(K_):
                optimizer.zero_grad()

                y_i = layer_block.forward_neuron(X, i)  # shape (B,)
                mi_i = estimate_mi_binned(y_i, labels, n_bins=n_bins)

                # Redundancy term sum_{j != i} <w_i, w_j>, computed as a
                # single matrix-vector product over the other rows.
                weight = layer_block.linear.weight
                other_rows = torch.arange(K_, device=weight.device) != i
                red_sum = torch.matmul(weight[other_rows], weight[i]).sum()

                # Minimising -MI maximises MI; the penalty discourages
                # neuron i from aligning with the other neurons.
                loss_i = -mi_i + lambda_reg * red_sum

                # The graph is rebuilt for every neuron, so it does not need
                # to be retained after this backward pass.
                loss_i.backward()

                # backward() fills gradients for every row; keep row i only so
                # that the step is local to neuron i.
                with torch.no_grad():
                    for name, param in layer_block.named_parameters():
                        if param.grad is None:
                            continue
                        if 'weight' in name or 'bias' in name:
                            param.grad[other_rows] = 0.0

                optimizer.step()

                total_loss += loss_i.item()
                num_updates += 1

        # Average over all (batch, neuron) updates performed this epoch.
        avg_loss = total_loss / num_updates if num_updates else float('nan')
        print(f"Epoch [{epoch+1}/{epochs}] - Avg Local Loss: {avg_loss:.4f}")

    return layer_block

In [5]:
def freeze_module(module):
    """Turn off gradient tracking for every parameter of ``module``, in place."""
    for tensor in module.parameters():
        tensor.requires_grad_(False)

def get_hidden_representation(layer_block, data_loader, device=None):
    """
    Run every batch of data_loader through layer_block and collect the outputs.

    Args:
        layer_block: module mapping a (B, in_dim) tensor to a (B, K) tensor.
        data_loader: iterable of (images, labels) batches; images are
            flattened to (B, -1) before the forward pass.
        device: device the inputs are moved to. When None (default) it is
            taken from layer_block's first parameter, so the inputs always
            land on the device the layer actually lives on; a layer without
            parameters falls back to the CPU.

    Returns:
        (features, labels): all outputs and all labels concatenated along
        dim 0, both on the CPU.
    """
    if device is None:
        first_param = next(layer_block.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    layer_block.eval()
    all_feats = []
    all_labels = []
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.view(images.size(0), -1).to(device)
            Y = layer_block(images)  # (B, K)
            # Move results to the CPU so the collected set does not occupy
            # accelerator memory.
            all_feats.append(Y.cpu())
            all_labels.append(labels.cpu())
    return torch.cat(all_feats, dim=0), torch.cat(all_labels, dim=0)

def build_and_train_2layer_local_mi(
    train_loader,
    test_loader,
    device=device,
    in_dim=784,
    hidden_dim=16,
    lr=1e-3,
    lambda_reg=0.01,
    epochs=2,
    n_bins=10,
    batch_size=128,
):
    """
    Greedily train two LocalLayerBlock layers with the local MI rule.

    Layer 1 is trained on the raw inputs, frozen, and used to produce hidden
    features; layer 2 is then trained on those features.

    Args:
        train_loader, test_loader: loaders yielding (images, labels) batches.
        device: device passed to local_update_layer for both layers.
        in_dim: flattened input size fed to layer 1.
        hidden_dim: number of neurons in each layer (layer 2 also takes
            hidden_dim inputs, matching layer 1's output width).
        lr, lambda_reg, epochs, n_bins: forwarded to local_update_layer for
            both layers.
        batch_size: batch size of the hidden-feature loaders.

    Returns:
        (layer1, layer2, train_hid_loader, test_hid_loader), where the two
        loaders serve layer 1's outputs (i.e. layer 2's inputs).
    """
    # 1) Create and locally train layer 1.
    layer1 = LocalLayerBlock(in_dim=in_dim, K=hidden_dim)
    layer1 = local_update_layer(
        layer1, train_loader,
        lr=lr, lambda_reg=lambda_reg, epochs=epochs, n_bins=n_bins, device=device,
    )

    # 2) Freeze it so later stages cannot modify it.
    freeze_module(layer1)

    # 3) Extract the hidden representation of both splits.
    X_train_hid, y_train = get_hidden_representation(layer1, train_loader)
    X_test_hid, y_test = get_hidden_representation(layer1, test_loader)

    # Wrap the features in new loaders for the next layer.
    train_hid_ds = torch.utils.data.TensorDataset(X_train_hid, y_train)
    test_hid_ds = torch.utils.data.TensorDataset(X_test_hid, y_test)
    train_hid_loader = torch.utils.data.DataLoader(train_hid_ds, batch_size=batch_size, shuffle=True)
    test_hid_loader = torch.utils.data.DataLoader(test_hid_ds, batch_size=batch_size, shuffle=False)

    # 4) Create and locally train layer 2 on layer 1's outputs; its input
    #    width is tied to layer 1's output width rather than repeated as a
    #    separate literal.
    layer2 = LocalLayerBlock(in_dim=hidden_dim, K=hidden_dim)
    layer2 = local_update_layer(
        layer2, train_hid_loader,
        lr=lr, lambda_reg=lambda_reg, epochs=epochs, n_bins=n_bins, device=device,
    )

    return layer1, layer2, train_hid_loader, test_hid_loader

In [6]:
class FinalClassifier(nn.Module):
    """Single linear layer mapping in_dim features to num_classes logits."""

    def __init__(self, in_dim=16, num_classes=10):
        super().__init__()
        self.linear = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x):
        """Return the unnormalised class scores for x of shape (B, in_dim)."""
        logits = self.linear(x)
        return logits

def evaluate_final_representation(
    layer2,
    train_loader,
    test_loader,
    epochs=5,
    lr=1e-3,
    batch_size=128,
    num_classes=10,
):
    """
    Score a frozen layer by training a linear classifier on its outputs.

    Args:
        layer2: layer whose outputs are evaluated; it is frozen here.
        train_loader, test_loader: loaders yielding (inputs, labels) batches
            in the input space of layer2.
        epochs: number of training epochs for the linear classifier.
        lr: Adam learning rate for the linear classifier.
        batch_size: batch size of the internal feature loaders.
        num_classes: number of output classes of the classifier.

    Returns:
        Test accuracy of the linear classifier, in percent.
    """
    freeze_module(layer2)

    # Train the classifier on the device layer2 lives on (CPU if the layer
    # has no parameters) instead of relying on a module-level global.
    first_param = next(layer2.parameters(), None)
    clf_device = first_param.device if first_param is not None else torch.device('cpu')

    # Extract the final hidden representation of both splits.
    X_train_final, y_train = get_hidden_representation(layer2, train_loader)
    X_test_final, y_test = get_hidden_representation(layer2, test_loader)

    train_ds = torch.utils.data.TensorDataset(X_train_final, y_train)
    test_ds = torch.utils.data.TensorDataset(X_test_final, y_test)

    train_loader2 = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader2 = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    # Train a standard linear classifier on the extracted features.
    clf = FinalClassifier(in_dim=X_train_final.shape[1], num_classes=num_classes).to(clf_device)
    opt = torch.optim.Adam(clf.parameters(), lr=lr)
    ce_loss = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        clf.train()
        total_loss = 0.0
        for feats, labs in train_loader2:
            feats = feats.to(clf_device)
            labs = labs.to(clf_device)
            opt.zero_grad()
            logits = clf(feats)
            loss = ce_loss(logits, labs)
            loss.backward()
            opt.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader2)
        # The denominator in the message follows the epochs argument so the
        # log cannot drift from the actual loop length.
        print(f"Linear Classifier Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")

    # Evaluate on the held-out features.
    clf.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for feats, labs in test_loader2:
            feats = feats.to(clf_device)
            labs = labs.to(clf_device)
            logits = clf(feats)
            preds = logits.argmax(dim=1)
            correct += (preds == labs).sum().item()
            total += labs.size(0)
    acc = 100.0 * correct / total
    return acc

In [7]:
# Train both local-MI layers; layer 2 is fitted on the frozen layer-1 features.
layer1_local, layer2_local, train_hid_loader, test_hid_loader = build_and_train_2layer_local_mi(
    train_loader,
    test_loader,
    device=device,
)

# Score layer 2's output with a linear classifier fed from the hidden-feature loaders.
acc_local = evaluate_final_representation(
    layer2_local,
    train_hid_loader,
    test_hid_loader,
)
print(f"Final test accuracy (local MI rule): {acc_local:.2f}%")

KeyboardInterrupt: 

In [30]:
class TwoLayerNet(nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear, producing class logits."""

    def __init__(self, in_dim=784, hidden_dim=64, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_dim, out_features=hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x):
        """Map x of shape (B, in_dim) to logits of shape (B, num_classes)."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

In [31]:
def train_backprop(model, train_loader, lr=1e-3, epochs=5, device='cpu'):
    """
    Train model end-to-end with Adam and cross-entropy loss.

    Args:
        model: module mapping flattened inputs (B, -1) to class logits.
        train_loader: iterable of (images, labels) batches.
        lr: Adam learning rate.
        epochs: number of passes over train_loader.
        device: device the model and every batch are moved to.

    Returns:
        None. The model is updated in place and the mean batch loss of each
        epoch is printed.
    """
    model.to(device)
    # Reference the optimizer through the imported torch package; no bare
    # `optim` name is imported in this notebook.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    ce_loss = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images = images.view(images.size(0), -1).to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(images)
            loss = ce_loss(logits, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
        # Counting batches (rather than calling len) avoids a division by
        # zero when the loader is empty.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {mean_loss:.4f}")

def evaluate_accuracy(model, data_loader, device='cpu'):
    """
    Return the classification accuracy (in percent) of model on data_loader.

    The model is switched to eval mode; each batch of images is flattened to
    (B, -1) and moved to device together with its labels.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in data_loader:
            flat = batch_images.view(batch_images.size(0), -1).to(device)
            targets = batch_labels.to(device)
            # Predicted class = index of the largest logit.
            predicted = model(flat).argmax(dim=1)
            hits += predicted.eq(targets).sum().item()
            seen += targets.size(0)
    return 100.0 * hits / seen

In [None]:
# Baseline: a two-layer network trained end-to-end with backprop on the same loaders.
model_bp = TwoLayerNet(
    in_dim=784,
    hidden_dim=64,
    num_classes=10,
)
train_backprop(
    model_bp,
    train_loader,
    lr=1e-3,
    epochs=5,
    device=device,
)

test_acc_bp = evaluate_accuracy(model_bp, test_loader, device=device)
print(f"Standard Backprop 2-layer test accuracy: {test_acc_bp:.2f}%")