In [138]:
import numpy as np
import datetime

import matplotlib.pyplot as plt
import seaborn.apionly as sns

plt.style.use('clean')

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [127]:
class InterpolatedLocalVol:
    """Bilinear interpolator over a (date, strike) grid of point estimates.

    Rows of ``point_estimates`` correspond to ``dates`` and columns to
    ``strikes``. Both axes must be sorted ascending, because the bracketing
    search in ``get_box`` takes the first node strictly greater than the
    query. Each axis must also have at least two nodes, since interpolation
    always uses a pair of bracketing nodes.

    Calling the object returns the interpolated value at ``(date, strike)``,
    or ``np.nan`` when the query lies outside the grid.
    """

    def __init__(self, point_estimates, dates, strikes):
        # asarray so the 2-D slicing in get_box also works for nested lists;
        # an existing ndarray is kept as-is (no copy).
        self.point_estimates = np.asarray(point_estimates)
        self.dates = np.array(dates)
        self.strikes = np.array(strikes)

    def _in_grid(self, date, strike):
        """Return True if (date, strike) lies within the closed grid bounds."""
        return (self.dates[0] <= date <= self.dates[-1]
                and self.strikes[0] <= strike <= self.strikes[-1])

    def get_box(self, date, strike):
        """Locate the grid cell that brackets (date, strike).

        Returns ``(i, j, box)``:

        * ``i`` and ``j`` are the indices of the upper bracketing date node
          and strike node.
        * ``box`` is ``point_estimates[i-1:i+1, j-1:j+1]``, with the lower
          bound clamped at 0.

        Returns the scalar ``np.nan`` if the point is outside the grid.
        """
        if not self._in_grid(date, strike):
            return np.nan
        # First node strictly greater than the query; the query sits in [i-1, i].
        i = np.argmax(self.dates > date)
        j = np.argmax(self.strikes > strike)
        # A query on the last node has no strictly-greater node (argmax would
        # give 0), so pin it to the last cell instead.
        if date == self.dates[-1]:
            i = len(self.dates) - 1
        if strike == self.strikes[-1]:
            j = len(self.strikes) - 1
        return i, j, self.point_estimates[max(0, i-1):i+1,
                                          max(0, j-1):j+1]

    def __call__(self, date, strike):
        """Bilinearly interpolate at (date, strike); ``np.nan`` outside the grid."""
        located = self.get_box(date, strike)
        if not isinstance(located, tuple):
            # get_box signals an out-of-grid query with a scalar NaN; previously
            # unpacking it here raised TypeError instead of returning NaN.
            return np.nan
        i, j, box = located
        # Linear weights: the lower node is weighted by the distance to the
        # upper node, and vice versa.
        date_weights = np.array([(self.dates[i] - date).days,
                                 (date - self.dates[i-1]).days], dtype=float)
        date_weights /= date_weights.sum()
        strike_weights = np.array([self.strikes[j] - strike,
                                   strike - self.strikes[j-1]], dtype=float)
        strike_weights /= strike_weights.sum()
        # Interpolate along strikes within each date row, then across dates.
        return np.dot(np.dot(box, strike_weights), date_weights)

In [4]:
# Removed a leftover `InterpolatedLocalVol()` call: without point_estimates,
# dates and strikes the constructor raises TypeError, which broke
# Restart & Run All. The interpolator is built as `f` further down.

In [132]:
# Small 3x3 example grid: row k of test_points pairs with test_dates[k],
# column m with test_strikes[m]. Both axes are ascending.
test_points = np.arange(3 * 3).reshape(3, 3)
test_dates = [
    datetime.date(2017, 1, 1),
    datetime.date(2017, 2, 1),
    datetime.date(2017, 4, 1),
]
test_strikes = [103, 105, 109]

In [133]:
# Interpolator over the example grid; used by the inspection and timing cells below.
f = InterpolatedLocalVol(
    point_estimates=test_points,
    dates=test_dates,
    strikes=test_strikes,
)

In [134]:
# Date axis as stored by the interpolator: an object-dtype array of datetime.date.
f.dates

array([datetime.date(2017, 1, 1), datetime.date(2017, 2, 1),
       datetime.date(2017, 4, 1)], dtype=object)

In [135]:
# Strike axis (ascending) as a NumPy array.
f.strikes

array([103, 105, 109])

In [137]:
# Grid values: rows follow f.dates, columns follow f.strikes.
f.point_estimates

array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])

In [136]:
# Time one query on the last date node (the `date == self.dates[-1]` branch
# of get_box), interpolating between strikes 105 and 109.
%timeit f(datetime.date(2017,4,1), 108)

36.3 µs ± 1.95 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
