In [888]:
import torch
import torch.nn.functional as F

In [889]:
# Pairwise matrix products through broadcasting: result[i, j] = A[i] @ B[j].
# Tensor layouts used in this demo:
#   A: [B, 1, m, n]
#   B: [1, B, n, m]
B, m, n = 3, 4, 5  # toy sizes
# Random stand-ins. Each singleton axis is dropped immediately, leaving
# A as [B, m, n] and B as [B, n, m].
# NOTE: `B` is the integer batch size while both randn calls are evaluated and
# is rebound to the second tensor afterwards.
A = torch.randn(B, 1, m, n).squeeze(1)
B = torch.randn(1, B, n, m).squeeze(0)
# Re-insert the singleton axes so the two batch axes broadcast against each other.
A_expanded = A[:, None]  # [B, 1, m, n]
B_expanded = B[None]  # [1, B, n, m]
# Broadcasting pairs every A[i] with every B[j]; the result is [B, B, m, m].
result = A_expanded @ B_expanded
print(result.shape)  # expected: torch.Size([3, 3, 4, 4]) for the sizes above

torch.Size([3, 3, 4, 4])


In [1142]:
import torch
import torch.nn.functional as F
# Fix the RNG seed so the synthetic data drawn below is reproducible.
torch.manual_seed(0)
# Sizes: number of samples, feature dimension, number of classes, number of environments.
batch_size, d, num_classes, num_envs = 50, 200, 2, 2
x = torch.randn(batch_size, d)  # random features, shape [batch_size, d]
y = torch.randint(0, num_classes, (batch_size,))  # random integer labels in [0, num_classes)
logits = torch.randn(batch_size, num_classes)  # random logits, shape [batch_size, num_classes]

# Each sample is assigned at random to one of the `num_envs` environments (labels 0 .. num_envs - 1).
envs = torch.randint(0, num_envs, (batch_size,))




In [1143]:
def hessian(x, logits):
    """Batch-averaged Kronecker-structured Hessian, computed two independent ways.

    For every sample ``b`` the per-sample matrix is
    ``kron(diag(p_b) - p_b p_b^T, x_b x_b^T)`` with ``p_b = softmax(logits[b])``.
    This is the form taken by the Hessian of the mean softmax cross-entropy
    loss with respect to the row-major flattened weights ``W`` of a linear
    classifier ``logits = x @ W.T``.

    Args:
        x: Features, shape ``[batch_size, d]``.
        logits: Class scores, shape ``[batch_size, num_classes]``.

    Returns:
        Tuple ``(H, H2)`` of ``[num_classes * d, num_classes * d]`` tensors.
        ``H`` is the reference result from an explicit per-sample
        ``torch.kron`` loop; ``H2`` is the vectorised einsum result. Both use
        the same layout: entry ``[k * d + i, l * d + j]`` holds the
        class-pair ``(k, l)`` / feature-pair ``(i, j)`` term.

    Note:
        An empty batch divides by zero and yields NaN entries.
    """
    batch_size, d = x.shape
    num_classes = logits.shape[1]
    dC = num_classes * d  # side length of the flattened Hessian

    # p[b] holds the class probabilities of sample b.
    p = F.softmax(logits, dim=1)  # [batch_size, num_classes]

    # Per-sample class block diag(p) - p p^T: p_k (1 - p_k) on the diagonal,
    # -p_k p_l off the diagonal.
    class_block = torch.diag_embed(p) - p.unsqueeze(2) * p.unsqueeze(1)  # [batch_size, C, C]

    # Per-sample feature outer product x x^T.
    X_outer = torch.einsum('bi,bj->bij', x, x)  # [batch_size, d, d]

    # Vectorised path. The output axes are ordered (k, i, l, j) so that the
    # reshape to [dC, dC] reproduces the Kronecker layout; flattening a
    # (k, l, i, j) tensor directly would scramble rows and columns.
    H2 = torch.einsum('bkl,bij->kilj', class_block, X_outer).reshape(dC, dC) / batch_size

    # Reference path: explicit Kronecker product per sample, accumulated in
    # the same dtype and on the same device as the vectorised result.
    H = torch.zeros(dC, dC, dtype=H2.dtype, device=H2.device)
    for i in range(batch_size):
        H += torch.kron(class_block[i], X_outer[i])
    H /= batch_size
    return H, H2

