In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import math

In [21]:
class AlexNet(nn.Module):
    """Baseline AlexNet: a convolutional trunk followed by a dense classifier.

    Args:
        num_classes: output width of the last linear layer (default 1000).
    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()

        # Convolutional trunk, listed in the order the layers are applied.
        trunk_layers = [
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*trunk_layers)

        # Dense head; its first layer consumes the flattened trunk output.
        head_layers = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the trunk, flatten to (batch, 256 * 6 * 6), then classify."""
        feature_maps = self.features(x)
        batch_size = feature_maps.size(0)
        flattened = feature_maps.view(batch_size, 256 * 6 * 6)
        return self.classifier(flattened)

In [22]:
class lookupconv(nn.Conv2d):
    """Conv2d whose output is re-gathered along dim 1 through an index table.

    The convolution is an ordinary ``nn.Conv2d`` built from ``*args`` and
    ``**kwargs``. ``forward`` then applies ``torch.index_select`` on dim 1 of
    the result, so the output carries one channel per entry of ``indices``
    (entries may repeat, which duplicates channels).

    Args:
        indices: index tensor passed to ``torch.index_select``. It is kept as
            a plain attribute, not a registered buffer, so the caller is
            responsible for putting it on the same device as the input.
        *args, **kwargs: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, indices, *args, **kwargs):
        super(lookupconv, self).__init__(*args, **kwargs)
        self.indices = indices
        # Parenthesised form prints the same text on Python 2 and Python 3;
        # the bare print statement is a SyntaxError on Python 3.
        print("init time")

    def forward(self, input):
        """Convolve ``input`` and return the channels picked by ``self.indices``."""
        x = super(lookupconv, self).forward(input)
        return torch.index_select(x, 1, self.indices)
        
class lookupfc(nn.Linear):
    """Linear layer whose output features are re-gathered through an index table.

    Args:
        indices: index tensor used to select along dim 1 of the linear output.
        *args, **kwargs: forwarded unchanged to ``nn.Linear``.
    """

    def __init__(self, indices, *args, **kwargs):
        super(lookupfc, self).__init__(*args, **kwargs)
        # Kept as a plain attribute; it is read on every forward call.
        self.indices = indices

    def forward(self, input):
        """Apply the linear map, then pick columns listed in ``self.indices``."""
        dense_out = super(lookupfc, self).forward(input)
        gathered = dense_out.index_select(1, self.indices)
        return gathered
    
    
class AlexNet_lookup(nn.Module):
    """AlexNet variant whose five convolutions are ``lookupconv`` layers.

    Stage ``i`` computes ``ochannels[i]`` filters and then gathers its output
    along dim 1 through ``indices[i]``, so the next stage receives exactly as
    many channels as ``indices[i]`` has entries. The input widths of stages
    2-5 and of the first classifier layer are derived from those entry counts
    rather than hard-coded; with tables of 96/256/384/384/256 entries the
    result is identical to the fixed AlexNet widths.

    Args:
        ochannels: number of filters actually computed by each of the five
            convolutions.
        indices: five index tensors, one per convolution, handed to
            ``lookupconv``.
        dummyfcidx: accepted for call compatibility; not used by this class.
        num_classes: output width of the last linear layer (default 1000).
    """

    def __init__(self, ochannels, indices, dummyfcidx, num_classes=1000):
        super(AlexNet_lookup, self).__init__()

        # Channel count leaving each stage after the lookup gather.
        widths = [idx.numel() for idx in indices]

        self.features = nn.Sequential(

            lookupconv(indices[0], 3, ochannels[0], kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            lookupconv(indices[1], widths[0], ochannels[1], kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            lookupconv(indices[2], widths[1], ochannels[2], kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            lookupconv(indices[3], widths[2], ochannels[3], kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            lookupconv(indices[4], widths[3], ochannels[4], kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Flattened trunk output: last stage width over a 6x6 grid (the same
        # 6 * 6 factor the baseline AlexNet uses).
        self._flat_features = widths[4] * 6 * 6

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(self._flat_features, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Run the lookup trunk, flatten per sample, then classify."""
        x = self.features(x)
        x = x.view(x.size(0), self._flat_features)
        x = self.classifier(x)
        return x

