In [141]:
import numpy as np
import matplotlib.pyplot as plt
import math


class conv:
    """Toy strided 2-D convolution comparing HWCN and NHWC ("NHWxC") data layouts.

    The constructor draws random tensors in these layouts:

    * ``ifm``    : H x W x C x N  (input feature map)
    * ``filter`` : R x S x C x K
    * ``ofm``    : P x Q x K x N  (the update methods also read it as the
      output-side operand)

    Every method stores its intermediates (``sfm``, ``ofm``, ``filter`` and the
    re-laid-out ``ifm``) on the instance so their shapes can be inspected.  The
    current layout of ``ifm`` and ``ofm`` is tracked in ``_ifm_layout`` /
    ``_ofm_layout``, so the methods may be called in any order and re-run
    without re-transposing data that was already converted.  If you assign
    ``ifm`` or ``ofm`` by hand, set the matching layout tag as well.

    Parameters
    ----------
    H, W, C, N : int
        Input height, width, channels and batch size.
    P, Q : int
        Output height and width.  Slicing reads ``ifm[p*stride + r, q*stride + s]``,
        so ``(P-1)*stride + R <= H`` and ``(Q-1)*stride + S <= W`` must hold.
    K : int
        Number of output channels (filters).
    R, S : int
        Filter height and width.
    stride : int, optional
        Convolution stride; defaults to 2, the value that used to be hard-coded.
    seed : int or None, optional
        Seed for reproducible tensors.  ``None`` (default) draws from NumPy's
        global random state, exactly like ``np.random.rand``.
    """

    def __init__(self, H, W, C, N, P, Q, K, R, S, stride=2, seed=None):
        self.H = H
        self.W = W
        self.C = C
        self.N = N
        self.P = P
        self.Q = Q
        self.K = K
        self.R = R
        self.S = S
        self.stride = stride
        # np.random.random_sample(shape) is what np.random.rand(*shape) calls.
        draw = np.random.random_sample if seed is None else np.random.default_rng(seed).random
        self.ifm = draw((H, W, C, N))
        self.filter = draw((R, S, C, K))
        self.ofm = draw((P, Q, K, N))
        self._ifm_layout = "HWCN"
        self._ofm_layout = "PQKN"

    # ------------------------------------------------------------------
    # Layout helpers
    # ------------------------------------------------------------------
    def _ifm_hwcn(self):
        """Return ``self.ifm`` as H x W x C x N, converting back from N x H x W x C if needed."""
        if self._ifm_layout != "HWCN":
            # NHWxC -reshape-> Nx(HWC) -transpose-> (HWC)xN -reshape-> HWCxN
            flat = self.ifm.reshape(self.N, self.H * self.W * self.C)
            self.ifm = np.transpose(flat).reshape(self.H, self.W, self.C, self.N)
            self._ifm_layout = "HWCN"
        return self.ifm

    def _ifm_nhwc(self):
        """Return ``self.ifm`` as N x H x W x C, converting from H x W x C x N if needed."""
        if self._ifm_layout != "NHWC":
            # HWCxN -reshape-> (HWC)xN -transpose-> Nx(HWC) -reshape-> NHWxC
            flat = self.ifm.reshape(self.H * self.W * self.C, self.N)
            self.ifm = np.transpose(flat).reshape(self.N, self.H, self.W, self.C)
            self._ifm_layout = "NHWC"
        return self.ifm

    def _ofm_pqkn(self):
        """Return ``self.ofm`` as P x Q x K x N, converting back from batch-major if needed."""
        if self._ofm_layout != "PQKN":
            # N(PQ)xK -reshape-> Nx(PQK) -transpose-> (PQK)xN -reshape-> PQKxN
            flat = self.ofm.reshape(self.N, self.P * self.Q * self.K)
            self.ofm = np.transpose(flat).reshape(self.P, self.Q, self.K, self.N)
            self._ofm_layout = "PQKN"
        return self.ofm

    def _ofm_npq_k(self):
        """Return ``self.ofm`` as N x (PQ) x K, converting from P x Q x K x N if needed."""
        if self._ofm_layout != "NPQK":
            # PQKxN -reshape-> (PQK)xN -transpose-> Nx(PQK)
            flat = self.ofm.reshape(self.P * self.Q * self.K, self.N)
            self.ofm = np.transpose(flat)
            self._ofm_layout = "NPQK"
        # N x P x Q x K and N x (PQ) x K differ only by a reshape.
        self.ofm = self.ofm.reshape(self.N, self.P * self.Q, self.K)
        return self.ofm

    def _filter_rsck(self):
        """Return ``self.filter`` viewed as R x S x C x K.

        ``NHWxC_partial_fprop`` stores the filter as R x S1 x (S0*C) x K with
        S1*S0 == S, which has the same C-order element sequence, so a plain
        reshape recovers the canonical layout.
        """
        return self.filter.reshape(self.R, self.S, self.C, self.K)

    # Fprop
    def HWCN_fprop(self):
        """Forward pass in HWCN layout.

        Builds the sliced feature map ``sfm`` (P x Q x (RSC) x N) and computes
        ``ofm[p, q] = filterT . sfm[p, q]`` for every output position.

        Returns
        -------
        ndarray of shape (P, Q, K, N), also stored as ``self.ofm``.
        """
        ifm = self._ifm_hwcn()
        # Kx(RSC); the reshape is valid for both the canonical and the partial filter layout.
        self.filterT = np.transpose(self.filter.reshape([self.R*self.S*self.C, self.K]))
        self.sfm = np.empty([self.P, self.Q, self.R, self.S, self.C, self.N])
        # Slice IFM: sfm[p,q,r,s] is the CxN block under filter tap (r,s) at output (p,q).
        for p in range(self.P):
            for q in range(self.Q):
                for r in range(self.R):
                    for s in range(self.S):
                        self.sfm[p,q,r,s] = ifm[p*self.stride+r, q*self.stride+s]

        # OFM = FilterT*SFM
        self.ofm = np.empty([self.P, self.Q, self.K, self.N])
        self._ofm_layout = "PQKN"
        self.sfm = self.sfm.reshape([self.P, self.Q, self.R*self.S*self.C, self.N])
        for p in range(self.P):
            for q in range(self.Q):
                # dot(Kx(RSC), (RSC)xN) = KxN already covers every output channel,
                # so no loop over k is needed.
                self.ofm[p,q] = np.dot(self.filterT, self.sfm[p,q])
        return self.ofm

    def NHWxC_fprop(self):
        """Forward pass in NHWxC layout.

        Converts ``self.ifm`` to N x H x W x C (only if it is not already),
        slices it into ``sfm`` (N x R x S x P x Q x C) and accumulates
        ``ofm[n] += sfm[n, r, s] . filter[r, s]`` over all filter taps.

        Returns
        -------
        ndarray of shape (N, P, Q, K), also stored as ``self.ofm``.
        """
        ifm = self._ifm_nhwc()
        filt = self._filter_rsck()
        self.sfm = np.empty([self.N, self.R, self.S, self.P, self.Q, self.C])

        # Slice IFM
        for n in range(self.N):
            for r in range(self.R):
                for s in range(self.S):
                    for p in range(self.P):
                        for q in range(self.Q):
                            self.sfm[n,r,s,p,q] = ifm[n, p*self.stride+r, q*self.stride+s]

        # OFM = SFM*Filter: dot(PQxC, CxK) = PQxK per (n, r, s)
        self.ofm = np.zeros([self.N, self.P, self.Q, self.K])
        self._ofm_layout = "NPQK"
        for n in range(self.N):
            for r in range(self.R):
                for s in range(self.S):
                    self.ofm[n] += np.dot(self.sfm[n,r,s], filt[r,s])
        return self.ofm

    def NHWxC_partial_fprop(self):
        """NHWxC forward pass that folds ``S0 = stride`` filter columns into the channel axis.

        The filter is stored as R x S1 x (S0*C) x K with ``S1 = S - stride + 1``
        and the slices as N x R x S1 x P x Q x (S0*C), so each GEMM contracts
        over S0*C values and only S1 filter columns are looped over.

        Raises
        ------
        ValueError
            If ``S1 * S0 != S``, i.e. unless ``S == stride`` or ``stride == 1``.

        Returns
        -------
        ndarray of shape (N, P, Q, K), also stored as ``self.ofm``.
        """
        S0 = self.stride
        S1 = self.S - S0 + 1
        # Check before touching any state, so a bad config leaves the instance intact.
        if S1 * S0 != self.S:
            raise ValueError(
                f"partial fprop needs S == stride or stride == 1 (got S={self.S}, stride={S0})")
        ifm = self._ifm_nhwc()

        # Same element order as R x S x C x K, so this reshape works from either filter layout.
        self.filter = self.filter.reshape([self.R, S1, S0*self.C, self.K])
        self.sfm = np.empty([self.N, self.R, S1, self.P, self.Q, S0, self.C])

        # Slice IFM
        for n in range(self.N):
            for r in range(self.R):
                for s1 in range(S1):
                    for p in range(self.P):
                        for q in range(self.Q):
                            for s0 in range(S0):
                                self.sfm[n,r,s1,p,q,s0] = ifm[n, p*self.stride+r, q*self.stride+s1+s0]
        self.sfm = self.sfm.reshape([self.N, self.R, S1, self.P, self.Q, S0*self.C])

        # OFM = SFM*Filter: dot(PQx(S0C), (S0C)xK) = PQxK per (n, r, s1)
        self.ofm = np.zeros([self.N, self.P, self.Q, self.K])
        self._ofm_layout = "NPQK"
        for n in range(self.N):
            for r in range(self.R):
                for s in range(S1):
                    self.ofm[n] += np.dot(self.sfm[n,r,s], self.filter[r,s])
        return self.ofm

    # Update
    def HWCN_update(self):
        """Filter update in HWCN layout.

        ``filter[r, s] = sum_{p,q} sfm[p, q, r, s] . ofm[p, q]^T`` (CxN . NxK = CxK),
        reading ``self.ofm`` as P x Q x K x N.

        Returns
        -------
        ndarray of shape (R, S, C, K), also stored as ``self.filter``.
        """
        ifm = self._ifm_hwcn()
        ofm = self._ofm_pqkn()
        self.sfm = np.empty([self.P, self.Q, self.R, self.S, self.C, self.N])
        # Slice IFM
        for p in range(self.P):
            for q in range(self.Q):
                for r in range(self.R):
                    for s in range(self.S):
                        self.sfm[p,q,r,s] = ifm[p*self.stride+r, q*self.stride+s]

        # Filter[r,s] = sum_{p,q} SFM[p,q,r,s] . OFM[p,q]^T
        self.filter = np.zeros([self.R, self.S, self.C, self.K])
        for r in range(self.R):
            for s in range(self.S):
                for p in range(self.P):
                    for q in range(self.Q):
                        self.filter[r,s] += np.dot(self.sfm[p,q,r,s], np.transpose(ofm[p,q]))
        return self.filter

    def NHWxC_update(self):
        """Filter update in NHWxC layout.

        Converts ``self.ifm`` to N x H x W x C and ``self.ofm`` to N x (PQ) x K
        (each only if not already in that layout), then accumulates
        ``filter[r, s] += sfm[n, r, s]^T . ofm[n]`` (Cx(PQ) . (PQ)xK = CxK).

        Returns
        -------
        ndarray of shape (R, S, C, K), also stored as ``self.filter``.
        """
        ifm = self._ifm_nhwc()
        self.sfm = np.empty([self.N, self.R, self.S, self.P, self.Q, self.C])
        # Slice IFM NRSPQxC
        for n in range(self.N):
            for r in range(self.R):
                for s in range(self.S):
                    for p in range(self.P):
                        for q in range(self.Q):
                            self.sfm[n,r,s,p,q] = ifm[n, p*self.stride+r, q*self.stride+s]

        # OFM PQKxN -> N(PQ)xK
        ofm = self._ofm_npq_k()

        self.filter = np.zeros([self.R, self.S, self.C, self.K])
        pq = self.P * self.Q
        for r in range(self.R):
            for s in range(self.S):
                for n in range(self.N):
                    # SFM[n,r,s]: PQxC -reshape-> (PQ)xC -transpose-> Cx(PQ); OFM[n]: (PQ)xK
                    self.filter[r,s] += np.dot(np.transpose(self.sfm[n,r,s].reshape(pq, self.C)), ofm[n])
        # filter = RSCxK
        return self.filter

