In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np

In [2]:
def get_quantized_filters(filters, shrink=16):
    """Cluster convolution filters with k-means and return the quantized result.

    Every one of the ``filters.shape[0]`` filters is flattened into a vector,
    and the vectors are grouped into ``filters.shape[0] // shrink`` clusters.

    NOTE(review): this function needs ``KMeans`` and ``mx`` in scope, but the
    visible part of the notebook imports neither -- presumably
    ``sklearn.cluster.KMeans`` and ``mxnet``; confirm and import them.

    Args:
        filters: 4-D array exposing ``shape``, ``reshape`` and ``asnumpy``
            (presumably an MXNet NDArray -- TODO confirm).
        shrink: ratio between the number of filters and the number of clusters.

    Returns:
        A tuple ``(indexes, centers, quantized)`` of ``mx.nd.array`` values:
        the cluster index assigned to each filter, the cluster centers, and
        each flattened filter replaced by the center of its cluster.
    """
    shape = filters.shape
    # Floor division keeps the cluster count an integer on Python 3 as well;
    # plain "/" would hand a float to KMeans there.
    n_clusters = shape[0] // shrink

    # Convert to NumPy once and reuse the result for both fit and predict.
    flat_filters = filters.reshape((shape[0], shape[1] * shape[2] * shape[3])).asnumpy()
    estimator = KMeans(n_clusters=n_clusters)
    estimator.fit(flat_filters)

    filter_kmean_indexes = estimator.predict(X=flat_filters)
    # Fancy indexing gathers one center row per filter in a single step.
    filters_quantized = estimator.cluster_centers_[filter_kmean_indexes]

    return (
        mx.nd.array(filter_kmean_indexes),
        mx.nd.array(estimator.cluster_centers_),
        mx.nd.array(filters_quantized),
    )

          
def get_onehot(data, nclusters, batch_size):
    """Build a one-hot matrix from ``data`` and repeat it along a new batch axis.

    NOTE(review): relies on ``mx`` being in scope (presumably ``mxnet``); it is
    not imported in the visible part of this notebook.

    Args:
        data: cluster indices passed to ``mx.nd.one_hot``.
        nclusters: depth of the one-hot encoding.
        batch_size: size of the leading axis of the returned array.

    Returns:
        The one-hot matrix with a new leading axis broadcast to ``batch_size``.
    """
    encoded = mx.nd.one_hot(data, depth=nclusters)
    # reshape(0, -1): presumably keeps the first axis and flattens the rest
    # (MXNet reshape convention) -- TODO confirm.
    index_mat = encoded.reshape(0, -1)
    with_batch_axis = mx.nd.expand_dims(index_mat, axis=0)
    return mx.nd.broadcast_axes(with_batch_axis, axis=0, size=batch_size)

In [15]:
class Net1(nn.Module):
    """Baseline network: one 3x3 convolution producing ``ochannels`` channels."""

    def __init__(self, inchannels, ochannels):
        super(Net1, self).__init__()
        # Single convolution layer: inchannels -> ochannels, 3x3 kernel.
        self.conv1 = nn.Conv2d(inchannels, ochannels, kernel_size=3)

    def forward(self, x):
        """Apply the convolution to ``x``; no pooling or activation follows."""
        return self.conv1(x)
    
class Net2(nn.Module):
    """Clustered convolution that expands its output with a channel lookup.

    The convolution only computes ``ochannels // shrink`` channels; ``forward``
    then gathers channels by ``indices`` so the result has ``len(indices)``
    channels.

    Args:
        inchannels: number of input channels of the convolution.
        ochannels: number of output channels before shrinking.
        shrink: factor by which the convolution's output channels are reduced.
        indices: 1-D index tensor handed to ``torch.index_select`` along the
            channel axis. It is stored as a plain attribute, so moving the
            module to another device does not move it.
    """

    def __init__(self, inchannels, ochannels, shrink=2, indices=None):
        super(Net2, self).__init__()
        self.indices = indices
        # Floor division: Conv2d needs an integer channel count, and "/"
        # produces a float on Python 3.
        n_clusters = ochannels // shrink
        # Kept from the original notebook, which echoes the reduced channel count.
        print(n_clusters)
        self.conv1 = torch.nn.Conv2d(in_channels=inchannels, out_channels=n_clusters, kernel_size=3)

    def forward(self, x):
        """Convolve ``x`` and gather output channels according to ``indices``."""
        x = self.conv1(x)
        # Dimension 1 is the channel axis of the convolution output.
        return torch.index_select(x, 1, self.indices)
    