In [1001]:
# Per-environment Hessians (both results returned by `hessian`) and the running
# sum of the first result.
envs_hessian = torch.zeros(num_envs, d * num_classes, d * num_classes)
envs_hessian2 = torch.zeros_like(envs_hessian)
sum_hessian = torch.zeros_like(envs_hessian[0])
for e in range(num_envs):
    # Select the samples that belong to environment e.
    env_mask = envs == e
    x_env, logits_env = x[env_mask], logits[env_mask]
    H, H2 = hessian(x_env, logits_env)
    envs_hessian[e], envs_hessian2[e] = H, H2
    sum_hessian += H
# Unweighted mean of the per-environment Hessians.
avg_hessian = sum_hessian / num_envs

# Penalty: mean over environments of ||H_e - H_avg||_F^2.
hess_penalty2 = 0
for e in range(num_envs):
    # Fraction of the batch that falls in environment e; not read by the
    # penalty computation.
    env_freq = (envs == e).sum() / batch_size
    deviation = envs_hessian[e] - avg_hessian
    hess_penalty2 += torch.norm(deviation, p='fro') ** 2 * num_envs ** (-1)


In [1002]:
# Display the penalty accumulated by the per-environment loop in the previous cell.
hess_penalty2

tensor(0.0014)

In [1003]:
# Print the squared Frobenius norm of each environment's second Hessian result (H2 from `hessian`).
for e in range(num_envs):
    print(envs_hessian2[e].norm(p = 'fro')** 2) 

tensor(0.3344)
tensor(0.2968)


In [1144]:
import time
def hessian_pen(logits, x, envs, num_envs):
    """Cross-environment Hessian variance penalty without building any Hessian.

    Each sample ``b`` contributes ``kron(P_b, X_b)`` with
    ``P_b = diag(p_b) - p_b p_b^T`` (``p_b = softmax(logits[b])``) and
    ``X_b = x_b x_b^T``. The Frobenius inner product of two such terms
    factorises as ``<P_b, P_c>_F * (x_b . x_c)^2``, so the inner products
    between per-environment mean Hessians follow from an ``[N, N]`` table of
    pairwise sample products.

    All sizes are taken from the arguments; no module-level variables are read.

    Args:
        logits: Class scores, shape ``[N, num_classes]``.
        x: Features, shape ``[N, d]``.
        envs: Integer environment label per sample, shape ``[N]``, with
            values in ``0 .. num_envs - 1``.
        num_envs: Number of environments ``E``.

    Returns:
        Tuple ``(f_norm_env, sum_h_minus_h_bar_sq, H_H_f)``:
        ``f_norm_env`` is ``[E]``, the squared Frobenius norm of each
        environment's mean Hessian (the diagonal of ``H_H_f``);
        ``sum_h_minus_h_bar_sq`` is a scalar, the mean over environments of
        ``||H_e - H_bar||_F^2`` divided by ``d * num_classes``;
        ``H_H_f`` is ``[E, E]``, the Frobenius inner products between
        per-environment mean Hessians.

    Note:
        An environment with no samples gives a zero denominator and therefore
        NaN entries in the outputs.
    """
    d = x.shape[1]
    num_classes = logits.shape[1]

    p = F.softmax(logits, dim=1)
    # Per-sample class block diag(p) - p p^T, shape [N, C, C].
    prob_hess = torch.diag_embed(p) - torch.einsum('bi,bj->bij', p, p)
    # <P_b, P_c>_F for every sample pair, shape [N, N].
    prob_trace = torch.einsum('bik,cik->bc', prob_hess, prob_hess)
    # tr(x_b x_b^T x_c x_c^T) = (x_b . x_c)^2, so the d x d outer products
    # never need to be formed.
    x_traces = (x @ x.T) ** 2
    product_matrix = prob_trace * x_traces

    # Row e of `masks` marks the samples that belong to environment e.
    env_indices = torch.arange(num_envs, device=envs.device).unsqueeze(1)  # [E, 1]
    masks = env_indices == envs  # [E, N]
    counts = masks.sum(1)
    denoms = counts.unsqueeze(1) * counts.unsqueeze(0)  # n_e1 * n_e2

    # Sum the pairwise products over (i in e1, j in e2) with two matmuls
    # instead of an [E, E, N, N] mask tensor.
    membership = masks.to(product_matrix.dtype)
    H_H_f = membership @ product_matrix @ membership.T / denoms

    # ||H_e - H_bar||^2 = <H_e, H_e> + <H_bar, H_bar> - 2 <H_e, H_bar>.
    f_norm_env = H_H_f.diagonal()
    shared_term = H_H_f.sum() / (num_envs ** 2)
    individual_term = 2 * H_H_f.sum(dim=1) / num_envs
    sum_h_minus_h_bar_sq = torch.sum(f_norm_env + shared_term - individual_term) / num_envs
    # Normalise by the number of parameters.
    sum_h_minus_h_bar_sq /= (d * num_classes)

    return f_norm_env, sum_h_minus_h_bar_sq, H_H_f


