In [None]:
import torch.nn as nn
import numpy
import torch.nn.functional as F
import torch
from torch.autograd import Variable


# Util

In [None]:
def sto_quant(tensor):
    """Stochastically binarize *tensor* to values in {-1., 1.}.

    Each element x is mapped to (x + 1) / 2, perturbed by uniform noise in
    [-0.5, 0.5), clamped to [0, 1], rounded and mapped back with 2y - 1.
    The noise is drawn on ``tensor``'s own device, so CPU and CUDA inputs
    both work (previously the noise was always moved to CUDA).
    """
    noise = torch.rand(tensor.size(), device=tensor.device).add(-0.5)
    return tensor.add(1.).div(2.).add(noise).clamp(0., 1.).round().mul(2.).add(-1.)

class BinOp():
    """XNOR-Net style binarizer for the ``nn.Conv2d`` weights of a model.

    Presumably driven per training step as: ``binarization()`` before the
    forward pass, ``restore()`` after backward, then
    ``updateBinaryGradWeight()`` before the optimizer step -- confirm
    against the training loop.
    """

    def __init__(self, model, mode='allbin'):
        """Collect the Conv2d weights of *model* that should be binarized.

        Args:
            model: module whose ``nn.Conv2d`` layers are enumerated in
                ``model.modules()`` order.
            mode: ``'allbin'`` selects every Conv2d layer; ``'nin'`` skips the
                first and the last Conv2d layer.

        Raises:
            AssertionError: if *mode* is not ``'allbin'`` or ``'nin'``.
        """
        convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
        count_Conv2d = len(convs)

        assert mode in ['allbin','nin'], 'No such a mode!'

        if mode == 'allbin':
            start_range, end_range = 0, count_Conv2d - 1
        else:  # 'nin'
            start_range, end_range = 1, count_Conv2d - 2

        # Inclusive index range. ``range`` yields an empty list (instead of
        # numpy.linspace raising ValueError) when 'nin' mode meets a model
        # with fewer than two Conv2d layers.
        self.bin_range = list(range(start_range, end_range + 1))

        self.num_of_params = len(self.bin_range)
        self.saved_params = []
        # Not used by this class; kept so external code reading it still works.
        self.target_params = []
        # Holds the weight Parameters (not the modules) despite the name.
        self.target_modules = []
        for index, conv in enumerate(convs):
            if index in self.bin_range:
                # Full-precision buffer that save_params()/restore() copy through.
                self.saved_params.append(conv.weight.data.clone())
                self.target_modules.append(conv.weight)

    def binarization(self, quant_mode='det'):
        """Mean-center, clamp, save full precision, then binarize the weights."""
        self.meancenterConvParams()
        self.clampConvParams()
        self.save_params()
        self.binarizeConvParams(quant_mode)

    def meancenterConvParams(self):
        """Subtract from each weight its mean over dimension 1."""
        for weight in self.target_modules:
            data = weight.data
            weight.data = data.add(data.mean(1, keepdim=True).mul(-1).expand_as(data))

    def clampConvParams(self):
        """Clamp every target weight into [-1, 1]."""
        for weight in self.target_modules:
            weight.data = weight.data.clamp(min=-1.0, max=1.0)

    def save_params(self):
        """Copy the current (full-precision) weights into ``saved_params``."""
        for saved, weight in zip(self.saved_params, self.target_modules):
            saved.copy_(weight.data)

    def binarizeConvParams(self, quant_mode):
        """Replace each weight by sign(w) (or sto_quant(w)) times the per-filter mean |w|.

        Raises:
            AssertionError: if *quant_mode* is not ``'det'`` or ``'sto'``.
        """
        assert quant_mode in ['det', 'sto'], 'No such a quant_mode'
        for weight in self.target_modules:
            data = weight.data
            n = data[0].nelement()
            s = data.size()
            # Mean absolute value over dims 1..3, kept with shape (out, 1, 1, 1).
            m = data.norm(1, 3, keepdim=True).sum(2, keepdim=True).sum(1, keepdim=True).div(n)
            if quant_mode == 'det':
                weight.data = data.sign().mul(m.expand(s))
            else:
                weight.data = sto_quant(data).mul(m.expand(s))

    def restore(self):
        """Copy the saved full-precision values back into the weights."""
        for saved, weight in zip(self.saved_params, self.target_modules):
            weight.data.copy_(saved)

    def updateBinaryGradWeight(self):
        """Rewrite each weight's gradient for the binarized forward pass."""
        for param in self.target_modules:
            weight = param.data
            grad = param.grad.data
            n = weight[0].nelement()
            s = weight.size()
            # .clone(): expand() returns a stride-0 view whose elements share
            # memory; the masked in-place writes below are unsafe on such a
            # view and are rejected by recent PyTorch versions.
            m = weight.norm(1, 3, keepdim=True).sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(s).clone()
            m[weight.lt(-1.0)] = 0
            m[weight.gt(1.0)] = 0
            m = m.mul(grad)
            m_add = weight.sign().mul(grad)
            m_add = m_add.sum(3, keepdim=True).sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(s)
            m_add = m_add.mul(weight.sign())
            param.grad.data = m_add.add(m).mul(1.0-1.0/s[1]).mul(n)

