Skip to content

Commit

Permalink
Improved formatting and type hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
bgrimstad authored and Anders Wenhaug committed Oct 25, 2016
1 parent ce11e30 commit 9fe4f7c
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 44 deletions.
2 changes: 1 addition & 1 deletion python/examples/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# Add the SPLINTER directory to the search path, so we can include it
import numpy as np
import matplotlib.pyplot as plt
from os import sys, path, remove
from os import sys, path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
import splinter

Expand Down
2 changes: 1 addition & 1 deletion python/examples/bspline_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# Add the SPLINTER directory to the search path, so we can include it
import numpy as np
import matplotlib.pyplot as plt
from os import sys, path, remove
from os import sys, path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
import splinter

Expand Down
6 changes: 3 additions & 3 deletions python/examples/bspline_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
# Add the SPLINTER directory to the search path, so we can include it
import numpy as np
import matplotlib.pyplot as plt
from os import sys, path, remove
from os import sys, path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
import splinter

# Only for dev purposes
# splinter.load("/home/bjarne/Code/C++/splinter4/splinter/bin/Release/libsplinter-3-0.so")
splinter.load("/home/anders/SPLINTER/build/debug/libsplinter-3-0.so")
splinter.load("/home/bjarne/Code/C++/splinter/splinter/bin/Release/libsplinter-3-1.so")
# splinter.load("/home/anders/SPLINTER/build/debug/libsplinter-3-0.so")


# Example with one variable
Expand Down
13 changes: 7 additions & 6 deletions python/splinter/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import splinter
from .function import Function
from .utilities import *
from typing import List


class BSpline(Function):
Expand All @@ -27,7 +28,7 @@ def __init__(self, handle_or_filename):
self._num_variables = splinter._call(splinter._get_handle().splinter_bspline_get_num_variables, self._handle)

@staticmethod
def init_from_param(coefficients, knot_vectors, degrees):
def init_from_param(coefficients: List[float], knot_vectors: List[List[float]], degrees: List[int]) -> 'BSpline':

n = len(degrees)
if len(knot_vectors) != n:
Expand All @@ -44,7 +45,7 @@ def init_from_param(coefficients, knot_vectors, degrees):
dim)
return BSpline(handle)

def get_knot_vectors(self):
def get_knot_vectors(self) -> List[List[float]]:
"""
:return List of knot vectors (of possibly differing lengths)
"""
Expand All @@ -67,7 +68,7 @@ def get_knot_vectors(self):

return knot_vectors

def get_coefficients(self):
def get_coefficients(self) -> List[float]:
"""
:return List of the coefficients of the BSpline
"""
Expand All @@ -76,7 +77,7 @@ def get_coefficients(self):

return c_array_to_list(coefficients_raw, num_coefficients)

def get_control_points(self):
def get_control_points(self) -> List[List[float]]:
"""
Get the matrix with the control points of the BSpline.
:return Matrix (as a list of lists) with getNumVariables+1 columns and len(getCoefficients) rows
Expand All @@ -97,7 +98,7 @@ def get_control_points(self):

return control_points

def get_basis_degrees(self):
def get_basis_degrees(self) -> List[int]:
"""
:return List with the basis degrees of the BSpline
"""
Expand All @@ -106,7 +107,7 @@ def get_basis_degrees(self):

return c_array_to_list(basis_degrees, num_vars)

def insert_knots(self, val, dim, multiplicity=1):
def insert_knots(self, val: float, dim: int, multiplicity: int=1):
"""
Insert knot at 'val' to knot vector for variable 'dim'. The knot is inserted until a knot multiplicity of
'multiplicity' is obtained.
Expand Down
3 changes: 2 additions & 1 deletion python/splinter/bsplineboosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@


