In [77]:
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn import datasets
from torch_mas.ciel import Ciel

## Loading the iris dataset

In [78]:
data = datasets.load_iris()

## Splitting the dataset intro train and test

In [79]:
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, random_state=0)

## Learning with Context

In [80]:
from torch_mas.ciel import Ciel

re = Ciel(
    4, 
    1, 
    R=0.8,
    imprecise_th=0.5,
    bad_th=0.0015,
    alpha=0.5,
    memory_length=3,
    n_epochs=5
    )

## Training the agents

In [81]:
import time
t = time.time()
re.fit(X_train, y_train)
tt = time.time() - t
print(f"Total training time: {tt}s")

print("Number of agents created:", re.estimator.agents.n_agents)

Total training time: 0.745032548904419s
Number of agents created: 32


## Computing performance

In [82]:
y_pred = re.predict(X_test)
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        13
           1       0.94      0.94      0.94        16
           2       0.89      0.89      0.89         9

    accuracy                           0.95        38
   macro avg       0.94      0.94      0.94        38
weighted avg       0.95      0.95      0.95        38



# Hyperparameter optimisation

In [83]:
print(re.get_params())

{'R': 0.8, 'alpha': 0.5, 'bad_th': 0.0015, 'imprecise_th': 0.5, 'input_dim': 4, 'memory_length': 3, 'n_epochs': 5, 'output_dim': 1}


In [84]:
import numpy as np
param_grid = [
  {
      'R': np.linspace(0,1,11),
      'imprecise_th': np.linspace(0,1,11),
      'alpha': np.linspace(0,1,11),
      'bad_th': np.linspace(0,1,11)
}]

In [85]:
re = Ciel(
    4, 
    1, 
    R=0.8,
    imprecise_th=0.5,
    bad_th=0.0015,
    alpha=0.5,
    memory_length=3,
    n_epochs=5
    )

In [86]:
grid = GridSearchCV(estimator=re, param_grid=param_grid, scoring='accuracy', verbose=2)

In [87]:
grid.fit(X_train, y_train)

Fitting 5 folds for each of 14641 candidates, totalling 73205 fits
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.0; total time=   0.6s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.0; total time=   0.6s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.0; total time=   0.6s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.0; total time=   0.6s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.0; total time=   0.6s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.1; total time=   0.6s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.1; total time=   0.6s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.1; total time=   0.6s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.1; total time=   0.7s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.1; total time=   0.6s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0, imprecise_th=0.2; total time=   0.6s
[CV] END .....R=0.0, alpha=0.0, bad_th=0.0

KeyboardInterrupt: 

In [None]:
print(grid.best_params_)