class WeightedLoss(nn.Module):
    """Cross-entropy loss with optional per-sample or per-sample-per-class weights.

    Modes (``aggregate``):
        'normal_ce_mean': plain mean cross-entropy; ``weights`` is ignored.
        's_ce_mean':  per-sample loss times one weight per sample (``weights``
                      of shape (N,) or (N, 1)), then averaged.
        'sc_ce_mean': per-sample loss times ``weights[i, target[i]]`` where
                      ``weights`` has shape (N, C), then averaged.
    """
    def __init__(self, aggregate='normal_ce_mean'):
        super(WeightedLoss, self).__init__()
        # The default must be an accepted mode; the former default 'mean'
        # made WeightedLoss() fail this assertion unconditionally.
        assert aggregate in ['normal_ce_mean', 's_ce_mean', 'sc_ce_mean'], 'No such a mode'
        self.aggregate = aggregate

    def forward(self, input, target, weights=None):
        """Return the scalar loss for logits *input* and class indices *target*.

        Raises:
            ValueError: if a weighted mode is used without *weights*.
        """
        if self.aggregate == 'normal_ce_mean':
            return F.cross_entropy(input, target)
        if weights is None:
            raise ValueError('aggregate=%r requires a weights tensor' % self.aggregate)

        sep_loss = F.cross_entropy(input, target, reduction='none')
        if self.aggregate == 's_ce_mean':
            # reshape(-1) rather than the in-place squeeze_(): does not mutate
            # the caller's tensor and keeps shape (1,) for a batch of one.
            sample_weights = weights.reshape(-1)
        else:  # 'sc_ce_mean'
            # Pick weights[i, target[i]] directly instead of building an N x N
            # matrix and taking its diagonal.
            rows = torch.arange(target.numel(), device=weights.device)
            sample_weights = weights[rows, target.to(weights.device)]
        assert sep_loss.size() == sample_weights.size(), 'Required size: %r, but got: %r' % (str(sep_loss.size()),str(sample_weights.size()))
        # Move weights to the loss's device/dtype instead of assuming CUDA.
        return (sep_loss * sample_weights.to(device=sep_loss.device, dtype=sep_loss.dtype)).mean()


# Networks - NIN

In [None]:
class BinActive(torch.autograd.Function):
    '''
    Binarize the input activations and calculate the mean across channel dimension.

    forward(input) returns ``(sign(input), mean(|input|, dim=1, keepdim=True))``.
    backward is a straight-through estimator: the gradient of the sign output
    passes through unchanged where -1 < input < 1 and is zeroed elsewhere; the
    gradient arriving at the mean output is ignored.

    Invoke as ``BinActive.apply(x)``. Current PyTorch rejects the legacy
    instance-call style (``BinActive()(x)``) with non-static methods.
    '''
    @staticmethod
    def forward(ctx, input):
        # Keep the pre-binarization input for the straight-through mask.
        ctx.save_for_backward(input)
        mean = torch.mean(input.abs(), 1, keepdim=True)
        return input.sign(), mean

    @staticmethod
    def backward(ctx, grad_output, grad_output_mean):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input