class BSplineBoosting:
def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, subsample=1.0, alpha=1.0):
def __init__(self, loss: str='ls', learning_rate: float=0.1, n_estimators: int=100, subsample: float=1.0,
alpha: float=1.0):
"""
Class for stochastic gradient boosting with B-spline learners
:param loss: loss function, 'ls' for least squares loss function
Expand Down
65 changes: 38 additions & 27 deletions python/splinter/bsplinebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ctypes import *
from .bspline import BSpline
from .datatable import DataTable
from typing import Union, List


class BSplineBuilder:
Expand All @@ -28,7 +29,8 @@ class KnotSpacing:
def is_valid(value):
return value in range(3)

def __init__(self, x, y, degree=3, smoothing=Smoothing.NONE, alpha=0.1, knot_spacing=KnotSpacing.AS_SAMPLED, num_basis_functions=int(1e6)):
def __init__(self, x, y, degree: int=3, smoothing: int=Smoothing.NONE, alpha: float=0.1,
knot_spacing: int=KnotSpacing.AS_SAMPLED, num_basis_functions: int=int(1e6)):
self._handle = None # Handle for referencing the c side of this object
self._datatable = DataTable(x, y)
self._num_basis_functions = [10 ** 3] * self._datatable.get_num_variables()
Expand All @@ -39,86 +41,95 @@ def __init__(self, x, y, degree=3, smoothing=Smoothing.NONE, alpha=0.1, knot_spa
self._knot_spacing = None
self._num_basis_functions = None

self._handle = splinter._call(splinter._get_handle().splinter_bspline_builder_init, self._datatable._get_handle())
f_handle_init = splinter._get_handle().splinter_bspline_builder_init
self._handle = splinter._call(f_handle_init, self._datatable._get_handle())
self.degree(degree)
self.set_alpha(alpha)
self.smoothing(smoothing)
self.knot_spacing(knot_spacing)
self.num_basis_functions(num_basis_functions)

def degree(self, degrees):
def degree(self, degrees: Union[List[int], int]) -> 'BSplineBuilder':
# If the value is a single number, make it a list of numVariables length
if not isinstance(degrees, list):
degrees = [degrees] * self._datatable.get_num_variables()

if len(degrees) != self._datatable.get_num_variables():
raise ValueError("BSplineBuilder:degree: Inconsistent number of degrees.")
raise ValueError("Inconsistent number of degrees.")

valid_degrees = range(0, 6)
for deg in degrees:
if deg not in valid_degrees:
raise ValueError("BSplineBuilder:degree: Invalid degree: " + str(deg))
raise ValueError("Invalid degree: " + str(deg))

self._degrees = degrees

splinter._call(splinter._get_handle().splinter_bspline_builder_set_degree, self._handle, (c_int * len(self._degrees))(*self._degrees), len(self._degrees))
f_handle = splinter._get_handle().splinter_bspline_builder_set_degree
splinter._call(f_handle, self._handle, (c_int * len(self._degrees))(*self._degrees), len(self._degrees))
return self

def set_alpha(self, new_alpha):
def set_alpha(self, new_alpha: float) -> 'BSplineBuilder':
if new_alpha < 0:
raise ValueError("BSplineBuilder:set_alpha: alpha must be non-negative.")
raise ValueError("'alpha' must be non-negative.")

self._alpha = new_alpha

splinter._call(splinter._get_handle().splinter_bspline_builder_set_alpha, self._handle, self._alpha)
f_handle = splinter._get_handle().splinter_bspline_builder_set_alpha
splinter._call(f_handle, self._handle, self._alpha)
return self

def smoothing(self, smoothing):
def smoothing(self, smoothing: int) -> 'BSplineBuilder':
if not BSplineBuilder.Smoothing.is_valid(smoothing):
raise ValueError("BSplineBuilder::smoothing: Invalid smoothing: " + str(smoothing))
raise ValueError("Invalid smoothing: " + str(smoothing))

self._smoothing = smoothing

splinter._call(splinter._get_handle().splinter_bspline_builder_set_smoothing, self._handle, self._smoothing)
f_handle = splinter._get_handle().splinter_bspline_builder_set_smoothing
splinter._call(f_handle, self._handle, self._smoothing)
return self

def knot_spacing(self, knot_spacing):
def knot_spacing(self, knot_spacing: int) -> 'BSplineBuilder':
if not BSplineBuilder.KnotSpacing.is_valid(knot_spacing):
raise ValueError("BSplineBuilder::knot_spacing: Invalid knotspacing: " + str(knot_spacing))
raise ValueError("Invalid knotspacing: " + str(knot_spacing))

self._knot_spacing = knot_spacing

splinter._call(splinter._get_handle().splinter_bspline_builder_set_knot_spacing, self._handle, self._knot_spacing)
f_handle = splinter._get_handle().splinter_bspline_builder_set_knot_spacing
splinter._call(f_handle, self._handle, self._knot_spacing)
return self

def num_basis_functions(self, num_basis_functions):
def num_basis_functions(self, num_basis_functions: Union[List[int], int]) -> 'BSplineBuilder':
# If the value is a single number, make it a list of num_variables length
if not isinstance(num_basis_functions, list):
num_basis_functions = [num_basis_functions] * self._datatable.get_num_variables()

if len(num_basis_functions) != self._datatable.get_num_variables():
raise ValueError("BSplineBuilder:num_basis_functions: Inconsistent number of degrees.")
raise ValueError("Inconsistent number of degrees.")

for num_basis_function in num_basis_functions:
if not isinstance(num_basis_function, int):
raise ValueError(
"BSplineBuilder:num_basis_functions: Invalid number of basis functions (must be integer): " + str(
num_basis_function))
for num_basis_func in num_basis_functions:
if not isinstance(num_basis_func, int):
raise TypeError("Number of basis functions not integer: " + str(
num_basis_func))
if num_basis_func < 1:
raise ValueError("Number of basis functions < 1: " + str(
num_basis_func))

self._num_basis_functions = num_basis_functions

splinter._call(splinter._get_handle().splinter_bspline_builder_set_num_basis_functions, self._handle,
(c_int * len(self._num_basis_functions))(*self._num_basis_functions),
f_handle = splinter._get_handle().splinter_bspline_builder_set_num_basis_functions
splinter._call(f_handle, self._handle, (c_int * len(self._num_basis_functions))(*self._num_basis_functions),
len(self._num_basis_functions))
return self

# Returns a handle to the created internal BSpline object
def build(self):
bspline_handle = splinter._call(splinter._get_handle().splinter_bspline_builder_build, self._handle)
def build(self) -> BSpline:
f_handle = splinter._get_handle().splinter_bspline_builder_build
bspline_handle = splinter._call(f_handle, self._handle)

return BSpline(bspline_handle)

def __del__(self):
if self._handle is not None:
splinter._call(splinter._get_handle().splinter_bspline_builder_delete, self._handle)
f_handle = splinter._get_handle().splinter_bspline_builder_delete
splinter._call(f_handle, self._handle)
self._handle = None
4 changes: 3 additions & 1 deletion python/splinter/datatable.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def __transfer(self):
# print(str(self.__samples[i*(self.__xDim+1)]) + "," + str(self.__samples[i*(self.__xDim+1)+1]) + " = " + str(self.__samples[i*(self.__xDim+1)+2]))

if self.__num_samples > 0:
splinter._call(splinter._get_handle().splinter_datatable_add_samples_row_major, self.__handle, (c_double * len(self.__samples))(*self.__samples), self.__num_samples, self.__x_dim)
f_handle = splinter._get_handle().splinter_datatable_add_samples_row_major
splinter._call(f_handle, self.__handle, (c_double * len(self.__samples))(*self.__samples),
self.__num_samples, self.__x_dim)

self.__samples = []
self.__num_samples = 0
Expand Down
13 changes: 9 additions & 4 deletions python/splinter/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,19 @@ def eval(self, x):
x = self._transform_input(x)

num_points = len(x) // self._num_variables
res = splinter._call(splinter._get_handle().splinter_bspline_eval_row_major, self._handle, (c_double * len(x))(*x), len(x))

f_handle = splinter._get_handle().splinter_bspline_eval_row_major
res = splinter._call(f_handle, self._handle, (c_double * len(x))(*x), len(x))

return c_array_to_list(res, num_points)

def eval_jacobian(self, x):
x = self._transform_input(x)

num_points = len(x) // self._num_variables
jac = splinter._call(splinter._get_handle().splinter_bspline_eval_jacobian_row_major, self._handle, (c_double * len(x))(*x), len(x))

f_handle = splinter._get_handle().splinter_bspline_eval_jacobian_row_major
jac = splinter._call(f_handle, self._handle, (c_double * len(x))(*x), len(x))

# Convert from ctypes array to Python list of lists
# jacobians is a list of the jacobians in all evaluated points
Expand All @@ -44,7 +48,9 @@ def eval_hessian(self, x):
x = self._transform_input(x)

num_points = len(x) // self._num_variables
hes = splinter._call(splinter._get_handle().splinter_bspline_eval_hessian_row_major, self._handle, (c_double * len(x))(*x), len(x))

f_handle = splinter._get_handle().splinter_bspline_eval_hessian_row_major
hes = splinter._call(f_handle, self._handle, (c_double * len(x))(*x), len(x))

# Convert from ctypes array to Python list of list of lists
# hessians is a list of the hessians in all points
Expand Down Expand Up @@ -78,7 +84,6 @@ def _transform_input(self, x):

return x


def __del__(self):
if self._handle is not None:
splinter._call(splinter._get_handle().splinter_bspline_delete, self._handle)
Expand Down

0 comments on commit 9fe4f7c

Please sign in to comment.