In [1]:
from bounds import *
import torch
import numpy as np

# Spectral norm estimation on dense matrix

This code is related to dense spectral norm estimation, see Section 5.5 in paper. 

In [2]:
n, m = 500, 400
with_cuda = False
G = torch.randn(n, m).double().cuda()
sigma_1_reference = torch.linalg.matrix_norm(G, ord=2).item()
print("dense sigma_1_reference", sigma_1_reference)

dense sigma_1_reference 42.809956870632924


# Test Power iteration on dense matrix

In [3]:
n_iter_pi = 1000
nb_reps = 10
burn = 10
sigmas_1_pi, times = [], []
for _ in range(nb_reps + burn):
    sigma_1_pi, time = estimate_dense(G, 
                                      n_iter=n_iter_pi, 
                                      name_func="pi", 
                                      return_time=True)
    sigmas_1_pi.append(sigma_1_pi.item())
    times.append(time)
print("Diff Power iteration", np.mean(sigmas_1_pi[burn:]) - sigma_1_reference, 
      "Mean time", np.mean(times[burn:]))

Diff Power iteration -2.7071678232459817e-12 Mean time 0.0829003095626831


# Test Gram iteration on dense matrix

In [4]:
n_iter_gram = 15
nb_reps = 10
burn = 10
sigmas_1_gram, times = [], []
for _ in range(nb_reps + burn):
    sigma_1_gram, time = estimate_dense(G, 
                                        n_iter=n_iter_gram, 
                                        name_func="ours", 
                                        return_time=True)
    sigmas_1_gram.append(sigma_1_gram.item())
    times.append(time)
print("Diff Gram iteration", np.mean(sigmas_1_gram[burn:]) - sigma_1_reference, 
      "Mean time", np.mean(times[burn:]))

Diff Gram iteration -2.6929569685307797e-12 Mean time 0.0015299320220947266


# Spectral norm estimation on convolutional layer

This code is related to convolutional layer spectral norm estimation, see Section 5.2 in paper.# Define convolution kernel

In [5]:
cout = 64
cin = 64
input_size_n = 12
kernel_size = 5


kernel = torch.randn(cout, cin, kernel_size, kernel_size).cuda()

sigma_1_reference_sedghi2019 = estimate(kernel, 
                                        n=input_size_n, 
                                        name_func="sedghi2019").item()
print("kernel sigma_1_reference_sedghi2019", sigma_1_reference_sedghi2019)

n_iter_ryu2019_ref = 100
sigma_1_reference_ryu2019 = estimate(kernel, 
                                     n=input_size_n, 
                                     n_iter=n_iter_ryu2019_ref, 
                                     name_func="ryu2019").item()
print("kernel sigma_1_reference_ryu2019", sigma_1_reference_ryu2019)

kernel sigma_1_reference_sedghi2019 81.72451782226562
kernel sigma_1_reference_ryu2019 77.01795959472656


# Test Araujo2021 on convolutional layer

In [6]:
nb_samples_araujo2021 = 50
nb_reps = 10
burn = 10
sigmas_1_araujo2021, times = [], []
for _ in range(nb_reps + burn):
    sigma_1_araujo2021, time = estimate(kernel, 
                                  n=input_size_n, 
                                  n_iter=nb_samples_araujo2021, 
                                  name_func="araujo2021", 
                                  return_time=True)
    sigmas_1_araujo2021.append(sigma_1_araujo2021.item())
    times.append(time)
print("Diff Sedghi2019 with Araujo2021 conv ", np.mean(sigmas_1_araujo2021[burn:]) - sigma_1_reference_sedghi2019,
      "\nDiff Ryu2019 with Araujo2021 conv", np.mean(sigmas_1_araujo2021[burn:]) - sigma_1_reference_ryu2019, 
      "\nMean time", np.mean(times[burn:]))

Diff Sedghi2019 with Araujo2021 conv  272.14373779296875 
Diff Ryu2019 with Araujo2021 conv 276.8502960205078 
Mean time 0.0016119718551635743


# Test Singla2021 on convolutional layer

In [7]:
n_iter_singla2021 = 50
nb_reps = 10
burn = 10
sigmas_1_singla2021, times = [], []
for _ in range(nb_reps + burn):
    sigma_1_singla2021, time = estimate(kernel, 
                                  n=input_size_n, 
                                  n_iter=n_iter_singla2021, 
                                  name_func="singla2021", 
                                  return_time=True)
    sigmas_1_singla2021.append(sigma_1_singla2021.item())
    times.append(time)
print("Diff Sedghi2019 with Singla2021conv", np.mean(sigmas_1_singla2021[burn:]) - sigma_1_reference_sedghi2019,
      "\nDiff Ryu2019 with Singla2021 conv", np.mean(sigmas_1_singla2021[burn:]) - sigma_1_reference_ryu2019, 
      "\nMean time", np.mean(times[burn:]))

Diff Sedghi2019 with Singla2021conv 92.49748229980469 
Diff Ryu2019 with Singla2021 conv 97.20404052734375 
Mean time 0.027846980094909667


# Test Gram iteration on convolution layer

In [8]:
n_iter_gram = 5
nb_reps = 10
burn = 10
sigmas_1_gram, times = [], []
for _ in range(nb_reps + burn):
    sigma_1_gram, time = estimate(kernel, 
                                  n=input_size_n, 
                                  n_iter=n_iter_gram, 
                                  name_func="ours", 
                                  return_time=True)
    sigmas_1_gram.append(sigma_1_gram.item())
    times.append(time)