In [142]:
# Forward-propagation experiments: 16x16x128 input, batch 32, 64 filters of 2x2 -> 8x8 output.
cfg = conv(H=16, W=16, C=128, N=32, P=8, Q=8, K=64, R=2, S=2)

In [143]:
hwcn_fprop = cfg.HWCN_fprop()

# Report the layout of every tensor touched by the HWCN forward pass.
hwcn_fprop_shapes = (
    ("IFM shape HWCxN:", cfg.ifm.shape),
    ("Filter shape RSCxK:", cfg.filter.shape),
    ("SFM shape PQRSCxN:", cfg.sfm.shape),
    ("OFM shape PQKxN:", cfg.ofm.shape),
)
for label, shape in hwcn_fprop_shapes:
    print(label, shape)


IFM shape HWCxN: (16, 16, 128, 32)
Filter shape RSCxK: (2, 2, 128, 64)
SFM shape PQRSCxN: (8, 8, 512, 32)
OFM shape PQKxN: (8, 8, 64, 32)


In [144]:

nhwc_fprop = cfg.NHWxC_fprop()

# Report the layout of every tensor touched by the NHWxC forward pass.
nhwc_fprop_shapes = (
    ("IFM shape NHWxC:", cfg.ifm.shape),
    ("Filter shape RSCxK:", cfg.filter.shape),
    ("SFM shape NRSPQxC:", cfg.sfm.shape),
    ("OFM shape NPQxK:", cfg.ofm.shape),
)
for label, shape in nhwc_fprop_shapes:
    print(label, shape)


