# Classification Exercise

We'll be working with some California census data, using various features of an individual to predict which income class they belong to (>50K or <=50K).

Here is some information about the data:

<table>
<thead>
<tr>
<th>Column Name</th>
<th>Type</th>
<th>Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>age</td>
<td>Continuous</td>
<td>The age of the individual</td>
</tr>
<tr>
<td>workclass</td>
<td>Categorical</td>
<td>The type of employer the individual has (government, military, private, etc.).</td>
</tr>
<tr>
<td>fnlwgt</td>
<td>Continuous</td>
<td>The number of people the census takers believe that observation represents (sample weight). This variable will not be used.</td>
</tr>
<tr>
<td>education</td>
<td>Categorical</td>
<td>The highest level of education achieved by the individual.</td>
</tr>
<tr>
<td>education_num</td>
<td>Continuous</td>
<td>The highest level of education in numerical form.</td>
</tr>
<tr>
<td>marital_status</td>
<td>Categorical</td>
<td>Marital status of the individual.</td>
</tr>
<tr>
<td>occupation</td>
<td>Categorical</td>
<td>The occupation of the individual.</td>
</tr>
<tr>
<td>relationship</td>
<td>Categorical</td>
<td>Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.</td>
</tr>
<tr>
<td>race</td>
<td>Categorical</td>
<td>White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.</td>
</tr>
<tr>
<td>gender</td>
<td>Categorical</td>
<td>Female, Male.</td>
</tr>
<tr>
<td>capital_gain</td>
<td>Continuous</td>
<td>Capital gains recorded.</td>
</tr>
<tr>
<td>capital_loss</td>
<td>Continuous</td>
<td>Capital losses recorded.</td>
</tr>
<tr>
<td>hours_per_week</td>
<td>Continuous</td>
<td>Hours worked per week.</td>
</tr>
<tr>
<td>native_country</td>
<td>Categorical</td>
<td>Country of origin of the individual.</td>
</tr>
<tr>
<td>income_bracket</td>
<td>Categorical</td>
<td>"&gt;50K" or "&lt;=50K", meaning whether the person makes more than \$50,000 annually.</td>
</tr>
</tbody>
</table>

## Follow the Directions in Bold. If you get stuck, check out the solutions lecture.

### THE DATA

**Read in the census_data.csv data with pandas.**

In [43]:
import pandas as pd

In [44]:
data = pd.read_csv('census_data.csv')

In [45]:
data.head()

|   | age | workclass | education | education_num | marital_status | occupation | relationship | race | gender | capital_gain | capital_loss | hours_per_week | native_country | income_bracket |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 39 | State-gov | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 2174 | 0 | 40 | United-States | <=50K |
| 1 | 50 | Self-emp-not-inc | Bachelors | 13 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 13 | United-States | <=50K |
| 2 | 38 | Private | HS-grad | 9 | Divorced | Handlers-cleaners | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
| 3 | 53 | Private | 11th | 7 | Married-civ-spouse | Handlers-cleaners | Husband | Black | Male | 0 | 0 | 40 | United-States | <=50K |
| 4 | 28 | Private | Bachelors | 13 | Married-civ-spouse | Prof-specialty | Wife | Black | Female | 0 | 0 | 40 | Cuba | <=50K |


**TensorFlow can't use strings as labels, so you'll need to use the pandas .apply() method with a custom function that converts them to 0s and 1s. This might be hard if you aren't very familiar with pandas, so feel free to take a peek at the solutions for this part.**

**Convert the label column (income_bracket) to 0s and 1s instead of strings.**

In [46]:
data['income_bracket'].unique()

array([' <=50K', ' >50K'], dtype=object)

In [47]:
def label_fix(label):
    # The raw labels carry a leading space: ' <=50K' or ' >50K'
    if label == ' <=50K':
        return 0
    else:
        return 1

In [48]:
data['income_bracket'] = data['income_bracket'].apply(label_fix)
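If you'd rather avoid a custom function, the same 0/1 encoding can be written as a vectorized pandas expression. This is a minimal sketch, assuming the labels are the strings shown above, possibly with a leading space; `str.strip()` removes that space before comparing, and the name `encode_income` is hypothetical.

```python
import pandas as pd


def encode_income(labels: pd.Series) -> pd.Series:
    """Map '>50K' to 1 and everything else to 0, ignoring surrounding whitespace."""
    # Compare the stripped strings, then turn the boolean mask into 0/1 integers.
    return (labels.str.strip() == '>50K').astype(int)


# Usage on a few labels written the same way as in the CSV.
sample = pd.Series([' <=50K', ' >50K', ' <=50K'])
print(encode_income(sample).tolist())  # [0, 1, 0]
```

One difference from label_fix: here anything other than >50K (including a missing value) becomes 0, whereas label_fix sends anything other than ' <=50K' to 1.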

In [49]:
data = data.dropna(axis = 0)
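The unique() output above shows that the label strings carry a leading space (' <=50K'). If the other string columns in census_data.csv are stored the same way, exact-match vocabulary lists such as ['Female', 'Male'] will silently fail to match. A minimal sketch for stripping surrounding whitespace from every text column; `strip_text_columns` is a hypothetical helper name, and it assumes text columns are recognized by `pd.api.types.is_string_dtype`.

```python
import pandas as pd


def strip_text_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of df with leading/trailing whitespace removed from text columns."""
    out = df.copy()
    for col in out.columns:
        if pd.api.types.is_string_dtype(out[col]):
            out[col] = out[col].str.strip()
    return out


sample = pd.DataFrame({'gender': [' Male', ' Female'], 'age': [39, 50]})
cleaned = strip_text_columns(sample)
print(cleaned['gender'].tolist())  # ['Male', 'Female']
```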

In [50]:
cols_to_norm = ['age', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']


In [51]:
# Min-max scale each numeric column to the [0, 1] range
data[cols_to_norm] = data[cols_to_norm].apply(lambda x: (x - x.min()) / (x.max() - x.min()))
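The lambda above applies min-max scaling, x' = (x - min) / (max - min), which maps each column onto [0, 1]. Because it takes min and max over the full dataset before the train/test split, the test rows influence the scaling. A minimal sketch of the stricter variant, which learns min and max from the training rows only and reuses them for the test rows (so test values may fall slightly outside [0, 1]); `fit_min_max` and `apply_min_max` are hypothetical names.

```python
import pandas as pd


def fit_min_max(train: pd.DataFrame, cols):
    """Learn per-column min and max from the training rows only."""
    return train[cols].min(), train[cols].max()


def apply_min_max(df: pd.DataFrame, cols, col_min, col_max) -> pd.DataFrame:
    """Scale cols with the stored statistics: (x - min) / (max - min)."""
    out = df.copy()
    out[cols] = (out[cols] - col_min) / (col_max - col_min)
    return out


train = pd.DataFrame({'age': [20, 40, 60]})
test = pd.DataFrame({'age': [30, 70]})
lo, hi = fit_min_max(train, ['age'])
print(apply_min_max(train, ['age'], lo, hi)['age'].tolist())  # [0.0, 0.5, 1.0]
print(apply_min_max(test, ['age'], lo, hi)['age'].tolist())   # [0.25, 1.25]
```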

### Perform a Train Test Split on the Data

In [52]:
from sklearn.model_selection import train_test_split

In [53]:
X = data.drop(['income_bracket'], axis = 1)
Y = data['income_bracket']

In [54]:
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.3, random_state = 101)
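With test_size=0.3, train_test_split shuffles the rows and holds out 30% of them for testing; fixing random_state=101 makes that shuffle repeatable. The stdlib sketch below mimics the idea on row indices. It is an illustration under our own assumptions, not scikit-learn's exact algorithm, so it will not select the same rows; the one detail borrowed is that a fractional test size is rounded up. `shuffle_split` is a hypothetical name.

```python
import math
import random


def shuffle_split(n_rows, test_size=0.3, seed=101):
    """Return (train_idx, test_idx) after a seeded shuffle of range(n_rows)."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)    # same seed -> same order every run
    n_test = math.ceil(test_size * n_rows)  # fractional sizes are rounded up
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = shuffle_split(10)
print(len(train_idx), len(test_idx))  # 7 3
```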

### Create the Feature Columns for tf.estimator

**Take note of categorical vs. continuous values!**

In [55]:
data.columns

Index(['age', 'workclass', 'education', 'education_num', 'marital_status',
       'occupation', 'relationship', 'race', 'gender', 'capital_gain',
       'capital_loss', 'hours_per_week', 'native_country', 'income_bracket'],
      dtype='object')

**Import TensorFlow.**

In [56]:
import tensorflow as tf

**Create the tf.feature_column columns for the categorical values. Use vocabulary lists or just use hash buckets.**

In [57]:
workclass = tf.feature_column.categorical_column_with_hash_bucket('workclass', hash_bucket_size = data['workclass'].unique().size)
education = tf.feature_column.categorical_column_with_hash_bucket('education', hash_bucket_size = data['education'].unique().size)
marital_status = tf.feature_column.categorical_column_with_hash_bucket('marital_status', hash_bucket_size = data['marital_status'].unique().size)
occupation = tf.feature_column.categorical_column_with_hash_bucket('occupation', hash_bucket_size = data['occupation'].unique().size)
relationship = tf.feature_column.categorical_column_with_hash_bucket('relationship', hash_bucket_size = data['relationship'].unique().size)
race = tf.feature_column.categorical_column_with_hash_bucket('race', hash_bucket_size = data['race'].unique().size)
gender = tf.feature_column.categorical_column_with_vocabulary_list('gender', ['Female', 'Male'])
native_country = tf.feature_column.categorical_column_with_hash_bucket('native_country', hash_bucket_size = data['native_country'].unique().size)
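Each hash_bucket_size above equals the column's number of distinct values (for example 9 for workclass, per the feat_cols output). Hashing does not assign one bucket per category, so with as many buckets as categories some categories will very likely share a bucket and become indistinguishable to the model. Treating the hash as a uniformly random bucket choice (a simplifying assumption, not a description of TensorFlow's hash function), the chance that k distinct categories all land in different buckets out of m is the birthday-problem product P = ((m - 0)/m) × ((m - 1)/m) × ... × ((m - k + 1)/m), sketched here:

```python
def p_no_collision(k: int, m: int) -> float:
    """Probability that k distinct keys land in k different buckets out of m,
    assuming each key goes to a uniformly random bucket."""
    p = 1.0
    for i in range(k):
        p *= (m - i) / m
    return p


print(round(p_no_collision(9, 9), 5))     # 0.00094 (buckets == categories)
print(round(p_no_collision(9, 1000), 3))  # 0.965 (a much larger bucket count)
```

For k = m = 9 the product is 9!/9^9, roughly 0.1%, so at least one collision is almost certain; a larger bucket count or a vocabulary list avoids most or all of them.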


**Create the feature columns for the continuous values using numeric_column.**

In [58]:
age = tf.feature_column.numeric_column('age')
education_num = tf.feature_column.numeric_column('education_num')
capital_gain = tf.feature_column.numeric_column('capital_gain')
capital_loss = tf.feature_column.numeric_column('capital_loss')
hours_week = tf.feature_column.numeric_column('hours_per_week')

**Put all these variables into a single list named feat_cols.**

In [59]:
feat_cols = [age, workclass, education, education_num, marital_status, occupation, relationship, race, gender, capital_gain, capital_loss, hours_week, native_country]

In [60]:
feat_cols

[_NumericColumn(key='age', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None),
 _HashedCategoricalColumn(key='workclass', hash_bucket_size=9, dtype=tf.string),
 _HashedCategoricalColumn(key='education', hash_bucket_size=16, dtype=tf.string),
 _NumericColumn(key='education_num', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None),
 _HashedCategoricalColumn(key='marital_status', hash_bucket_size=7, dtype=tf.string),
 _HashedCategoricalColumn(key='occupation', hash_bucket_size=15, dtype=tf.string),
 _HashedCategoricalColumn(key='relationship', hash_bucket_size=6, dtype=tf.string),
 _HashedCategoricalColumn(key='race', hash_bucket_size=5, dtype=tf.string),
 _VocabularyListCategoricalColumn(key='gender', vocabulary_list=('Female', 'Male'), dtype=tf.string, default_value=-1, num_oov_buckets=0),
 _NumericColumn(key='capital_gain', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None),
 _NumericColumn(key='capital_loss', shape=(1,), default_

### Create Input Function

**The batch size is up to you, but do make sure to shuffle!**

In [61]:
input_func = tf.estimator.inputs.pandas_input_fn(x=X_train, y=Y_train, batch_size=10, num_epochs=1000, shuffle=True)
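Conceptually, a shuffled input function hands the model one batch of rows per training step and reshuffles whenever it starts a new pass (epoch) over the data. The stdlib generator below is a hypothetical sketch of that behaviour, not TensorFlow's implementation (in particular, keeping a smaller final batch is our own choice). It also shows why num_epochs=1000 never comes into play here: training stops after `steps` batches, i.e. steps × batch_size rows.

```python
import random


def shuffled_batches(n_rows, batch_size, num_epochs, seed=0):
    """Yield lists of row indices, reshuffling at the start of every epoch."""
    rng = random.Random(seed)
    for _ in range(num_epochs):
        order = list(range(n_rows))
        rng.shuffle(order)
        for start in range(0, n_rows, batch_size):
            yield order[start:start + batch_size]  # last batch may be smaller


def epochs_covered(n_rows, batch_size, steps):
    """How many passes over the data `steps` batches amount to."""
    return steps * batch_size / n_rows


batches = list(shuffled_batches(n_rows=25, batch_size=10, num_epochs=2))
print([len(b) for b in batches])  # [10, 10, 5, 10, 10, 5]
```

With batch_size=10 and steps=10000, training consumes 100,000 rows; for a hypothetical training set of 20,000 rows, epochs_covered(20000, 10, 10000) gives 5.0 passes, far below num_epochs=1000.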

### Create your model with tf.estimator

**Create a LinearClassifier. (If you want to use a DNNClassifier instead, keep in mind you'll need to create embedding columns out of the categorical features that use strings; check out the previous lecture for more info.)**

In [62]:
model = tf.estimator.LinearClassifier(feature_columns = feat_cols,n_classes=2)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'C:\\Users\\4CFA~1\\AppData\\Local\\Temp\\tmpn_d7zw80', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x0000003A90BEB080>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
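With n_classes=2, a LinearClassifier is logistic regression: it adds one learned weight per numeric feature and per active categorical bucket, plus a bias, to get a logit, then passes the logit through the sigmoid function. The predictions printed in the Evaluation section expose both numbers ('logits' and 'logistic'). The sketch below reproduces the first printed prediction from its logit; the 0.5 probability threshold (logit 0) for class_ids is our reading of those outputs, which are consistent with it, and `to_prediction` is a hypothetical helper.

```python
import math


def sigmoid(logit: float) -> float:
    """Logistic function: maps a logit to P(class 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


def to_prediction(logit: float) -> dict:
    """Rebuild the main fields of a predict() item from a single logit."""
    p1 = sigmoid(logit)
    return {
        'logistic': p1,
        'probabilities': [1.0 - p1, p1],
        'class_ids': 1 if p1 > 0.5 else 0,
    }


# First printed prediction: logit -1.1149942, logistic 0.24694102, class 0.
pred = to_prediction(-1.1149942)
print(round(pred['logistic'], 4), pred['class_ids'])  # 0.2469 0
```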


**Train your model on the data for at least 5000 steps.**

In [63]:
model.train(input_fn=input_func,steps=10000)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into C:\Users\4CFA~1\AppData\Local\Temp\tmpn_d7zw80\model.ckpt.
INFO:tensorflow:loss = 6.931472, step = 1
INFO:tensorflow:global_step/sec: 133.317
INFO:tensorflow:loss = 4.4819045, step = 101 (0.754 sec)
INFO:tensorflow:global_step/sec: 309.44
INFO:tensorflow:loss = 3.2601838, step = 201 (0.319 sec)
INFO:tensorflow:global_step/sec: 301.716
INFO:tensorflow:loss = 3.1889381, step = 301 (0.338 sec)
INFO:tensorflow:global_step/sec: 307.039
INFO:tensorflow:loss = 3.1431003, step = 401 (0.325 sec)
INFO:tensorflow:global_step/sec: 299.325
INFO:tensorflow:loss = 1.8840094, step = 501 (0.334 sec)
INFO:tensorflow:global_step/sec: 282.796
INFO:tensorflow:loss = 8.5540285, step = 601 (0.349 sec)
INFO:tensorflow:global_step/s

INFO:tensorflow:loss = 3.2419906, step = 8001 (0.364 sec)
INFO:tensorflow:global_step/sec: 274.208
INFO:tensorflow:loss = 8.140498, step = 8101 (0.365 sec)
INFO:tensorflow:global_step/sec: 285.712
INFO:tensorflow:loss = 4.7900696, step = 8201 (0.355 sec)
INFO:tensorflow:global_step/sec: 279.177
INFO:tensorflow:loss = 3.7475827, step = 8301 (0.353 sec)
INFO:tensorflow:global_step/sec: 283.639
INFO:tensorflow:loss = 3.4648356, step = 8401 (0.355 sec)
INFO:tensorflow:global_step/sec: 282.26
INFO:tensorflow:loss = 3.759623, step = 8501 (0.357 sec)
INFO:tensorflow:global_step/sec: 289.913
INFO:tensorflow:loss = 2.0553997, step = 8601 (0.340 sec)
INFO:tensorflow:global_step/sec: 277.921
INFO:tensorflow:loss = 3.8970118, step = 8701 (0.360 sec)
INFO:tensorflow:global_step/sec: 280.038
INFO:tensorflow:loss = 7.6958656, step = 8801 (0.356 sec)
INFO:tensorflow:global_step/sec: 278.301
INFO:tensorflow:loss = 5.511405, step = 8901 (0.359 sec)
INFO:tensorflow:global_step/sec: 281.445
INFO:tensorflo

<tensorflow.python.estimator.canned.linear.LinearClassifier at 0x3a89bebcf8>
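The first logged loss, 6.931472, has a simple explanation if the model starts from all-zero weights, as that value suggests: every prediction is then 0.5, each example contributes ln 2 ≈ 0.693 to the binary cross-entropy, and 10 × ln 2 ≈ 6.931472. The logged value is therefore consistent with the loss being summed over a batch of 10 rather than averaged (to our understanding, summing was the default for TF 1.x canned estimators, but treat that as an assumption). It also explains why later values jump around (1.88, 8.55, ...): each one reflects only ten examples. A stdlib sketch of that batch-summed loss:

```python
import math


def batch_log_loss(probs, labels):
    """Sum of binary cross-entropy terms over one batch.

    probs[i] is the predicted P(label = 1); labels[i] is 0 or 1.
    """
    total = 0.0
    for p, y in zip(probs, labels):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total


# Untrained model: every probability is 0.5, whatever the labels are.
print(round(batch_log_loss([0.5] * 10, [0, 1, 0, 0, 1, 0, 0, 0, 1, 0]), 6))  # 6.931472
```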

### Evaluation

**Create a prediction input function. Remember to supply only the X_test data and keep shuffle=False.**

In [64]:
# Features only: predict() does not need the labels
eval_input_func = tf.estimator.inputs.pandas_input_fn(x=X_test, batch_size=10, num_epochs=1, shuffle=False)

**Use model.predict() and pass in your input function. This produces a generator of predictions, which you can turn into a list with list().**

In [65]:
predictions = model.predict(eval_input_func)

In [66]:
predicts = list(predictions)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from C:\Users\4CFA~1\AppData\Local\Temp\tmpn_d7zw80\model.ckpt-10000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


**Each item in your list will look like this:**

In [67]:
for pred in predicts:
    print(pred)

{'logits': array([-1.1149942], dtype=float32), 'logistic': array([0.24694102], dtype=float32), 'probabilities': array([0.75305897, 0.24694099], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.6458511], dtype=float32), 'logistic': array([0.02543535], dtype=float32), 'probabilities': array([0.9745646 , 0.02543534], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.0821493], dtype=float32), 'logistic': array([0.2530995], dtype=float32), 'probabilities': array([0.7469005, 0.2530995], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.2329445], dtype=float32), 'logistic': array([0.03794461], dtype=float32), 'probabilities': array([0.9620554 , 0.03794461], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.6824363], dtype=float32), 'log

{'logits': array([-1.6104268], dtype=float32), 'logistic': array([0.16652937], dtype=float32), 'probabilities': array([0.83347064, 0.16652937], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.2234404], dtype=float32), 'logistic': array([0.09766519], dtype=float32), 'probabilities': array([0.9023348 , 0.09766519], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.3306308], dtype=float32), 'logistic': array([0.0886177], dtype=float32), 'probabilities': array([0.9113823 , 0.08861771], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.9859495], dtype=float32), 'logistic': array([0.01823607], dtype=float32), 'probabilities': array([0.9817639 , 0.01823607], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.7933173], dtype=float32), 'l

{'logits': array([-3.6984456], dtype=float32), 'logistic': array([0.02416365], dtype=float32), 'probabilities': array([0.97583634, 0.02416365], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-4.1614857], dtype=float32), 'logistic': array([0.01534524], dtype=float32), 'probabilities': array([0.9846548 , 0.01534524], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.28288996], dtype=float32), 'logistic': array([0.57025456], dtype=float32), 'probabilities': array([0.4297454 , 0.57025456], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-3.9841938], dtype=float32), 'logistic': array([0.01826753], dtype=float32), 'probabilities': array([0.9817324 , 0.01826753], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.0016706], dtype=float32), '

{'logits': array([-1.79393], dtype=float32), 'logistic': array([0.14259157], dtype=float32), 'probabilities': array([0.8574084 , 0.14259157], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.6480994], dtype=float32), 'logistic': array([0.02537967], dtype=float32), 'probabilities': array([0.97462034, 0.02537967], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.3340213], dtype=float32), 'logistic': array([0.08834425], dtype=float32), 'probabilities': array([0.9116558 , 0.08834425], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.165991], dtype=float32), 'logistic': array([0.23758042], dtype=float32), 'probabilities': array([0.76241964, 0.23758042], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.08408737], dtype=float32), 'log

{'logits': array([-2.0281675], dtype=float32), 'logistic': array([0.11627709], dtype=float32), 'probabilities': array([0.8837229 , 0.11627709], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.8343265], dtype=float32), 'logistic': array([0.13772367], dtype=float32), 'probabilities': array([0.8622763 , 0.13772367], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.395679], dtype=float32), 'logistic': array([0.59764904], dtype=float32), 'probabilities': array([0.40235093, 0.59764904], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([1.2266474], dtype=float32), 'logistic': array([0.77323127], dtype=float32), 'probabilities': array([0.22676875, 0.77323127], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-1.0324619], dtype=float32), 'log

{'logits': array([-4.43705], dtype=float32), 'logistic': array([0.01169246], dtype=float32), 'probabilities': array([0.98830754, 0.01169246], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.468239], dtype=float32), 'logistic': array([0.07811495], dtype=float32), 'probabilities': array([0.921885  , 0.07811495], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-4.359063], dtype=float32), 'logistic': array([0.01262884], dtype=float32), 'probabilities': array([0.9873712 , 0.01262884], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.9473076], dtype=float32), 'logistic': array([0.12484724], dtype=float32), 'probabilities': array([0.87515277, 0.12484724], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.17876256], dtype=float32), 'log

{'logits': array([-0.7463452], dtype=float32), 'logistic': array([0.32161817], dtype=float32), 'probabilities': array([0.67838186, 0.3216182 ], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.7909675], dtype=float32), 'logistic': array([0.05781423], dtype=float32), 'probabilities': array([0.9421858 , 0.05781424], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.008078], dtype=float32), 'logistic': array([0.11835737], dtype=float32), 'probabilities': array([0.8816426 , 0.11835738], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.7537111], dtype=float32), 'logistic': array([0.6799868], dtype=float32), 'probabilities': array([0.3200132, 0.6799868], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-2.0974655], dtype=float32), 'logis

{'logits': array([-4.719359], dtype=float32), 'logistic': array([0.00884202], dtype=float32), 'probabilities': array([0.99115795, 0.00884202], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.48002923], dtype=float32), 'logistic': array([0.61775476], dtype=float32), 'probabilities': array([0.3822452 , 0.61775476], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-1.2404858], dtype=float32), 'logistic': array([0.22435144], dtype=float32), 'probabilities': array([0.77564853, 0.22435142], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.535212], dtype=float32), 'logistic': array([0.17723238], dtype=float32), 'probabilities': array([0.8227676 , 0.17723238], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.68993616], dtype=float32), 'l

{'logits': array([11.5354805], dtype=float32), 'logistic': array([0.9999902], dtype=float32), 'probabilities': array([9.776879e-06, 9.999902e-01], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-1.4219325], dtype=float32), 'logistic': array([0.19435881], dtype=float32), 'probabilities': array([0.8056412 , 0.19435881], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-5.222973], dtype=float32), 'logistic': array([0.00536237], dtype=float32), 'probabilities': array([0.99463767, 0.00536237], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.33689904], dtype=float32), 'logistic': array([0.5834371], dtype=float32), 'probabilities': array([0.41656294, 0.5834371 ], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-1.7942808], dtype=float32), 

{'logits': array([-1.7530382], dtype=float32), 'logistic': array([0.1476644], dtype=float32), 'probabilities': array([0.85233563, 0.14766441], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.4206946], dtype=float32), 'logistic': array([0.19455272], dtype=float32), 'probabilities': array([0.8054473 , 0.19455272], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.0252895], dtype=float32), 'logistic': array([0.11657315], dtype=float32), 'probabilities': array([0.88342685, 0.11657315], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.88535094], dtype=float32), 'logistic': array([0.29207018], dtype=float32), 'probabilities': array([0.70792985, 0.29207015], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.693388], dtype=float32), 'l

{'logits': array([0.61468744], dtype=float32), 'logistic': array([0.64900935], dtype=float32), 'probabilities': array([0.3509907 , 0.64900935], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-4.0644917], dtype=float32), 'logistic': array([0.01688183], dtype=float32), 'probabilities': array([0.9831182 , 0.01688183], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.9121891], dtype=float32), 'logistic': array([0.12873513], dtype=float32), 'probabilities': array([0.8712649 , 0.12873511], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.5361304], dtype=float32), 'logistic': array([0.0733638], dtype=float32), 'probabilities': array([0.9266362, 0.0733638], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.2993226], dtype=float32), 'log

{'logits': array([-1.2247608], dtype=float32), 'logistic': array([0.22709973], dtype=float32), 'probabilities': array([0.7729003 , 0.22709973], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.5323858], dtype=float32), 'logistic': array([0.17764488], dtype=float32), 'probabilities': array([0.82235515, 0.17764488], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([1.7496784], dtype=float32), 'logistic': array([0.85191226], dtype=float32), 'probabilities': array([0.14808777, 0.85191226], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-1.1530876], dtype=float32), 'logistic': array([0.23992558], dtype=float32), 'probabilities': array([0.76007444, 0.23992558], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([1.1413671], dtype=float32), 'lo

{'logits': array([-3.0363057], dtype=float32), 'logistic': array([0.04581239], dtype=float32), 'probabilities': array([0.95418763, 0.04581239], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.2654135], dtype=float32), 'logistic': array([0.03677696], dtype=float32), 'probabilities': array([0.963223  , 0.03677695], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.6228557], dtype=float32), 'logistic': array([0.02601163], dtype=float32), 'probabilities': array([0.97398835, 0.02601163], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-4.8521996], dtype=float32), 'logistic': array([0.00775064], dtype=float32), 'probabilities': array([0.99224937, 0.00775064], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.5340266], dtype=float32), '

{'logits': array([-3.6435673], dtype=float32), 'logistic': array([0.02549202], dtype=float32), 'probabilities': array([0.97450805, 0.02549202], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.2616677], dtype=float32), 'logistic': array([0.09434777], dtype=float32), 'probabilities': array([0.9056522 , 0.09434777], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.7653377], dtype=float32), 'logistic': array([0.02263555], dtype=float32), 'probabilities': array([0.9773645 , 0.02263555], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.6237253], dtype=float32), 'logistic': array([0.16469175], dtype=float32), 'probabilities': array([0.83530825, 0.16469175], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.4437606], dtype=float32), '

{'logits': array([-2.2142713], dtype=float32), 'logistic': array([0.09847622], dtype=float32), 'probabilities': array([0.9015237 , 0.09847622], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.9312296], dtype=float32), 'logistic': array([0.01924201], dtype=float32), 'probabilities': array([0.98075795, 0.01924201], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.80664504], dtype=float32), 'logistic': array([0.30860588], dtype=float32), 'probabilities': array([0.6913941 , 0.30860585], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.6968999], dtype=float32), 'logistic': array([0.33249992], dtype=float32), 'probabilities': array([0.6675001 , 0.33249992], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.13931584], dtype=float32), 

{'logits': array([-1.8212478], dtype=float32), 'logistic': array([0.13928421], dtype=float32), 'probabilities': array([0.86071575, 0.1392842 ], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.464123], dtype=float32), 'logistic': array([0.6139918], dtype=float32), 'probabilities': array([0.38600817, 0.6139918 ], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-1.577236], dtype=float32), 'logistic': array([0.17118728], dtype=float32), 'probabilities': array([0.8288127 , 0.17118728], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.5045717], dtype=float32), 'logistic': array([0.18174466], dtype=float32), 'probabilities': array([0.8182553 , 0.18174465], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.8037143], dtype=float32), 'logi

{'logits': array([-1.7412529], dtype=float32), 'logistic': array([0.14915387], dtype=float32), 'probabilities': array([0.8508462 , 0.14915387], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.6567464], dtype=float32), 'logistic': array([0.06557442], dtype=float32), 'probabilities': array([0.93442565, 0.06557442], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.6370473], dtype=float32), 'logistic': array([0.02565449], dtype=float32), 'probabilities': array([0.9743455 , 0.02565449], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.2642624], dtype=float32), 'logistic': array([0.220241], dtype=float32), 'probabilities': array([0.779759  , 0.22024101], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.3996906], dtype=float32), 'lo

{'logits': array([-2.227561], dtype=float32), 'logistic': array([0.09730266], dtype=float32), 'probabilities': array([0.9026973 , 0.09730266], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.7769623], dtype=float32), 'logistic': array([0.14467864], dtype=float32), 'probabilities': array([0.85532135, 0.14467864], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.33043242], dtype=float32), 'logistic': array([0.41813543], dtype=float32), 'probabilities': array([0.5818646, 0.4181354], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-4.5538425], dtype=float32), 'logistic': array([0.01041702], dtype=float32), 'probabilities': array([0.989583  , 0.01041702], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.4709506], dtype=float32), 'lo

{'logits': array([-1.4575076], dtype=float32), 'logistic': array([0.18884882], dtype=float32), 'probabilities': array([0.81115115, 0.18884882], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.2413366], dtype=float32), 'logistic': array([0.09609938], dtype=float32), 'probabilities': array([0.9039006 , 0.09609938], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.8305728], dtype=float32), 'logistic': array([0.02123641], dtype=float32), 'probabilities': array([0.9787636 , 0.02123641], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.245573], dtype=float32), 'logistic': array([0.03748629], dtype=float32), 'probabilities': array([0.9625137 , 0.03748629], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([1.0812918], dtype=float32), 'lo

{'logits': array([-4.5964575], dtype=float32), 'logistic': array([0.00998677], dtype=float32), 'probabilities': array([0.99001324, 0.00998677], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.1746395], dtype=float32), 'logistic': array([0.2360174], dtype=float32), 'probabilities': array([0.76398253, 0.23601738], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.2386346], dtype=float32), 'logistic': array([0.03773744], dtype=float32), 'probabilities': array([0.96226263, 0.03773744], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.5286925], dtype=float32), 'logistic': array([0.07387105], dtype=float32), 'probabilities': array([0.9261289 , 0.07387105], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.9623048], dtype=float32), 'l

{'logits': array([-3.3454564], dtype=float32), 'logistic': array([0.03404427], dtype=float32), 'probabilities': array([0.96595573, 0.03404427], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.4079857], dtype=float32), 'logistic': array([0.03204682], dtype=float32), 'probabilities': array([0.96795315, 0.03204682], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.5278077], dtype=float32), 'logistic': array([0.02853129], dtype=float32), 'probabilities': array([0.9714687 , 0.02853129], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.335361], dtype=float32), 'logistic': array([0.03437782], dtype=float32), 'probabilities': array([0.9656222 , 0.03437782], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.632664], dtype=float32), 'lo

{'logits': array([-1.9494275], dtype=float32), 'logistic': array([0.1246158], dtype=float32), 'probabilities': array([0.8753842, 0.1246158], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.5186121], dtype=float32), 'logistic': array([0.02878727], dtype=float32), 'probabilities': array([0.97121274, 0.02878727], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.5527899], dtype=float32), 'logistic': array([0.3652174], dtype=float32), 'probabilities': array([0.6347826 , 0.36521736], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.1339598], dtype=float32), 'logistic': array([0.10583966], dtype=float32), 'probabilities': array([0.89416033, 0.10583966], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.1630766], dtype=float32), 'logi

{'logits': array([-1.8676386], dtype=float32), 'logistic': array([0.1338152], dtype=float32), 'probabilities': array([0.8661848 , 0.13381518], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.882792], dtype=float32), 'logistic': array([0.0530108], dtype=float32), 'probabilities': array([0.9469892, 0.0530108], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.8751193], dtype=float32), 'logistic': array([0.2941902], dtype=float32), 'probabilities': array([0.7058098 , 0.29419017], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.2654186], dtype=float32), 'logistic': array([0.2200425], dtype=float32), 'probabilities': array([0.7799575, 0.2200425], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([2.9644127], dtype=float32), 'logistic':

{'logits': array([-0.8338355], dtype=float32), 'logistic': array([0.3028347], dtype=float32), 'probabilities': array([0.6971653, 0.3028347], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.2081857], dtype=float32), 'logistic': array([0.09901781], dtype=float32), 'probabilities': array([0.90098214, 0.09901781], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.1596127], dtype=float32), 'logistic': array([0.23873769], dtype=float32), 'probabilities': array([0.7612623 , 0.23873767], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.9737357], dtype=float32), 'logistic': array([0.12198821], dtype=float32), 'probabilities': array([0.87801176, 0.1219882 ], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.690329], dtype=float32), 'logi

{'logits': array([-1.0350376], dtype=float32), 'logistic': array([0.26210862], dtype=float32), 'probabilities': array([0.7378914 , 0.26210862], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.7579353], dtype=float32), 'logistic': array([0.05964006], dtype=float32), 'probabilities': array([0.94035995, 0.05964006], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.8907242], dtype=float32), 'logistic': array([0.2909604], dtype=float32), 'probabilities': array([0.7090396 , 0.29096043], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.0413419], dtype=float32), 'logistic': array([0.26089114], dtype=float32), 'probabilities': array([0.7391088 , 0.26089114], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-5.4431686], dtype=float32), 'l

{'logits': array([-2.095343], dtype=float32), 'logistic': array([0.10955027], dtype=float32), 'probabilities': array([0.8904497 , 0.10955027], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.6365587], dtype=float32), 'logistic': array([0.16293387], dtype=float32), 'probabilities': array([0.8370661 , 0.16293387], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.9343557], dtype=float32), 'logistic': array([0.05048113], dtype=float32), 'probabilities': array([0.94951886, 0.05048113], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([4.6256285], dtype=float32), 'logistic': array([0.9902976], dtype=float32), 'probabilities': array([0.00970244, 0.9902976 ], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([0.7230723], dtype=float32), 'logi

{'logits': array([-4.6253505], dtype=float32), 'logistic': array([0.00970511], dtype=float32), 'probabilities': array([0.99029493, 0.00970511], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.9554008], dtype=float32), 'logistic': array([0.7222001], dtype=float32), 'probabilities': array([0.2778   , 0.7222001], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-1.8775933], dtype=float32), 'logistic': array([0.13266556], dtype=float32), 'probabilities': array([0.8673344 , 0.13266556], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.2934396], dtype=float32), 'logistic': array([0.09166775], dtype=float32), 'probabilities': array([0.9083322 , 0.09166774], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.7588651], dtype=float32), 'logi

{'logits': array([-1.8165002], dtype=float32), 'logistic': array([0.13985436], dtype=float32), 'probabilities': array([0.8601456 , 0.13985436], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([1.8743427], dtype=float32), 'logistic': array([0.8669599], dtype=float32), 'probabilities': array([0.13304004, 0.8669599 ], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([0.4250033], dtype=float32), 'logistic': array([0.6046799], dtype=float32), 'probabilities': array([0.39532015, 0.6046799 ], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-3.1319497], dtype=float32), 'logistic': array([0.04180843], dtype=float32), 'probabilities': array([0.9581916 , 0.04180843], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.16200757], dtype=float32), 'log

{'logits': array([-2.0276701], dtype=float32), 'logistic': array([0.11632821], dtype=float32), 'probabilities': array([0.8836718 , 0.11632821], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.8956497], dtype=float32), 'logistic': array([0.13060164], dtype=float32), 'probabilities': array([0.86939836, 0.13060163], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([1.5464082], dtype=float32), 'logistic': array([0.82439435], dtype=float32), 'probabilities': array([0.17560565, 0.82439435], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-2.991238], dtype=float32), 'logistic': array([0.04782328], dtype=float32), 'probabilities': array([0.9521767 , 0.04782328], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.2482696], dtype=float32), 'lo

{'logits': array([-1.0099282], dtype=float32), 'logistic': array([0.26699388], dtype=float32), 'probabilities': array([0.73300606, 0.26699388], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.3791723], dtype=float32), 'logistic': array([0.03295276], dtype=float32), 'probabilities': array([0.9670473 , 0.03295276], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.22407031], dtype=float32), 'logistic': array([0.55578434], dtype=float32), 'probabilities': array([0.4442156 , 0.55578434], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-2.551722], dtype=float32), 'logistic': array([0.07231088], dtype=float32), 'probabilities': array([0.9276891 , 0.07231087], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.3996994], dtype=float32), 'l

{'logits': array([-2.461126], dtype=float32), 'logistic': array([0.07862872], dtype=float32), 'probabilities': array([0.9213713 , 0.07862872], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.1980886], dtype=float32), 'logistic': array([0.23181541], dtype=float32), 'probabilities': array([0.7681846 , 0.23181541], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.91622686], dtype=float32), 'logistic': array([0.28572732], dtype=float32), 'probabilities': array([0.7142727 , 0.28572732], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-5.374859], dtype=float32), 'logistic': array([0.00461022], dtype=float32), 'probabilities': array([0.9953898 , 0.00461022], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.9485486], dtype=float32), 'l

{'logits': array([-3.4925528], dtype=float32), 'logistic': array([0.02952487], dtype=float32), 'probabilities': array([0.9704751 , 0.02952487], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-4.0783443], dtype=float32), 'logistic': array([0.01665345], dtype=float32), 'probabilities': array([0.9833466 , 0.01665345], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.3726313], dtype=float32), 'logistic': array([0.08528364], dtype=float32), 'probabilities': array([0.9147164 , 0.08528365], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.9554863], dtype=float32), 'logistic': array([0.01878955], dtype=float32), 'probabilities': array([0.98121053, 0.01878955], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.957206], dtype=float32), 'l

{'logits': array([-0.801047], dtype=float32), 'logistic': array([0.30980158], dtype=float32), 'probabilities': array([0.6901984 , 0.30980158], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.7117417], dtype=float32), 'logistic': array([0.15293795], dtype=float32), 'probabilities': array([0.84706205, 0.15293795], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-4.164235], dtype=float32), 'logistic': array([0.01530375], dtype=float32), 'probabilities': array([0.98469627, 0.01530375], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.71946573], dtype=float32), 'logistic': array([0.32751065], dtype=float32), 'probabilities': array([0.67248935, 0.32751065], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.7830931], dtype=float32), 'l

{'logits': array([-2.720447], dtype=float32), 'logistic': array([0.06177755], dtype=float32), 'probabilities': array([0.9382225 , 0.06177755], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.1747042], dtype=float32), 'logistic': array([0.23600572], dtype=float32), 'probabilities': array([0.7639943 , 0.23600574], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.57345784], dtype=float32), 'logistic': array([0.36043933], dtype=float32), 'probabilities': array([0.63956064, 0.3604393 ], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.169929], dtype=float32), 'logistic': array([0.10248356], dtype=float32), 'probabilities': array([0.8975165 , 0.10248357], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.99399805], dtype=float32), '

{'logits': array([-3.4859252], dtype=float32), 'logistic': array([0.02971537], dtype=float32), 'probabilities': array([0.9702846 , 0.02971536], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.7870238], dtype=float32), 'logistic': array([0.05802943], dtype=float32), 'probabilities': array([0.9419706 , 0.05802943], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.471936], dtype=float32), 'logistic': array([0.18664853], dtype=float32), 'probabilities': array([0.8133515 , 0.18664855], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-5.229873], dtype=float32), 'logistic': array([0.00532569], dtype=float32), 'probabilities': array([0.9946743 , 0.00532569], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.5976003], dtype=float32), 'lo

{'logits': array([-3.438388], dtype=float32), 'logistic': array([0.03111705], dtype=float32), 'probabilities': array([0.968883  , 0.03111705], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.3804739], dtype=float32), 'logistic': array([0.08467383], dtype=float32), 'probabilities': array([0.9153261 , 0.08467383], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.7146642], dtype=float32), 'logistic': array([0.15255973], dtype=float32), 'probabilities': array([0.8474403 , 0.15255973], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.7061162], dtype=float32), 'logistic': array([0.33045757], dtype=float32), 'probabilities': array([0.66954243, 0.3304576 ], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-3.4801147], dtype=float32), 'l

{'logits': array([-3.112748], dtype=float32), 'logistic': array([0.04258447], dtype=float32), 'probabilities': array([0.9574156 , 0.04258447], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([0.00715899], dtype=float32), 'logistic': array([0.50178975], dtype=float32), 'probabilities': array([0.49821028, 0.50178975], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([0.9589908], dtype=float32), 'logistic': array([0.72291964], dtype=float32), 'probabilities': array([0.27708027, 0.72291964], dtype=float32), 'class_ids': array([1], dtype=int64), 'classes': array([b'1'], dtype=object)}
{'logits': array([-0.79311776], dtype=float32), 'logistic': array([0.31149963], dtype=float32), 'probabilities': array([0.68850034, 0.31149963], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-2.080829], dtype=float32), 'lo

{'logits': array([-2.6382523], dtype=float32), 'logistic': array([0.06671678], dtype=float32), 'probabilities': array([0.9332832 , 0.06671678], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.2808793], dtype=float32), 'logistic': array([0.2174006], dtype=float32), 'probabilities': array([0.7825994 , 0.21740058], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-0.0990572], dtype=float32), 'logistic': array([0.47525597], dtype=float32), 'probabilities': array([0.5247441 , 0.47525597], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([-1.4437022], dtype=float32), 'logistic': array([0.1909727], dtype=float32), 'probabilities': array([0.80902725, 0.19097267], dtype=float32), 'class_ids': array([0], dtype=int64), 'classes': array([b'0'], dtype=object)}
{'logits': array([12.020106], dtype=float32), 'log

**Create a list containing only the `class_ids` values from the list of prediction dictionaries. These are the predictions you will compare against the true `Y_test` labels.**

In [68]:
class_ids = []
for pred in predicts:
    # Each prediction dict stores its predicted label as a 1-element array under 'class_ids'
    class_ids.append(pred['class_ids'][0])
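Every prediction dictionary printed above follows the same pattern: `logistic` is the sigmoid of `logits`, `probabilities` is `[1 - logistic, logistic]`, and `class_ids` is 1 only when that probability is above 0.5 (equivalently, when the logit is positive). The sketch below checks this against a few rows copied from the output. The 0.5 threshold and the helper names `sigmoid` and `class_id_from_logit` are assumptions inferred from the printed values, not taken from TensorFlow's source.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function mapping a raw logit to P(class 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def class_id_from_logit(logit: float, threshold: float = 0.5) -> int:
    """Assumed decision rule: predict class 1 when P(class 1) > threshold."""
    return 1 if sigmoid(logit) > threshold else 0


# (logit, logistic, class_id) triples copied from the prediction output above
observed = [
    (-2.2413366, 0.09609938, 0),
    (4.6256285, 0.9902976, 1),
    (0.00715899, 0.50178975, 1),
    (-0.0990572, 0.47525597, 0),
]

for logit, logistic, class_id in observed:
    p = sigmoid(logit)
    probabilities = [1.0 - p, p]  # same layout as the 'probabilities' key
    # The printed values are float32, so compare with a small tolerance
    assert abs(p - logistic) < 1e-5
    assert class_id_from_logit(logit) == class_id
    print(f"logit={logit:+.4f}  probabilities=[{probabilities[0]:.4f}, {probabilities[1]:.4f}]  class_id={class_id}")
```

Note how close the third row is to the boundary: a logit of about 0.007 gives a probability of about 0.502, which is still enough to flip the prediction to class 1.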

In [69]:
class_ids

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,


**Import `classification_report` from `sklearn.metrics`, then figure out how to use it to get a full report of your model's performance on the test data.**
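To make the report's columns concrete, here is a minimal pure-Python sketch that computes per-class precision, recall, F1, and support from two label lists. The `per_class_report` helper and the toy `y_true`/`y_pred` lists are hypothetical and only for illustration; on the real test data you would use `classification_report` from `sklearn.metrics`.

```python
def per_class_report(y_true, y_pred, labels):
    """Return {label: (precision, recall, f1, support)} for each label."""
    pairs = list(zip(y_true, y_pred))
    report = {}
    for label in labels:
        tp = sum(1 for t, p in pairs if t == label and p == label)  # correctly predicted as label
        fp = sum(1 for t, p in pairs if t != label and p == label)  # wrongly predicted as label
        fn = sum(1 for t, p in pairs if t == label and p != label)  # label instances that were missed
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        support = tp + fn  # number of true instances of this label
        report[label] = (precision, recall, f1, support)
    return report


# Hypothetical toy labels, not taken from the census data
y_true = [0, 0, 0, 0, 1, 1, 1, 0]
y_pred = [0, 0, 1, 0, 1, 0, 0, 0]

for label, (p, r, f, s) in per_class_report(y_true, y_pred, labels=[0, 1]).items():
    print(f"{label}  precision={p:.2f}  recall={r:.2f}  f1={f:.2f}  support={s}")
```

For these toy lists, class 1 gets precision 0.50, recall 0.33, and F1 0.40 with support 3. The sketch returns 0.0 when a denominator is zero; scikit-learn does the same by default but also emits a warning.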

In [70]:
from sklearn import metrics

In [71]:
# Compare the true test labels against the model's predicted class ids
y_pred = class_ids
class_names = ['0', '1']
print(metrics.classification_report(Y_test, y_pred, target_names=class_names))

             precision    recall  f1-score   support

          0       0.84      0.95      0.89      7436
          1       0.74      0.44      0.55      2333

avg / total       0.82      0.83      0.81      9769
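The `avg / total` row in this older report format is the support-weighted average of the per-class rows, not a plain mean. The arithmetic below recomputes it from the printed per-class values. Because those inputs are already rounded to two decimals, the result is only expected to match to two decimals. Newer scikit-learn versions print separate `macro avg` and `weighted avg` rows instead.

```python
# Per-class rows copied from the report above: label -> (precision, recall, f1, support)
rows = {
    "0": (0.84, 0.95, 0.89, 7436),
    "1": (0.74, 0.44, 0.55, 2333),
}

total_support = sum(support for *_, support in rows.values())

# Support-weighted average of each metric column (what "avg / total" reports)
weighted = [
    sum(row[i] * row[3] for row in rows.values()) / total_support
    for i in range(3)
]

# Unweighted (macro) average, shown for contrast
macro = [sum(row[i] for row in rows.values()) / len(rows) for i in range(3)]

print("total support:", total_support)
print("weighted:", [round(v, 2) for v in weighted])
print("macro:   ", [round(v, 3) for v in macro])
```

The weighted values round to `[0.82, 0.83, 0.81]`, matching the `avg / total` row, while the macro precision would be only 0.79. The gap comes from class 0 having roughly three times the support of class 1, so its stronger scores dominate the weighted figure and hide the weak recall (0.44) on the `>50K` class.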



# Great Job!