# Optimise ISM for the Basset Architecture

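Experimental: trying to find the right way to speed up in-silico mutagenesis (ISM) for simple convolutional architectures. The idea is to recompute, for each single-position perturbation, only the output positions that the perturbation can affect. Those positions are then spliced into the activations of the unperturbed reference sequence.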

Architecture: [link](https://github.com/kundajelab/GenoPyT/blob/c84f38dfaa0c986f91383dd7e6278c1cb993498d/src/models/sequence_only/basset.py)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import Counter
from copy import deepcopy

In [266]:
def get_idxs_conv_maxpool(seqlen, kernelsize, padding, maxpool_kernel, change_ranges,
                          conv_stride=1,
                          maxpool_stride=None,
                          maxpool_ceil_mode=False):
    """Map changed input ranges through a (stride-1) conv followed by a maxpool.

    For every half-open ``(start, end)`` range in ``change_ranges`` (indices
    into the input BEFORE padding), compute:

    * the slice of the PADDED input (length ``seqlen + 2*padding``) that must
      be fed to an unpadded copy of the conv+maxpool layer to recompute every
      maxpool output the change can affect. All slices have the same width so
      they can be batched;
    * the offset of ``start`` inside that slice, i.e. where the changed values
      have to be written;
    * the ``(start, end)`` range of maxpool outputs that the slice produces.

    Args:
        seqlen: length of the layer input before padding.
        kernelsize: conv kernel width.
        padding: number of zeros added on each side before the conv.
        maxpool_kernel: maxpool kernel width; the maxpool stride is assumed to
            equal it.
        change_ranges: list of half-open ``(start, end)`` ranges that change.
        conv_stride: only 1 is supported.
        maxpool_stride: only ``None`` or ``maxpool_kernel`` is supported.
        maxpool_ceil_mode: only ``False`` is supported.

    Returns:
        ``((slice_ranges, offsets), out_ranges)``; each list has one entry per
        element of ``change_ranges``.

    Raises:
        NotImplementedError: for unsupported stride / ceil_mode settings.
    """
    if (maxpool_ceil_mode or conv_stride != 1
            or (maxpool_stride is not None and maxpool_stride != maxpool_kernel)):
        # would need extra care, e.g. repeated values in the last block
        raise NotImplementedError

    if not change_ranges:
        # nothing changes -> nothing to recompute (max() below needs >= 1 range)
        return ([], []), []

    padded_len = seqlen + 2*padding

    # raw ranges for each change_range -- this is the input range in which
    # changing the change_range will affect the output
    raw_seq_ranges = [(x-kernelsize+1, y+kernelsize-1) for x, y in change_ranges]

    # re-adjust since there will be `padding` number of zeros in the beginning
    raw_seq_pad_adjusted = [(x+padding, y+padding) for x, y in raw_seq_ranges]

    range_corrected = []
    for x, y in raw_seq_pad_adjusted:
        # shift windows that stick out of the padded input back inside,
        # keeping their width where possible
        if x < 0 and y > padded_len:
            # degenerate: the window covers everything (happens for FC layers)
            range_corrected.append((0, padded_len))
        elif x < 0:
            range_corrected.append((0, y-x))
        elif y > padded_len:
            range_corrected.append((x-(y-padded_len), padded_len))
        else:
            range_corrected.append((x, y))

    # the conv output range affected by each input
    conv_out_ranges = [(x, y-kernelsize+1) for x, y in range_corrected]

    # length of sequence after convolution
    conv_seqlen = padded_len - kernelsize + 1

    # shift to the edges of the nearest maxpool block
    mod_shifted = [(maxpool_kernel*(x//maxpool_kernel),
                    maxpool_kernel*((y-1)//maxpool_kernel+1)) for x, y in conv_out_ranges]
    # each should be the same size so the slices can be batched
    maxwidth = max(y-x for x, y in mod_shifted)

    mod_shifted = [(x, x+maxwidth) if y <= conv_seqlen else (y-maxwidth, y) for x, y in mod_shifted]

    # with ceil_mode==False the maxpool ignores a trailing partial block, so a
    # window ending past the conv output is moved back by one block
    mod_shifted = [(x, y) if y <= conv_seqlen else (x-maxpool_kernel, y-maxpool_kernel)
                   for x, y in mod_shifted]
    assert all(y <= conv_seqlen for _, y in mod_shifted), \
        "maxpool window extends past the conv output"

    # this would be the output ranges AFTER maxpool
    out_ranges = [(x//maxpool_kernel, y//maxpool_kernel) for x, y in mod_shifted]

    # work back input slices for desired output maxpool ranges
    slice_ranges = [(x, y+kernelsize-1) for x, y in mod_shifted]

    # position of each change's start (in padded coordinates) inside its slice
    offsets = [x+padding-s for (x, _), (s, _) in zip(change_ranges, slice_ranges)]

    return (slice_ranges, offsets), out_ranges

In [267]:
(s_slices, _), mxp1_out_ranges = get_idxs_conv_maxpool(1000, 19, 9, 3, [(i,i+1) for i in range(1000)])
inp_width = s_slices[0][1]-s_slices[0][0]
print(inp_width)

39


In [268]:
inp_seq = torch.rand(1, 4, 1000).cuda()
inp_seq_perturbed = inp_seq.repeat(1000,1,1)
inp_seq_perturbed[torch.arange(1000), :, torch.arange(1000)] = 0

In [269]:
inp_seq_slices = sum([list(range(*x)) for x in s_slices], [])
inp_seq_slices = torch.tensor(inp_seq_slices).view(-1, inp_width)
inp_seq_slices = inp_seq_slices.unsqueeze(1).repeat(1,4,1).cuda()

In [270]:
inp_seq_perturbed.shape

torch.Size([1000, 4, 1000])

In [271]:
inp_seq_slices.shape

torch.Size([1000, 4, 39])

In [272]:
padded_inp_seq = torch.zeros(1000, 4, 1018).cuda()
padded_inp_seq[:, :, 9:1009] = inp_seq_perturbed

In [273]:
def f():
    """Benchmark: copy the perturbed batch into the interior of the padded buffer.

    The padded buffer ``padded_inp_seq`` is preallocated outside this function,
    so only the in-place copy is timed.
    """
    padded_inp_seq[..., 9:1009] = inp_seq_perturbed

In [10]:
%timeit f()

140 µs ± 56.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [11]:
%timeit torch.gather(padded_inp_seq, 2, inp_seq_slices)

13.9 µs ± 1.45 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


## Conv1 + Maxpool1

In [274]:
class layer1(nn.Module):
    """Basset conv block 1: Conv1d(4 -> 300, k=19) -> BatchNorm -> ReLU -> MaxPool(3).

    The conv uses ``padding=0`` so it can be applied directly to pre-sliced
    windows; the notebook sets ``conv1.padding = (9,)`` on a deep copy to
    get the full-length variant.

    Args:
        device: device the submodules are moved to (default ``"cuda"``, as
            before; pass ``"cpu"`` to build the block without a GPU).
    """

    def __init__(self, device="cuda"):
        super().__init__()
        self.conv1 = nn.Conv1d(4, 300, 19, stride=1, padding=0).to(device)
        self.bn1 = nn.BatchNorm1d(300).to(device)
        self.maxpool1 = nn.MaxPool1d(3).to(device)

    def forward(self, s):
        return self.maxpool1(F.relu(self.bn1(self.conv1(s))))

In [275]:
l1 = layer1().cuda()
l1_w_padding = deepcopy(l1)
l1_w_padding.conv1.padding = (9,)

In [276]:
l1.eval()
l1_w_padding.eval()

layer1(
  (conv1): Conv1d(4, 300, kernel_size=(19,), stride=(1,), padding=(9,))
  (bn1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool1): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)

In [277]:
with torch.no_grad():
    print(l1(torch.gather(padded_inp_seq, 2, inp_seq_slices)).shape)

torch.Size([1000, 300, 7])


In [278]:
@torch.no_grad()
def f():
    """Benchmark layer 1 on the gathered per-mutation input windows (autograd off)."""
    padded_inp_seq[..., 9:1009] = inp_seq_perturbed
    windows = torch.gather(padded_inp_seq, 2, inp_seq_slices)
    l1(windows)

In [17]:
%timeit f()

1.5 ms ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [279]:
with torch.no_grad():
    mxp1_out = l1(torch.gather(padded_inp_seq, 2, inp_seq_slices))
mxp1_out.shape

torch.Size([1000, 300, 7])

### Compare

In [280]:
inp_seq_perturbed.shape

torch.Size([1000, 4, 1000])

In [281]:
with torch.no_grad():
    mxp1_ism_out = l1_w_padding(inp_seq_perturbed)
mxp1_ism_out.shape

torch.Size([1000, 300, 333])

In [282]:
mxp1_out_ranges[500]

(163, 170)

In [283]:
# check if equality holds for all slices
all([torch.all(torch.isclose(mxp1_ism_out[i, :, range(*mxp1_out_ranges[i])], mxp1_out[i]))==True for i in range(1000)])

True

---

In [284]:
def get_slices_and_scatter_mat(out_slices, out_offsets, outlen, num_channels, device="cuda"):
    """Build the gather and scatter index tensors for the next layer's input.

    Args:
        out_slices: half-open ``(start, end)`` ranges (one per mutation) into
            the padded next-layer input, as returned by
            ``get_idxs_conv_maxpool``.
        out_offsets: for each mutation, the position inside its slice at which
            the recomputed previous-layer output starts.
        outlen: width of the recomputed previous-layer output.
        num_channels: number of channels of the next-layer input.
        device: device for the returned tensors (default ``"cuda"``, as before).

    Returns:
        ``(inp_seq_slices, inp_scatter_mat)``:

        * ``inp_seq_slices`` has shape ``(num_channels, total slice length)``:
          every slice's indices concatenated and repeated per channel, for
          ``torch.gather`` along the last dim.
        * ``inp_scatter_mat`` has shape ``(len(out_offsets), num_channels,
          outlen)``: the target positions ``offset .. offset+outlen-1``, for
          ``Tensor.scatter_`` along the last dim.
    """
    # flat comprehension instead of sum(list_of_lists, []), which is quadratic
    flat_idxs = [i for start, end in out_slices for i in range(start, end)]
    inp_seq_slices = torch.tensor(flat_idxs, device=device)
    inp_seq_slices = inp_seq_slices.unsqueeze(0).repeat(num_channels, 1)

    inp_scatter_mat = torch.tensor([list(range(x, x + outlen)) for x in out_offsets], device=device)
    inp_scatter_mat = inp_scatter_mat.unsqueeze(1).repeat(1, num_channels, 1)

    return inp_seq_slices, inp_scatter_mat

In [285]:
with torch.no_grad():
    mxp1_out_ref = l1_w_padding(inp_seq)

In [286]:
print(mxp1_out_ref.shape)
mxp1_out_ref = mxp1_out_ref.squeeze(0)
print(mxp1_out_ref.shape)

torch.Size([1, 300, 333])
torch.Size([300, 333])


In [287]:
(mxp1_out_slices, mxp1_out_offsets), mxp2_out_ranges = get_idxs_conv_maxpool(333, 11, 5, 4, mxp1_out_ranges)

In [288]:
conv2_inp_width = mxp1_out_slices[0][1]-mxp1_out_slices[0][0]
print(conv2_inp_width)
all([y-x==conv2_inp_width for x,y in mxp1_out_slices])

30


True

In [289]:
conv2_inp_num_channels = mxp1_out.shape[1]
conv2_inp_num_channels

300

In [290]:
Counter(mxp1_out_offsets)

Counter({5: 12,
         6: 3,
         7: 3,
         8: 3,
         9: 3,
         10: 237,
         11: 237,
         12: 237,
         13: 237,
         14: 3,
         15: 3,
         16: 3,
         17: 3,
         18: 3,
         19: 13})

In [291]:
conv2_inp_seq_slices, conv2_inp_scatter_mat = get_slices_and_scatter_mat(mxp1_out_slices, mxp1_out_offsets, mxp1_out.shape[2], conv2_inp_num_channels)

In [292]:
padded_conv2_inp_seq = torch.zeros(conv2_inp_num_channels, 343).cuda()
padded_conv2_inp_seq[ :, 5:-5] = mxp1_out_ref

In [393]:
def f():
    """Benchmark: refresh the interior of the padded layer-2 input buffer in place."""
    padded_conv2_inp_seq[..., 5:-5] = mxp1_out_ref

In [394]:
%timeit f()

12 µs ± 311 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [293]:
padded_conv2_inp_seq.shape

torch.Size([300, 343])

In [30]:
%timeit torch.gather(padded_conv2_inp_seq, 1, conv2_inp_seq_slices).view(conv2_inp_num_channels, 1000, conv2_inp_width).permute(1,0,2).shape

476 µs ± 55.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [86]:
%timeit aligned.scatter_(2, conv2_inp_scatter_mat, mxp1_out)

187 µs ± 10.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [294]:
aligned = torch.gather(padded_conv2_inp_seq, 1, conv2_inp_seq_slices).view(conv2_inp_num_channels, 1000, conv2_inp_width).permute(1,0,2)
aligned = aligned.scatter_(2, conv2_inp_scatter_mat, mxp1_out)

In [295]:
aligned.shape

torch.Size([1000, 300, 30])

## Conv2 + Maxpool2

In [296]:
class layer2(nn.Module):
    """Basset conv block 2: Conv1d(300 -> 200, k=11) -> BatchNorm -> ReLU -> MaxPool(4).

    The conv uses ``padding=0`` so it can be applied directly to pre-sliced
    windows; the notebook sets ``conv.padding = (5,)`` on a deep copy to
    get the full-length variant.

    Args:
        device: device the submodules are moved to (default ``"cuda"``, as
            before; pass ``"cpu"`` to build the block without a GPU).
    """

    def __init__(self, device="cuda"):
        super().__init__()
        self.conv = nn.Conv1d(300, 200, 11, stride=1, padding=0).to(device)
        self.bn = nn.BatchNorm1d(200).to(device)
        self.maxpool = nn.MaxPool1d(4).to(device)

    def forward(self, s):
        return self.maxpool(F.relu(self.bn(self.conv(s))))

In [297]:
l2 = layer2().cuda()
l2_w_padding = deepcopy(l2)
l2_w_padding.conv.padding = (5,)

In [298]:
l2.eval()
l2_w_padding.eval()

layer2(
  (conv): Conv1d(300, 200, kernel_size=(11,), stride=(1,), padding=(5,))
  (bn): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
)

In [299]:
with torch.no_grad():
    print(l2(aligned).shape)

torch.Size([1000, 200, 5])


In [300]:
def f():
    """Benchmark the fast layer-2 step.

    It refreshes the padded reference buffer, gathers the per-mutation windows,
    splices in the recomputed layer-1 output and applies the unpadded layer 2.
    """
    with torch.no_grad():
        padded_conv2_inp_seq[:, 5:-5] = mxp1_out_ref
        # gather -> (channels, mutations*width); reshape to (mutations, channels, width)
        # using the named sizes instead of hard-coded 300/1000/30
        aligned = torch.gather(padded_conv2_inp_seq, 1, conv2_inp_seq_slices).view(
            conv2_inp_num_channels, -1, conv2_inp_width).permute(1, 0, 2)
        aligned = aligned.scatter_(2, conv2_inp_scatter_mat, mxp1_out)
        l2(aligned)

In [61]:
%timeit f()

3.62 ms ± 5.51 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [301]:
with torch.no_grad():
    mxp2_out = l2(aligned)
print(mxp2_out.shape)

torch.Size([1000, 200, 5])


### Compare

In [302]:
with torch.no_grad():
    mxp2_ism_out = l2_w_padding(l1_w_padding(inp_seq_perturbed))
mxp2_ism_out.shape

torch.Size([1000, 200, 83])

In [303]:
mxp2_out_ranges[500]

(39, 44)

In [318]:
# does not seem to hold for smaller atol values
print(all([torch.all(torch.isclose(mxp2_ism_out[i, :, range(*mxp2_out_ranges[i])], mxp2_out[i], atol=1e-6))==True for i in range(1000)]))
print(all([torch.all(torch.isclose(mxp2_out[i], mxp2_ism_out[i, :, range(*mxp2_out_ranges[i])], atol=1e-6))==True for i in range(1000)]))

True
True


---

In [305]:
with torch.no_grad():
    mxp2_out_ref = l2_w_padding(l1_w_padding(inp_seq))

In [306]:
print(mxp2_out_ref.shape)
mxp2_out_ref = mxp2_out_ref.squeeze(0)
print(mxp2_out_ref.shape)

torch.Size([1, 200, 83])
torch.Size([200, 83])


In [308]:
(mxp2_out_slices, mxp2_out_offsets), mxp3_out_ranges = get_idxs_conv_maxpool(83, 7, 4, 4, mxp2_out_ranges)

In [309]:
conv3_inp_width = mxp2_out_slices[0][1]-mxp2_out_slices[0][0]
print(conv3_inp_width)
all([y-x==conv3_inp_width for x,y in mxp2_out_slices])

22


True

In [310]:
conv3_inp_num_channels = mxp2_out.shape[1]
conv3_inp_num_channels

200

In [311]:
Counter(mxp2_out_offsets)

Counter({4: 36,
         5: 12,
         6: 216,
         7: 216,
         8: 216,
         9: 216,
         10: 12,
         11: 12,
         12: 12,
         13: 12,
         14: 40})

In [312]:
conv3_inp_seq_slices, conv3_inp_scatter_mat = get_slices_and_scatter_mat(mxp2_out_slices, mxp2_out_offsets, mxp2_out.shape[2], conv3_inp_num_channels)

In [313]:
padded_conv3_inp_seq = torch.zeros(conv3_inp_num_channels, 91).cuda()
padded_conv3_inp_seq[ :, 4:-4] = mxp2_out_ref

In [314]:
padded_conv3_inp_seq.shape

torch.Size([200, 91])

In [277]:
%timeit torch.gather(padded_conv3_inp_seq, 1, conv3_inp_seq_slices).view(conv3_inp_num_channels, 1000, conv3_inp_width).permute(1,0,2).shape

221 µs ± 22.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [278]:
%timeit aligned.scatter_(2, conv3_inp_scatter_mat, mxp2_out)

176 µs ± 18.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [319]:
aligned = torch.gather(padded_conv3_inp_seq, 1, conv3_inp_seq_slices).view(conv3_inp_num_channels, 1000, conv3_inp_width).permute(1,0,2)
aligned = aligned.scatter_(2, conv3_inp_scatter_mat, mxp2_out)

In [320]:
aligned.shape

torch.Size([1000, 200, 22])

## Conv3 + Maxpool3

In [321]:
class layer3(nn.Module):
    """Basset conv block 3: Conv1d(200 -> 200, k=7) -> BatchNorm -> ReLU -> MaxPool(4).

    The conv uses ``padding=0`` so it can be applied directly to pre-sliced
    windows; the notebook sets ``conv.padding = (4,)`` on a deep copy to
    get the full-length variant.

    Args:
        device: device the submodules are moved to (default ``"cuda"``, as
            before; pass ``"cpu"`` to build the block without a GPU).
    """

    def __init__(self, device="cuda"):
        super().__init__()
        self.conv = nn.Conv1d(200, 200, 7, stride=1, padding=0).to(device)
        self.bn = nn.BatchNorm1d(200).to(device)
        self.maxpool = nn.MaxPool1d(4).to(device)

    def forward(self, s):
        return self.maxpool(F.relu(self.bn(self.conv(s))))

In [322]:
l3 = layer3().cuda()
l3_w_padding = deepcopy(l3)
l3_w_padding.conv.padding = (4,)

In [323]:
l3.eval()
l3_w_padding.eval()

layer3(
  (conv): Conv1d(200, 200, kernel_size=(7,), stride=(1,), padding=(4,))
  (bn): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
)

In [324]:
with torch.no_grad():
    print(l3(aligned).shape)

torch.Size([1000, 200, 4])


In [325]:
@torch.no_grad()
def f():
    """Benchmark the fast layer-3 step (buffer refresh, gather, splice, conv) without autograd."""
    padded_conv3_inp_seq[..., 4:-4] = mxp2_out_ref
    gathered = torch.gather(padded_conv3_inp_seq, 1, conv3_inp_seq_slices)
    windows = gathered.view(conv3_inp_num_channels, 1000, conv3_inp_width)
    windows = windows.permute(1, 0, 2)
    windows = windows.scatter_(2, conv3_inp_scatter_mat, mxp2_out)
    l3(windows)

In [326]:
%timeit f()

2.4 ms ± 2.79 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [327]:
with torch.no_grad():
    mxp3_out = l3(aligned)
print(mxp3_out.shape)

torch.Size([1000, 200, 4])


### Compare

In [390]:
with torch.no_grad():
    x = l2_w_padding(l1_w_padding(inp_seq_perturbed))

In [391]:
@torch.no_grad()
def f():
    """Benchmark the unoptimised, padded layer 3 on the full-length layer-2 ISM output ``x``."""
    l3_w_padding(x)

In [392]:
%timeit f()

7.88 ms ± 26.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [329]:
with torch.no_grad():
    mxp3_ism_out = l3_w_padding(l2_w_padding(l1_w_padding(inp_seq_perturbed)))
mxp3_ism_out.shape

torch.Size([1000, 200, 21])

In [330]:
mxp3_out_ranges[500]

(9, 13)

In [332]:
# does not seem to hold for smaller atol values
print(all([torch.all(torch.isclose(mxp3_ism_out[i, :, range(*mxp3_out_ranges[i])], mxp3_out[i], atol=1e-6))==True for i in range(1000)]))
print(all([torch.all(torch.isclose(mxp3_out[i], mxp3_ism_out[i, :, range(*mxp3_out_ranges[i])], atol=1e-6))==True for i in range(1000)]))

True
True


---

In [333]:
with torch.no_grad():
    mxp3_out_ref = l3_w_padding(l2_w_padding(l1_w_padding(inp_seq)))

In [334]:
print(mxp3_out_ref.shape)
mxp3_out_ref = mxp3_out_ref.squeeze(0)
print(mxp3_out_ref.shape)

torch.Size([1, 200, 21])
torch.Size([200, 21])


In [335]:
# next layer is FC layer, can be treated as conv with filter=width, no padding, no maxpool (maxpool width 1)
(mxp3_out_slices, mxp3_out_offsets), _ = get_idxs_conv_maxpool(21, 21, 0, 1, mxp3_out_ranges)

In [336]:
conv4_inp_width = mxp3_out_slices[0][1]-mxp3_out_slices[0][0]
print(conv4_inp_width)
all([y-x==conv4_inp_width for x,y in mxp3_out_slices])

21


True

In [337]:
conv4_inp_num_channels = mxp3_out.shape[1]
conv4_inp_num_channels

200

In [338]:
Counter(mxp3_out_offsets)

Counter({0: 96,
         1: 48,
         2: 48,
         3: 48,
         4: 48,
         5: 48,
         6: 48,
         7: 48,
         8: 48,
         9: 48,
         10: 48,
         11: 48,
         12: 48,
         13: 48,
         14: 48,
         15: 48,
         16: 48,
         17: 136})

In [339]:
# don't need to slice, taking the entire mxp3_out_ref
_, conv4_inp_scatter_mat = get_slices_and_scatter_mat(mxp3_out_slices, mxp3_out_offsets, mxp3_out.shape[2], conv4_inp_num_channels)

In [340]:
aligned = mxp3_out_ref.clone().unsqueeze(0).repeat(1000,1,1)

In [341]:
aligned.shape

torch.Size([1000, 200, 21])

In [107]:
%timeit mxp3_out_ref.clone().unsqueeze(0).repeat(1000,1,1)

63.1 µs ± 16.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [342]:
aligned = aligned.scatter_(2, conv4_inp_scatter_mat, mxp3_out)

In [108]:
%timeit aligned.scatter_(2, conv4_inp_scatter_mat, mxp3_out)

115 µs ± 4.78 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [343]:
aligned.shape

torch.Size([1000, 200, 21])

### Check against the reference conv3 output

In [352]:
# given the out_ranges, do the remaining values remain the same as output on original sequence?
# this mostly debugs out_ranges
truths = []
for i in range(1000):
    idxs = list(set(range(21)) - set(range(*mxp3_out_ranges[i])))
    truths.append(bool(torch.all(torch.isclose(mxp3_out_ref[:,idxs], mxp3_ism_out[i][:,idxs], atol=1e-6))))
sum(truths)

# atol < 1e-6 doesn't work even when comparing what should be identical!

1000

In [354]:
# is recreated pre fc output same as reference?
torch.all(torch.isclose(aligned, mxp3_ism_out, atol=1e-6))

tensor(True, device='cuda:0')

## FCs (Fully Connected Layers)

In [355]:
class fc_layer(nn.Module):
    """Basset fully connected head: 4200 -> 1000 -> 1000 -> 10.

    The input is flattened with ``view(-1, 4200)``; the notebook feeds it the
    ``(N, 200, 21)`` output of conv block 3. Each of the first two linear
    layers is followed by BatchNorm and ReLU; the final linear layer's output
    is returned without an activation.
    """

    def __init__(self):
        super().__init__()
        # hidden layer 1
        self.fc1 = nn.Linear(4200, 1000)
        self.bn4 = nn.BatchNorm1d(1000)
        # hidden layer 2
        self.fc2 = nn.Linear(1000, 1000)
        self.bn5 = nn.BatchNorm1d(1000)
        # output layer
        self.fc3 = nn.Linear(1000, 10)

    def forward(self, s):
        hidden = s.view(-1, 4200)
        for linear, norm in ((self.fc1, self.bn4), (self.fc2, self.bn5)):
            hidden = F.relu(norm(linear(hidden)))
        return self.fc3(hidden)

In [357]:
fcl = fc_layer().cuda()
fcl.eval()

fc_layer(
  (fc1): Linear(in_features=4200, out_features=1000, bias=True)
  (bn4): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=1000, out_features=1000, bias=True)
  (bn5): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=1000, out_features=10, bias=True)
)

In [361]:
torch.all(torch.isclose(fcl(l3_w_padding(l2_w_padding(l1_w_padding(inp_seq_perturbed)))),
                        fcl(aligned)))

tensor(True, device='cuda:0')

In [366]:
def f():
    """Benchmark the FC step.

    It expands the reference conv3 output to one copy per mutation, splices in
    the recomputed conv3 windows and runs the FC head.
    """
    with torch.no_grad():
        # repeat() already copies the data, so the previous clone() was a redundant extra copy
        aligned = mxp3_out_ref.unsqueeze(0).repeat(1000, 1, 1)
        aligned = aligned.scatter_(2, conv4_inp_scatter_mat, mxp3_out)
        fcl(aligned)

In [368]:
%timeit f()

2.01 ms ± 4.31 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [370]:
@torch.no_grad()
def f():
    """Benchmark the fully connected head alone on the precomputed ``aligned`` batch."""
    fcl(aligned)

In [371]:
%timeit f()

1.84 ms ± 4.77 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## Showdown: naive ISM vs. fast ISM

In [386]:
# without optimisations
def normalISM():
    """Baseline ISM: run every perturbed full-length sequence through the padded layers.

    Returns:
        The ``fcl`` output for the batch ``inp_seq_perturbed``.
    """
    with torch.no_grad():
        h = l1_w_padding(inp_seq_perturbed)
        h = l2_w_padding(h)
        h = l3_w_padding(h)
        return fcl(h)

In [387]:
normalISM().shape

torch.Size([1000, 10])

In [376]:
%timeit base()

99.3 ms ± 358 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [377]:
def fastISM():
    """Fast ISM: recompute only the windows each single-position perturbation can affect.

    For layers 2 and 3, the per-mutation windows are gathered from the padded
    reference activations, which were computed once on the unperturbed
    ``inp_seq``. The previous layer's recomputed output is scattered into them
    at the precomputed offsets, and the unpadded layer is applied to the batch
    of windows. The reference activations and all index tensors are notebook
    globals computed beforehand, so their cost is not included here.

    Returns:
        ``fcl`` output, one row per perturbation (checked against
        ``normalISM()`` in the notebook).
    """
    with torch.no_grad():
        # first conv: copy the perturbed batch into the padded buffer, then
        # gather each mutation's input window
        padded_inp_seq[:, :, 9:1009] = inp_seq_perturbed
        x = l1(torch.gather(padded_inp_seq, 2, inp_seq_slices))

        # second conv: gather -> (channels, mutations*width) -> (mutations, channels, width)
        padded_conv2_inp_seq[:, 5:-5] = mxp1_out_ref
        aligned = torch.gather(padded_conv2_inp_seq, 1, conv2_inp_seq_slices).view(
            conv2_inp_num_channels, -1, conv2_inp_width).permute(1, 0, 2)
        aligned = aligned.scatter_(2, conv2_inp_scatter_mat, x)
        x = l2(aligned)

        # third conv
        padded_conv3_inp_seq[:, 4:-4] = mxp2_out_ref
        aligned = torch.gather(padded_conv3_inp_seq, 1, conv3_inp_seq_slices).view(
            conv3_inp_num_channels, -1, conv3_inp_width).permute(1, 0, 2)
        aligned = aligned.scatter_(2, conv3_inp_scatter_mat, x)
        x = l3(aligned)

        # fc: the FC input covers the whole conv3 output, so start from one
        # copy of the reference per mutation (repeat() already copies; no clone needed)
        aligned = mxp3_out_ref.unsqueeze(0).repeat(x.shape[0], 1, 1)
        aligned = aligned.scatter_(2, conv4_inp_scatter_mat, x)
        x = fcl(aligned)

        return x

In [382]:
fastISM().shape

torch.Size([1000, 10])

In [385]:
%timeit fastISM()

9.52 ms ± 11 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [388]:
torch.all(torch.isclose(normalISM(), fastISM()))

tensor(True, device='cuda:0')

In [389]:
%timeit torch.rand(1000,4,1000)

22 ms ± 844 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