IFM shape NHWxC: (32, 16, 16, 128)
Filter shape RSCxK: (2, 2, 128, 64)
SFM shape NRSPQxC: (32, 2, 2, 8, 8, 128)
OFM shape NPQxK: (32, 8, 8, 64)


In [145]:
# Convert OFM of NHWxC fprop from NPQxK to PQKxN. The result gets a new name so
# re-running this cell never reshapes an already-converted array (which would
# silently scramble it and make the comparison fail).
nhwc_fprop_pqkn = np.transpose(nhwc_fprop.reshape(cfg.N, cfg.P*cfg.Q*cfg.K)).reshape(cfg.P, cfg.Q, cfg.K, cfg.N)
np.allclose(nhwc_fprop_pqkn, hwcn_fprop)

True

In [146]:
# Partial NHWxC fprop: fold S0 = stride adjacent filter columns into the channel
# axis (S1 = S - stride + 1 columns remain), aiming to improve slice efficiency.
# A fresh config is drawn, so the HWCN reference is recomputed on the same data.
cfg = conv(16, 16, 128, 32, 8, 8, 64, 2, 2)
hwcn_fprop = cfg.HWCN_fprop()
nhwc_partial_fprop = cfg.NHWxC_partial_fprop()

# In the partial layout the filter is R x S1 x (S0*C) x K and the SFM is
# N x R x S1 x P x Q x (S0*C): the "S"/"C" in the labels below are merged axes.
print("IFM shape NHWxC:", cfg.ifm.shape)
print("Filter shape RSCxK:", cfg.filter.shape)
print("SFM shape NRSPQxC:", cfg.sfm.shape)
print("OFM shape NPQxK:", cfg.ofm.shape)

