<a href="https://www.kaggle.com/code/masoudmahdavii/classification-of-flags-by-transfer-learning?scriptVersionId=155183442" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
#for dirname, _, filenames in os.walk('/kaggle/input'):
#    for filename in filenames:
#        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# 0. Import libraries

In [None]:
import re
import cv2
import random
import shutil
import tensorflow as tf
import matplotlib.pyplot as plt 
from keras.models import load_model
from keras.applications import ResNet50
from keras.preprocessing.image import ImageDataGenerator


# 1. Plot one flag per country

In [None]:
train_data = "/kaggle/input/country-flags/flags/train/"
test_data = "/kaggle/input/country-flags/flags/test/"
train_data_id = "/kaggle/input/country-flags-in-the-wild/verified_flags_train/"
countries = ['Afghanistan', 'Azerbaijan', 'Iran', 'Iraq', 'Pakistan', 'Turkey', 'Turkmenistan']

# Show one sample flag per country in a 2x4 grid.
# The figure is created first so that its figsize applies to the subplots below.
plt.figure(figsize=(16, 10))
for i, country_name in enumerate(countries):
    country_path = os.path.join(train_data, country_name)
    # os.listdir returns files in arbitrary order; sort so every run shows the same sample.
    country_first_im = sorted(os.listdir(country_path))[0]
    image_path = os.path.join(country_path, country_first_im)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals an unreadable file by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV loads images as BGR; matplotlib expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    plt.subplot(2, 4, i + 1)
    plt.imshow(image)
    plt.xlabel(country_name)
    plt.xticks([])
    plt.yticks([])
plt.show()

# 2. Set up paths

In [None]:
def create_paths(root='Data', classes=None):
    """Define the working train/test folders and create one sub-folder per class.

    Sets the module-level ``paths`` dict (keys ``'Train_path'`` and
    ``'Test_path'``), which later cells read. It then creates
    ``<root>/Train/<class>`` and ``<root>/Test/<class>`` for every class name.
    Re-running it is harmless because existing folders are kept.

    root: parent folder of the Train/Test directories (default ``'Data'``).
    classes: class (country) names; defaults to the global ``countries`` list.
    """
    global paths
    paths = {
        'Train_path': os.path.join(root, 'Train'),
        'Test_path': os.path.join(root, 'Test'),
    }
    if classes is None:
        classes = countries

    for path in paths.values():
        for class_name in classes:
            # exist_ok=True replaces the former existence check and shell `mkdir -p`.
            os.makedirs(os.path.join(path, class_name), exist_ok=True)
# Create the Train/Test folder tree for every country and set the global `paths` dict.
create_paths()

# 3. Copy the train and test images into the working directory

In [None]:
# Copy every image of the selected countries into paths['Train_path']/<country>
# and paths['Test_path']/<country>, keeping the dataset's own train/test assignment.
for country_name in countries:
    for src_root, dst_root in ((train_data, paths['Train_path']),
                               (test_data, paths['Test_path'])):
        src_dir = os.path.join(src_root, country_name)
        dst_dir = os.path.join(dst_root, country_name)
        # Make sure the destination exists (e.g. after the working directory was
        # wiped), instead of asking the user to re-run the cell.
        os.makedirs(dst_dir, exist_ok=True)
        for file_name in os.listdir(src_dir):
            shutil.copyfile(os.path.join(src_dir, file_name),
                            os.path.join(dst_dir, file_name))

# 4. Create image data generators

In [None]:
# Training pipeline: scale pixels to [0, 1] and apply random geometric augmentation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    # Pixels exposed by a transform are filled from the nearest edge pixel.
    fill_mode='nearest',
)

# Test pipeline: rescaling only. validation_split=0.5 lets the Test folder be
# divided into two halves ("training" and "validation" subsets) when the
# generators are built from it.
test_datagen = ImageDataGenerator(rescale=1.0/255, validation_split=0.5)

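# 4.1 Class distribution of the training data

Note: this cell uses `train_generator`, which is created in a later cell; run that cell first.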

In [None]:
# Class distribution of the training set.
# NOTE: uses `train_generator`, which is created in a later cell; run that cell first.
class_indices = train_generator.class_indices
# Order the names by label index so that position k in the list is class k.
class_labels = sorted(class_indices, key=class_indices.get)

# `classes` holds the label index of every file the iterator yields. Counting it
# gives the per-class totals without loading (and augmenting) every batch.
class_counts = dict(zip(
    class_labels,
    np.bincount(train_generator.classes, minlength=len(class_labels)).tolist(),
))

# One distinct colour per country.
class_colors = plt.cm.tab20(np.linspace(0, 1, len(class_labels)))

# Bar chart of the image count per country.
plt.figure(figsize=(8, 4))
plt.bar(list(class_counts.keys()), list(class_counts.values()), color=class_colors)
plt.ylabel('Count')
plt.title('Number of Train Images per Class')
plt.xticks(rotation=45)

# Legend mapping each colour to its country.
legend_labels = [
    plt.Line2D([0], [0], color=color, lw=4, label=label)
    for color, label in zip(class_colors, class_labels)
]
plt.legend(handles=legend_labels, title="Classes")
plt.show()

Distribute Test data

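# 4.2 Class distribution of the test data

Note: this cell uses `test_generator`, which is created in a later cell; run that cell first.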

In [None]:
# Class distribution of the held-out test set.
# NOTE: uses `test_generator`, which is created in a later cell; run that cell first.
class_indices = test_generator.class_indices
# Order the names by label index so that position k in the list is class k.
class_labels = sorted(class_indices, key=class_indices.get)

# `classes` holds the label index of every file in this subset. Counting it gives
# the per-class totals without loading every batch of images.
class_counts = dict(zip(
    class_labels,
    np.bincount(test_generator.classes, minlength=len(class_labels)).tolist(),
))

# One distinct colour per country.
class_colors = plt.cm.tab20(np.linspace(0, 1, len(class_labels)))

# Bar chart of the image count per country.
plt.figure(figsize=(8, 4))
plt.bar(list(class_counts.keys()), list(class_counts.values()), color=class_colors)
plt.ylabel('Count')
plt.title('Number of Test Images per Class')
plt.xticks(rotation=45)

# Legend mapping each colour to its country.
legend_labels = [
    plt.Line2D([0], [0], color=color, lw=4, label=label)
    for color, label in zip(class_colors, class_labels)
]
plt.legend(handles=legend_labels, title="Classes")
plt.show()


In [None]:
# (height, width) every image is resized to, and the number of images per batch.
input_size, batch_size = (224, 224), 16

In [None]:
# Augmented, shuffled training batches from the Train_path folder. The fixed seed
# makes the shuffling reproducible.
train_generator = train_datagen.flow_from_directory(
    directory=paths['Train_path'],
    class_mode='categorical',   # one-hot encoded labels
    target_size=input_size,
    batch_size=batch_size,
    shuffle=True,
    seed=32,
)

# test_datagen's validation_split divides the Test_path folder in two. The
# "training" subset becomes the validation data; the "validation" subset becomes
# the held-out test set.
val_generator, test_generator = (
    test_datagen.flow_from_directory(
        directory=paths['Test_path'],
        class_mode='categorical',
        target_size=input_size,
        batch_size=batch_size,
        subset=subset_name,
    )
    for subset_name in ('training', 'validation')
)

# 5. Create classification model

Add a classification head on top of the pretrained ResNet50 model.

In [None]:
def feature_extractor(inputs):
    """Run `inputs` through a ResNet50 backbone loaded from local Kaggle weights.

    The backbone is built without its top (include_top=False) and with
    weights=None, so the weights come from the local .h5 file instead of a
    download. It is left trainable, so the whole backbone is fine-tuned along
    with the freshly initialised classification head.

    inputs: Keras tensor of shape input_size + (3,).
    returns: the backbone's output tensor for `inputs`.
    """
    weights_file = '../input/resnet50/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
    backbone = ResNet50(
        input_shape=input_size + (3,),
        include_top=False,
        weights=None,
    )
    # Keep the pretrained layers trainable (no freezing).
    backbone.trainable = True
    # NOTE(review): these ImageNet weights were trained with
    # resnet50.preprocess_input, while the generators here only rescale by 1/255.
    # Consider aligning the preprocessing.
    backbone.load_weights(weights_file, by_name=True)
    return backbone(inputs)

def classifier(inputs, num_classes=None):
    """Classification head: pool the backbone features and map them to class probabilities.

    inputs: feature-map tensor produced by the backbone.
    num_classes: size of the softmax output; defaults to len(countries).
    returns: softmax output tensor named "classification".
    """
    if num_classes is None:
        num_classes = len(countries)
    # GlobalAveragePooling2D already outputs a (batch, channels) tensor, so no
    # Flatten layer is needed before the dense layers.
    x = tf.keras.layers.GlobalAveragePooling2D()(inputs)
    x = tf.keras.layers.Dense(1024, activation="relu")(x)
    x = tf.keras.layers.Dense(512, activation="relu")(x)
    return tf.keras.layers.Dense(num_classes, activation="softmax", name="classification")(x)

def final_model(inputs):
    """Chain the ResNet50 feature extractor and the classification head.

    inputs: Keras input tensor holding the images.
    returns: the softmax classification tensor.
    """
    features = feature_extractor(inputs)
    return classifier(features)

def define_compile_model():
    """Build the end-to-end Keras model and compile it for training.

    The optimizer is 'SGD'. The loss is categorical cross-entropy, which matches
    the one-hot labels from class_mode='categorical'. Accuracy is tracked as the
    metric.

    returns: the compiled tf.keras.Model.
    """
    image_input = tf.keras.layers.Input(shape=input_size + (3,))
    model = tf.keras.Model(inputs=image_input, outputs=final_model(image_input))
    model.compile(
        loss='categorical_crossentropy',
        optimizer='SGD',
        metrics=['accuracy'],
    )
    return model


# Build the model once and print its layer summary.
model = define_compile_model()
model.summary()

In [None]:
# Sanity check: push one training batch through the (untrained) model.
image_batch, label_batch = next(iter(train_generator))
prediction_batch = model(image_batch)
# The model ends in a softmax layer, so this is the classifier's output shape,
# not the feature extractor's.
print(f"Output shape of model: {prediction_batch.shape}")

In [None]:
def test_model(model, test_datagen, num_images=8):
    """Plot a few images from a batch iterator with their true and predicted labels.

    Used to inspect the model's progress visually. Each call consumes one batch
    from `test_datagen`. Correct predictions are labelled in green, wrong ones in
    red. Nothing is returned; the figure is shown.

    model: the classification model; its per-image prediction is argmax-ed.
    test_datagen: iterator yielding (images, one-hot labels) batches, e.g. the
        result of flow_from_directory.
    num_images: how many images of the batch to show (default 8).
    """
    examples, labels = next(test_datagen)
    # Only the images that are plotted need a prediction.
    examples, labels = examples[:num_images], labels[:num_images]
    pred_y = model.predict(examples, verbose=False)

    # Map label indices back to names with the iterator's own mapping when it has
    # one. Fall back to the global `countries` list otherwise.
    class_indices = getattr(test_datagen, 'class_indices', None)
    if class_indices:
        class_names = sorted(class_indices, key=class_indices.get)
    else:
        class_names = countries

    plt.figure(figsize=(16, 10))
    for i, (image, label, pred) in enumerate(zip(examples, labels, pred_y)):
        gt = class_names[np.argmax(label)]
        predict_name = class_names[np.argmax(pred)]

        color = 'green' if gt == predict_name else 'red'
        plt.subplot(1, num_images, i + 1)
        plt.imshow(image)
        plt.xlabel('Pred: {}'.format(predict_name), color=color)
        plt.ylabel('label: {}'.format(gt), color=color)
        plt.xticks([])
        plt.yticks([])
    plt.show()

def test(model):
    """Plot one batch of predictions from the global `test_generator`."""
    test_model(model, test_generator)


test(model)

# 6. Define Callback

In [None]:
class ShowTestImages(tf.keras.callbacks.Callback):
    """Keras callback that plots test-set predictions every fifth epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` counts from 0, so this runs after epochs 0, 5, 10, ...
        if epoch % 5:
            return
        test(self.model)

# 7. Training the classification model

In [None]:
"""
#We could use earlystopping to restore weights if validation loss doesn't get better in five epochs
earlystopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", 
    patience=5, 
    restore_best_weights=True
)
"""

history = model.fit(train_generator,
                    #steps_per_epoch=train_generator.samples // batch_size,
                    validation_data=val_generator, 
                    #validation_steps=val_generator.samples // batch_size,
                    epochs=50, 
                    batch_size=batch_size,
                    callbacks=[
              ShowTestImages()]
)

In [None]:
# Plot training vs. validation curves: loss on the left, accuracy on the right.
# Each line gets an explicit label, so the legend does not depend on key order.
history_values = history.history
plt.figure(figsize=(13, 6))
for position, metric in enumerate(('loss', 'accuracy'), start=1):
    plt.subplot(2, 2, position)
    for prefix, label in (('', 'train'), ('val_', 'val')):
        key = prefix + metric
        if key in history_values:
            plt.plot(history_values[key], label=label)
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend(loc='upper left')
plt.show()

# 8. Evaluate the final model

In [None]:
# Score the trained model on the held-out test subset. The generator already
# batches its data, so evaluate() is not given a batch_size.
loss, acc = model.evaluate(test_generator, verbose=True)
print("Test accuracy: {:.2f}% \nTest loss: {:.4f}".format(acc*100, loss))


In [None]:
# Plot the trained model's predictions on the next batch from test_generator.
test(model)

# Confusion Matrix

In [None]:
"""
import itertools
from sklearn.metrics import classification_report, confusion_matrix

true_labels = test_generator.classes
predictions = model.predict(test_generator)
predicted_labels = np.argmax(predictions, axis=1)
cm = confusion_matrix(true_labels, predicted_labels)
class_names = test_generator.class_indices.keys()

def plot_confusion_matrix(cm, classes, normalize=False, title="Confusion Matrix", cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation="nearest", cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j], horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel("True label")
    plt.xlabel("Predicted label")

# Plot non-normalized confusion matrix
plt.figure(figsize=(8, 6))
plot_confusion_matrix(cm, classes=class_names, title="Confusion Matrix")
plt.show()"""