In [1]:
import torch
import numpy as np
import sys

sys.path.append("monotone_op_net")
import train
import splitting as sp

<h2>Generalization bound for FC monDEQ model</h2>

In [2]:
# Instantiate the fully-connected monDEQ model and load its trained weights.
# `m` is forwarded as the network's `m` argument and `h` as its `out_dim`;
# both are reused further down when the bound is evaluated, so the model is
# built from these names instead of repeating the literals.
m = 20
h = 40
model = train.SingleFcNet(
    sp.MONPeacemanRachford,
    in_dim=28**2,
    out_dim=h,
    alpha=1.,
    max_iter=300,
    tol=1e-3,
    m=m,
)
# NOTE(review): the checkpoint file name appears to encode h=40 and m=20.0;
# keep it in step with the two values above.
model.load_state_dict(
    torch.load("models/mon_h40_m20.0.pt", map_location=torch.device('cpu'))
)
# Last expression of the cell: switches to eval mode and displays the module.
model.eval()

SingleFcNet(
  (mon): MONPeacemanRachford(
    (linear_module): MONSingleFc(
      (U): Linear(in_features=784, out_features=40, bias=True)
      (A): Linear(in_features=40, out_features=40, bias=False)
      (B): Linear(in_features=40, out_features=40, bias=False)
    )
    (nonlin_module): MONReLU()
  )
  (Wout): Linear(in_features=40, out_features=10, bias=True)
)

In [7]:
def n(t):
    """Return ``t`` as a plain Python float if it is a tensor, else unchanged.

    Anything that is not a ``torch.Tensor`` is handed back as-is, so callers
    can pass either tensors or ordinary Python values.
    """
    if not isinstance(t, torch.Tensor):
        return t
    # Detach from autograd and move to host memory before going through NumPy.
    as_array = t.detach().cpu().numpy()
    return float(as_array)

def bound_term_2(h, c, beta, B, m, W_norm_sum, M, gamma, delta):
    """Evaluate the second term of the bound as a tensor.

    With ``cp = c + 1``, ``cm = c - 1`` and ``c2 = gamma / (4 * beta)`` the
    function computes::

        inner = 2 beta^3 cp^3 (1 + c2)(B + 1) + 2 m beta cp (B + 1) cm^2 + m^2 cm^3
        num_1 = 16 h log(24 h) inner^2 W_norm_sum / (gamma^2 m^4 cm^6 (M - 1))
        num_2 = log(3 M^1.5 / (c delta))
        result = 4 * sqrt(num_1 + num_2 / (M - 1))

    Args:
        h, c, B, m, M, gamma, delta: scalars entering the formula above.
        beta: scalar or 0-dim tensor; a Python ``0`` raises
            ``ZeroDivisionError`` in ``gamma / 4 / beta``.
        W_norm_sum: scalar or 0-dim tensor multiplying the leading term.

    Returns:
        A ``torch.Tensor`` holding the value of the term. Plain numeric
        arguments are accepted as well as tensors; the radicand is wrapped
        with ``torch.as_tensor`` because ``torch.sqrt`` rejects non-tensors.
    """
    cp, cm = c + 1, c - 1
    c2 = gamma / 4 / beta
    inner = 2 * beta ** 3 * cp ** 3 * (1 + c2) * (B + 1) + 2 * m * beta * cp * (B + 1) * cm ** 2 + m ** 2 * cm ** 3
    num = 16 * h * np.log(24 * h) * inner ** 2
    num_1 = num * W_norm_sum / (gamma ** 2 * m ** 4 * cm ** 6 * (M - 1))
    num_2 = np.log(3 * M ** 1.5 / c / delta)
    # as_tensor is a no-op for tensors (autograd history is kept) and promotes
    # Python / NumPy scalars so the all-float call no longer raises TypeError.
    radicand = torch.as_tensor(num_1 + num_2 / (M - 1))
    return 4 * torch.sqrt(radicand)


