In [None]:
!pip install tensorflow
!pip install scikit-learn
!pip install keras

In [1]:
"""
    A simple neural network written in Keras (TensorFlow backend) to classify the IRIS data: https://gist.github.com/NiharG15/cd8272c9639941cf8f481a7c4478d525#file-iris-keras-nn-py
"""

import numpy as np

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

from keras.models import Sequential
from keras.layers import Dense
#from keras.optimizers import Adam

from tensorflow.keras.optimizers import Adam 

iris_data = load_iris() # load the iris dataset

print('Example data: ')
print(iris_data.data[:5])
print('Example labels: ')
print(iris_data.target[:5])

x = iris_data.data
y_ = iris_data.target.reshape(-1, 1) # Convert data to a single column

# One Hot encode the class labels
encoder = OneHotEncoder(sparse=False)
y = encoder.fit_transform(y_)
print(y)

# Split the data for training and testing
train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.20)

# Build the model
model = Sequential()
model.add(Dense(10, input_shape=(4,), activation='relu', name='fc1'))
model.add(Dense(10, activation='relu', name='fc2'))
model.add(Dense(3, activation='softmax', name='output'))

# Adam optimizer with learning rate of 0.001
optimizer = Adam(lr=0.001)
model.compile(optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

print('Neural Network Model Summary: ')
print(model.summary())

# Train the model
model.fit(train_x, train_y, verbose=2, batch_size=5, epochs=200)

# Test on unseen data

results = model.evaluate(test_x, test_y)

print('Final test set loss: {:4f}'.format(results[0]))
print('Final test set accuracy: {:4f}'.format(results[1]))

Example data: 
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]]
Example labels: 
[0 0 0 0 0]
[[1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 

  super(Adam, self).__init__(name, **kwargs)


Neural Network Model Summary: 
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 fc1 (Dense)                 (None, 10)                50        
                                                                 
 fc2 (Dense)                 (None, 10)                110       
                                                                 
 output (Dense)              (None, 3)                 33        
                                                                 
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/200
24/24 - 1s - loss: 1.6028 - accuracy: 0.3583 - 520ms/epoch - 22ms/step
Epoch 2/200
24/24 - 0s - loss: 1.3917 - accuracy: 0.4250 - 24ms/epoch - 994us/step
Epoch 3/200
24/24 - 0s - loss: 1.2646 - accuracy: 0.4333 - 22ms/epoch - 917us/step
Epoch 4/200
24/24 - 0s - loss: 1.1856

Epoch 90/200
24/24 - 0s - loss: 0.1125 - accuracy: 0.9667 - 25ms/epoch - 1ms/step
Epoch 91/200
24/24 - 0s - loss: 0.1154 - accuracy: 0.9667 - 25ms/epoch - 1ms/step
Epoch 92/200
24/24 - 0s - loss: 0.1134 - accuracy: 0.9667 - 25ms/epoch - 1ms/step
Epoch 93/200
24/24 - 0s - loss: 0.1090 - accuracy: 0.9583 - 25ms/epoch - 1ms/step
Epoch 94/200
24/24 - 0s - loss: 0.1127 - accuracy: 0.9667 - 26ms/epoch - 1ms/step
Epoch 95/200
24/24 - 0s - loss: 0.1082 - accuracy: 0.9667 - 25ms/epoch - 1ms/step
Epoch 96/200
24/24 - 0s - loss: 0.1118 - accuracy: 0.9583 - 25ms/epoch - 1ms/step
Epoch 97/200
24/24 - 0s - loss: 0.1164 - accuracy: 0.9583 - 25ms/epoch - 1ms/step
Epoch 98/200
24/24 - 0s - loss: 0.1188 - accuracy: 0.9583 - 22ms/epoch - 921us/step
Epoch 99/200
24/24 - 0s - loss: 0.1056 - accuracy: 0.9583 - 22ms/epoch - 917us/step
Epoch 100/200
24/24 - 0s - loss: 0.1059 - accuracy: 0.9583 - 30ms/epoch - 1ms/step
Epoch 101/200
24/24 - 0s - loss: 0.1058 - accuracy: 0.9583 - 40ms/epoch - 2ms/step
Epoch 102/

Epoch 188/200
24/24 - 0s - loss: 0.0756 - accuracy: 0.9667 - 33ms/epoch - 1ms/step
Epoch 189/200
24/24 - 0s - loss: 0.0749 - accuracy: 0.9667 - 39ms/epoch - 2ms/step
Epoch 190/200
24/24 - 0s - loss: 0.0724 - accuracy: 0.9667 - 38ms/epoch - 2ms/step
Epoch 191/200
24/24 - 0s - loss: 0.0832 - accuracy: 0.9667 - 28ms/epoch - 1ms/step
Epoch 192/200
24/24 - 0s - loss: 0.0764 - accuracy: 0.9667 - 24ms/epoch - 999us/step
Epoch 193/200
24/24 - 0s - loss: 0.0823 - accuracy: 0.9583 - 25ms/epoch - 1ms/step
Epoch 194/200
24/24 - 0s - loss: 0.0732 - accuracy: 0.9667 - 24ms/epoch - 1ms/step
Epoch 195/200
24/24 - 0s - loss: 0.0795 - accuracy: 0.9750 - 26ms/epoch - 1ms/step
Epoch 196/200
24/24 - 0s - loss: 0.0788 - accuracy: 0.9667 - 37ms/epoch - 2ms/step
Epoch 197/200
24/24 - 0s - loss: 0.0746 - accuracy: 0.9667 - 38ms/epoch - 2ms/step
Epoch 198/200
24/24 - 0s - loss: 0.0723 - accuracy: 0.9667 - 36ms/epoch - 2ms/step
Epoch 199/200
24/24 - 0s - loss: 0.0736 - accuracy: 0.9750 - 41ms/epoch - 2ms/step
Ep

In [2]:
iris_data

{'data': array([[5.1, 3.5, 1.4, 0.2],
        [4.9, 3. , 1.4, 0.2],
        [4.7, 3.2, 1.3, 0.2],
        [4.6, 3.1, 1.5, 0.2],
        [5. , 3.6, 1.4, 0.2],
        [5.4, 3.9, 1.7, 0.4],
        [4.6, 3.4, 1.4, 0.3],
        [5. , 3.4, 1.5, 0.2],
        [4.4, 2.9, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.1],
        [5.4, 3.7, 1.5, 0.2],
        [4.8, 3.4, 1.6, 0.2],
        [4.8, 3. , 1.4, 0.1],
        [4.3, 3. , 1.1, 0.1],
        [5.8, 4. , 1.2, 0.2],
        [5.7, 4.4, 1.5, 0.4],
        [5.4, 3.9, 1.3, 0.4],
        [5.1, 3.5, 1.4, 0.3],
        [5.7, 3.8, 1.7, 0.3],
        [5.1, 3.8, 1.5, 0.3],
        [5.4, 3.4, 1.7, 0.2],
        [5.1, 3.7, 1.5, 0.4],
        [4.6, 3.6, 1. , 0.2],
        [5.1, 3.3, 1.7, 0.5],
        [4.8, 3.4, 1.9, 0.2],
        [5. , 3. , 1.6, 0.2],
        [5. , 3.4, 1.6, 0.4],
        [5.2, 3.5, 1.5, 0.2],
        [5.2, 3.4, 1.4, 0.2],
        [4.7, 3.2, 1.6, 0.2],
        [4.8, 3.1, 1.6, 0.2],
        [5.4, 3.4, 1.5, 0.4],
        [5.2, 4.1, 1.5, 0.1],
  