In [None]:
!pip install wandb



In [None]:
!nvidia-smi

Sat Dec 26 11:18:37 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.27.04    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P8    10W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
from torch import nn
from time import time
import torch
from typing import Tuple, Union, Iterable
import numpy as np
from torch.fft import rfftn, irfftn
import torch.nn.functional as f
from functools import partial
from sympy import Rational
import numpy as np
import wandb

torch.manual_seed(42)

<torch._C.Generator at 0x7f4717629c48>

In [None]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Direct Convolution (7 Nested Loops)

In [None]:
def convolution_for_loop(inp,ker):
  """Direct 2-D cross-correlation written as seven nested Python loops.

  No padding and stride 1; the kernel is not flipped
  (``out[n, m, x, y] = sum_{k, i, j} inp[n, k, x+i, y+j] * ker[m, k, i, j]``).

  Args:
    inp: input tensor of shape (N, C, H, W).
    ker: kernel tensor of shape (M, C, R, S).

  Returns:
    Tuple ``(out, t)``. ``out`` has shape (N, M, H-R+1, W-S+1) and is allocated
    with the dtype and device of ``inp``; ``t`` is the elapsed wall-clock time
    in seconds.

  Raises:
    ValueError: if ``inp`` and ``ker`` disagree on the number of channels.
  """
  N, C, H, W = inp.size()
  M, C_ker, R, S = ker.size()
  if C != C_ker:
    raise ValueError(f"Channel mismatch: input has {C} channels, kernel expects {C_ker}.")
  X = H-R+1
  Y = W-S+1

  # CUDA kernels run asynchronously; synchronise so the timer only sees our work.
  if inp.is_cuda:
    torch.cuda.synchronize(inp.device)
  start = time()

  # Allocate next to the input instead of relying on a notebook-level `device`.
  out = torch.zeros(N, M, X, Y, dtype=inp.dtype, device=inp.device)
  for n in range(N):
    for m in range(M):
      for x in range(X):
        for y in range(Y):
          for i in range(R):
            for j in range(S):
              for k in range(C):
                a = inp[n, k, x + i, y + j]
                b = ker[m, k, i, j]
                out[n, m, x, y] += a*b

  if inp.is_cuda:
    torch.cuda.synchronize(inp.device)
  t = time()-start
  return out,t

# Im2Col

In [None]:
def convolution_im2col(inp,ker):
  """2-D cross-correlation via im2col followed by a single matrix multiply.

  Every receptive field of the input is unrolled into one column, the kernel is
  flattened to one row per output channel, and the product of the two matrices
  is folded back into image layout. No padding, stride 1.

  Args:
    inp: input tensor of shape (N, C, H, W).
    ker: kernel tensor of shape (M, C, R, S).

  Returns:
    Tuple ``(out, t)``. ``out`` has shape (N, M, H-R+1, W-S+1) on the device of
    ``inp``; ``t`` is the elapsed wall-clock time in seconds.

  Raises:
    ValueError: if ``inp`` and ``ker`` disagree on the number of channels.
  """
  N, C, H, W = inp.size()
  M, C_ker, R, S = ker.size()
  if C != C_ker:
    raise ValueError(f"Channel mismatch: input has {C} channels, kernel expects {C_ker}.")

  U = 1  # stride

  X = (H - R)//U + 1
  Y = (W - S)//U + 1

  # CUDA kernels run asynchronously; synchronise so the timer only sees our work.
  if inp.is_cuda:
    torch.cuda.synchronize(inp.device)
  start = time()

  # im2col: col[n, c, i, j, x, y] = inp[n, c, x*U + i, y*U + j].
  # Allocated next to the input instead of relying on a notebook-level `device`.
  col = torch.zeros((N, C, R, S, X, Y), dtype=inp.dtype, device=inp.device)
  for i in range(R):
    i_max = i + U*X
    for j in range(S):
      j_max = j + U*Y
      col[:, :, i, j, :, :] = inp[:, :, i:i_max:U, j:j_max:U]

  # Rows index (c, i, j); columns index (n, x, y).
  inp_ch = col.permute(0, 4, 5, 1, 2, 3).reshape(N*X*Y, -1).T
  ker_ch = torch.flatten(ker, start_dim=1)

  out = torch.matmul(ker_ch, inp_ch)  # shape (M, N*X*Y)

  # col2im: split the column axis back into (N, X, Y) and move batch to the front.
  out = out.reshape(M, N, X, Y).permute(1, 0, 2, 3).contiguous()

  if inp.is_cuda:
    torch.cuda.synchronize(inp.device)
  t = time()-start

  return out, t

