In [61]:
import pandas as pd
import numpy as np
from os.path import join, isfile
import pickle
import scipy.linalg
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [62]:
file = 'measurements.csv'
measurements = pd.read_csv(file)

In [63]:
measurements.head()

Unnamed: 0,ramp_dist,drive,voltage,jump_dist
0,10,43,7.9,40
1,11,55,7.9,60
2,6,73,7.9,52
3,6,79,7.9,53
4,6,100,7.9,57


In [66]:
%matplotlib notebook
data = np.c_[measurements.ramp_dist, measurements.jump_dist, measurements.drive] 

mn = np.min(data, axis=0)
mx = np.max(data, axis=0)
X,Y = np.meshgrid(np.linspace(mn[0], mx[0], 20), np.linspace(mn[1], mx[1], 20))
XX = X.flatten()
YY = Y.flatten()

order = 3    # 1: linear, 2: quadratic
if order == 1:
    # best-fit linear plane
    A = np.c_[data[:,0], data[:,1], np.ones(data.shape[0])]
    C,_,_,_ = scipy.linalg.lstsq(A, data[:,2])    # coefficients
    
    # evaluate it on grid
    Z = C[0]*X + C[1]*Y + C[2]
    
    # or expressed using matrix/vector product
    #Z = np.dot(np.c_[XX, YY, np.ones(XX.shape)], C).reshape(X.shape)

elif order == 2:
    # best-fit quadratic curve
    A = np.c_[np.ones(data.shape[0]), data[:,:2], np.prod(data[:,:2], axis=1), data[:,:2]**2]
    C,_,_,_ = scipy.linalg.lstsq(A, data[:,2])
    
    # evaluate it on a grid
    Z = np.dot(np.c_[np.ones(XX.shape), XX, YY, XX*YY, XX**2, YY**2], C).reshape(X.shape)
elif order == 3:
    # best-fit quadratic curve
    A = np.c_[np.ones(data.shape[0]), data[:,:2], np.prod(data[:,:2], axis=1), data[:,:2]**3]
    C,_,_,_ = scipy.linalg.lstsq(A, data[:,2])
    
    # evaluate it on a grid
    Z = np.dot(np.c_[np.ones(XX.shape),
                     XX, YY, XX*YY, XX**2, YY**2, (XX**2) * (YY**2), XX**3, YY**3], C).reshape(X.shape)
    
# plot points and fitted surface
fig = plt.figure(figsize=(10, 8))
ax = fig.gca(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.2)
ax.scatter(data[:,0], data[:,1], data[:,2], c='r', s=50)
plt.xlabel('Ramp Distance')
plt.ylabel('Jump Distance')
ax.set_zlabel('Drive Speed')
ax.axis('equal')
ax.axis('tight')
ax.set_zlim(30, 110)
plt.show()

ValueError: shapes (400,9) and (6,) not aligned: 9 (dim 1) != 6 (dim 0)

In [45]:
print(C)

[-7.37226661e+01  8.14133625e+00  3.35467997e+00 -3.97289114e-01
  3.81501847e-01  1.37040507e-02]


In [57]:
def temp(ramp_dist, jump_dist):
    return np.dot(np.c_[1, ramp_dist, jump_dist, ramp_dist*jump_dist, ramp_dist**2, jump_dist**2], C)
print(temp(9, 54))

[58.48223058]
