Skip to content

Commit

Permalink
dgemm
Browse files Browse the repository at this point in the history
  • Loading branch information
jpn-- committed Sep 1, 2017
1 parent 57c2437 commit 356ca99
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ src/.DS_Store

**.ipynb_checkpoints
pingpong*.bat
gemm.c
Empty file added v4/larch4/__init__.py
Empty file.
Empty file added v4/larch4/linalg/__init__.py
Empty file.
79 changes: 79 additions & 0 deletions v4/larch4/linalg/gemm.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from scipy.linalg.cython_blas cimport dgemm as _dgemm
cimport numpy as np
cimport cython

cpdef _dot_strider(
double[:,:] a, int lda,
double[:,:] b, int ldb,
double[:,:] c, int ldc,
double alpha= 1.0,
double beta= 0.0,
int trans_a=0,
int trans_b=0
):
cdef int m, n, k, i
cdef char trans_a_, trans_b_

if trans_a:
k = a.shape[0]
m = a.shape[1]
trans_a_ = b'T'
else:
k = a.shape[1]
m = a.shape[0]
trans_a_ = b'N'

if trans_b:
n = b.shape[0]
trans_b_ = b'T'
else:
n = b.shape[1]
trans_b_ = b'N'

#print('m,n,k',m,n,k)

_dgemm(&trans_a_, &trans_b_, &m, &n, &k, &alpha, &a[0,0], &lda, &b[0,0], &ldb, &beta, &c[0,0], &ldc)


def _isSorted(x, reverse=False):
import operator
my_operator = operator.ge if reverse else operator.le
return all(my_operator(x[i], x[i + 1])
for i in range(len(x) - 1))


def _fortran_check(z, label):
if z.flags.c_contiguous:
#print("FC-C+",label)
return z.T, 1 #'T'
elif z.flags.f_contiguous:
#print("FC-F+",label)
return z, 0 # 'N'
elif _isSorted(z.strides): # f-not-contiguous
#print("FC-F-",label)
return z, 0 # 'N'
else: # c-not-contiguous
#print("FC-C-",label)
return z.T, 1 #'T'


def dgemm(alpha,a,b,beta,c):
c, trans_c = _fortran_check(c, "MAT-C")
ldc = int(c.strides[1] / c.dtype.itemsize)
if trans_c:
a, trans_a = _fortran_check(a.T, "MAT-At")
b, trans_b = _fortran_check(b.T, "MAT-Bt")
lda = int(a.strides[1] / a.dtype.itemsize)
ldb = int(b.strides[1] / b.dtype.itemsize)
#print("b...",b.strides,ldb,b.shape,'T' if trans_b else 'N')
#print("a...",a.strides,lda,a.shape,'T' if trans_a else 'N')
return _dot_strider(b,ldb,a,lda,c,ldc,alpha,beta,trans_b,trans_a)
else:
a, trans_a = _fortran_check(a, "MAT-A")
b, trans_b = _fortran_check(b, "MAT-B")
lda = int(a.strides[1] / a.dtype.itemsize)
ldb = int(b.strides[1] / b.dtype.itemsize)
#print("a...",a.strides,lda,a.shape,'T' if trans_a else 'N')
#print("b...",b.strides,ldb,b.shape,'T' if trans_b else 'N')
return _dot_strider(a,lda,b,ldb,c,ldc,alpha,beta,trans_a,trans_b)

91 changes: 91 additions & 0 deletions v4/larch4/linalg/test_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from .gemm import dgemm

import numpy

def test_gemm_overwrite():
a = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')
b = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')

def pick(aa,bb):
c = numpy.ones((7,7), float, order="C")
dgemm(1.0,aa,bb,0,c)
c_ = numpy.dot(aa,bb)
assert (c[:c_.shape[0],:c_.shape[1]]==c_).all()
assert (c[c_.shape[0]:,:]==1).all()
assert (c[:,c_.shape[1]:]==1).all()

pick(a,b)
pick(a.T,b.T)
pick(a,b.T)
pick(a.T,b)
pick(a[1:4,:],b[:,2:4])
pick(a.T[1:4,:],b[:,2:4])
pick(a[1:4,:],b.T[:,2:4])
pick(a.T[1:4,:],b.T[:,2:4])

def test_gemm_addon():
a = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')
b = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')

def pick(aa,bb):
c = numpy.ones((7,7), float, order="C")
dgemm(1.0,aa,bb,1,c)
c_ = numpy.dot(aa,bb) + 1
assert (c[:c_.shape[0],:c_.shape[1]]==c_).all()
assert (c[c_.shape[0]:,:]==1).all()
assert (c[:,c_.shape[1]:]==1).all()

pick(a,b)
pick(a.T,b.T)
pick(a,b.T)
pick(a.T,b)
pick(a[1:4,:],b[:,2:4])
pick(a.T[1:4,:],b[:,2:4])
pick(a[1:4,:],b.T[:,2:4])
pick(a.T[1:4,:],b.T[:,2:4])

def test_gemm_offset():
a = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')
b = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')

def pick(aa,bb):
c = numpy.ones((7,7), float, order="C")
cc = c[1:,1:]
dgemm(1.0,aa,bb,0,cc)
c_ = numpy.dot(aa,bb)
assert (c[1:c_.shape[0]+1,1:c_.shape[1]+1]==c_).all()
assert (c[c_.shape[0]+1:,:]==1).all()
assert (c[:,c_.shape[1]+1:]==1).all()
assert (c[0,:]==1).all()
assert (c[:,0]==1).all()

pick(a,b)
pick(a.T,b.T)
pick(a,b.T)
pick(a.T,b)
pick(a[1:4,:],b[:,2:4])
pick(a.T[1:4,:],b[:,2:4])
pick(a[1:4,:],b.T[:,2:4])
pick(a.T[1:4,:],b.T[:,2:4])

def test_gemm_T():
a = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')
b = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')

def pick(aa,bb):
c = numpy.ones((7,7), float, order="C")
cc = c.T
dgemm(1.0,aa,bb,0,cc)
c_ = numpy.dot(aa,bb)
assert (c.T[:c_.shape[0],:c_.shape[1]]==c_).all()
assert (c.T[c_.shape[0]:,:]==1).all()
assert (c.T[:,c_.shape[1]:]==1).all()

pick(a,b)
pick(a.T,b.T)
pick(a,b.T)
pick(a.T,b)
pick(a[1:4,:],b[:,2:4])
pick(a.T[1:4,:],b[:,2:4])
pick(a[1:4,:],b.T[:,2:4])
pick(a.T[1:4,:],b.T[:,2:4])
31 changes: 31 additions & 0 deletions v4/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from setuptools import setup, find_packages
from Cython.Build import cythonize
import numpy
import os


def find_pyx(path='.'):
pyx_files = []
for root, dirs, filenames in os.walk(path):
for fname in filenames:
if fname.endswith('.pyx'):
pyx_files.append(os.path.join(root, fname))
return pyx_files



setup(
name = 'larch4',
ext_modules = cythonize(find_pyx()),
include_dirs=[numpy.get_include()],
packages=find_packages(),
)




###
# python setup.py build_ext --inplace
#


0 comments on commit 356ca99

Please sign in to comment.