In [2]:
import time

import numpy as np
import sklearn
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
# Untrained MLP used as the classifier for the on/off StateTrainer below.
# Seeded so that retraining (weight init / shuffling) is reproducible.
mlp = MLPClassifier(random_state=0)

from environment.system import System
from utils import save_pickle, load_pickle

In [4]:
# Host of the brick the System wrapper talks to.
BRICK_IP = 'ev3dev.local'

# In 'dict' mode get_state() returns a mapping keyed by 'bot', 'cs', 'top',
# 'ts1', 'ts2' (see the output below); the last expression is displayed.
env = System(get_state_mode='dict', brick_ip=BRICK_IP)
env.get_state()

{'bot': (222,),
 'cs': ((192, 183, 235),),
 'top': (98,),
 'ts1': (0,),
 'ts2': (0,)}

TODO: when the current state is classified as `on`, call `calc_reward`.

In [8]:
class StateTrainer:
    """Collect labelled sensor samples from ``env`` and fit a classifier on them.

    Samples are gathered interactively: for every class in ``class_names`` the
    user presses button 1 (the ``'ts1'`` entry of ``env.get_state()``) once per
    sample, and the first element of the ``'cs'`` entry is recorded.

    Args:
        env: object whose ``get_state()`` returns a mapping with ``'cs'`` and
            ``'ts1'`` entries, each indexable with ``[0]`` (presumably the
            ``System`` built with ``get_state_mode='dict'`` -- confirm for other envs).
        class_names: label names; label ``i`` is ``class_names[i]``.
        samples_per_class: number of button presses recorded per class.
        clf: estimator with scikit-learn style ``fit`` / ``score`` methods.
        test_size: fraction of samples held out when ``train()`` scores the model.
        random_state: forwarded to ``train_test_split``. ``None`` (default)
            keeps the split unseeded; pass an int for a reproducible split.
        poll_interval: seconds to sleep between button polls so that waiting
            does not query ``env.get_state()`` in a tight loop.
    """

    def __init__(self, env, class_names, samples_per_class, clf, test_size=0.1,
                 random_state=None, poll_interval=0.01):
        self.class_names = class_names
        self.class_map = {
            class_name: number for number, class_name in enumerate(self.class_names)
        }
        self.samples_per_class = samples_per_class
        self.clf = clf
        self.test_size = test_size
        self.random_state = random_state
        self.poll_interval = poll_interval

        self.env = env

        # Filled by get_data_for_training(). `measurements` may also be assigned
        # directly (one list of samples per class, in class_names order) to skip
        # interactive collection.
        self.measurements = None
        self.X = None
        self.y = None

    def _rgb(self):
        """Return the first element of the ``'cs'`` reading (a 3-value tuple in the sample output)."""
        return self.env.get_state()['cs'][0]

    def _button1_pressed(self):
        """Return the raw ``'ts1'`` reading; truthy while button 1 is held."""
        return self.env.get_state()['ts1'][0]

    def _wait_for_button1_press(self):
        """Block until button 1 goes from released to pressed (a fresh press).

        Waiting for the release first means that holding the button down
        records a single sample instead of one sample per poll.
        """
        while self._button1_pressed():
            time.sleep(self.poll_interval)
        while not self._button1_pressed():
            time.sleep(self.poll_interval)

    def _gather_measurements(self):
        """Interactively record ``samples_per_class`` readings for every class.

        Returns:
            A list holding one list of readings per class, in ``class_names`` order.
        """
        final_measurements = []
        for class_name in self.class_names:
            class_measurements = []
            print(class_name)
            for _ in range(self.samples_per_class):
                print("Press button 1 to collect a sample.")
                self._wait_for_button1_press()
                colors = self._rgb()
                class_measurements.append(colors)
                print("Collected sample:", colors, "of class", class_name)
            final_measurements.append(class_measurements)
        return final_measurements

    def get_data_for_training(self):
        """Build the feature matrix and labels, then split them for training.

        Collects measurements interactively if none are stored yet; otherwise
        reuses ``self.measurements``. Stores the full ``X`` / ``y`` on the instance.

        Returns:
            ``(X_train, X_test, y_train, y_test)`` from ``train_test_split``,
            stratified by label.

        Raises:
            ValueError: if ``self.measurements`` does not hold exactly one list
                per entry of ``class_names``.
        """
        if self.measurements is None:
            self.measurements = self._gather_measurements()
        else:
            print("Already have measurements.")

        if len(self.measurements) != len(self.class_names):
            raise ValueError(
                f"Expected one list of measurements per class "
                f"({len(self.class_names)}), got {len(self.measurements)}."
            )

        X = np.concatenate(self.measurements)
        # Label every sample by the class list it came from. Using the actual
        # list lengths (not samples_per_class) keeps X and y aligned even when
        # measurements were assigned by hand.
        counts = [len(class_measurements) for class_measurements in self.measurements]
        y = np.repeat([self.class_map[name] for name in self.class_names], counts)
        self.X = X
        self.y = y

        return train_test_split(
            X, y, test_size=self.test_size, stratify=y, random_state=self.random_state
        )

    def train(self):
        """Fit ``clf`` on the training split and return its ``score`` on the held-out split."""
        X_train, X_test, y_train, y_test = self.get_data_for_training()
        self.clf.fit(X_train, y_train)
        return self.clf.score(X_test, y_test)

    def save_model(self, name):
        """Persist ``self.clf`` with the project's ``save_pickle`` helper under ``name``."""
        save_pickle(name, self.clf)

