In [7]:
import torch
import torch.nn as nn
import numpy as np
import sys

sys.path.append("monotone_op_net")
import train
import splitting as sp
import mon

<h2>Generalization bound for FC monDEQ model</h2>

In [8]:
# Instantiate the FC monDEQ model and load its trained weights.
# `m` and `h` are reused by the bound computation further down, so the model is
# built from these names instead of repeating literal values.
m = 20
h = 40  # NOTE(review): assumed to be the monDEQ hidden width that `out_dim` sets — TODO confirm against train.SingleFcNet
MODEL_PATH = 'data/gen_bounds/9-17/mon_m20_d0_1.pt'
model = train.SingleFcNet(sp.MONPeacemanRachford,
                          in_dim=28**2,   # flattened 28x28 inputs
                          out_dim=h,      # keep in sync with the `h` used in the bound
                          alpha=1.,
                          max_iter=300,
                          tol=1e-3,
                          m=m)
# torch.load unpickles the checkpoint; only load files from a trusted source.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
model.eval()

FileNotFoundError: [Errno 2] No such file or directory: 'data/gen_bounds/9-17/mon_m20_d0_1.pt'

In [35]:
def n(t):
    """Convert a single-element tensor to a Python float; pass other values through.

    A tensor is first moved to the CPU and detached from the autograd graph,
    then converted via NumPy. This makes it safe to call on CUDA tensors and
    on tensors that require grad. Any non-tensor argument is returned as-is.
    """
    if not isinstance(t, torch.Tensor):
        return t
    host_array = t.cpu().detach().numpy()
    return float(host_array)

def bound_term_2(h, c, beta, B, m, W_norm_sum, M, gamma, delta):
    """Evaluate the second term of the FC monDEQ generalization bound.

    Computes
        4 * sqrt( 16 h log(24 h) * inner^2 * W_norm_sum
                  / (gamma^2 m^4 (c-1)^6 (M-1))
                  + log(3 M^1.5 / (c delta)) / (M-1) )
    with
        inner = 2 beta^3 (c+1)^3 (1 + gamma/(4 beta)) (B+1)
                + 2 m beta (c+1) (B+1) (c-1)^2
                + m^2 (c-1)^3.

    Args:
        h: width-like integer. In this notebook it is the same value as the
            model's `out_dim` (presumably the hidden width — confirm).
        c: constant of the bound; `c == 1` makes the denominator zero.
        beta: norm bound. In this notebook it is the value returned by
            `beta(model)`; `beta == 0` makes `gamma / 4 / beta` divide by zero.
        B: constant of the bound (presumably an input-norm bound — confirm).
        m: the same `m` that is passed to the model constructor.
        W_norm_sum: sum of squared parameter norms, e.g. from `W_norm_sum(...)`.
        M: count-like integer (presumably the number of samples — confirm); must be > 1.
        gamma: constant of the bound (presumably the margin — confirm).
        delta: constant of the bound (presumably the failure probability — confirm).

    Returns:
        A 0-d ``torch.Tensor``. If any input is a tensor, the autograd graph is
        preserved. Plain Python or NumPy scalars are also accepted.
    """
    cp, cm = c + 1, c - 1
    c2 = gamma / 4 / beta
    inner = (2 * beta ** 3 * cp ** 3 * (1 + c2) * (B + 1)
             + 2 * m * beta * cp * (B + 1) * cm ** 2
             + m ** 2 * cm ** 3)
    num = 16 * h * np.log(24 * h) * inner ** 2
    num_1 = num * W_norm_sum / (gamma ** 2 * m ** 4 * cm ** 6 * (M - 1))
    num_2 = np.log(3 * M ** 1.5 / c / delta)
    radicand = num_1 + num_2 / (M - 1)
    # torch.sqrt only accepts tensors. torch.as_tensor returns a tensor
    # argument unchanged (keeping its graph) and wraps Python/NumPy scalars.
    return 4 * torch.sqrt(torch.as_tensor(radicand))


def monW(model):
    """Reconstruct the weight matrix W of the model's monotone linear module.

    Reads ``model.mon.linear_module`` and returns

        W = (1 - m) I - A^T A + B - B^T,

    where A and B are the weights of ``A0`` and ``B0``, and m is the module's
    ``m`` attribute. Algebraically, the symmetric part of ``I - W`` equals
    ``m I + A^T A``, which is at least ``m I``.

    Returns:
        A square tensor of side ``A.shape[1]``, on A's device and dtype.
    """
    lin_mod = model.mon.linear_module
    m = lin_mod.m
    A = lin_mod.A0.weight
    B = lin_mod.B0.weight
    # Create I directly on A's device/dtype, so W is correct for CPU-loaded and
    # GPU models alike without relying on an external device-moving helper.
    identity = torch.eye(A.shape[1], dtype=A.dtype, device=A.device)
    return (1 - m) * identity - A.T @ A + B - B.T

