In [1]:
import argparse
import cv2
import os
import platform
import sys
from collections import defaultdict

import ipywidgets as widgets
from IPython.display import display
from IPython.core.debugger import set_trace
 
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
 
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.layers as L
import tensorflow.keras.backend as K
import tensorflow_probability as tfp
from tqdm import tqdm
%matplotlib inline
 
# Record the runtime environment so the results below can be tied to a specific setup.
print("OS Type: %s" % os.name)
print("OS Name: %s" % platform.system())
print("OS Release: %s" % platform.release())
print(f'Using Python={sys.version}')
print(f'Using Tensorflow={tf.__version__}')
print(f'Using Keras={keras.__version__}')
# tf.test.is_gpu_available() is a deprecated test utility in TF 2.x; listing the
# visible physical GPU devices answers the same question.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
print("GPU Available: ", len(gpu_devices) > 0)



OS Type: posix
OS Name: Linux
OS Release: 4.15.0-50-generic
Using Python=3.7.2 (default, Mar 30 2019, 15:56:42) 
[GCC 5.4.0 20160609]
Using Tensorflow=2.0.0-rc2
Using Keras=2.2.4-tf
GPU Available:  True


In [2]:
# Show the non-hidden sub-folders of the working directory, so the reader can
# check that the training-image folder is where the notebook expects it.
print('Folders in current directory (for reference):')
visible_dirs = (entry for entry in os.listdir('.')
                if not entry.startswith('.') and os.path.isdir(entry))
for dir_name in visible_dirs:
    print(f'   ./{dir_name}')

Folders in current directory (for reference):
   ./train_imgs


In [3]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.applications import Xception
from tensorflow.keras.applications import InceptionResNetV2

from tensorflow.keras.applications import imagenet_utils
from tensorflow.keras.applications.inception_v3 import preprocess_input

from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img

# Maps a short backbone name (the value of `fex_model_name`) to the Keras
# application constructor used to build the pretrained feature extractor.
# NOTE(review): the `preprocess_input` imported above comes from inception_v3;
# it matches the Inception/Xception-style models but presumably not ResNet50,
# which has its own preprocessing -- confirm before feeding images to 'resnet'.
IMG_FEATURE_EXTRACTORS = dict(
    inception=InceptionV3,
    xception=Xception,
    resnet=ResNet50,
    inceptresnet2=InceptionResNetV2,
)

## Key Training Parameters

In [9]:
# Folder holding one sub-folder of training images per country.
train_folder = './train_imgs'
# Key into IMG_FEATURE_EXTRACTORS selecting the pretrained backbone.
fex_model_name = 'xception'
# Pretrained weights passed to the Keras application constructor.
fex_model_wgts = 'imagenet'
# TODO: choose a value. Presumably the embedding dimensionality for a
# triplet-loss head (confirm). None marks it as not yet chosen so the cell runs.
triplet_loss_dims = None


In [10]:
# Build the pretrained backbone used as an image feature extractor.
# include_top=False drops the ImageNet classifier head; no input_shape is given,
# so the spatial input dimensions are left unconstrained.
if fex_model_name not in IMG_FEATURE_EXTRACTORS:
    raise KeyError(f'Unknown feature extractor {fex_model_name!r}; '
                   f'choose one of {sorted(IMG_FEATURE_EXTRACTORS)}')
fex_model = IMG_FEATURE_EXTRACTORS[fex_model_name](weights=fex_model_wgts, include_top=False)
print(f'Model {fex_model_name} has {len(fex_model.layers)} layers')
fex_model.summary()

