https://github.com/tcwangshiqi-columbia/symbolic_interval

In [11]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [12]:
class Interval():
    '''Naive interval class.

    Naive interval propagation is cheap (only around two times slower
    than regular NN propagation), but the output range it provides is
    loose because the dependencies between inputs are ignored.
    See ReluVal https://arxiv.org/abs/1804.10829 for more details of
    the tradeoff.
    Naive interval propagation is used by many existing training
    schemes:
    (1) DiffAi: http://proceedings.mlr.press/v80/mirman18b/mirman18b.pdf
    (2) IBP: https://arxiv.org/pdf/1810.12715.pdf
    These training schemes are fast, but the robustness of the trained
    models suffers from the loose estimates of naive interval propagation.

    Args:
        lower: array/tensor (numpy or torch) of lower bounds for the
            layer's nodes.
        upper: array/tensor of upper bounds, same shape as ``lower``.
            No upper value may be less than the corresponding lower
            value (enforced with an assert unless the instance is an
            ``Inverse_interval``).
        use_cuda: flag stored on the instance; not used by this class.

    Attributes:
    * :attr:`l` and :attr:`u` keep the lower and upper values of the
      interval; naive interval propagation uses them.
    * :attr:`c` and :attr:`e` are the center point and the half-width
      (error) of the interval. Symbolic interval propagation uses this
      form since it keeps the dependencies more efficiently.
      ``update_lu`` and ``update_ce`` keep both representations in sync.
    * :attr:`mask` starts as an empty list. It is meant to keep, per ReLU
      layer, the value u/(u-l) of each hidden node's ReLU input range
      [l, u], i.e. the slope of the estimated output dependency. 0 means
      the ReLU input is always negative and the output is always 0; 1
      means it is always positive and the output is unchanged; anything
      else means the node is over-approximated during propagation.
      Nothing in this class appends to it.
    '''
    def __init__(self, lower, upper, use_cuda=False):
        if self._checks_order():
            assert not ((upper-lower)<0).any(), "upper less than lower"
        self.l = lower
        self.u = upper
        self.c = (lower+upper)/2
        self.e = (upper-lower)/2
        self.mask = []
        self.use_cuda = use_cuda

    def _checks_order(self):
        '''Return True when lower <= upper must be enforced.

        ``Inverse_interval`` instances are exempt. The class is looked up
        lazily so that this module still works when ``Inverse_interval``
        is not defined (previously that raised NameError on every
        construction).
        '''
        inverse_cls = globals().get("Inverse_interval")
        return inverse_cls is None or not isinstance(self, inverse_cls)

    def update_lu(self, lower, upper):
        '''Update this interval from new lower and upper bounds.

        Args:
            lower: array/tensor of the lower bound for each node.
            upper: array/tensor of the upper bound for each node.
        Also recomputes :attr:`c` and :attr:`e`.
        '''
        if self._checks_order():
            assert not ((upper-lower)<0).any(), "upper less than lower"
        self.l = lower
        self.u = upper
        self.c = (lower+upper)/2
        self.e = (upper-lower)/2

    def update_ce(self, center, error):
        '''Update this interval from a new center and error (half-width).

        Args:
            center: array/tensor of the interval midpoint for each node.
            error: array/tensor of the non-negative half-width for each
                node.
        Also recomputes :attr:`l` and :attr:`u`.
        '''
        if self._checks_order():
            assert not (error<0).any(), "upper less than lower"
        self.c = center
        self.e = error
        self.u = self.c+self.e
        self.l = self.c-self.e

    def __str__(self):
        '''Return a human-readable summary: shape, lower and upper bounds.
        '''
        string = "interval shape:"+str(self.c.shape)
        string += "\nlower:"+str(self.l)
        string += "\nupper:"+str(self.u)
        return string

    def worst_case(self, y, output_size):
        '''Calculate the worst case of the analyzed output ranges.

        For every example i, it returns the upper bound of each label
        minus the lower bound of the target label y[i], with the target
        entry itself set to 0. If every returned value is less than 0,
        the worst case provided by interval analysis never exceeds the
        target label y's output.

        Args:
            y: 1-D integer labels, one per row of ``l`` / ``u``.
            output_size: unused; kept for interface compatibility.

        Note: :attr:`u` is modified in place and returned; :attr:`c` and
        :attr:`e` are not updated.
        '''
        assert y.shape[0] == self.l.shape[0] == self.u.shape[0],\
                "wrong input shape"

        for i in range(y.shape[0]):
            target_lower = self.l[i, y[i]]
            self.u[i] = self.u[i]-target_lower
            self.u[i, y[i]] = 0.0
        return self.u