In [1145]:
import time
# Loop-based reference: Frobenius inner products between the mean Hessians of
# every pair of environments, computed from per-sample factors.
unique_envs = envs.unique()
num_envs = len(unique_envs)
H_H_f = torch.zeros(num_envs, num_envs)
# Rows and columns are addressed by position in `unique_envs`, not by label
# value, so the matrix stays in bounds when some label is absent from the batch.
for row, e1 in enumerate(unique_envs):
    # Everything that depends only on the first environment is computed once
    # per outer iteration.
    mask1 = envs == e1
    x_env1 = x[mask1]
    logits_env1 = logits[mask1]
    p1 = F.softmax(logits_env1, dim=1)
    diag1 = torch.diag_embed(p1)
    off_diag1 = torch.einsum('bi,bj->bij', p1, p1)
    diff1 = diag1 - off_diag1  # diag(p) - p p^T per sample
    X_outer1 = torch.einsum('bi,bj->bij', x_env1, x_env1)  # x x^T per sample
    for col, e2 in enumerate(unique_envs):
        mask2 = envs == e2
        x_env2 = x[mask2]
        logits_env2 = logits[mask2]
        p2 = F.softmax(logits_env2, dim=1)
        diag2 = torch.diag_embed(p2)
        off_diag2 = torch.einsum('bi,bj->bij', p2, p2)
        diff2 = diag2 - off_diag2
        X_outer2 = torch.einsum('bi,bj->bij', x_env2, x_env2)
        # Frobenius inner product for every (sample in e1, sample in e2) pair,
        # contracted directly rather than via a [b, c, ., .] intermediate.
        prob_trace_1_2 = torch.einsum('bik,cik->bc', diff1, diff2)
        x_traces_1_2 = torch.einsum('bik,cik->bc', X_outer1, X_outer2)
        # Average over all sample pairs of the two environments.
        H_H_f[row, col] = torch.sum(prob_trace_1_2 * x_traces_1_2) / (mask1.sum() * mask2.sum())