class Net3(nn.Module):
    """Clustered convolution that expands its output with a matrix product.

    The convolution only computes ``ochannels // shrink`` channels; ``forward``
    multiplies ``indices`` with the flattened feature maps so the result has
    ``indices.shape[0]`` channels.

    Args:
        inchannels: number of input channels of the convolution.
        ochannels: number of output channels before shrinking.
        shrink: factor by which the convolution's output channels are reduced.
        indices: 2-D tensor used as the left operand of ``torch.matmul``; its
            second dimension must equal ``ochannels // shrink``. It is stored
            as a plain attribute, so moving the module to another device does
            not move it.
    """

    def __init__(self, inchannels, ochannels, shrink=2, indices=None):
        super(Net3, self).__init__()
        self.indices = indices
        # Floor division: Conv2d needs an integer channel count, and "/"
        # produces a float on Python 3.
        self.conv1 = torch.nn.Conv2d(in_channels=inchannels, out_channels=ochannels // shrink, kernel_size=3)

    def forward(self, x):
        """Convolve ``x`` and mix the resulting channels with ``indices``."""
        x = self.conv1(x)
        batch, channels, height, width = x.shape
        # Flatten the spatial axes so each sample becomes a
        # (channels, height*width) matrix that matmul can broadcast over.
        # The per-call debug prints were dropped: they ran inside the timed
        # benchmark loop and flooded the output.
        mixed = torch.matmul(self.indices, x.reshape(batch, channels, height * width))
        # Restore the spatial layout with the new channel count.
        return mixed.reshape(mixed.shape[0], mixed.shape[1], height, width)


def get_with_context(data, ctx=False):
    """Return ``data.cuda()`` when ``ctx`` is truthy, otherwise ``data`` itself."""
    return data.cuda() if ctx else data
    
def get_res_with_ctx(inp, ctx=False):
    """Convert ``inp`` to a NumPy array, calling ``.cpu()`` first when ``ctx`` is truthy."""
    # numpy() is called on the .data view in both cases; only the .cpu() step
    # depends on the flag.
    host = inp.cpu() if ctx else inp
    return host.data.numpy()
        
# Benchmark configuration shared by the timing cells.
usecuda = False  # when True, get_with_context moves tensors and modules to the GPU

batch = 128  # number of samples in the random input
ich = 16     # input channels
och = 32     # output channels of every benchmarked network
shrink = 2   # channel reduction factor used by the clustered networks

# Random input of shape (batch, ich, 27, 27).
x = get_with_context(torch.rand(batch, ich, 27, 27), usecuda)
# NOTE(review): oshape and xshape are not used by any visible cell.
oshape = (32, och, 32, 32)
xshape = x.shape

The original convolution: `ich` to `och` channels (16 to 32 with the settings above)

In [16]:
# Baseline: time 10 forward passes of the plain convolution, including the
# conversion of each result to NumPy.
net = get_with_context(Net1(x.shape[1], och), usecuda)

begin = time.time()
for _ in range(10):
    res = get_res_with_ctx(net(x), usecuda)
t1 = time.time() - begin
# Function-call form of print works on both Python 2 and Python 3.
print(t1)
print(res.shape)

0.300846099854
(128, 32, 25, 25)


clusterconv using lookup

In [17]:
# Lookup variant: each of the och output channels is drawn at random from the
# och // shrink channels the convolution computes. Floor division keeps the
# population size an integer on Python 3, which np.random.choice requires.
indices = get_with_context(torch.LongTensor(np.random.choice(och // shrink, och)), usecuda)
net = get_with_context(Net2(x.shape[1], och, shrink, indices), usecuda)

begin = time.time()
for _ in range(10):
    res = get_res_with_ctx(net(x), usecuda)
t2 = time.time() - begin
# Function-call form of print works on both Python 2 and Python 3.
print(t2)
print(res.shape)

16
0.202553033829
(128, 32, 25, 25)


clusterconv using matmul

In [14]:
# Ratio of baseline time (t1) to lookup time (t2). Re-run this cell after both
# timing cells: the value depends on whichever t1/t2 are currently in the
# kernel, and a stale recorded output will not match the timings shown above.
speedup = float(t1) / t2
print(speedup)

1.80984535001


In [None]:
# Matmul variant: a random (och, och // shrink) matrix mixes the computed
# channels. Floor division keeps the dimension an integer on Python 3, which
# torch.rand requires.
indices2 = get_with_context(torch.rand(och, och // shrink), usecuda)
net = get_with_context(Net3(x.shape[1], och, shrink, indices2), usecuda)

begin = time.time()
# NOTE(review): this loop runs 100 iterations while the t1/t2 cells run 10, so
# t3 is not directly comparable to them without dividing by 10.
for _ in range(100):
    res = get_res_with_ctx(net(x), usecuda)
t3 = time.time() - begin
# Function-call form of print works on both Python 2 and Python 3.
print(t3)
print(res.shape)