# Winograd

In [None]:
from __future__ import print_function
from sympy import symbols, Matrix, Poly, zeros, eye, Indexed, simplify, IndexedBase, init_printing, pprint
from operator import mul
from functools import reduce

def At(a,m,n):
    """Return the m-by-n matrix whose (i, j) entry is ``a[i]**j``."""
    def power_entry(row, col):
        return a[row]**col
    return Matrix(m, n, power_entry)

def A(a,m,n):
    """Return ``At(a, m-1, n)`` with one extra bottom row ``[0, ..., 0, 1]``.

    The result is an m-by-n matrix.
    """
    def unit_entry(_, col):
        return 1 if col==n-1 else 0
    upper = At(a, m-1, n)
    bottom_row = Matrix(1, n, unit_entry)
    return upper.row_insert(m-1, bottom_row)

def T(a,n):
    """Return the n-by-(n+1) matrix ``[I_n | c]`` where ``c[i] = -a[i]**n``."""
    def negated_power(row, _):
        return -a[row]**n
    last_col = Matrix(n, 1, negated_power)
    widened = Matrix.eye(n).col_insert(n, last_col)
    return Matrix(widened)

def Lx(a,n):
    """Return an n-by-1 matrix of polynomials in the symbol ``x``.

    Entry ``i`` is ``Poly(prod_{k != i} (x - a[k]), x)`` with ``k`` ranging over
    ``range(n)``, expanded before being wrapped in ``Poly``.
    """
    x=symbols('x')
    # reduce(mul, ..., 1) multiplies the factors together; the k == i factor is replaced by 1.
    return Matrix(n, 1, lambda i,j: Poly((reduce(mul, ((x-a[k] if k!=i else 1) for k in range(0,n)), 1)).expand(basic=True), x))

def F(a,n):
    """Return the n-by-1 matrix whose entry ``i`` is ``prod_{k != i} (a[i] - a[k])``.

    ``k`` ranges over ``range(n)``; the ``k == i`` factor is taken as 1.
    """
    def node_product(i, _):
        product = 1
        for k in range(0,n):
            factor = a[i]-a[k] if k!=i else 1
            product = product * factor
        return product
    return Matrix(n, 1, node_product)

def Fdiag(a,n):
    """Return the n-by-n diagonal matrix carrying ``F(a, n)`` on its diagonal."""
    column = F(a,n)
    def diag_entry(i, j):
        return column[i,0] if i==j else 0
    return Matrix(n, n, diag_entry)

def FdiagPlus1(a,n):
    """Return ``Fdiag(a, n-1)`` extended to n-by-n.

    A zero column is appended on the right, then the row ``[0, ..., 0, 1]`` is
    appended at the bottom.
    """
    core = Fdiag(a,n-1)
    with_zero_col = core.col_insert(n-1, zeros(n-1,1))
    unit_row = Matrix(1,n, lambda _, col: (1 if col==n-1 else 0))
    return with_zero_col.row_insert(n-1, unit_row)