def W_norm_sum(model, W_init=None):
    """Sum of squared Frobenius norms of the model parameters, offset by an init.

    The six terms, printed in this order, are:
      A0.weight, B0.weight, U.weight - W_init['U'], U.bias - W_init['b'],
      Wout.weight - W_init['Wo'], Wout.bias - W_init['bo'].
    A0/B0 are not offset: ``W_init['W']`` is never read.

    Args:
        model: object exposing ``mon.linear_module.{A0,B0,U}`` and ``Wout``.
        W_init: mapping with keys 'U', 'b', 'Wo', 'bo' (tensors or scalars).
            ``None`` means all initial values are zero.

    Returns:
        The sum as a 0-d tensor. Each squared norm is also printed.
    """
    if W_init is None:
        W_init = {k: 0 for k in ['Wo', 'b', 'bo', 'U']}
    lin_mod = model.mon.linear_module
    diffs = [
        lin_mod.A0.weight,
        lin_mod.B0.weight,
        lin_mod.U.weight - W_init['U'],
        lin_mod.U.bias - W_init['b'],
        model.Wout.weight - W_init['Wo'],
        model.Wout.bias - W_init['bo'],
    ]
    # Compute each norm once and reuse it for both the report and the total.
    sq_norms = [torch.norm(d) ** 2 for d in diffs]
    print("F_norms\n\t", [float(s) for s in sq_norms])
    return sum(sq_norms)

def beta(model):
    """Largest of several parameter norms of the model, used as `beta` in the bound.

    Returns max(||U.bias||_2, sigma_max(U.weight), sigma_max(Wout.weight),
    sigma_max(A0.weight)), where sigma_max is the largest singular value.
    Wout.bias and B0 are not included. The individual norms are printed.

    If the norm/SVD computation raises a ``RuntimeError`` (e.g. SVD fails to
    converge), a warning is emitted and 0 is returned. Note that ``0`` then
    makes ``gamma / 4 / beta`` in ``bound_term_2`` divide by zero.
    """
    import warnings

    lin_mod = model.mon.linear_module
    U = lin_mod.U.weight
    b = lin_mod.U.bias
    Wo = model.Wout.weight
    A = lin_mod.A0.weight
    try:
        norms = [
            torch.norm(b),
            torch.max(torch.svd(U)[1]),   # svd()[1] holds the singular values
            torch.max(torch.svd(Wo)[1]),
            torch.max(torch.svd(A)[1]),
        ]
    except RuntimeError as err:
        # Only numerical failures are tolerated. Other errors (e.g. a model
        # missing a layer) propagate instead of silently turning into beta = 0.
        warnings.warn(f"beta(): norm/SVD computation failed ({err}); returning 0")
        return 0
    print("l2_norms\n\t", [float(v) for v in norms])
    return max(norms)

In [36]:
# Constants passed to bound_term_2 for this evaluation.
# `h` and `m` are reused from the model-setup cell above.
M = 4096
gamma = 10
delta = 0.1
c = .0001
B = 50

# Norms are measured relative to an all-zero initialisation.
W_init = dict.fromkeys(['W', 'Wo', 'b', 'bo', 'U'], 0)
lin_mod = model.mon.linear_module

W_norm_sum_val = W_norm_sum(model, W_init)
print("W_norm_sum", W_norm_sum_val)

beta_val = beta(model)
print("beta", beta_val)

b = bound_term_2(h=h, c=c, beta=beta_val, B=B, m=m,
                 W_norm_sum=W_norm_sum_val, M=M, gamma=gamma, delta=delta)
print(b)

F_norms
	 [8.780768394470215, 39.16813278198242, 1560.82177734375, 0.0, 1599.3623046875, 0.0]
W_norm_sum tensor(3208.1328, grad_fn=<AddBackward0>)
l2_norms
	 [0.0, 14.334519386291504, 17.337421417236328, 1.0733709335327148]
beta tensor(17.3374, grad_fn=<MaxBackward1>)
tensor(37761.9141, grad_fn=<MulBackward0>)
