Skip to content

Commit

Permalink
complex addition(#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjjkkkjjj committed Jul 18, 2022
1 parent f4f0edc commit ca948f4
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 17 deletions.
4 changes: 4 additions & 0 deletions Sources/Matft/core/object/mfarray.swift
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ open class MfComplexArray: MfStructuredProtocol{
self.mfdata = MFDATA(ref_realdata: real.mfdata, ref_imagdata: imag.mfdata, offset: real.offsetIndex)
self.mfstructure = MfStructure(shape: real.shape, strides: real.strides)
}
public init (mfdata: MfComplexData, mfstructure: MfStructure){
self.mfdata = mfdata
self.mfstructure = mfstructure
}

deinit {
self.base = nil
Expand Down
33 changes: 33 additions & 0 deletions Sources/Matft/core/protocol/mftypeProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,39 @@ extension Double: MfStorable{
}
}

// DSPSplitComplex
public protocol vDSP_ComplexTypable{
associatedtype T: MfStorable
associatedtype blasType: blas_ComplexTypable

var realp: UnsafeMutablePointer<T> { get set }
var imagp: UnsafeMutablePointer<T> { get set }

init(realp: UnsafeMutablePointer<T>, imagp: UnsafeMutablePointer<T>)
}

extension DSPSplitComplex: vDSP_ComplexTypable{
public typealias blasType = DSPComplex
}
extension DSPDoubleSplitComplex: vDSP_ComplexTypable{
public typealias blasType = DSPDoubleComplex
}

// DSPComplex
public protocol blas_ComplexTypable{
associatedtype T: MfStorable
associatedtype vDSPType: vDSP_ComplexTypable

var real: T { get set }
var imag: T { get set }
init(real: T, imag: T)
}
extension DSPComplex: blas_ComplexTypable{
public typealias vDSPType = DSPSplitComplex
}
extension DSPDoubleComplex: blas_ComplexTypable{
public typealias vDSPType = DSPDoubleSplitComplex
}


/*
Expand Down
57 changes: 57 additions & 0 deletions Sources/Matft/core/util/pointer/withptr.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//

import Foundation
import Accelerate

extension MfArray{

Expand Down Expand Up @@ -67,6 +68,26 @@ extension MfComplexArray{

return ret
}

public func withUnsafeMutablevDSPPointer<T: vDSP_ComplexTypable, R>(datatype: T.Type, _ body: (UnsafeMutablePointer<T>) throws -> R) rethrows -> R{

let ret = try self.withUnsafeMutableStartPointer(datatype: T.T.self){ (ptrrT, ptriT) -> R in
var ptr = T(realp: ptrrT, imagp: ptriT)
return try body(&ptr)
}

return ret
}
internal func withUnsafeMutableblasPointer<T: blas_ComplexTypable, R>(datatype: T.Type, vDSP_func: vDSP_convertcz_func<T.vDSPType, T>, _ body: (UnsafeMutablePointer<T>) throws -> R) rethrows -> R{

let ret = try self.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){ [unowned self](ptr) -> R in
var arr = Array(repeating: T(real: T.T.zero, imag: T.T.zero), count: self.storedSize)
wrap_vDSP_convertcz(arr.count, ptr, 1, &arr, 1, vDSP_func)
return try body(&arr)
}

return ret
}
}

extension MfData{
Expand All @@ -84,3 +105,39 @@ extension MfData{
return ret
}
}

extension MfComplexData{
public func withUnsafeMutableStartRawPointer<R>(_ body: (UnsafeMutableRawPointer, UnsafeMutableRawPointer) throws -> R) rethrows -> R{
return try body(self.data_real + self.byteOffset, self.data_imag + self.byteOffset)
}
public func withUnsafeMutableStartPointer<T, R>(datatype: T.Type, _ body: (UnsafeMutablePointer<T>, UnsafeMutablePointer<T>) throws -> R) rethrows -> R{
let ret = try self.withUnsafeMutableStartRawPointer{
[unowned self](ptrr, ptri) -> R in
let datarptr = ptrr.bindMemory(to: T.self, capacity: self.storedSize)
let dataiptr = ptri.bindMemory(to: T.self, capacity: self.storedSize)
return try body(datarptr, dataiptr)
}

return ret
}

public func withUnsafeMutablevDSPPointer<T: vDSP_ComplexTypable, R>(datatype: T.Type, _ body: (UnsafeMutablePointer<T>) throws -> R) rethrows -> R{

let ret = try self.withUnsafeMutableStartPointer(datatype: T.T.self){ (ptrrT, ptriT) -> R in
var ptr = T(realp: ptrrT, imagp: ptriT)
return try body(&ptr)
}

return ret
}
internal func withUnsafeMutableblasPointer<T: blas_ComplexTypable, R>(datatype: T.Type, vDSP_func: vDSP_convertcz_func<T.vDSPType, T>, _ body: (UnsafeMutablePointer<T>) throws -> R) rethrows -> R{

let ret = try self.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){ [unowned self](ptr) -> R in
var arr = Array(repeating: T(real: T.T.zero, imag: T.T.zero), count: self.storedSize)
wrap_vDSP_convertcz(arr.count, ptr, 1, &arr, 1, vDSP_func)
return try body(&arr)
}

return ret
}
}
15 changes: 10 additions & 5 deletions Sources/Matft/function/static/biop+static.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@ extension Matft{
return biopvv_by_vDSP(l_mfarray, r_mfarray, vDSP_func: vDSP_vadd)
case .Double:
return biopvv_by_vDSP(l_mfarray, r_mfarray, vDSP_func: vDSP_vaddD)
/*
case .ComplexFloat:
return biopzvv_by_vDSP(l_mfarray, r_mfarray, datatype: DSPComplex.self, vDSP_func: vDSP_zvadd)
case .ComplexDouble:
return biopzvv_by_vDSP(l_mfarray, r_mfarray, datatype: DSPDoubleComplex.self, vDSP_func: vDSP_zvaddD)*/
}
}