In [13]:
class Interval_network(nn.Module):
    '''Convert a nn.Sequential model to a network supporting symbolic
    interval propagation / naive interval propagation.

    Supported layers: ``nn.Linear`` (the last module of ``model``, if it
    is Linear, receives the specification matrix ``c``), ``nn.ReLU``,
    ``nn.Conv2d``, and modules whose class name contains "Flatten",
    "Vlayer" or "bn".
    NOTE(review): any other module (e.g. pooling, dropout, other
    activations) is silently skipped, so the resulting bounds ignore it.

    Args:
        model: an iterable, indexable container of modules
            (``nn.Sequential``).
        c: optional specification matrix folded into the final
            ``nn.Linear``; ``None`` disables it.
    '''
    def __init__(self, model, c):
        nn.Module.__init__(self)

        layers = []
        first_layer = True
        last_layer = False

        for layer in model:
            layer_name = str(layer.__class__.__name__)
            if isinstance(layer, nn.Linear):
                if layer is model[-1]:
                    last_layer = True
                # Only the final Linear layer gets the specification matrix.
                wc_matrix = c if last_layer else None
                layers.append(Interval_Dense(layer, first_layer, wc_matrix=wc_matrix))
                first_layer = False
            if isinstance(layer, nn.ReLU):
                layers.append(Interval_ReLU(layer))
            if isinstance(layer, nn.Conv2d):
                layers.append(Interval_Conv2d(layer, first_layer))
                first_layer = False
            if 'Flatten' in layer_name:
                layers.append(Interval_Flatten())
            if 'Vlayer' in layer_name:
                layers.append(Interval_Vlayer(layer))
            if 'bn' in layer_name:
                layers.append(Interval_BN(layer))
        self.net = nn.Sequential(*layers)

    def forward(self, ix):
        '''Forward an interval through every converted layer.

        :attr:`ix` is the input for each layer. If ix is a naive
        interval, it propagates naively; if ix is a symbolic interval,
        it propagates symbolically.
        '''
        return self.net(ix)

In [14]:
class Interval_Dense(nn.Module):
    '''Naive interval propagation through an ``nn.Linear`` layer.

    The interval is propagated in center/error form,
    center' = W c + b and error' = |W| e, so [c'-e', c'+e'] encloses
    every layer output for inputs in [c-e, c+e].

    Args:
        layer: the wrapped ``nn.Linear`` (bias may be ``None``).
        first_layer: stored for reference; not used in ``forward``.
        wc_matrix: optional specification matrix C, presumably of shape
            (batch, K, out_features) as built by ``Interval_Bound`` --
            TODO confirm for other callers. When given, C is folded into
            the layer so the bounds are computed for C (W x + b) directly,
            which is generally tighter than bounding W x + b first.
    '''
    def __init__(self, layer, first_layer=False, wc_matrix=None):
        nn.Module.__init__(self)
        self.layer = layer
        self.first_layer = first_layer
        self.wc_matrix = wc_matrix

    def forward(self, ix):
        '''Propagate ``ix`` through the layer; mutates and returns ``ix``.'''
        assert isinstance(ix, Interval), "Not Interval instance"

        weight = self.layer.weight
        bias = self.layer.bias
        if self.wc_matrix is None:
            c = F.linear(ix.c, weight, bias=bias)
            e = F.linear(ix.e, weight.abs())
        else:
            weight = self.wc_matrix.matmul(weight)
            c = weight.matmul(ix.c.unsqueeze(-1)).squeeze(-1)
            # nn.Linear(bias=False) has bias None; matmul(None) would crash.
            if bias is not None:
                c = c + self.wc_matrix.matmul(bias)
            e = weight.abs().matmul(ix.e.unsqueeze(-1)).squeeze(-1)

        ix.update_lu(c-e, c+e)
        return ix

