# Image classification
- **Task:** predict the disease type of bean plants from images.
- **Dataset:** `beans` from TensorFlow Datasets (TFDS).

## 1. Loading the dataset

In [56]:
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf

In [76]:
# Load the TFDS "beans" dataset.
#   as_supervised=True -> elements are (image, label) tuples
#   with_info=True     -> a DatasetInfo metadata object is returned as well
# NOTE(review): only the first 20% of the test split is loaded — confirm
# this subset is intentional.
(ds_train, ds_test), ds_info = tfds.load(
    name='beans',
    split=['train', 'test[:20%]'],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)

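## 2. Preprocessing the dataset
Includes:
- casting images to float32, resizing them to `IMG_SIZE` × `IMG_SIZE` and flattening them
- shuffling the training data
- batching and prefetching
- rotation (augmentation not yet implemented)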

In [77]:
# Target side length (pixels) for resized images, and the tf.data batch size.
IMG_SIZE, batch_size = 256, 32
def normalize_img(image, label):
    """Convert ``image`` to float32 and divide it by 255.

    For ``uint8`` input this maps pixel values from [0, 255] onto [0, 1].
    ``label`` is returned untouched, so the function can be handed directly
    to ``Dataset.map`` over ``(image, label)`` pairs.
    """
    as_float = tf.cast(image, tf.float32)
    return as_float / 255.0, label

def resize_rescale(image, label):
    """Resize an image, rescale its pixels by 1/255, and flatten it.

    Intended for ``Dataset.map`` over ``(image, label)`` pairs; ``label`` is
    returned unchanged.

    Args:
        image: image tensor — presumably (height, width, channels) with
            uint8 pixels in [0, 255]; TODO confirm against the dataset.
        label: passed through untouched.

    Returns:
        ``(flat_image, label)`` where ``flat_image`` is a rank-1 float32
        tensor of length ``IMG_SIZE * IMG_SIZE * channels``.
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    # Actually "rescale", as the name promises: without this the pixels
    # would stay in [0, 255] (assuming uint8 input) and reach the model
    # unnormalized.
    image = image / 255.0
    # Flatten to a 1-D feature vector (presumably for a dense classifier).
    image = tf.reshape(image, [-1])
    return image, label

In [78]:
# Input pipelines: preprocess each (image, label) pair in parallel, batch,
# and prefetch so input preparation overlaps model execution.
#
# Only the training data is shuffled (buffer of 2000 elements). The test
# pipeline is deliberately NOT shuffled: shuffling gives evaluation nothing,
# and because `Dataset.shuffle` reshuffles on every iteration by default,
# predictions and labels read from `test` in separate passes would not line
# up.
train = (
    ds_train
    .shuffle(2000)
    .map(resize_rescale, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

test = (
    ds_test
    .map(resize_rescale, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)