In [209]:
import numpy as np
import pandas as pd
import sklearn as sk
import sklearn.linear_model
import sklearn.neural_network
import matplotlib as mpl
import matplotlib.pyplot as plt
import colorsys
import json
import nncolor
import nncolor.data as nc

In [210]:
data = pd.read_csv('./resources/experiment_1_1_combined_edited.csv')
train_data, test_data, val_data = nc.split(data)

In [211]:
def to_hsv(d):
    data_hsv = pd.concat([pd.DataFrame([
        [row['ans'], *colorsys.rgb_to_hsv(row['circle_r'], row['circle_g'], row['circle_b']), 
         colorsys.rgb_to_hsv(*[row['bg_r'],]*3)[2]]],
        columns=['ans', 'circle hue', 'circle sat', 'circle val', 'bg val'])
     for idx, row in data.iterrows()])
    return data_hsv

data_hsv = to_hsv(data)

In [212]:
orange_or_brown = data_hsv[data_hsv['ans'].isin((0, 1))]
orange = data_hsv[data_hsv['ans'].isin((0, ))]
brown = data_hsv[data_hsv['ans'].isin((1,))]
both = data_hsv[data_hsv['ans'].isin((2,))]
neither = data_hsv[data_hsv['ans'].isin((3,))]
X = orange_or_brown[['circle hue', 'circle sat', 'circle val', 'bg val']].to_numpy()
y = orange_or_brown['ans']

def as_X(ptable):
    if 'circle hue' in ptable.columns:
        return ptable[['circle hue', 'circle sat', 'circle val', 'bg val']].to_numpy()
    else:
        return ptable[['circle_r', 'circle_g', 'circle_b', 'bg_r']].to_numpy()

def as_Xy(pdata, class_0, class_1):
    all_classes = class_0 +  class_1
    filtered = pdata[pdata['ans'].isin(all_classes)]
    X = as_X(filtered)
    y = filtered['ans']
    y = y.apply(lambda i : 0 if i in class_0 else 1)
    return X, y
#assert np.array_equal(X_test, X) and np.array_equal(y_test, y)

X_orange = as_X(orange)
X_brown = as_X(brown)
X_both = as_X(both)
X_neither = as_X(neither)

In [213]:
def print_min_max(ds):
    min_ = ds.min()
    max_ = ds.max()
    print('MIN')
    print(min_)
    print('MAX')
    print(max_)

print("Orange or brown")
print_min_max(orange_or_brown)
print("All")
print_min_max(data_hsv)

Orange or brown
MIN
ans           0.000000
circle hue    0.040021
circle sat    0.165457
circle val    0.019313
bg val        0.001038
dtype: float64
MAX
ans           1.000000
circle hue    0.163806
circle sat    0.997000
circle val    0.998380
bg val        0.993696
dtype: float64
All
MIN
ans           0.000000
circle hue    0.040021
circle sat    0.000679
circle val    0.000180
bg val        0.001038
dtype: float64
MAX
ans           3.000000
circle hue    0.176777
circle sat    0.999016
circle val    0.999300
bg val        0.998774
dtype: float64


Can raise the min circle saturation to about 0.15

In [214]:
model = sk.linear_model.LogisticRegression(
    solver='liblinear', class_weight='balanced')
model.fit(X, y)
v = model.coef_[0]
b = model.intercept_
print(f'v: {v}, b: {b}')
y_predict = model.predict(X)
accuracy = sk.metrics.accuracy_score(y, y_predict)
print(f'Accuracy: {accuracy}')

v: [ 0.92163431 -1.25406499 -8.15600133  5.35231535], b: [1.89147702]
Accuracy: 0.9725557461406518


In [215]:
X_orange @ v + b

