In [16]:
import datajoint as dj
import numpy as np

In [17]:
# Database (DataJoint schema) that holds every table declared below.
# Change this one constant to run the notebook against a different database.
SCHEMA_NAME = 'dimitri_ex1'
# locals() at notebook top level is the global namespace; DataJoint uses it as
# the context for resolving `-> Table` references in the definitions below.
schema = dj.schema(SCHEMA_NAME, locals())

In [18]:
@schema
class Experiment(dj.Manual):
    """Manually entered experiments, each identified by an integer id.

    Root of the pipeline: Set and LinearModel reference it directly, and
    Stats references it through LinearModel.
    """
    # DataJoint declaration: `experiment` is the primary key (above `----`);
    # there are no secondary attributes.
    definition = """
    # A generic experiment.
    experiment : int
    ----
    """

    
@schema
class Set(dj.Imported):
    """Groups the synthetic data points generated for each Experiment.

    Populating this table inserts one row per experiment, plus `n` rows into
    the DataPoint part table that sample y = slope * x + Gaussian noise.
    """
    definition = """
    # A way to group datasets per experiment
    -> Experiment
    -----
    """

    class DataPoint(dj.Part):
        """Individual (x, y) samples belonging to one Set row."""
        definition = """
        # Raw collected data.
        -> Set
        datapoint : int 
        -----
        x : float
        y : float 
        """

    def _make_tuples(self, key):
        """Generate and insert the synthetic data for one experiment.

        Args:
            key: primary-key mapping of the Experiment being populated
                (it carries the `experiment` id).
        """
        n = 10        # number of points per experiment
        mu = 0        # mean of the Gaussian noise added to y
        sigma = .1    # standard deviation of that noise
        slope = 2     # true slope of the underlying line

        # Seed from the experiment id so re-populating a fresh database
        # regenerates identical data; the unseeded global RNG would not.
        # `% 2**32` keeps the seed in RandomState's valid range for any int id.
        rng = np.random.RandomState(key['experiment'] % 2**32)

        self.insert1(key)
        self.DataPoint().insert([
            dict(key,
                 datapoint=i,
                 x=i,
                 y=slope * i + rng.normal(mu, sigma))
            for i in range(n)])

  
    
@schema
class LinearModel(dj.Computed):
    """Least-squares fit of y = m*x + b to each experiment's data points."""
    # The table comment below says "DataCollection"; the fitted data actually
    # come from Set.DataPoint (see _make_tuples).
    definition = """
    # fits line a DataCollection. y=mx+b form
    -> Experiment
    -----
    m : float     # Slope
    b : float     # intercept
    """

    def _make_tuples(self, key):
        """Fit a line to one experiment's points and insert its slope/intercept.

        Args:
            key: primary-key mapping of the Experiment to fit.
        """
        x, y = (Set.DataPoint() & key).fetch['x', 'y']
        # Design matrix with columns [x, 1] so the solution is (m, b).
        design = np.stack([x, np.ones_like(x)], axis=-1)
        # Passing rcond explicitly silences NumPy's FutureWarning (>= 1.14);
        # -1 keeps the historical machine-precision cutoff on every version.
        m, b = np.linalg.lstsq(design, y, rcond=-1)[0]
        self.insert1(dict(key, m=m, b=b))
    
    
@schema
class Stats(dj.Computed):
    """Goodness-of-fit statistics for each fitted LinearModel."""
    definition = """
    # Computes Mean Square Error and R2 for a particular Model
    -> LinearModel
    -----
    mse : float         # The MSE value.
    r2  : float         # R-squared of linear fit
    """

    def _make_tuples(self, key):
        """Score one experiment's fitted line against its raw data points.

        Inserts the mean squared residual (`mse`) and the coefficient of
        determination R^2 = 1 - SS_res / SS_tot (`r2`).
        """
        xs, ys = (Set.DataPoint() & key).fetch['x', 'y']
        slope, intercept = (LinearModel() & key).fetch1['m', 'b']

        residuals = ys - (xs * slope + intercept)
        squared = residuals ** 2
        ss_res = np.sum(squared)
        ss_tot = np.sum((ys - np.mean(ys)) ** 2)

        self.insert1(dict(key,
                          mse=squared.mean(axis=0),
                          r2=1 - ss_res / ss_tot))

        
# schema.spawn_missing_classes()  # disabled: would auto-create classes for schema tables not defined in this notebook

In [24]:
# Generate some data: register experiments 1-3, then fill the dependent
# tables in dependency order (raw points -> fitted line -> fit statistics).
Experiment().insert([dict(experiment=e) for e in (1, 2, 3)], skip_duplicates=True)
for table in (Set, LinearModel, Stats):
    table().populate()

In [26]:
# Preview the generated points; the rendered table below shows only the
# first rows, not every point of every experiment.
Set.DataPoint()

experiment,datapoint,x,y
1,0,0.0,0.104471
1,1,1.0,2.04742
1,2,2.0,4.01147
1,3,3.0,6.05707
1,4,4.0,7.91878
1,5,5.0,9.91739
1,6,6.0,11.8811


In [25]:
# One row of fit statistics (MSE, R^2) per experiment.
Stats()

experiment,mse,r2
1,0.00615429,0.99981
2,0.00743637,0.999775
3,0.00550032,0.999833