In [15]:
class Interval_Conv2d(nn.Module):
    '''Naive interval propagation through an ``nn.Conv2d`` layer.

    center' = conv(c, W) + b and error' = conv(e, |W|), using the wrapped
    layer's stride, padding, dilation and groups so that the interval
    convolution matches the layer's own geometry.
    NOTE(review): the layer's ``padding_mode`` is not consulted; only
    zero padding is reproduced here.
    '''
    def __init__(self, layer, first_layer=False):
        nn.Module.__init__(self)
        self.layer = layer
        self.first_layer = first_layer

    def forward(self, ix):
        '''Propagate ``ix`` through the conv; mutates and returns ``ix``.'''
        assert isinstance(ix, Interval), "Not Interval instance"

        conv = self.layer
        c = F.conv2d(ix.c, conv.weight, bias=conv.bias,
                     stride=conv.stride, padding=conv.padding,
                     dilation=conv.dilation, groups=conv.groups)
        e = F.conv2d(ix.e, conv.weight.abs(),
                     stride=conv.stride, padding=conv.padding,
                     dilation=conv.dilation, groups=conv.groups)
        ix.update_lu(c-e, c+e)
        return ix

In [16]:
class Interval_BN(nn.Module):
    '''Batch-statistics normalization of a naive interval.

    For each feature, the endpoint with the larger magnitude (u where
    u > -l, otherwise l) is taken per example. Its batch mean and the L2
    norm of its deviations from that mean (a norm, not a standard
    deviation) are used to normalize both endpoints. The statistics are
    stored on the wrapped layer as ``layer.mean`` / ``layer.sigma``.
    '''
    def __init__(self, layer, first_layer=False):
        nn.Module.__init__(self)
        self.layer = layer
        self.first_layer = first_layer

    def forward(self, ix):
        '''Normalize ``ix``; mutates and returns it.'''
        assert isinstance(ix, Interval), "Not Interval instance"

        shape = ix.u.shape
        # Naive Interval objects do not carry batch_size; fall back to dim 0.
        batch_size = getattr(ix, "batch_size", shape[0])
        upper = ix.u.view(batch_size, -1)
        lower = ix.l.view(batch_size, -1)

        tmax = torch.where(upper > -lower, upper, lower)
        mean = tmax.mean(dim=0, keepdim=True)
        sigma = torch.norm(tmax-mean, dim=0, keepdim=True)

        self.layer.mean = mean
        self.layer.sigma = sigma

        # update_lu keeps c/e in sync; later Dense/Conv layers read c and e,
        # so assigning only u/l would silently drop this normalization.
        # sigma >= 0, so dividing by it preserves lower <= upper.
        ix.update_lu(((lower-mean)/sigma).view(shape),
                     ((upper-mean)/sigma).view(shape))
        return ix

In [17]:
class Interval_ReLU(nn.Module):
    '''Naive interval propagation through a ReLU.

    ReLU is monotone non-decreasing, so [l, u] maps to
    [relu(l), relu(u)].
    '''
    def __init__(self, layer):
        nn.Module.__init__(self)
        self.layer = layer

    def forward(self, ix):
        '''Apply ReLU to both bounds; mutates and returns ``ix``.'''
        assert isinstance(ix, Interval), "Not Interval instance"

        ix.update_lu(F.relu(ix.l), F.relu(ix.u))
        return ix

In [18]:
class Interval_Flatten(nn.Module):
    '''Collapse every non-batch dimension of an interval into one axis.'''

    def __init__(self):
        super().__init__()

    def forward(self, ix):
        '''Flatten ``ix.l`` / ``ix.u`` to (batch, -1); mutates and returns ``ix``.'''
        assert isinstance(ix, Interval), "Not Interval instance"
        flat_lower = ix.l.view(ix.l.size(0), -1)
        flat_upper = ix.u.view(ix.u.size(0), -1)
        ix.update_lu(flat_lower, flat_upper)
        return ix