In [27]:
def get_with_context(data, ctx=False):
    """Return ``data.cuda()`` when ``ctx`` is truthy, otherwise ``data`` itself."""
    return data.cuda() if ctx else data

def get_res_with_ctx(inp, ctx=False):
    """Convert a network output to a NumPy array.

    When ``ctx`` is truthy the value is first brought back with ``.cpu()``;
    in both cases the result is ``.data.numpy()`` of the tensor.
    """
    host_side = inp.cpu() if ctx else inp
    return host_side.data.numpy()
    
    
# Benchmark configuration: device switch, compression factor and the
# per-layer output widths of the baseline AlexNet.
usecuda = False
shrink = 4
ochannels = [96, 256, 384, 384, 256]
fcnums = []
indices = []
onehot_indices = []
compressed_ochannels = []

# `//` keeps every size an int on both Python 2 and Python 3; with plain `/`
# Python 3 yields floats, which np.random.choice and torch.rand reject.
# The table is placed with `usecuda` like every other tensor in this cell.
fc_dummy_index = get_with_context(
    torch.LongTensor(np.random.choice(4096 // shrink, 4096)), usecuda)

for och in ochannels:
    compressed = och // shrink
    # One random lookup table per conv layer: `och` entries, each pointing at
    # one of the `compressed` filters that are actually computed.
    indices.append(get_with_context(torch.LongTensor(np.random.choice(compressed, och)), usecuda))
    onehot_indices.append(get_with_context(torch.rand(och, compressed), usecuda))
    compressed_ochannels.append(compressed)

# Random input batch shared by all timing runs.
x = get_with_context(torch.rand(128, 3, 224, 224), usecuda)

alexnet        = get_with_context(AlexNet(), usecuda)
alexnet_lookup = get_with_context(AlexNet_lookup(compressed_ochannels, indices, fc_dummy_index), usecuda)
# Disabled: no AlexNet_dot class is defined in this notebook.
#alexnet_dot    = get_with_context(AlexNet_dot(compressed_ochannels, onehot_indices, fc_dummy_index),usecuda)




init time
init time
init time
init time
init time


In [28]:
# Wall-clock time for 100 forward passes of the baseline network.
begin = time.time()
for i in range(100):
    res1 = get_res_with_ctx(alexnet(x), usecuda)
print(time.time() - begin)

# Same measurement for the lookup variant, on the same input batch.
begin = time.time()
for i in range(100):
    res2 = get_res_with_ctx(alexnet_lookup(x), usecuda)
print(time.time() - begin)

709.412333965
300.882972956


In [None]:
# Wall-clock time for 5000 forward passes of the dot-product variant.
# NOTE(review): `alexnet_dot` is never created in this notebook (its
# construction is commented out in the setup cell), so this cell raises
# NameError until that model exists.
begin = time.time()
for i in range(5000):
    res3 = get_res_with_ctx(alexnet_dot(x), usecuda)
print(time.time() - begin)

In [None]:
"test"

In [None]:
# Echo the array produced by the last lookup-network forward pass in the timing cell.
res2

In [18]:
# Scratch experiment comparing a dense convolution with a grouped one.
# NOTE: this cell rebinds `shrink` and `indices`, which the benchmark setup
# cell also defines; re-run that cell before benchmarking again.
shrink = 2
dummy = torch.rand(7, 16, 32, 32)
dummyfilter = torch.rand(128, 16, 3, 3)
# `//` keeps the filter count an int on Python 3 as well as Python 2.
dummyfilter2 = torch.rand(16 * 128 // shrink, 1, 3, 3)
# Integer (Long) table so it can be passed to torch.index_select; sized from
# `shrink` instead of a literal 2 so the two stay in step.
indices = torch.LongTensor(np.random.choice(128 // shrink, 128))

# Dense reference: 128 filters over all 16 input channels.
c1 = F.conv2d(dummy, dummyfilter)
c1.shape

torch.Size([7, 128, 30, 30])

In [17]:
# Grouped convolution: groups=16 gives each of the 16 input channels its own
# slice of the single-channel filters in dummyfilter2.
c2_1 = F.conv2d(dummy, dummyfilter2, groups=16)
print(c2_1.shape)
# NOTE(review): this binds the function object itself, not a result -- the
# gather step of the experiment was left unfinished.
c2_2 = torch.index_select

torch.Size([7, 1024, 30, 30])

In [19]:
# Inspect the size of the scratch lookup table.
# NOTE(review): the recorded output (torch.Size([16])) does not match the
# 128-entry table built in the scratch cell above -- it looks stale.
indices.shape

torch.Size([16])

In [None]:
        # Scratch fragment, not runnable on its own: it is the body of an
        # __init__ and relies on `self`, `dummyfcidx` and `num_classes`.
        # Classifier head whose first two dense layers are lookupfc modules:
        # each computes 2048 features and then gathers them through
        # `dummyfcidx`.
        # NOTE(review): the second lookupfc and the final nn.Linear both take
        # 4096 inputs, so `dummyfcidx` presumably needs 4096 entries -- confirm.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            lookupfc(dummyfcidx, 256 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            lookupfc(dummyfcidx, 4096, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

In [None]:
        # Scratch fragment, not runnable on its own: it is the body of an
        # __init__ and relies on `self` and `num_classes`.
        # Plain dense classifier head: dropout before each of the two hidden
        # 4096-wide layers, then a final projection to `num_classes`.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

In [None]:
class AlexNet_lookup(nn.Module):
    """Draft AlexNet with named layers and a functional lookup helper.

    The layers are registered under their draft names (``conv1`` ... ``pool5``
    inside ``self.features``; ``drop1`` ... ``fc3`` inside ``self.classifier``)
    so that ``forward`` can run them as two sequential stacks.

    NOTE(review): this cell reuses the name ``AlexNet_lookup``. Running it
    rebinds that name and replaces the class defined earlier in the notebook,
    whose constructor takes ``(ochannels, indices, dummyfcidx)``.

    Args:
        num_classes: output width of the last linear layer (default 1000).
    """

    def __init__(self, num_classes=1000):
        # Must name this class: super(AlexNet, self) raises TypeError because
        # AlexNet_lookup does not derive from AlexNet.
        super(AlexNet_lookup, self).__init__()

        # add_module registers each layer under its name; bare local
        # assignments are discarded when __init__ returns.
        features = nn.Sequential()
        features.add_module('conv1', nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2))
        features.add_module('relu1', nn.ReLU(inplace=True))
        features.add_module('pool1', nn.MaxPool2d(kernel_size=3, stride=2))
        features.add_module('conv2', nn.Conv2d(64, 192, kernel_size=5, padding=2))
        features.add_module('relu2', nn.ReLU(inplace=True))
        features.add_module('pool2', nn.MaxPool2d(kernel_size=3, stride=2))
        features.add_module('conv3', nn.Conv2d(192, 384, kernel_size=3, padding=1))
        features.add_module('relu3', nn.ReLU(inplace=True))
        features.add_module('conv4', nn.Conv2d(384, 256, kernel_size=3, padding=1))
        features.add_module('relu4', nn.ReLU(inplace=True))
        features.add_module('conv5', nn.Conv2d(256, 256, kernel_size=3, padding=1))
        features.add_module('relu5', nn.ReLU(inplace=True))
        features.add_module('pool5', nn.MaxPool2d(kernel_size=3, stride=2))
        self.features = features

        classifier = nn.Sequential()
        classifier.add_module('drop1', nn.Dropout())
        classifier.add_module('fc1', nn.Linear(256 * 6 * 6, 4096))
        classifier.add_module('relu6', nn.ReLU(inplace=True))
        classifier.add_module('drop2', nn.Dropout())
        classifier.add_module('fc2', nn.Linear(4096, 4096))
        classifier.add_module('relu7', nn.ReLU(inplace=True))
        classifier.add_module('fc3', nn.Linear(4096, num_classes))
        self.classifier = classifier

    def lookupconv(self, in_data, in_function, indices):
        """Apply ``in_function`` to ``in_data`` and gather along dim 1 via ``indices``.

        NOTE(review): the draft left this method without a body; this mirrors
        what the ``lookupconv`` layer class does in its forward pass -- confirm
        that this is the intended behaviour.
        """
        return torch.index_select(in_function(in_data), 1, indices)

    def forward(self, x):
        """Run the feature stack, flatten to (batch, 256 * 6 * 6), then classify."""
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x