# ||H_e - H_bar||^2 = <H_e, H_e> + <H_bar, H_bar> - 2 <H_e, H_bar>, averaged
# over environments and normalised by the number of parameters.
f_norm_env = H_H_f.diagonal()
shared_term = H_H_f.sum() / (num_envs ** 2)
individual_term = 2 * H_H_f.sum(dim=1) / num_envs
sum_h_minus_h_bar_sq = torch.sum(f_norm_env + shared_term - individual_term) / num_envs
sum_h_minus_h_bar_sq /= (d * num_classes)


In [1146]:
# Display the environment-by-environment matrix filled in by the pairwise loop in the previous cell.
H_H_f

tensor([[346.3997,  30.1349],
        [ 30.1349, 271.5005]])

In [1147]:
# Evaluate hessian_pen on the full batch. Its penalty is stored under a separate
# name (sum_h_minus_h_bar_sq2) so the loop-based sum_h_minus_h_bar_sq is kept;
# note that f_norm_env from the loop cell is overwritten here.
f_norm_env, sum_h_minus_h_bar_sq2, HHf = hessian_pen(logits, x, envs, num_envs)

In [1148]:
# Display the penalty returned by hessian_pen.
sum_h_minus_h_bar_sq2

tensor(0.3485)

In [1149]:
# Shape check on X_outer1, which is left over from the final iteration of the pairwise environment loop.
X_outer1.shape

torch.Size([28, 200, 200])

In [1150]:
def gradient(x, logits, y, envs):
    """Return the batch-averaged matrix ``(softmax(logits) - onehot(y))^T @ x``.

    Args:
        x: Features; expected shape ``[N, d]``. ``x.size(0)`` is the divisor.
        logits: Class scores; expected shape ``[N, C]``.
        y: Class index per sample; expected shape ``[N]``. Cast to ``long``
            before being used as scatter indices.
        envs: Accepted for signature compatibility; not used.

    Returns:
        Tensor of shape ``[C, d]``.
    """
    probs = F.softmax(logits, dim=-1)
    # One-hot targets: start from zeros and write a 1 at column y[i] of row i.
    targets = torch.zeros_like(probs).scatter_(1, y.long().unsqueeze(-1), 1)
    residual = probs - targets
    return residual.T @ x / x.size(0)

In [1151]:
# Display the penalty computed by the loop-based cell (the hessian_pen result is stored separately as sum_h_minus_h_bar_sq2).
sum_h_minus_h_bar_sq

tensor(0.3485)

In [1152]:
# Display the environment-by-environment matrix returned by hessian_pen.
HHf

tensor([[346.3997,  30.1349],
        [ 30.1349, 271.5005]])

In [1153]:
# Evaluate `gradient` on the full synthetic batch.
grad_w1 = gradient(x, logits, y, envs)

In [1067]:
# Squared Frobenius norm of grad_w1.
# NOTE(review): this cell's execution count (1067) is lower than that of the cell
# assigning grad_w1 (1153), so the saved output comes from an earlier kernel state.
grad_w1.norm(p='fro') ** 2

tensor(0.0406)

In [1062]:
# Squared Frobenius norm of grad_w1.
# NOTE(review): same expression as the cell with execution count 1067; this
# one ran earlier (1062) and its saved output differs.
grad_w1.norm(p='fro') ** 2

tensor(0.0207)

In [1026]:
# Squared Frobenius norm of grad_w1.
# NOTE(review): a third cell with this same expression (execution count 1026).
grad_w1.norm(p='fro') ** 2

tensor(0.0207)

In [858]:
# NOTE(review): in the visible cells y_onehot is only assigned inside gradient(),
# so this cell relies on a module-level variable from an earlier kernel state.
# The saved output is 32 x 4, unlike the batch_size=50, num_classes=2 setup.
y_onehot

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [859]:
# In-place one-hot fill: writes 1 into column y[i] of row i and returns the modified tensor.
# NOTE(review): y_onehot is not assigned at module level in the visible cells, so this cell depends on earlier kernel state.
y_onehot.scatter_(1, y.long().unsqueeze(-1), 1)

tensor([[0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.]])