In [21]:
class Interval_Bound(nn.Module):
    '''Bound the outputs of ``net`` over an element-wise epsilon box
    around the inputs using naive interval propagation.

    Args:
        net: an ``nn.Sequential`` whose last module exposes
            ``out_features`` (i.e. an ``nn.Linear``).
        epsilon: input perturbation radius; inputs are perturbed by
            +/- epsilon and clamped to [X.min(), X.max()].
        method: one of the accepted method names; only "naive" is
            implemented by ``forward``.
        proj: optional positive int; stored, unused by ``forward``.
        use_cuda: forwarded to ``Interval``.
        norm: stored, unused by ``forward``.
        worst_case: when True, ``forward`` returns upper bounds of the
            margins z_k - z_y; otherwise the negated lower bounds of the
            raw outputs.
    '''
    def __init__(self, net, epsilon, method="sym",
                 proj=None, use_cuda=True, norm="linf", worst_case=True):
        nn.Module.__init__(self)
        self.net = net
        self.epsilon = epsilon
        self.use_cuda = use_cuda
        self.proj = proj
        if proj is not None:
            # Type check first so a non-int proj gets the intended message
            # instead of a TypeError from the comparison.
            assert isinstance(proj, int), "project dimension has to"\
                    " be integer!"
            assert proj > 0, "project dimension has to be larger than 0,"\
                    " please use naive bound propagation (proj=None)!"

        assert method in ["sym", "naive", "inverse", "center_sym", "new", "gen", "mix"],\
                "No such interval methods!"
        self.method = method
        self.norm = norm
        self.worst_case = worst_case

    def forward(self, X, y):
        '''Return a (batch, out_features) tensor of bounds for inputs X.

        With ``worst_case`` it is -lower(z_y - z_k) = upper(z_k - z_y),
        which is 0 at k == y[b].
        '''
        out_features = self.net[-1].out_features

        if self.worst_case:
            eye = torch.eye(out_features).type_as(X)
            # c[b, k] = e_{y_b} - e_k: row k of the final layer computes z_y - z_k.
            c = eye[y].unsqueeze(1) - eye.unsqueeze(0)
        else:
            c = None

        # Transfer original model to interval models
        inet = Interval_network(self.net, c)

        minimum = X.min().item()
        maximum = X.max().item()

        # Create a naive interval class from X
        assert self.method == "naive", "Not naive method"
        ix = Interval(torch.clamp(X-self.epsilon, minimum, maximum),
                      torch.clamp(X+self.epsilon, minimum, maximum),
                      self.use_cuda)
        # Propagate the naive interval through the interval network
        ix = inet(ix)
        return -ix.l

In [22]:
def naive_interval_analyze(net, epsilon, X, y,
                           use_cuda=True, parallel=False, norm="linf"):
    '''Evaluate ``net`` on (X, y) with naive interval bounds.

    Args:
        net: ``nn.Sequential`` ending in an ``nn.Linear``.
        epsilon: input perturbation radius (see ``Interval_Bound``).
        X: input batch; y: integer labels, one per row of X.
        use_cuda, norm: forwarded to ``Interval_Bound``.
        parallel: wrap the bound module in ``nn.DataParallel``.

    Returns:
        (iloss, ierr): cross-entropy of the worst-case margin vectors
        against ``y`` (a tensor), and the fraction of examples whose
        worst-case margin vector attains its maximum at a label other
        than ``y`` (a Python float).
    '''
    bound = Interval_Bound(net, epsilon, method="naive",
                           use_cuda=use_cuda, norm=norm)
    if parallel:
        bound = nn.DataParallel(bound)
    wc = bound(X, y)

    iloss = nn.CrossEntropyLoss()(wc, y)
    # .float() keeps the tensor on its device instead of converting it
    # to the default CPU tensor type.
    ierr = (wc.max(1)[1] != y).float()
    ierr = ierr.sum().item()/X.shape[0]

    return iloss, ierr