def L(a,n):
    """Return the n-by-n coefficient matrix built from ``Lx(a, n)`` and ``F(a, n)``.

    Before the final transpose, entry (i, j) is the coefficient of ``x**j`` in
    ``Lx(a, n)[i]`` divided by ``F(a, n)[i]``; the transposed matrix is returned.
    """
    lx = Lx(a,n)
    f = F(a, n)
    # .nth(j) reads the x**j coefficient of the i-th polynomial.
    return Matrix(n, n, lambda i,j: lx[i, 0].nth(j)/f[i]).T

def Bt(a,n):
    """Return the matrix product ``L(a, n) * T(a, n)`` (an n-by-(n+1) matrix)."""
    left = L(a,n)
    right = T(a,n)
    return left*right

def B(a,n):
    """Return ``Bt(a, n-1)`` with one extra bottom row ``[0, ..., 0, 1]``.

    The result is an n-by-n matrix.
    """
    upper = Bt(a,n-1)
    unit_row = Matrix(1, n, lambda _, col: 1 if col==n-1 else 0)
    return upper.row_insert(n-1, unit_row)

# Selector values for the ``fractionsIn`` argument of cookToomFilter: they choose
# which returned matrix is combined with the diagonal matrix f (or its inverse).
FractionsInG=0  # G is multiplied by f**(-1), BT by f (the default)
FractionsInA=1  # AT is multiplied by f**(-1), BT by f
FractionsInB=2  # AT, G and BT are all returned without f applied
FractionsInF=3  # AT and G unscaled, BT multiplied by f (also the fallback for any other value)

def cookToomFilter(a,n,r,fractionsIn=FractionsInG):
    """Build the transform matrices for output size ``n`` and kernel size ``r``.

    With ``alpha = n + r - 1``, the unscaled matrices are ``A(a, alpha, n).T``,
    ``A(a, alpha, r)`` and ``B(a, alpha).T``. ``fractionsIn`` decides where the
    diagonal matrix ``f = FdiagPlus1(a, alpha)`` or its inverse is applied:

    * ``FractionsInG``: ``G`` takes ``f**(-1)``, ``BT`` takes ``f``.
    * ``FractionsInA``: ``AT`` takes ``f**(-1)``, ``BT`` takes ``f``.
    * ``FractionsInB``: nothing is scaled.
    * anything else: only ``BT`` takes ``f``.

    Returns:
        The tuple ``(AT, G, BT, f)``.
    """
    alpha = n+r-1
    scale = FdiagPlus1(a,alpha)
    # Flip the sign of the first row when its leading entry is negative.
    if scale[0,0] < 0:
        scale[0,:] *= -1

    # Unscaled building blocks shared by every mode.
    AT = A(a,alpha,n).T
    G = A(a,alpha,r)
    BT = B(a,alpha).T

    if fractionsIn == FractionsInB:
        return (AT,G,BT,scale)

    # Every remaining mode scales BT by f.
    BT = scale * BT
    if fractionsIn == FractionsInG:
        G = (G.T*scale**(-1)).T
    elif fractionsIn == FractionsInA:
        AT = AT*scale**(-1)
    return (AT,G,BT,scale)