print("Diff Sedghi2019 with Gram iteration conv", np.mean(sigmas_1_gram[burn:]) - sigma_1_reference_sedghi2019,
      "\nDiff Ryu2019 with Gram iteration conv", np.mean(sigmas_1_gram[burn:]) - sigma_1_reference_ryu2019, 
      "\nMean time", np.mean(times[burn:]))

Diff Sedghi2019 with Gram iteration conv 0.0018463134765625 
Diff Ryu2019 with Gram iteration conv 4.708404541015625 
Mean time 0.0010091543197631836


# Compute spectral norm of convolutional layers of ResNet18

This code is related to Section 5.3 in paper.

In [9]:
from torchvision import models

n_iter_name = {
               "ours":7,
               "singla2021":50,
               "ryu2019" : 100,
               "araujo2021" : 20,
               "sedghi2019":None}
func_names = [
              "ryu2019",
              #"sedghi2019", # commented because it takes a while
              "araujo2021",
              "singla2021",
              "ours",
]
model_resnet_18 = models.resnet18(pretrained=True).cuda().eval()



inp_shape = (224, 224)
lip_tot = {name : 1 for name in func_names}
times_tot = {name : 0 for name in func_names}

with torch.no_grad():
    for name, module in model_resnet_18.named_modules():
        print()
        is_downsample = name.endswith("downsample")
        is_regular_conv = "conv" in name
        is_max_pool = "MaxPool2d" in module.__class__.__name__
        if is_max_pool:
            stride = module.stride
            inp_shape = (inp_shape[0] // stride, inp_shape[1] // stride)
        if is_downsample:
            # dowsampling layer in residual connection
            conv, bn = module[0], module[1]
            lip_bn = (bn.weight.detach() / bn.running_var).max().item()
            inp_shape = (inp_shape[0] * stride, inp_shape[1] * stride)
        elif is_regular_conv:
            conv = module
            stride = conv.stride[0]
            lip_bn = 1.0
        if is_downsample or is_regular_conv:
            param = conv.weight.clone().detach()
            out_channels, in_channels, H, W = param.shape
            for name in func_names:
                bound, curr_time = estimate(param,
                                            inp_shape[0],
                                            n_iter_name[name],
                                            name,
                                            return_time=True)
                bound = bound.item()
                print(name, "conv weight dim", param.shape, "n", inp_shape[0], "bound", bound, "time", curr_time)
                if is_downsample:
                    lip_tot[name] += bound* lip_bn
                else:
                    # Lipschiz of batch norm cancels in ratio
                    lip_tot[name] *= bound
                times_tot[name] += curr_time
                print("Total Lipschitz bound", lip_tot[name], "total time", times_tot[name], "\n")
            inp_shape = (inp_shape[0] // stride, inp_shape[1] // stride)






ryu2019 conv weight dim torch.Size([64, 3, 7, 7]) n 224 bound 15.869194984436035 time 0.26656007766723633
Total Lipschitz bound 15.869194984436035 total time 0.26656007766723633 

araujo2021 conv weight dim torch.Size([64, 3, 7, 7]) n 224 bound 31.370620727539062 time 0.003934144973754883
Total Lipschitz bound 31.370620727539062 total time 0.003934144973754883 

singla2021 conv weight dim torch.Size([64, 3, 7, 7]) n 224 bound 28.89466667175293 time 0.03105306625366211
Total Lipschitz bound 28.89466667175293 total time 0.03105306625366211 

ours conv weight dim torch.Size([64, 3, 7, 7]) n 224 bound 15.916428565979004 time 0.04238271713256836
Total Lipschitz bound 15.916428565979004 total time 0.04238271713256836 







ryu2019 conv weight dim torch.Size([64, 64, 3, 3]) n 56 bound 5.979285717010498 time 0.3389859199523926
Total Lipschitz bound 94.88645091089302 total time 0.6055459976196289 

araujo2021 conv weight dim torch.Size([64, 64, 3, 3]) n 56 bound 15.715842247009277 time 0.00

ryu2019 conv weight dim torch.Size([256, 128, 1, 1]) n 14 bound 1.215047001838684 time 0.07498335838317871
Total Lipschitz bound 360117990.75099486 total time 1.5774972438812256 

araujo2021 conv weight dim torch.Size([256, 128, 1, 1]) n 14 bound 5.973367214202881 time 0.000865936279296875
Total Lipschitz bound 421865104751764.56 total time 0.021115541458129883 

singla2021 conv weight dim torch.Size([256, 128, 1, 1]) n 14 bound 1.215043544769287 time 0.029513835906982422
Total Lipschitz bound 11833540310.659184 total time 0.3956260681152344 

ours conv weight dim torch.Size([256, 128, 1, 1]) n 14 bound 1.215043544769287 time 0.0018472671508789062
Total Lipschitz bound 384220381.6761297 total time 0.1572742462158203 





ryu2019 conv weight dim torch.Size([256, 256, 3, 3]) n 14 bound 6.226568222045898 time 0.030767202377319336
Total Lipschitz bound 2242299237.3971634 total time 1.608264446258545 

araujo2021 conv weight dim torch.Size([256, 256, 3, 3]) n 14 bound 30.153610229492188 ti

# Total Lipschitz ratio bound 

In [10]:
lip_tot_ref = lip_tot["ryu2019"]
for name in func_names:
    lip_tot[name] /= lip_tot_ref
print("Total Lipschitz ratio", lip_tot)
print("Total times", times_tot)

Total Lipschitz ratio {'ryu2019': 1.0, 'araujo2021': 30256475961.577263, 'singla2021': 86.50689854246077, 'ours': 1.4725238046216356}
Total times {'ryu2019': 1.9191737174987793, 'araujo2021': 0.03477597236633301, 'singla2021': 0.6056292057037354, 'ours': 0.17369508743286133}