array([-3.50742769, -4.53711499, -3.4889635 , -0.10072297, -2.54326948,
       -4.33270233, -2.53628902, -0.14546873, -1.74397314, -2.28475024,
       -2.39353635, -5.50675963, -4.58896336, -3.27737705, -0.89845613,
       -0.03254778, -0.78560271, -1.71229847, -3.25309568, -4.95114509,
       -1.16008887, -4.47898323, -4.63276612, -1.01150051, -1.67221891,
       -4.68869937, -1.5598308 , -3.69923627, -2.63672935, -1.32236586,
       -0.78575512, -1.02950949, -4.20272441, -1.08564317, -1.39712764,
       -2.54247123, -1.57742439, -0.35298262, -1.69362263,  0.13550761,
       -5.68643702, -4.86227178, -0.08990765, -0.6638898 , -1.16527421,
       -5.99305255, -0.70746899, -0.64591086, -4.38502337, -1.7051397 ,
       -0.29780371, -1.27113632, -0.37773182, -0.7621408 , -5.35157621,
       -1.99478779, -2.58552965, -0.68070026, -2.76124425, -2.21719708,
       -1.23820471, -0.53129859, -3.43933679, -0.6422066 , -1.84879376,
       -0.72514045, -2.52046534, -4.34273962, -0.27190088, -0.43

In [216]:
X_brown @ v + b

array([ 2.82420497,  2.31529253,  0.40139507,  2.21295973,  0.81040789,
        2.01723929, -0.68408759,  3.22407282,  1.76825433,  1.92527269,
        3.39992244,  2.47358215,  0.75031666,  1.06757259,  1.50301098,
        2.36165206,  2.37493806,  0.88000231,  0.54605501,  4.18834584,
        2.66395665,  4.18735929,  1.08496164,  0.5971429 ,  0.66391943,
        0.81611249,  0.05659493,  3.15387155,  2.69806859,  2.67028897,
        1.01774079,  1.74638619,  2.2341872 ,  0.21044946,  2.4682988 ,
        2.23397317,  0.11456033,  1.50232109,  0.95635803,  3.52947918,
        1.14725992,  2.40485971,  2.88650866,  0.96218756,  3.89759457,
        2.71088804,  0.2763058 ,  0.92044521,  0.32876028,  2.23852783,
        0.48776312,  1.82373287,  1.92393759,  1.12206534,  0.60380509,
        3.53240896,  1.10671515,  1.07416966,  3.05097773,  0.94590352,
        1.73720514,  1.08135808,  2.62322743, -0.40230295, -0.68204147,
        2.15555621,  3.31794159,  0.32883951,  2.60378151,  0.57

In [217]:
X_both @ v + b

array([-1.01144842, -2.28989008,  0.84685019,  0.15899947, -0.11630138,
       -0.76974448,  0.55163944, -0.02372326,  0.01690994,  0.02653075,
        1.77957655,  0.14092813,  0.4008932 , -0.25214739, -0.78812394,
       -0.19591175, -0.07228245, -2.30192784,  0.54128606,  0.31323946,
        0.57244486, -0.02466676,  0.23843738, -0.81865377, -0.04771472,
        0.43439159,  0.22710219, -1.48368935, -0.40005147,  0.13089427,
       -0.6841809 , -0.03802384,  0.16070278, -0.45063607, -0.15742539,
        1.08689989,  0.59707402, -0.26903242,  0.17820465,  0.54439509,
        0.45890577,  0.86576088, -0.07907642])

In [218]:
X_neither @ v + b

array([ 4.9195463 ,  4.73534857, -0.7390965 , -2.52271766,  1.40518317,
        2.35831633, -1.17996371,  4.64953327,  0.06841541, -1.90236152,
        0.07032224,  1.24468785,  2.99838978, -2.54481832,  2.23523904,
       -5.16488444, -5.16490919,  0.40647526, -4.72108246,  1.81226121,
       -2.24307778, -5.40491195,  1.04396155, -0.99622794, -0.9307324 ,
       -0.24893815,  3.98320134, -4.94413887, -5.78528944,  0.51314367,
       -2.21403666, -1.44253292,  3.01545672,  2.31879528,  1.24582879,
       -0.65111435,  4.94400834, -0.6683185 , -6.14318202,  1.34081299,
        0.43449878, -3.07988997, -3.00829572,  1.00803309, -4.25658486,
       -0.11233586,  3.73784902,  2.71367188,  2.05971526,  2.12594241,
       -0.07526012, -3.44187005, -0.61877567,  3.3686197 , -3.11791059,
       -2.05694051, -4.07829846, -3.00206129,  3.46550883,  3.63788539,
        2.72328111, -5.26632446, -0.90433064,  2.42330944, -1.19194637,
       -5.06550286, -1.49463145, -5.22904633, -4.05092351, -1.93

In [219]:
def classify2():
    X, y = as_Xy(data_hsv, class_0=(3,), class_1=(0,1,2))
    model = sk.linear_model.LogisticRegression(
        solver='liblinear',
        class_weight={0:1, 1: 10})
    model.fit(X, y)
    v = model.coef_[0]
    b = model.intercept_
    print(f'v: {v}, b: {b}')
    y_predict = model.predict(X)
    accuracy = sk.metrics.recall_score(y, y_predict)
    print(f'Recall: {accuracy}')
    print(f'Num exclude: {len(y_predict) - np.count_nonzero(y_predict)}') 
    print(f'Num include: {np.count_nonzero(y_predict)}') 
    exclude = data_hsv[data_hsv['circle sat'] < 0.05]
    print(np.count_nonzero(model.predict(as_X(exclude))))
classify2()

v: [-18.98851248   2.91729522  -1.44684852   1.58152362], b: [1.92274336]
Recall: 1.0
Num exclude: 161
Num include: 1356
10


In [180]:
def classify_circle_rgb_only():
    X, y = as_Xy(data_test, class_0=(0,), class_1=(1,))
    X = X[:,1:4]
    model = sk.linear_model.LogisticRegression(
        solver='liblinear',
        class_weight='balanced')
    model.fit(X, y)
    v = model.coef_[0]
    b = model.intercept_
    print(f'v: {v}, b: {b}')
    y_predict = model.predict(X)
    accuracy = sk.metrics.accuracy_score(y, y_predict)
    print(f'Accuracy: {accuracy}')
classify_circle_rgb_only()

def classify_bg_rgb_only():
    X, y = as_Xy(data, class_0=(0,), class_1=(1,))
    X = X[:, 3:4]
    model = sk.linear_model.LogisticRegression(
        solver='liblinear',
        class_weight='balanced')
    model.fit(X, y)
    v = model.coef_[0]
    b = model.intercept_
    print(f'v: {v}, b: {b}')
    y_predict = model.predict(X)
    accuracy = sk.metrics.accuracy_score(y, y_predict)
    print(f'Accuracy: {accuracy}')
classify_bg_rgb_only()

v: [-7.49972137  0.44580334  4.26445705], b: [0.09141476]
Accuracy: 0.8473413379073756
v: [2.56757375], b: [-1.36848644]
Accuracy: 0.6826758147512865


In [222]:
def mlp_classify():
    model = sklearn.neural_network.MLPClassifier(hidden_layer_sizes=[], 
                                                 max_iter=3000, 
                                                 solver='lbfgs',
                                                 tol=1e-8)
    def toXy(d):
        filtered = d[d['ans'].isin({0, 1, 3})]
        X = filtered[['circle_r', 'circle_g', 'circle_b', 'bg_r']].to_numpy()
        y = filtered[['ans']].to_numpy().squeeze()
        return X, y
    X, y = toXy(train_data)
    model.fit(X, y)
    Xtest, ytest = toXy(test_data)
    y_predict = model.predict(Xtest)
    accuracy = sk.metrics.accuracy_score(ytest, y_predict)
    print(f'Accuracy: {accuracy}')
    return accuracy
mlp_classify()

def mlp_classify_by_circle():
    model = sklearn.neural_network.MLPClassifier(hidden_layer_sizes=[], 
                                                 max_iter=3000, 
                                                 solver='lbfgs',
                                                 tol=1e-8)
    filtered = data[data['ans'].isin({0, 1, 3})]
    X = filtered[['circle_r', 'circle_g', 'circle_b']].to_numpy()
    y = filtered[['ans']].to_numpy().squeeze()
    model.fit(X, y)
    y_predict = model.predict(X)
    accuracy = sk.metrics.accuracy_score(y, y_predict)
    print(f'Accuracy (from circle): {accuracy}')
    return accuracy
mlp_classify_by_circle()

def mlp_classify_by_bg():
    model = sklearn.neural_network.MLPClassifier(hidden_layer_sizes=[], 
                                                 max_iter=3000, 
                                                 solver='lbfgs',
                                                 tol=1e-8)
    filtered = data[data['ans'].isin({0, 1, 3})]
    X = filtered[['bg_r']].to_numpy()
    y = filtered[['ans']].to_numpy().squeeze()
    model.fit(X, y)
    y_predict = model.predict(X)
    accuracy = sk.metrics.accuracy_score(y, y_predict)
    print(f'Accuracy (from bg): {accuracy}')
    return accuracy
mlp_classify_by_bg()

Accuracy: 0.7355371900826446
Accuracy (from circle): 0.7170963364993216
Accuracy (from bg): 0.6037991858887382


0.6037991858887382

In [242]:
import torch
t1 = torch.zeros((3, 224, 224))
m = torch.mean(t1, dim=(1,2), keepdim=True)
res = torch.cat((t1, m.expand(-1, 224, 224)), dim=0)
res.shape

torch.Size([6, 224, 224])

In [234]:
print(m[:,None,None].shape, t1.shape)

torch.Size([3, 1, 1]) torch.Size([3, 224, 224])