In [None]:
def convolution_winograd(input, filter):
    """2-D cross-correlation using transform matrices from ``cookToomFilter``.

    The whole H-by-H image is processed as a single tile: the matrices are
    generated for output size ``H - r + 1`` and kernel size ``r`` from the
    interpolation points ``0, 1, -1, 2, -2, ...``.

    NOTE(review): the points and their powers grow with the input size, so the
    transforms may be badly conditioned in float32 for larger images - verify
    the output against a reference before relying on the values.

    Args:
        input: tensor of shape (N, C, H, W) with ``H == W``.
        filter: tensor of shape (K, C, r, r).

    Returns:
        Tuple ``(output, t)``. ``output`` has shape (N, K, H-r+1, H-r+1) on the
        device of ``input``; ``t`` is the wall-clock time in seconds spent on
        the tensor transforms (the symbolic construction of the matrices is
        not timed).

    Raises:
        ValueError: if the input or the kernel is not square, or if the channel
            counts disagree.
    """
    N, C, H, W = input.size()
    K, Cprime, r, rprime = filter.size()
    if H != W or r != rprime:
        raise ValueError(
            f"Square input and kernel required, got input {H}x{W} and kernel {r}x{rprime}."
        )
    if C != Cprime:
        raise ValueError(f"Channel mismatch: input has {C} channels, kernel expects {Cprime}.")

    output_size = H - r + 1
    num_values = r + output_size - 2

    # Interpolation points 0, 1, -1, 2, -2, ...; builtin float replaces the
    # removed `np.float` alias.
    polynomials = [0]
    for i in range(num_values - 1):
        polynomials.append(float(i + 1))
        polynomials.append(float(-1*(i + 1)))

    matrices = cookToomFilter(polynomials[:num_values], output_size, r)

    def to_tensor(sym_matrix):
        # Convert a symbolic matrix to a tensor that lives next to the input
        # (no dependency on a notebook-level `device`).
        values = np.array(sym_matrix).astype(np.float64)
        return torch.tensor(values, dtype=input.dtype, device=input.device)

    A_T = to_tensor(matrices[0])
    G = to_tensor(matrices[1])
    B_T = to_tensor(matrices[2])

    # CUDA kernels run asynchronously; synchronise so the timer only sees our work.
    if input.is_cuda:
        torch.cuda.synchronize(input.device)
    start = time()

    # Input transform, broadcast over (N, C): V[n, c] = B_T @ input[n, c] @ B.
    V = torch.matmul(B_T, torch.matmul(input, B_T.T))
    # Kernel transform, broadcast over (K, C): U[k, c] = G @ filter[k, c] @ G_T.
    U = torch.matmul(G, torch.matmul(filter, G.T))
    # Element-wise product in the transformed domain, summed over channels.
    # The output transform is linear, so it is applied once after the sum.
    M = torch.einsum("kcij,ncij->nkij", U, V)
    output = torch.matmul(A_T, torch.matmul(M, A_T.T))

    if input.is_cuda:
        torch.cuda.synchronize(input.device)
    t = time() - start

    return output,t

# FFT