class BinConv2d(nn.Module):
    '''
    Conv layer with binarized input activations.

    forward: BatchNorm2d -> sign-binarize activations (BinActive) ->
    optional Dropout -> Conv2d -> ReLU.  The Conv2d weights are not binarized
    here; BinOp in this file operates on nn.Conv2d modules such as self.conv.
    '''
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, dropout=0):
        super(BinConv2d, self).__init__()
        self.layer_type = 'BinConv2d'
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dropout_ratio = dropout

        self.bn = nn.BatchNorm2d(input_channels, eps=1e-4, momentum=0.1, affine=True)
        if dropout!=0:
            self.dropout = nn.Dropout(dropout)
        self.conv = nn.Conv2d(input_channels, output_channels,
                kernel_size=kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.bn(x)
        # Autograd Functions must be invoked through .apply; the legacy
        # BinActive()(x) instance call raises in current PyTorch.
        x, _mean = BinActive.apply(x)
        if self.dropout_ratio!=0:
            x = self.dropout(x)
        x = self.conv(x)
        x = self.relu(x)
        return x


class RealConv2d(nn.Module):
    '''
    Full-precision counterpart of BinConv2d: BatchNorm2d -> optional
    Dropout -> Conv2d -> ReLU, without the activation-binarization step.
    '''
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, dropout=0):
        super(RealConv2d, self).__init__()
        self.layer_type = 'RealConv2d'
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding
        self.dropout_ratio = dropout

        # Submodules are registered in the order bn, [dropout], conv, relu;
        # that order determines the state_dict key order.
        self.bn = nn.BatchNorm2d(input_channels, eps=1e-4, momentum=0.1, affine=True)
        if self.dropout_ratio != 0:
            self.dropout = nn.Dropout(dropout)
        self.conv = nn.Conv2d(input_channels, output_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.bn(x)
        if self.dropout_ratio != 0:
            out = self.dropout(out)
        return self.relu(self.conv(out))



class Net(nn.Module):
    '''
    The original binarized XNOR-NIN model.

    The first (3 -> 192, 5x5) and last (192 -> 10, 1x1) convolutions are
    full-precision nn.Conv2d layers; everything in between is BinConv2d.
    forward() reshapes the pooled output to (batch, 10); for a 32x32 input
    the final 8x8 average pool receives an 8x8 map and yields 1x1.
    '''
    def __init__(self):
        super(Net, self).__init__()
        layers = [
            # full-precision stem
            nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False),
            nn.ReLU(inplace=True),
            BinConv2d(192, 96, kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            # second stage
            BinConv2d(96, 192, kernel_size=5, stride=1, padding=2, dropout=0.5),
            BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
            # third stage and full-precision classifier
            BinConv2d(192, 192, kernel_size=3, stride=1, padding=1, dropout=0.5),
            BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False),
            nn.Conv2d(192, 10, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        ]
        self.xnor = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.xnor(x)
        return scores.view(scores.size(0), 10)


####################################################
#
# model variants
#
####################################################


class Net_Cut(nn.Module):
    '''
    The 'narrower' XNOR-NIN model: Net with its hidden widths scaled by
    cut_ratio (192 -> int(192*cut_ratio), 96 -> int(96*cut_ratio)).
    '''
    def __init__(self, cut_ratio=0.5):
        super(Net_Cut, self).__init__()
        wide = int(192*cut_ratio)
        narrow = int(96*cut_ratio)
        self.xnor = nn.Sequential(
            nn.Conv2d(3, wide, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(wide, eps=1e-4, momentum=0.1, affine=False),
            nn.ReLU(inplace=True),
            BinConv2d(wide, narrow, kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(narrow, wide, kernel_size=5, stride=1, padding=2, dropout=0.5),
            BinConv2d(wide, wide, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(wide, wide, kernel_size=3, stride=1, padding=1, dropout=0.5),
            BinConv2d(wide, wide, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(wide, eps=1e-4, momentum=0.1, affine=False),
            nn.Conv2d(wide, 10, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        )

    def forward(self, x):
        scores = self.xnor(x)
        return scores.view(scores.size(0), 10)



class RealNet(nn.Module):
    '''
    The float NIN model: the same layer stack as Net with every BinConv2d
    replaced by its full-precision counterpart RealConv2d.
    '''
    def __init__(self):
        super(RealNet, self).__init__()
        layers = [
            nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False),
            nn.ReLU(inplace=True),
            RealConv2d(192, 96, kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            RealConv2d(96, 192, kernel_size=5, stride=1, padding=2, dropout=0.5),
            RealConv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

            RealConv2d(192, 192, kernel_size=3, stride=1, padding=1, dropout=0.5),
            RealConv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False),
            nn.Conv2d(192, 10, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        ]
        self.xnor = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.xnor(x)
        return scores.view(scores.size(0), 10)


class AllBinNet(nn.Module):
    '''
    XNOR-NIN variant in which the first (3 -> 192) and last (192 -> 10)
    convolutions are BinConv2d blocks too, so every convolution sits behind
    an activation-binarizing BinConv2d.
    '''
    def __init__(self):
        super(AllBinNet, self).__init__()
        layers = [
            BinConv2d(3, 192, kernel_size=5, stride=1, padding=2),
            BinConv2d(192, 96, kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(96, 192, kernel_size=5, stride=1, padding=2, dropout=0.5),
            BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(192, 192, kernel_size=3, stride=1, padding=1, dropout=0.5),
            BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
            BinConv2d(192, 10, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        ]
        self.xnor = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.xnor(x)
        return scores.view(scores.size(0), 10)



class NotAllBinNet(nn.Module):
    '''
    XNOR-NIN variant that keeps the first convolution full-precision
    (Conv2d + BatchNorm2d + ReLU) but also turns the last (192 -> 10)
    convolution into a BinConv2d.
    '''
    def __init__(self):
        super(NotAllBinNet, self).__init__()
        layers = [
            nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False),
            nn.ReLU(inplace=True),
            BinConv2d(192, 96, kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(96, 192, kernel_size=5, stride=1, padding=2, dropout=0.5),
            BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(192, 192, kernel_size=3, stride=1, padding=1, dropout=0.5),
            BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
            BinConv2d(192, 10, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        ]
        self.xnor = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.xnor(x)
        return scores.view(scores.size(0), 10)





class AllBinNet_Cut(nn.Module):
    '''
    AllBinNet with its hidden widths scaled by cut_ratio
    (192 -> int(192*cut_ratio), 96 -> int(96*cut_ratio)).
    '''
    def __init__(self, cut_ratio=0.5):
        super(AllBinNet_Cut, self).__init__()
        wide = int(192*cut_ratio)
        narrow = int(96*cut_ratio)
        self.xnor = nn.Sequential(
            BinConv2d(3, wide, kernel_size=5, stride=1, padding=2),
            BinConv2d(wide, narrow, kernel_size=1, stride=1, padding=0, dropout=0.0),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(narrow, wide, kernel_size=5, stride=1, padding=2, dropout=0.5),
            BinConv2d(wide, wide, kernel_size=1, stride=1, padding=0, dropout=0.0),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(wide, wide, kernel_size=3, stride=1, padding=1, dropout=0.5),
            BinConv2d(wide, wide, kernel_size=1, stride=1, padding=0, dropout=0.0),
            BinConv2d(wide, 10, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        )

    def forward(self, x):
        scores = self.xnor(x)
        return scores.view(scores.size(0), 10)


class NotAllBinNet_Cut(nn.Module):
    '''
    NotAllBinNet with its hidden widths scaled by cut_ratio
    (192 -> int(192*cut_ratio), 96 -> int(96*cut_ratio)).
    '''
    def __init__(self, cut_ratio=0.5):
        super(NotAllBinNet_Cut, self).__init__()
        wide = int(192*cut_ratio)
        narrow = int(96*cut_ratio)
        self.xnor = nn.Sequential(
            nn.Conv2d(3, wide, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(wide, eps=1e-4, momentum=0.1, affine=False),
            nn.ReLU(inplace=True),
            BinConv2d(wide, narrow, kernel_size=1, stride=1, padding=0, dropout=0.0),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(narrow, wide, kernel_size=5, stride=1, padding=2, dropout=0.5),
            BinConv2d(wide, wide, kernel_size=1, stride=1, padding=0, dropout=0.0),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(wide, wide, kernel_size=3, stride=1, padding=1, dropout=0.5),
            BinConv2d(wide, wide, kernel_size=1, stride=1, padding=0, dropout=0.0),
            BinConv2d(wide, 10, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        )

    def forward(self, x):
        scores = self.xnor(x)
        return scores.view(scores.size(0), 10)




class RealNet_Cut(nn.Module):
    '''
    RealNet (full-precision NIN) with its hidden widths scaled by cut_ratio
    (192 -> int(192*cut_ratio), 96 -> int(96*cut_ratio)).
    '''
    def __init__(self, cut_ratio=0.5):
        super(RealNet_Cut, self).__init__()
        wide = int(192*cut_ratio)
        narrow = int(96*cut_ratio)
        self.xnor = nn.Sequential(
            nn.Conv2d(3, wide, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(wide, eps=1e-4, momentum=0.1, affine=False),
            nn.ReLU(inplace=True),
            RealConv2d(wide, narrow, kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            RealConv2d(narrow, wide, kernel_size=5, stride=1, padding=2, dropout=0.5),
            RealConv2d(wide, wide, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

            RealConv2d(wide, wide, kernel_size=3, stride=1, padding=1, dropout=0.5),
            RealConv2d(wide, wide, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(wide, eps=1e-4, momentum=0.1, affine=False),
            nn.Conv2d(wide, 10, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        )

    def forward(self, x):
        scores = self.xnor(x)
        return scores.view(scores.size(0), 10)

In [None]:
import torch
from torch import nn

class BinActive(torch.autograd.Function):
    """
    Sign-binarize activations and report their mean absolute value over dim 1.

    Returns ``(sign(input), mean(|input|, dim=1, keepdim=True))``. The
    backward pass is a straight-through estimator: gradients flow only where
    -1 < input < 1; the gradient of the mean output is discarded.
    """
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        channel_mean = input.abs().mean(dim=1, keepdim=True)
        return input.sign(), channel_mean

    @staticmethod
    def backward(ctx, grad_output, grad_output_mean):
        (saved_input,) = ctx.saved_tensors
        outside = saved_input.ge(1) | saved_input.le(-1)
        return grad_output.masked_fill(outside, 0)


class BaseConv2d(nn.Module):
    """Shared scaffolding for the conv blocks: an input BatchNorm2d, optional
    Dropout, a ReLU, and the convolution hyper-parameters.

    The base class creates no convolution layer and its forward() raises
    NotImplementedError; subclasses supply both.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dropout=0):
        super(BaseConv2d, self).__init__()
        self.bn = nn.BatchNorm2d(in_channels, eps=1e-4, momentum=0.1, affine=True)
        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None
        # Plain attributes read by subclasses when they build their conv.
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        raise NotImplementedError


class BinConv2d(BaseConv2d):
    """XNOR conv block: BatchNorm2d -> sign-binarized activations ->
    optional Dropout -> Conv2d -> ReLU.

    The convolution is an ``nn.Conv2d`` stored as ``self.conv`` (mirroring
    RealConv2d), so code that scans ``model.modules()`` for ``nn.Conv2d``,
    such as BinOp, can find its weights.
    """
    def __init__(self, *args, **kwargs):
        super(BinConv2d, self).__init__(*args, **kwargs)
        # BaseConv2d creates no convolution; forward previously referenced
        # the never-defined self.weight / self.bias and always failed.
        self.conv = nn.Conv2d(self.bn.num_features, self.out_channels,
                              self.kernel_size, self.stride, self.padding)

    def forward(self, x):
        x = self.bn(x)
        x, _mean = BinActive.apply(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.conv(x)
        return self.relu(x)


class RealConv2d(BaseConv2d):
    """Full-precision conv block: BatchNorm2d -> optional Dropout -> Conv2d -> ReLU."""
    def __init__(self, *args, **kwargs):
        super(RealConv2d, self).__init__(*args, **kwargs)
        in_channels = self.bn.num_features
        self.conv = nn.Conv2d(in_channels, self.out_channels, self.kernel_size,
                              self.stride, self.padding)

    def forward(self, x):
        h = self.bn(x)
        if self.dropout:
            h = self.dropout(h)
        return self.relu(self.conv(h))


class NetBase(nn.Module):
    """Abstract base for the network variants.

    Only stores ``cut_ratio`` (used by the *_Cut models in this file as a
    channel-width multiplier); subclasses must implement forward().
    """
    def __init__(self, cut_ratio=1.0):
        super().__init__()
        self.cut_ratio = cut_ratio

    def forward(self, x):
        raise NotImplementedError


class Net(NetBase):
    """Binarized XNOR-NIN model built on NetBase.

    Full-precision first (3 -> 192) and last (192 -> 10) convolutions with
    BinConv2d blocks in between. ``cut_ratio`` scales the hidden widths
    (192 -> int(192*cut_ratio), 96 -> int(96*cut_ratio)); the default 1.0
    gives the full-width model. forward() reshapes the output to (batch, 10).
    """
    def __init__(self, cut_ratio=1.0):
        super(Net, self).__init__(cut_ratio)
        # forward() depends on self.xnor, which the previous stub never built.
        wide = int(192 * cut_ratio)
        narrow = int(96 * cut_ratio)
        self.xnor = nn.Sequential(
            nn.Conv2d(3, wide, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(wide, eps=1e-4, momentum=0.1, affine=False),
            nn.ReLU(inplace=True),
            BinConv2d(wide, narrow, kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(narrow, wide, kernel_size=5, stride=1, padding=2, dropout=0.5),
            BinConv2d(wide, wide, kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

            BinConv2d(wide, wide, kernel_size=3, stride=1, padding=1, dropout=0.5),
            BinConv2d(wide, wide, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(wide, eps=1e-4, momentum=0.1, affine=False),
            nn.Conv2d(wide, 10, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.xnor(x)
        x = x.view(x.size(0), 10)
        return x

# ... (continue with other Net classes, following the same pattern as above)



In [1]:
import torch
import numpy as np

# Define the sto_quant function
def sto_quant(tensor):
    """Stochastically binarize *tensor* to values in {-1., 1.}.

    Each element x is mapped to (x + 1) / 2, perturbed by uniform noise in
    [-0.5, 0.5), clamped to [0, 1], rounded and mapped back with 2y - 1.
    The noise is created on ``tensor``'s device, so CUDA inputs no longer
    fail with a device mismatch.
    """
    noise = torch.rand(tensor.size(), device=tensor.device).add(-0.5)
    return tensor.add(1.).div(2.).add(noise).clamp(0., 1.).round().mul(2.).add(-1.)

# Example 4x4 input and 2x2 filter to run through sto_quant.
tensor = torch.tensor([[0.9, -1.2, 0.6, 0.3],
                       [-0.4, 2.0, -0.8, 1.0],
                       [0.7, -1.5, 1.1, -0.6],
                       [0.2, 0.4, -0.9, 0.8]])

filter_tensor = torch.tensor([[0.1, -0.2],
                              [-0.3, 0.4]])

# sto_quant draws random noise, so the printed result differs between runs
# unless the RNG is seeded beforehand. The tensor is quantized before the filter.
binarized_tensor, binarized_filter = sto_quant(tensor), sto_quant(filter_tensor)

print("Binarized Tensor: \n", binarized_tensor)
print("Binarized Filter: \n", binarized_filter)

Binarized Tensor: 
 tensor([[ 1., -1.,  1.,  1.],
        [ 1.,  1., -1.,  1.],
        [ 1., -1.,  1., -1.],
        [ 1.,  1., -1.,  1.]])
Binarized Filter: 
 tensor([[-1.,  1.],
        [ 1.,  1.]])
