Skip to content

Commit

Permalink
gemv
Browse files Browse the repository at this point in the history
  • Loading branch information
jpn-- committed Sep 4, 2017
1 parent d8610d4 commit ccdda31
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
55 changes: 55 additions & 0 deletions v4/larch4/linalg/gemv.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from scipy.linalg.cython_blas cimport dgemv as _dgemv
cimport numpy as np
cimport cython

cdef _dgemv_strider(
double[:,:] a, int lda,
double[:] x, int incx,
double[:] y, int incy,
double alpha= 1.0,
double beta= 0.0,
int trans_a=0
):
cdef int m, n
cdef char trans_a_

n = a.shape[1]
m = a.shape[0]

if trans_a:
trans_a_ = b'T'
else:
trans_a_ = b'N'

_dgemv(&trans_a_, &m, &n, &alpha, &a[0,0], &lda, &x[0], &incx, &beta, &y[0], &incy)


def _isSorted(x, reverse=False):
import operator
my_operator = operator.ge if reverse else operator.le
return all(my_operator(x[i], x[i + 1])
for i in range(len(x) - 1))


def _fortran_check(z):
if z.flags.c_contiguous:
#print("FC-C+",label)
return z.T, 1 #'T'
elif z.flags.f_contiguous:
#print("FC-F+",label)
return z, 0 # 'N'
elif _isSorted(z.strides): # f-not-contiguous
#print("FC-F-",label)
return z, 0 # 'N'
else: # c-not-contiguous
#print("FC-C-",label)
return z.T, 1 #'T'


def dgemv(alpha,a,x,beta,y):
incy = int(y.strides[0] / y.dtype.itemsize)
incx = int(x.strides[0] / x.dtype.itemsize)
a, trans_a = _fortran_check(a)
lda = int(a.strides[1] / a.dtype.itemsize)
return _dgemv_strider(a,lda,x,incx,y,incy,alpha,beta,trans_a)

54 changes: 54 additions & 0 deletions v4/larch4/linalg/test_gemv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from .gemv import dgemv

import numpy

def test_gemv_overwrite():
a = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')
b = numpy.asarray(numpy.arange(5),dtype=float, order= 'C')

def pick(aa,bb):
c = numpy.ones((10,), float, order="C")
dgemv(1.0,aa,bb,0,c)
c_ = numpy.dot(aa,bb)
assert (c[:c_.shape[0]]==c_).all()
assert (c[c_.shape[0]:]==1).all()

pick(a,b)
pick(a.T,b)
pick(a[1:4,1:3],b[2:4])
pick(a.T[1:4,1:3],b[2:4])

def test_gemv_addon():
a = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')
b = numpy.asarray(numpy.arange(5),dtype=float, order= 'C')

def pick(aa,bb):
c = numpy.ones((10,), float, order="C")
dgemv(1.0,aa,bb,1,c)
c_ = numpy.dot(aa,bb) + 1
assert (c[:c_.shape[0]]==c_).all()
assert (c[c_.shape[0]:]==1).all()

pick(a,b)
pick(a.T,b)
pick(a[1:4,1:3],b[2:4])
pick(a.T[1:4,1:3],b[2:4])

def test_gemv_offset():
a = numpy.asarray(numpy.arange(25).reshape((5,5)),dtype=float, order= 'C')
b = numpy.asarray(numpy.arange(5),dtype=float, order= 'C')

def pick(aa,bb):
c = numpy.ones((10,), float, order="C")
cc = c[1:]
dgemv(1.0,aa,bb,1,cc)
c_ = numpy.dot(aa,bb) + 1
assert (c[1:c_.shape[0]+1]==c_).all()
assert (c[c_.shape[0]+1:]==1).all()
assert (c[0]==1)

pick(a,b)
pick(a.T,b)
pick(a[1:4,1:3],b[2:4])
pick(a.T[1:4,1:3],b[2:4])

0 comments on commit ccdda31

Please sign in to comment.