In [None]:
def convolution_fft(
    signal,
    kernel,
    padding = 0,
    stride = 1,
    groups = 1,
):
    """N-d cross-correlation computed in the frequency domain.

    The kernel is zero-padded to the (padded) signal size, both are transformed
    with ``rfftn``, the signal spectrum is multiplied by the complex conjugate
    of the kernel spectrum, and the result is transformed back with ``irfftn``
    and cropped to the valid region.

    Args:
        signal: tensor of shape (batch, in_channels, *spatial).
        kernel: tensor of shape (out_channels, in_channels // groups, *kernel_size).
        padding: zero padding per spatial dim; an int or one value per dim.
        stride: output stride per spatial dim; an int or one value per dim.
        groups: number of channel groups.

    Returns:
        Tuple ``(output, t)``: the cropped result and the elapsed wall-clock
        time in seconds.

    Raises:
        ValueError: if ``padding`` or ``stride`` is an iterable whose length
            differs from the number of spatial dims.
    """
    def complex_matmul(a, b, groups = 1):
        # Grouped multiply-accumulate over the channel axis, done on the real
        # and imaginary parts separately.
        scalar_matmul = partial(torch.einsum, "agc..., gbc... -> agb...")
        a = a.view(a.size(0), groups, -1, *a.shape[2:])
        b = b.view(groups, -1, *b.shape[1:])

        real = scalar_matmul(a.real, b.real) - scalar_matmul(a.imag, b.imag)
        imag = scalar_matmul(a.imag, b.real) + scalar_matmul(a.real, b.imag)
        # torch.complex follows the input precision instead of forcing complex64.
        c = torch.complex(real, imag)

        # reshape (not view): the einsum result is not guaranteed to be contiguous.
        return c.reshape(c.size(0), -1, *c.shape[3:])

    def to_ntuple(val, n):
        if isinstance(val, Iterable):
            out = tuple(val)
            if len(out) == n:
                return out
            else:
                raise ValueError(f"Cannot cast tuple of length {len(out)} to length {n}.")
        else:
            return n * (val,)

    # CUDA kernels run asynchronously; synchronise so the timer only sees our work.
    if signal.is_cuda:
        torch.cuda.synchronize(signal.device)
    start = time()

    padding_ = to_ntuple(padding, n=signal.ndim - 2)
    stride_ = to_ntuple(stride, n=signal.ndim - 2)

    # f.pad expects (left, right) pairs starting from the last dimension.
    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
    signal = f.pad(signal, signal_padding)

    # Give the last dim an even length so irfftn restores the same size.
    if signal.size(-1) % 2 != 0:
        signal_ = f.pad(signal, [0, 1])
    else:
        signal_ = signal

    kernel_padding = [
        pad
        for i in reversed(range(2, signal_.ndim))
        for pad in [0, signal_.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    spatial_dims = tuple(range(2, signal.ndim))
    signal_fr = rfftn(signal_, dim=spatial_dims)
    kernel_fr = rfftn(padded_kernel, dim=spatial_dims)

    # Conjugating the kernel spectrum yields cross-correlation rather than convolution.
    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=spatial_dims)

    # Keep only positions where the kernel fits entirely inside the padded signal.
    # A tuple is required: indexing with a list of slices is deprecated.
    crop_slices = (slice(None), slice(None)) + tuple(
        slice(0, (signal.size(i) - kernel.size(i) + 1), stride_[i - 2])
        for i in range(2, signal.ndim)
    )
    output = output[crop_slices].contiguous()

    if signal.is_cuda:
        torch.cuda.synchronize(signal.device)
    t = time()-start
    return output,t

# Tests

In [None]:
# Authenticate this session with Weights & Biases before any run is created.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33momshri[0m (use `wandb login --relogin` to force relogin)


True

In [None]:
# Benchmark configurations, one per row: [N, C, H, W, M, R, S]
#   N batch size, C input channels, H x W input size,
#   M output channels, R x S kernel size.
# Inputs and kernels are kept square (H == W, R == S).
checks = [
  [128,1,5,5,1,2,2],
  [128,1,10,10,1,3,3],
  [128,3,10,10,1,3,3],
  [128,1,10,10,4,3,3],
  [128,3,10,10,4,3,3],
  [128,1,10,10,2,3,3],
  [128,3,10,10,2,3,3],
  [128,1,10,10,8,3,3],
  [128,3,10,10,8,3,3],
  [128,3,10,10,1,5,5],
  [128,6,10,10,1,5,5],
  [128,3,10,10,4,5,5],
  [128,6,10,10,4,5,5],
  # This row used to duplicate the next one; by the C=1 / C=3 pairing used for
  # the other 20x20 cases it is presumably meant to be the C=1 variant.
  [128,1,20,20,1,5,5],
  [128,3,20,20,1,5,5],
  [128,1,20,20,4,5,5],
  [128,3,20,20,4,5,5],
  [128,1,20,20,1,8,8],
  [128,3,20,20,1,8,8],
  [128,1,20,20,1,11,11],
  [128,3,20,20,1,11,11],
]
print(len(checks))

21


In [None]:
# Time each convolution implementation on every configuration and log the
# results (in milliseconds) to one Weights & Biases run per configuration.
for N,C,H,W,M,R,S in checks:

  run_name = f"N:{N}, C:{C}, H:{H}, W:{W}, M:{M}, R:{R}, S:{S}"
  wandb.init(project="conv_time_gpu",name=run_name,reinit=True)

  try:
    conv = nn.Conv2d(C,M,(R,S),bias=False).to(device)

    inp = torch.rand(N,C,H,W).to(device)
    # Detach the weights so the timings do not include autograd bookkeeping.
    ker = conv.weight.detach().to(device)

    # The direct-loop baseline is disabled (presumably because it is too slow
    # at these sizes); uncomment both lines below to include it.
    # out_7d = convolution_for_loop(inp,ker)
    out_im2 = convolution_im2col(inp,ker)
    out_win = convolution_winograd(inp,ker)
    out_fft = convolution_fft(inp,ker)

    # Each function returns (output, seconds); log milliseconds.
    # wandb.log({"7d":(out_7d[1]*1000)})
    wandb.log({"im2col":(out_im2[1]*1000)})
    wandb.log({"winograd":(out_win[1]*1000)})
    wandb.log({"fft":(out_fft[1]*1000)})
  finally:
    # Close the run even when one of the implementations raises.
    wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,3.22366
_step,2.0
_runtime,5.0
_timestamp,1608981524.0
winograd,27.50707
fft,3.1476


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,8.26144
_step,2.0
_runtime,1.0
_timestamp,1608981528.0
winograd,25.85363
fft,2.78902


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,7.69949
_step,2.0
_runtime,1.0
_timestamp,1608981533.0
winograd,76.5295
fft,2.62237


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,9.12237
_step,2.0
_runtime,1.0
_timestamp,1608981537.0
winograd,82.2103
fft,2.91634


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,11.64842
_step,2.0
_runtime,1.0
_timestamp,1608981541.0
winograd,163.40232
fft,2.68102


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,6.50287
_step,2.0
_runtime,1.0
_timestamp,1608981546.0
winograd,55.27282
fft,2.63238


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,8.71301
_step,2.0
_runtime,1.0
_timestamp,1608981550.0
winograd,89.26797
fft,2.49052


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,7.06983
_step,2.0
_runtime,1.0
_timestamp,1608981554.0
winograd,138.86929
fft,2.74873


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,6.59347
_step,2.0
_runtime,1.0
_timestamp,1608981559.0
winograd,323.58623
fft,2.57635


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,6.90222
_step,2.0
_runtime,1.0
_timestamp,1608981563.0
winograd,61.34009
fft,1.71494


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,9.61542
_step,2.0
_runtime,1.0
_timestamp,1608981567.0
winograd,96.56239
fft,2.08497


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,7.32708
_step,2.0
_runtime,1.0
_timestamp,1608981572.0
winograd,173.6784
fft,2.19846


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,17.58003
_step,2.0
_runtime,1.0
_timestamp,1608981576.0
winograd,263.23104
fft,1.79696


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,13.13281
_step,2.0
_runtime,1.0
_timestamp,1608981581.0
winograd,52.96469
fft,4.21071


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,8.87489
_step,2.0
_runtime,2.0
_timestamp,1608981586.0
winograd,54.32057
fft,1.78719


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,17.32707
_step,2.0
_runtime,1.0
_timestamp,1608981590.0
winograd,82.58104
fft,3.39818


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,11.97004
_step,2.0
_runtime,2.0
_timestamp,1608981595.0
winograd,160.89058
fft,3.21555


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,15.06686
_step,2.0
_runtime,1.0
_timestamp,1608981599.0
winograd,28.70679
fft,3.1178


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,28.72515
_step,2.0
_runtime,1.0
_timestamp,1608981604.0
winograd,57.30844
fft,1.91665


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,17.33112
_step,2.0
_runtime,1.0
_timestamp,1608981608.0
winograd,30.50041
fft,1.84178


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
im2col,23.11301
_step,2.0
_runtime,1.0
_timestamp,1608981613.0
winograd,50.47059
fft,1.70493


0,1
im2col,▁
_step,▁▅█
_runtime,▁▁▁
_timestamp,▁▁▁
winograd,▁
fft,▁