IFM shape NHWxC: (32, 16, 16, 128)
Filter shape RSCxK: (2, 1, 256, 64)
SFM shape NRSPQxC: (32, 2, 1, 8, 8, 256)
OFM shape NPQxK: (32, 8, 8, 64)


In [147]:
# Convert OFM of the partial NHWxC fprop from NPQxK to PQKxN under a new name,
# so re-running this cell does not reshape an already-converted array.
nhwc_partial_fprop_pqkn = np.transpose(nhwc_partial_fprop.reshape(cfg.N, cfg.P*cfg.Q*cfg.K)).reshape(cfg.P, cfg.Q, cfg.K, cfg.N)
np.allclose(nhwc_partial_fprop_pqkn, hwcn_fprop)

True

In [148]:
# Filter-update experiments on a freshly drawn config (same sizes as fprop).
cfg = conv(H=16, W=16, C=128, N=32, P=8, Q=8, K=64, R=2, S=2)

In [149]:
hwcn_update = cfg.HWCN_update()

# Report the layout of every tensor touched by the HWCN filter update.
hwcn_update_shapes = (
    ("IFM shape HWCxN:", cfg.ifm.shape),
    ("SFM shape PQRSCxN:", cfg.sfm.shape),
    ("OFM shape PQKxN:", cfg.ofm.shape),
    ("Filter shape RSCxK:", cfg.filter.shape),
)
for label, shape in hwcn_update_shapes:
    print(label, shape)

IFM shape HWCxN: (16, 16, 128, 32)
SFM shape PQRSCxN: (8, 8, 2, 2, 128, 32)
OFM shape PQKxN: (8, 8, 64, 32)
Filter shape RSCxK: (2, 2, 128, 64)


In [150]:
nhwc_update = cfg.NHWxC_update()

# Report the layout of every tensor touched by the NHWxC filter update.
nhwc_update_shapes = (
    ("IFM shape NHWxC:", cfg.ifm.shape),
    ("SFM shape NRSPQxC:", cfg.sfm.shape),
    ("OFM shape NPQxK:", cfg.ofm.shape),
    ("Filter shape RSCxK:", cfg.filter.shape),
)
for label, shape in nhwc_update_shapes:
    print(label, shape)


IFM shape NHWxC: (32, 16, 16, 128)
SFM shape NRSPQxC: (32, 2, 2, 8, 8, 128)
OFM shape NPQxK: (32, 64, 64)
Filter shape RSCxK: (2, 2, 128, 64)


In [151]:
np.allclose(a=nhwc_update, b=hwcn_update)  # both layouts must yield the same RSCxK filter update

True