In [9]:
# On/off detector: the MLP from the imports cell trained on colour-sensor readings.
class_names = ['off', 'on']
samples_per_class = 20
st = StateTrainer(env, class_names, samples_per_class=samples_per_class, clf=mlp, test_size=0.1)

# Set to True to record fresh samples, train, and overwrite the saved model.
# The next cell loads 'mlp_on_off.pickle' either way.
RETRAIN_ON_OFF = False
if RETRAIN_ON_OFF:
    print(st.train())
    st.save_model('mlp_on_off.pickle')

In [5]:
# Previously trained on/off model. load_pickle presumably unpickles the file,
# so only point it at trusted files.
ON_OFF_MODEL_PATH = 'mlp_on_off.pickle'
mlp = load_pickle(ON_OFF_MODEL_PATH)



In [10]:
# Live check of the on/off model: classify one colour reading per fresh press of button 1.
for _ in range(5):
    # Wait for release, then press, so holding the button down yields one
    # prediction instead of one per poll; sleep to avoid a tight polling loop.
    while st._button1_pressed():
        time.sleep(0.01)
    while not st._button1_pressed():
        time.sleep(0.01)
    colors = st._rgb()
    print(mlp.predict(np.array(colors).reshape(1, -1)))

[1]
[0]
[0]
[1]
[1]


In [109]:
# Keep the current trainer's collected samples (None if st.train() /
# st.get_data_for_training() was never run) so a later trainer can reuse them
# via `st.measurements = meas_save`; the class lists must match that trainer's class_names.
meas_save = st.measurements

In [125]:
# Three-class detector ('white', 'black', 'dot') using a cross-validated
# logistic regression instead of the MLP.
class_names = ['white', 'black', 'dot']
samples_per_class = 20
st = StateTrainer(
    env,
    class_names,
    clf=LogisticRegressionCV(),
    samples_per_class=samples_per_class,
    test_size=0.1,
)
# To reuse earlier samples instead of recording new ones (class lists must
# match class_names): st.measurements = meas_save
print(st.class_names, st.class_map)
print(st.train())
# NOTE: despite the "mlp_" prefix this pickle holds the LogisticRegressionCV
# model; the name is kept because a later cell loads it.
st.save_model('mlp_white_black.pickle')

['white', 'black', 'dot'] {'white': 0, 'black': 1, 'dot': 2}
white
Press button 1 to collect a sample.
Collected sample: (189, 178, 231) of class white
Press button 1 to collect a sample.
Collected sample: (182, 170, 222) of class white
Press button 1 to collect a sample.
Collected sample: (194, 183, 237) of class white
Press button 1 to collect a sample.
Collected sample: (190, 181, 233) of class white
Press button 1 to collect a sample.
Collected sample: (185, 176, 226) of class white
Press button 1 to collect a sample.
Collected sample: (203, 192, 245) of class white
Press button 1 to collect a sample.
Collected sample: (193, 182, 231) of class white
Press button 1 to collect a sample.
Collected sample: (190, 181, 233) of class white
Press button 1 to collect a sample.
Collected sample: (186, 178, 226) of class white
Press button 1 to collect a sample.
Collected sample: (190, 181, 234) of class white
Press button 1 to collect a sample.
Collected sample: (188, 179, 230) of class whit



1.0


In [11]:
# Three-class ('white', 'black', 'dot') model saved above; despite the file
# name it was trained with LogisticRegressionCV.
WHITE_BLACK_MODEL_PATH = 'mlp_white_black.pickle'
clf = load_pickle(WHITE_BLACK_MODEL_PATH)



In [13]:
# Live check of the three-class model: print class probabilities for one
# reading per fresh press of button 1 (columns follow clf.classes_, i.e. the
# indices in the class_map printed above).
for _ in range(5):
    # Wait for release, then press, so holding the button down yields one
    # prediction instead of one per poll; sleep to avoid a tight polling loop.
    while st._button1_pressed():
        time.sleep(0.01)
    while not st._button1_pressed():
        time.sleep(0.01)
    colors = st._rgb()
    print(clf.predict_proba(np.array(colors).reshape(1, -1)))

[[9.92136257e-01 7.78588830e-03 7.78542412e-05]]
[[3.95017874e-30 2.49098904e-02 9.75090110e-01]]
[[1.73590613e-14 6.68894507e-01 3.31105493e-01]]
[[9.99914401e-01 8.51675171e-05 4.31233506e-07]]
[[3.00178007e-12 7.23942901e-01 2.76057099e-01]]