Model xception has 132 layers
Model: "xception"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, None, None, 3 864         input_2[0][0]                    
__________________________________________________________________________________________________
block1_conv1_bn (BatchNormaliza (None, None, None, 3 128         block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_conv1_act (Activation)   (None, None, None, 3 0           block1_conv1_bn[0][0]            
_____________________________________________________________

In [11]:
# Report the backbone's input and output shapes (None = dimension not fixed).
for label, shape in (('Input', fex_model.input_shape),
                     ('Output', fex_model.output_shape)):
    print(f'{label}: {shape}')

Input: (None, None, None, 3)
Output: (None, None, None, 2048)


In [22]:
# Quick check of os.path.splitext: the extension comes back with its leading
# dot, which is the form the image-extension filter compares against.
sample_path = 'abc/xyz/hello.jpg'
os.path.splitext(sample_path)

('abc/xyz/hello', '.jpg')

In [25]:
# Index the training images: each sub-folder of `train_folder` is a country
# label, and the image files inside it are that country's examples.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

if not os.path.isdir(train_folder):
    # os.walk silently yields nothing for a missing folder; fail loudly instead.
    raise FileNotFoundError(f'Training folder not found: {train_folder!r}')

raw_filenames = defaultdict(list)   # country -> list of image paths
countries = set()                   # countries with at least one image
total_img_counts = {}               # country -> number of images

for root, _dirs, files in os.walk(train_folder):
    # Derive the label from the path *relative* to train_folder. str.lstrip must
    # not be used for this: it strips a set of characters rather than a prefix,
    # so names like "ireland" or "sri-lanka" would lose their leading letters.
    rel_root = os.path.relpath(root, train_folder)
    if rel_root == os.curdir:
        continue  # the top-level folder itself carries no country label
    # Nested folders are labelled by their own (innermost) name.
    # NOTE(review): if a country folder can contain sub-folders, the top-level
    # path component is probably the intended label -- confirm.
    country = os.path.basename(rel_root)
    if not country.strip():
        continue

    for f in files:
        if os.path.splitext(f)[1].lower() not in IMG_EXTENSIONS:
            continue
        raw_filenames[country].append(os.path.join(root, f))
        countries.add(country)

total_imgs = 0
for country in sorted(countries):
    # Sort so the per-country file ordering is identical on every run.
    raw_filenames[country] = sorted(raw_filenames[country])
    total_img_counts[country] = len(raw_filenames[country])
    total_imgs += total_img_counts[country]
    print(f'Country {country.ljust(30)} has {total_img_counts[country]} images')
print(f'Total {total_imgs} images across all countries')

Country Armenia                        has 11 images
Country Australia                      has 35 images
Country Germany                        has 107 images
Country Hungary+Slovakia+Croatia       has 49 images
Country Indonesia-Bali                 has 45 images
Country Japan                          has 62 images
Country Malaysia+Indonesia             has 55 images
Country Portugal+Brazil                has 54 images
Country Russia                         has 124 images
Country Spain                          has 68 images
Country Thailand                       has 104 images
Total 714 images across all countries


In [None]:
from sklearn.model_selection import KFold

# Plan for K-fold cross-validation: 5 shuffled folds. The integer seed makes
# every kf.split(...) call reproduce the same folds.
kf = KFold(n_splits=5, random_state=31415926, shuffle=True)

# Preview the folds on a toy set of 11 indices.
toy_indices = range(11)
for train_indices, test_indices in kf.split(toy_indices):
    print(train_indices, test_indices)

In [None]:
# Print each fold number together with its (train_indices, test_indices) pair.
for i, spl in enumerate(kf.split(range(11)), start=0):
    print(i, spl)

In [None]:
    
def distributed_train_test_split(shards=5, shard_test_index=-1, filenames_by_country=None):
    """Split every country's images into train/test by holding out one shard per country.

    Each country's file list is dealt round-robin into `shards` shards (file i
    goes to shard i % shards). Shard `shard_test_index` becomes that country's
    test set and the remaining shards form its training set. Because the split
    is done per country, every country contributes to both sets in proportion
    to its image count. (Presumably the intended "distributed" semantics -- confirm.)

    Args:
        shards: number of shards per country; must be >= 1.
        shard_test_index: shard held out as test. Python indexing, so -1 is the last shard.
        filenames_by_country: optional mapping country -> list of image paths.
            Defaults to the notebook-level `raw_filenames`, restricted to `countries`.

    Returns:
        (train, test): two dicts mapping country -> list of image paths, in the
        same relative order as the input lists.

    Raises:
        ValueError: if `shards` < 1.
        IndexError: if `shard_test_index` is out of range for `shards`.
    """
    if shards < 1:
        raise ValueError(f'shards must be >= 1, got {shards}')
    # Indexing a range validates shard_test_index and normalises negative values.
    test_shard = range(shards)[shard_test_index]

    if filenames_by_country is None:
        filenames_by_country = {c: raw_filenames[c] for c in countries}

    train, test = {}, {}
    for c in sorted(filenames_by_country):
        paths = list(filenames_by_country[c])
        test[c] = paths[test_shard::shards]
        train[c] = [p for i, p in enumerate(paths) if i % shards != test_shard]
    return train, test
        

In [None]:
# Placeholder cell: only displays the tf.data.Dataset class. Presumably the
# start of an input pipeline over the collected image paths -- TODO.
tf.data.Dataset