Load the data and build the datasets

In [None]:
import tensorflow as tf
import os
import numpy as np

# Expose only CUDA device 0 to this process, then query the visible GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
gpus = tf.config.experimental.list_physical_devices('GPU')

# Dataset locations plus batching / image-size settings shared by both splits.
train_dir = '../dataset/classification/train'
valid_dir = '../dataset/classification/valid'
batch_size = 64
crop_size = 224

# Training split: explicitly shuffled with a fixed seed.
train_dataset = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    image_size=(crop_size, crop_size),
    shuffle=True,
    seed=1,
)

# Validation split: same image size and batch size; `shuffle` is left at the
# loader's default.
valid_dataset = tf.keras.utils.image_dataset_from_directory(
    valid_dir,
    batch_size=batch_size,
    image_size=(crop_size, crop_size),
    seed=1,
)

def _normalize_img(img, label):
    """Run ResNetV2 input preprocessing on ``img``; pass ``label`` through.

    Returns:
        A ``(preprocessed_img, label)`` tuple, suitable for ``Dataset.map``.
    """
    preprocessed = tf.keras.applications.resnet_v2.preprocess_input(img)
    return preprocessed, label

# Apply the ResNetV2 preprocessing to every batch of both splits.
train_dataset = train_dataset.map(_normalize_img)
valid_dataset = valid_dataset.map(_normalize_img)


def _count_binary_labels(dataset):
    """Return how many labels in ``dataset`` equal 1 and how many equal 0."""
    ones, zeros = 0, 0
    for sample in dataset:
        labels = sample[1]
        ones = ones + np.sum(labels == 1)
        zeros = zeros + np.sum(labels == 0)
    return ones, zeros


# Report the class balance of each split.
cnt_1, cnt_0 = _count_binary_labels(train_dataset)
print('train:  positive ', cnt_1, ' negative ', cnt_0)

cnt_1, cnt_0 = _count_binary_labels(valid_dataset)
print('valid:  positive ', cnt_1, ' negative ', cnt_0)

In [None]:
def get_dataset(dataset):
    """Embed every batch of ``dataset`` with the module-level ``model``.

    Batches are collected in lists and concatenated once at the end, instead
    of re-concatenating the growing result on every iteration (which copies
    all previously accumulated rows each time).

    Args:
        dataset: Iterable of samples where ``sample[0]`` is a batch of model
            inputs and ``sample[1]`` the matching batch of labels.

    Returns:
        ``(X, Y)``: ``X`` holds the stacked ``model.predict`` outputs and
        ``Y`` the stacked labels, both concatenated along axis 0. For an
        empty dataset, ``X`` has shape ``[0, 2048]`` and ``Y`` is an empty
        ``int32`` tensor.
    """
    features, labels = [], []
    for sample in dataset:
        features.append(model.predict(sample[0]))
        labels.append(sample[1])

    if not features:
        # Nothing to concatenate: keep the empty-result shapes and dtypes.
        return tf.zeros([0, 2048]), tf.zeros([0], dtype=tf.int32)

    return tf.concat(features, 0), tf.concat(labels, 0)

# KNN

from sklearn.neighbors import KNeighborsClassifier

def measure(model, n_neighbors=1):
    """Score the embeddings with a k-nearest-neighbours classifier.

    The classifier is fitted on the embedded training split and evaluated on
    the embedded validation split.

    NOTE: ``model`` is not used directly in this function; the embeddings
    come from ``get_dataset``, which reads the module-level ``model``. The
    parameter is kept so existing callers keep working.

    Args:
        model: Embedding model (see the note above).
        n_neighbors: Number of neighbours used by the classifier.

    Returns:
        ``(n_neighbors, accuracy)`` where ``accuracy`` is the classifier's
        ``score`` on the validation split.
    """
    X_train, Y_train = get_dataset(train_dataset)
    X_valid, Y_valid = get_dataset(valid_dataset)

    neigh = KNeighborsClassifier(n_neighbors=n_neighbors)
    neigh.fit(X_train, Y_train)
    valid_acc = neigh.score(X_valid, Y_valid)

    return n_neighbors, valid_acc


In [None]:
import io
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from sklearn.cluster import KMeans
import tensorflow_hub as hub

# ImageNet-pretrained ResNet50V2, cut at its second-to-last layer
# (presumably the pooled features just before the classifier head; the
# 2048-wide buffer in get_dataset suggests so - TODO confirm).
origin_resnet = tf.keras.applications.ResNet50V2(weights='imagenet')
resnet = tf.keras.Model(
    inputs=origin_resnet.input,
    outputs=origin_resnet.layers[-2].output,
)

# Embedding network: backbone followed by L2 normalisation along axis 1.
model = tf.keras.Sequential([
    resnet,
    tf.keras.layers.Lambda(
        lambda features: tf.math.l2_normalize(features, axis=1)
    ),
])

# Optimiser and the path the training loop saves weights to.
optimizer = tf.keras.optimizers.Adam(1e-4)
checkpoint_filepath = '/tmp/cluster'