def monW(model):
    """Assemble ``W = (1 - m) I - A^T A + B - B^T`` from the model's weights.

    ``m``, ``A`` and ``B`` are read from ``model.mon.linear_module``. The
    identity matrix is created with the same dtype and on the same device as
    ``A``, so the expression works wherever the model's parameters live
    without relying on an external device helper.

    Returns:
        The square tensor ``W`` (side length ``A.shape[1]``).
    """
    lin_mod = model.mon.linear_module
    m = lin_mod.m
    A = lin_mod.A.weight
    B = lin_mod.B.weight
    identity = torch.eye(A.shape[1], dtype=A.dtype, device=A.device)
    W = (1 - m) * identity - A.T @ A + B - B.T
    return W

def W_norm_sum(model, W_init):
    """Sum the squared norms of the model's weights, offset by ``W_init``.

    ``A`` and ``B`` are used as they are; ``U``'s weight and bias and
    ``Wout``'s weight and bias have the entries ``W_init['U']``,
    ``W_init['b']``, ``W_init['Wo']`` and ``W_init['bo']`` subtracted first.
    The six individual squared norms are printed before the total is returned.
    """
    hidden = model.mon.linear_module
    readout = model.Wout
    terms = [
        hidden.A.weight,
        hidden.B.weight,
        hidden.U.weight - W_init['U'],
        hidden.U.bias - W_init['b'],
        readout.weight - W_init['Wo'],
        readout.bias - W_init['bo'],
    ]
    squared_norms = [torch.norm(term) ** 2 for term in terms]
    print("F_norms\n\t", [n(sq) for sq in squared_norms])
    # Accumulate from 0 in list order, exactly as the built-in sum() would.
    total = 0
    for sq in squared_norms:
        total = total + sq
    return total

def beta(model):
    """Return the largest of four norms taken from the model's layers.

    The candidates are the 2-norm of ``U``'s bias and the largest singular
    values of ``U``'s weight, ``Wout``'s weight and ``A``'s weight.
    ``Wout``'s bias is not part of the maximum. The four values are printed
    before the maximum is returned.

    Returns:
        The maximum as a 0-dim tensor, or ``0`` if a norm computation raises
        ``RuntimeError`` (a warning is emitted in that case). Any other
        exception propagates instead of being silently turned into ``0``,
        which would only resurface later as a division by zero in the bound.
    """
    lin_mod = model.mon.linear_module
    U = lin_mod.U.weight
    b = lin_mod.U.bias
    Wo = model.Wout.weight
    A = lin_mod.A.weight
    try:
        norms = [
            torch.norm(b),
            torch.max(torch.svd(U)[1]),
            torch.max(torch.svd(Wo)[1]),
            torch.max(torch.svd(A)[1]),
        ]
    except RuntimeError as err:
        # Best-effort fallback kept from the original, but only for failures
        # raised by the numerical routines, and no longer silent.
        import warnings
        warnings.warn("beta(): norm computation failed (%s); returning 0" % err)
        return 0
    print("l2_norms\n\t", [norm.item() for norm in norms])
    return max(norms)

In [12]:
# Constants plugged into the bound. `h`, `m` and `model` come from the model
# cell earlier in the notebook.
M = 4096
gamma = 10
delta = 0.1
c = .0001
B = 50

# Every reference weight is taken to be 0, so the norms below are those of the
# trained weights themselves.
W_init = dict.fromkeys(['W', 'Wo', 'b', 'bo', 'U'], 0)
lin_mod = model.mon.linear_module

# Sum of squared weight norms (the helper also prints the individual terms).
W_norm_sum_val = W_norm_sum(model, W_init)
print("W_norm_sum:", n(W_norm_sum_val))

# Largest layer norm (the helper also prints the candidates).
beta_val = beta(model)
print("beta:", n(beta_val))

# Second term of the generalization bound.
b = bound_term_2(h, c, beta_val, B, m, W_norm_sum_val, M, gamma, delta)
print("bound term 2:", n(b))

F_norms
	 [55.68635177612305, 47.7231330871582, 1747.61279296875, 0.3111502230167389, 110.89112091064453, 0.13506852090358734]
W_norm_sum: 1962.3594970703125
l2_norms
	 [0.5578083992004395, 15.990426063537598, 4.798660755157471, 5.0687079429626465]
beta: 15.990426063537598
bound term 2: 23623.228515625