public static func add(_ l_mfarray: MfComplexArray, _ r_mfarray: MfComplexArray) -> MfComplexArray{
switch l_mfarray.storedType{
case .Float:
return biopzvv_by_vDSP(l_mfarray, r_mfarray, vDSP_func: vDSP_zvadd)
case .Double:
return biopzvv_by_vDSP(l_mfarray, r_mfarray, vDSP_func: vDSP_zvaddD)
}
}

/**
Element-wise addition of mfarray and scalar
- parameters:
Expand Down
42 changes: 30 additions & 12 deletions Sources/Matft/library/vDSP.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ internal func wrap_vDSP_biopvv<T>(_ size: Int, _ lsrcptr: UnsafePointer<T>, _ ls
vDSP_func(rsrcptr, vDSP_Stride(rsrcStride), lsrcptr, vDSP_Stride(lsrcStride), dstptr, vDSP_Stride(dstStride), vDSP_Length(size))
}

/// Wrapper of vDSP binary operation function
/// - Parameters:
/// - size: A size
/// - lsrcptr: A left source pointer
/// - lsrcStride: A left source stride
/// - rsrcptr: A right source pointer
/// - rsrcStride: A right source stride
/// - dstptr: A destination pointer
/// - dstStride: A destination stride
/// - vDSP_func: The vDSP conversion function
@inline(__always)
internal func wrap_vDSP_biopzvv<T>(_ size: Int, _ lsrcptr: UnsafePointer<T>, _ lsrcStride: Int, _ rsrcptr: UnsafePointer<T>, _ rsrcStride: Int, _ dstptr: UnsafePointer<T>, _ dstStride: Int, _ vDSP_func: vDSP_biopzvv_func<T>){
vDSP_func(rsrcptr, vDSP_Stride(rsrcStride), lsrcptr, vDSP_Stride(lsrcStride), dstptr, vDSP_Stride(dstStride), vDSP_Length(size))
}

/// Wrapper of vDSP binary operation function
/// - Parameters:
Expand Down Expand Up @@ -468,34 +482,38 @@ internal func biopvv_by_vDSP<T: MfStorable>(_ l_mfarray: MfArray, _ r_mfarray: M

return MfArray(mfdata: newdata, mfstructure: newstructure)
}
/*


/// Binary operation by vDSP
/// - Parameters:
/// - l_mfarray: The left mfarray
/// - r_mfarray: The right mfarray
/// - vDSP_func: The vDSP biop function
/// - Returns: The result mfarray
internal func biopzvv_by_vDSP<T, U>(_ l_mfarray: MfArray, _ r_mfarray: MfArray, datatype: T.Type, vDSP_func: vDSP_biopzvv_func<U>) -> MfArray{
internal func biopzvv_by_vDSP<T: vDSP_ComplexTypable>(_ l_mfarray: MfComplexArray, _ r_mfarray: MfComplexArray, vDSP_func: vDSP_biopzvv_func<T>) -> MfComplexArray{
// biggerL: flag whether l is bigger than r
//return mfarray must be either row or column major
/*
let (l_mfarray, r_mfarray, biggerL, retsize) = check_biop_contiguous(l_mfarray, r_mfarray, .Row, convertL: true)
let newdata = MfData(size: retsize, mftype: l_mfarray.mftype)
newdata.withUnsafeMutableStartPointer(datatype: T.self){
let newdata = MfComplexData(size: retsize, mftype: l_mfarray.mftype)*/
let newdata = MfComplexData(size: l_mfarray.size, mftype: l_mfarray.mftype)
let biggerL = true
newdata.withUnsafeMutablevDSPPointer(datatype: T.self){
dstptrT in
l_mfarray.withUnsafeMutableStartPointer(datatype: T.self){
l_mfarray.withUnsafeMutablevDSPPointer(datatype: T.self){
[unowned l_mfarray] (lptr) in
r_mfarray.withUnsafeMutableStartPointer(datatype: T.self){
r_mfarray.withUnsafeMutablevDSPPointer(datatype: T.self){
[unowned r_mfarray] (rptr) in
if biggerL{// l is bigger
for vDSPPrams in OptOffsetParamsSequence(shape: l_mfarray.shape, bigger_strides: l_mfarray.strides, smaller_strides: r_mfarray.strides){
var ltmp = Array(repeating: T.zero, count: <#T##Int#>)
wrap_vDSP_biopvv(vDSPPrams.blocksize, lptr + vDSPPrams.b_offset, vDSPPrams.b_stride, rptr + vDSPPrams.s_offset, vDSPPrams.s_stride, dstptrT + vDSPPrams.b_offset, vDSPPrams.b_stride, vDSP_func)

wrap_vDSP_biopzvv(vDSPPrams.blocksize, lptr + vDSPPrams.b_offset, vDSPPrams.b_stride, rptr + vDSPPrams.s_offset, vDSPPrams.s_stride, dstptrT + vDSPPrams.b_offset, vDSPPrams.b_stride, vDSP_func)
}
}
else{// r is bigger
for vDSPPrams in OptOffsetParamsSequence(shape: r_mfarray.shape, bigger_strides: r_mfarray.strides, smaller_strides: l_mfarray.strides){
wrap_vDSP_biopvv(vDSPPrams.blocksize, lptr + vDSPPrams.s_offset, vDSPPrams.s_stride, rptr + vDSPPrams.b_offset, vDSPPrams.b_stride, dstptrT + vDSPPrams.b_offset, vDSPPrams.b_stride, vDSP_func)
wrap_vDSP_biopzvv(vDSPPrams.blocksize, lptr + vDSPPrams.s_offset, vDSPPrams.s_stride, rptr + vDSPPrams.b_offset, vDSPPrams.b_stride, dstptrT + vDSPPrams.b_offset, vDSPPrams.b_stride, vDSP_func)
}
}
}
Expand All @@ -510,8 +528,8 @@ internal func biopzvv_by_vDSP<T, U>(_ l_mfarray: MfArray, _ r_mfarray: MfArray,
newstructure = MfStructure(shape: r_mfarray.shape, strides: r_mfarray.strides)
}

return MfArray(mfdata: newdata, mfstructure: newstructure)
}*/
return MfComplexArray(mfdata: newdata, mfstructure: newstructure)
}


/// Stats operation by vDSP
Expand Down
12 changes: 12 additions & 0 deletions Tests/MatftTests/ComplexTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,16 @@ final class ComplexTests: XCTestCase {
XCTAssertEqual(a.imag, imag)
}
}

func testAdd() {
do {
let real = Matft.arange(start: 0, to: 16, by: 1).reshape([2,2,4])
let imag = Matft.arange(start: 0, to: -16, by: -1).reshape([2,2,4])
let a = MfComplexArray(real: real, imag: imag)

let ret = Matft.add(a, a)
XCTAssertEqual(ret.real, real+real)
XCTAssertEqual(ret.imag, imag+imag)
}
}
}

0 comments on commit ca948f4

Please sign in to comment.