In [11]:
pip install tensorflow_addons

Collecting tensorflow_addons
  Downloading tensorflow_addons-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (611 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/611.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m204.8/611.8 kB[0m [31m5.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m604.2/611.8 kB[0m [31m8.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m611.8/611.8 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
Collecting typeguard<3.0.0,>=2.7 (from tensorflow_addons)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, tensorflow_addons
Successfully installed tensorflow_addons-0.23.0 typeguard-2.13.3


In [8]:
pip install vit-keras

Collecting vit-keras
  Downloading vit_keras-0.1.2-py3-none-any.whl (24 kB)
Collecting validators (from vit-keras)
  Downloading validators-0.22.0-py3-none-any.whl (26 kB)
Installing collected packages: validators, vit-keras
Successfully installed validators-0.22.0 vit-keras-0.1.2


In [29]:
import time
import matplotlib.pyplot as plt
from google.colab import drive
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
#import tensorflow_addons as tfa
from vit_keras import vit

In [2]:
## For use Google Colab (mount Google Drive)
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
project_path="/content/drive/MyDrive/Academico/LeWagon/train_test"
train_dir = f"{project_path}/train"
test_dir = f"{project_path}/test"

In [40]:
"""
Result a generator with shape (150,150,3)
"""
# Create a generator with augmentation for training and validation:
dgen_train = ImageDataGenerator(rescale = 1./255,
                                    validation_split=0.2,
                                    shear_range=0.2,
                                    zoom_range = 0.2,
                                    horizontal_flip = False)

# Create a generator without augmentation for test:
dgen_test = ImageDataGenerator(rescale=1./255)

# Make generators by directories:
# The classes wiil be the subdirectories
train_generator = dgen_train.flow_from_directory(train_dir,
                                                 target_size=(224,224),
                                                 subset = "training",
                                                 batch_size = 32,
                                                 class_mode = "categorical")

validation_generator = dgen_train.flow_from_directory(train_dir,
                                                      target_size=(224,224),
                                                      subset = "validation",
                                                      batch_size = 32,
                                                      class_mode = "categorical")

test_generator = dgen_test.flow_from_directory(test_dir,
                                               target_size=(224,224),
                                               batch_size = 32,
                                               class_mode = "categorical")



Found 6798 images belonging to 4 classes.
Found 1698 images belonging to 4 classes.
Found 4867 images belonging to 4 classes.


In [None]:
train_generator.image_shape

In [6]:
def plot_history(history, title='', axs=None, exp_name=""):
    if axs is not None:
        ax1, ax2 = axs
    else:
        f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    if len(exp_name) > 0 and exp_name[0] != '_':
        exp_name = '_' + exp_name
    ax1.plot(history.history['loss'], label='train' + exp_name)
    ax1.plot(history.history['val_loss'], label='validation' + exp_name)
    #ax1.set_ylim(0., 2.2)
    ax1.set_title('loss')
    ax1.legend()

    ax2.plot(history.history['accuracy'], label='train'  + exp_name)
    ax2.plot(history.history['val_accuracy'], label='validation'  + exp_name)
    #ax2.set_ylim(0.25, 1.)
    ax2.set_title('Accuracy')
    ax2.legend()
    return (ax1, ax2)

In [None]:
vit.vit_b16()

In [41]:
#preproc_layer = preprocvit.vit_b16ess_input(train_generator)
vit_model = vit.vit_b16(image_size=224,
                         activation='relu',
                         pretrained=True,
                         include_top=True,
                         pretrained_top=False,
                         classes=4)

for layer in vit_model.layers:
    layer.trainable = False

model = Sequential()
model.add(vit_model)
model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(4, activation='softmax'))

model.summary()



Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vit-b16 (Functional)        (None, 4)                 85801732  
                                                                 
 flatten_2 (Flatten)         (None, 4)                 0         
                                                                 
 dense_4 (Dense)             (None, 32)                160       
                                                                 
 dropout_2 (Dropout)         (None, 32)                0         
                                                                 
 dense_5 (Dense)             (None, 4)                 132       
                                                                 
Total params: 85802024 (327.31 MB)
Trainable params: 292 (1.14 KB)
Non-trainable params: 85801732 (327.31 MB)
_________________________________________________________________


In [42]:
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(learning_rate=1e-4),
              metrics=['accuracy'])

In [38]:
es = EarlyStopping(monitor = 'val_loss',
                   patience = 5,
                   verbose = 1,
                   restore_best_weights = True)

In [43]:
start = time.perf_counter()
history = model.fit(train_generator,
                        batch_size = 32,
                        validation_data=validation_generator,
                        epochs = 20,
                        callbacks=[es])
end = time.perf_counter()
print(f"\n✅ Total time: ({round(end - start, 2)}s)")

Epoch 1/10
Epoch 2/10

KeyboardInterrupt: 

In [None]:
plot_history(history);

In [None]:
model.evaluate(test_generator)