def loss_fn(x, y, cluster_labels, training=True, n_clusters=4, margin=0.01):
    """Cluster-aware triplet loss with per-anchor hard mining.

    The batch is embedded with the module-level ``model`` and the embeddings
    are clustered with K-means. For every anchor ``a`` a hinge term
    ``max(d(a, p) - d(a, n) + margin, 0)`` is computed, where

    * ``p`` is the farthest sample that shares both the anchor's class and
      its K-means cluster, and
    * ``n`` is the nearest sample of a different class.

    Anchors with no other same-class/same-cluster sample, or with no
    different-class sample in the batch, are skipped.

    Must run eagerly: the embeddings are handed to scikit-learn as a NumPy
    array.

    Args:
        x: Batch of inputs for ``model``.
        y: Class labels, one per sample of ``x``.
        cluster_labels: Ignored. Cluster assignments are recomputed on every
            call; the parameter is kept for interface compatibility.
        training: Forwarded to ``model`` as its ``training`` flag.
        n_clusters: Requested number of K-means clusters (capped at the
            batch size).
        margin: Margin of the triplet hinge.

    Returns:
        ``(loss, cluster_labels)``: the mean hinge term over all anchors that
        produced a triplet, and the K-means assignment of each sample. When
        no anchor produced a triplet the loss is a zero scalar.
    """
    embeddings = model(x, training=training)
    batch = int(embeddings.shape[0])

    # K-means raises when asked for more clusters than samples (e.g. a short
    # final batch), so cap the cluster count at the batch size.
    kmeans = KMeans(n_clusters=min(n_clusters, batch)).fit(embeddings.numpy())
    cluster_labels = tf.constant(kmeans.labels_)
    class_labels = tf.convert_to_tensor(y)

    def l2(a, b):
        # The epsilon keeps sqrt's gradient finite when two embeddings
        # coincide (zero distance).
        return tf.sqrt(tf.reduce_sum(tf.square(a - b), axis=-1) + 1e-12)

    triplets = []
    for i in range(batch):
        anchor = embeddings[i]
        same_class = tf.equal(class_labels, class_labels[i])
        same_cluster = tf.equal(cluster_labels, cluster_labels[i])

        # Positive candidates: same class AND same cluster. The anchor is
        # always one of them, so at least two are needed.
        positive_ids = tf.where(tf.logical_and(same_class, same_cluster))
        if positive_ids.shape[0] <= 1:
            continue
        candidate_positives = tf.gather_nd(embeddings, positive_ids)
        hardest_pos = tf.argmax(l2(anchor, candidate_positives))
        positive = candidate_positives[hardest_pos]

        # Negative candidates: every sample of a different class.
        negative_ids = tf.where(tf.logical_not(same_class))
        if negative_ids.shape[0] == 0:
            continue
        candidate_negatives = tf.gather_nd(embeddings, negative_ids)
        hardest_neg = tf.argmin(l2(anchor, candidate_negatives))
        negative = candidate_negatives[hardest_neg]

        ap = l2(anchor, positive)
        an = l2(anchor, negative)
        triplets.append(tf.maximum(ap - an + margin, 0.))

    if not triplets:
        # No valid triplet in this batch. Averaging an empty list would give
        # NaN and a loss disconnected from the weights; return a zero that
        # still depends on the embeddings so gradients stay defined.
        return tf.reduce_sum(embeddings) * 0., cluster_labels

    return tf.reduce_mean(tf.stack(triplets)), cluster_labels

def train_step(x, y,cluster_labels):
    """Run one optimisation step on a batch.

    Computes ``loss_fn`` under a gradient tape, then applies the gradients
    to the trainable weights of the module-level ``model`` using the
    module-level ``optimizer``.

    Returns:
        ``(loss, cluster_labels)`` exactly as returned by ``loss_fn``.
    """
    with tf.GradientTape() as tape:
        loss, cluster_labels = loss_fn(x, y, cluster_labels, training=True)

    weights = model.trainable_weights
    gradients = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))

    return loss, cluster_labels

def test_step(x, y, cluster_labels):
    """Evaluate ``loss_fn`` on a batch with ``training=False``; no weight update.

    Returns:
        ``(loss, cluster_labels)`` as produced by ``loss_fn``.
    """
    return loss_fn(x, y, cluster_labels, training=False)

In [None]:
# Training driver: one pass over the training split, one pass over the
# validation split, then a KNN evaluation of the embeddings. The weights are
# checkpointed whenever the KNN accuracy matches or beats the best so far.
epochs = 350
best_acc = 0.
for epoch in range(epochs):
    print('---------------------------------------------------------------------------------')

    train_cluster_labels, test_cluster_labels = None, None

    # Iterate over the batches of the training split.
    for x_batch_train, y_batch_train in train_dataset:
        train_loss, train_cluster_labels = train_step(
            x_batch_train, y_batch_train, train_cluster_labels)

    # Run a validation loop at the end of each epoch. The validation split
    # is `valid_dataset`; no `test_dataset` exists in this notebook.
    for x_batch_test, y_batch_test in valid_dataset:
        test_loss, test_cluster_labels = test_step(
            x_batch_test, y_batch_test, test_cluster_labels)

    K, acc = measure(model)

    if acc >= best_acc:
        model.save_weights(checkpoint_filepath)
        best_acc = acc

    # NOTE: the losses printed here are those of the last batch of each
    # loop, not an average over the epoch.
    print("epoch: %d" % (epoch),"  train loss: %.4f" % (float(train_loss)), 
          "  test loss: %.4f" % (float(test_loss)),"  test acc: %.4f" % (float(acc)))

In [None]:
# Reload the weights stored at the checkpoint path and score them once more
# with the KNN evaluation.
model.load_weights(checkpoint_filepath)
K, acc = measure(model)

print(f'acc: {acc}')