In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow_addons.losses import TripletSemiHardLoss, TripletHardLoss
from tqdm import tqdm

from beeid2.data_utils import filename2image

# Test-set CSV read by eval_model_short_long_term. Set the BEEID_TEST_CSV
# environment variable to evaluate against another copy without editing this
# cell; the original location is kept as the default.
TEST_CSV = os.environ.get(
    "BEEID_TEST_CSV", "/home/jchan/beeid/notebooks/cmc_experiments/data/test.csv"
)

# Batch size used when running the model over the test images.
batch_size = 32

def to_np_array(values, dim=128):
    """Stack a sequence of embedding vectors into a 2-D ``(n, dim)`` array.

    Every element of ``values`` is flattened, the pieces are concatenated in
    order, and the result is reshaped so each row holds ``dim`` numbers.

    Parameters
    ----------
    values : iterable of array-like
        Embeddings to stack, e.g. ``df["emb"].values``.
    dim : int or None, default 128
        Row width. ``None`` infers it from the size of the first element.

    Returns
    -------
    np.ndarray
        Array of shape ``(n, dim)``. Empty ``values`` gives shape
        ``(0, dim)`` (``(0, 0)`` when ``dim`` is ``None``) instead of raising.
    """
    arrays = [np.ravel(v) for v in values]
    if not arrays:
        # np.concatenate raises "need at least one array" on an empty list;
        # an empty gallery is a legitimate degenerate input.
        return np.empty((0, dim or 0))
    if dim is None:
        dim = arrays[0].size
    return np.concatenate(arrays).reshape(-1, dim)

def get_shortterm_mean_dist(test_df):
    """Average within-track embedding distance ("short-term" statistic).

    For every ``global_track_id`` with more than one image, compute the mean
    of ``1 - e_i . e_j`` over all unordered pairs of distinct images in that
    track, then average those per-track means.

    NOTE(review): ``1 - dot`` is a cosine distance only if the embeddings are
    L2-normalised -- confirm the model normalises its output.

    Parameters
    ----------
    test_df : pd.DataFrame
        Needs ``global_track_id`` and ``emb`` columns.

    Returns
    -------
    tuple
        ``(mean over tracks, list of per-track means)``, per-track means in
        order of each track's first appearance. The mean is NaN when no track
        has two or more images.
    """
    means = []
    # One groupby pass instead of re-filtering the full frame for every track;
    # sort=False keeps first-appearance order like .unique() did.
    for _, track in test_df.groupby("global_track_id", sort=False):
        if len(track) < 2:
            continue
        embs = to_np_array(track["emb"].values)
        dmatrix = 1 - embs @ embs.T
        upper = np.triu_indices(dmatrix.shape[0], k=1)  # distinct pairs only
        means.append(dmatrix[upper].mean())
    overall = np.mean(means) if means else float("nan")
    return overall, means

def get_longterm_mean_dist(test_df):
    timegap=15
    timegap_unit="m"
    gtracks = test_df.groupby("track_tag_id").filter(lambda x: len(x["global_track_id"].unique()) > 1)
    gtracks = gtracks.global_track_id.unique()

    distances = []
    eval_tracks = len(gtracks)
    print("Evaluating {} tracks.".format(eval_tracks))

    queries_num = 0
    tag_id_dist = defaultdict(lambda: np.array([]))

    for gtrack in tqdm(gtracks):
        is_same_track = (test_df.global_track_id == gtrack)
        im_tracks = test_df[is_same_track]
        query_row = im_tracks.iloc[0]
        is_same_id = (query_row.track_tag_id == test_df.track_tag_id)
        is_enough_timegap = np.abs(test_df.datetime2 - query_row.datetime2).astype('timedelta64[{}]'.format(timegap_unit)) > timegap

        gallery_df = test_df[(is_enough_timegap & is_same_id & ~ is_same_track)]
        if np.sum(gallery_df.track_tag_id == query_row.track_tag_id) == 0:
            continue

        tag_id = query_row.track_tag_id
        gtrack_dist = np.array([])

        gallery = to_np_array(gallery_df["emb"].values)
        for _, row in im_tracks.iterrows():
            query_id = row.track_tag_id
            query = np.expand_dims(row.emb, axis=0)
            distances = 1 - tf.matmul(query, gallery.T).numpy()
            distances = np.squeeze(distances)
            gtrack_dist = np.append(gtrack_dist, distances)
        tag_id_dist[tag_id] = np.append(tag_id_dist[tag_id],  gtrack_dist)

    mean_per_tag_id = [np.mean(ds) for ds in tag_id_dist.values()]
    return np.mean(mean_per_tag_id), mean_per_tag_id


def get_notsame_mean_dist(test_df):
    timegap=15
    timegap_unit="m"
    gtracks = test_df.groupby("track_tag_id").filter(lambda x: len(x["global_track_id"].unique()) > 1)
    gtracks = gtracks.global_track_id.unique()

    distances = []
    eval_tracks = len(gtracks)
    print("Evaluating {} tracks.".format(eval_tracks))

    queries_num = 0
    tag_id_dist = defaultdict(lambda: np.array([]))

    for gtrack in tqdm(gtracks):
        is_same_track = (test_df.global_track_id == gtrack)
        im_tracks = test_df[is_same_track]
        query_row = im_tracks.iloc[0]
        is_same_id = (query_row.track_tag_id == test_df.track_tag_id)
        is_enough_timegap = np.abs(test_df.datetime2 - query_row.datetime2).astype('timedelta64[{}]'.format(timegap_unit)) > timegap

        gallery_df = test_df[~ is_same_id]

        tag_id = query_row.track_tag_id
        gtrack_dist = np.array([])

        gallery = to_np_array(gallery_df["emb"].values)
        for _, row in im_tracks.iterrows():
            query_id = row.track_tag_id
            query = np.expand_dims(row.emb, axis=0)
            distances = 1 - tf.matmul(query, gallery.T).numpy()
            distances = np.squeeze(distances)
            gtrack_dist = np.append(gtrack_dist, distances)
        tag_id_dist[tag_id] = np.append(tag_id_dist[tag_id],  gtrack_dist)

    mean_per_tag_id = [np.mean(ds) for ds in tag_id_dist.values()]
    return np.mean(mean_per_tag_id), mean_per_tag_id

def eval_model_short_long_term(model_path, test_csv=None):
    """Embed the test set with a saved model and compute the distance statistics.

    Parameters
    ----------
    model_path : str
        Path of a Keras model loadable with ``load_model``.
    test_csv : str, optional
        Test CSV with (at least) ``filename`` and ``datetime`` columns plus the
        columns the distance functions read. ``None`` uses the module-level
        ``TEST_CSV`` at call time.

    Returns
    -------
    tuple
        ``(shortterm_mean, shortterm_per_track, longterm_mean,
        longterm_per_tag, notsame_mean, notsame_per_tag)``.
    """
    # NOTE(review): 'tf' in custom_objects presumably lets Lambda layers that
    # reference tf deserialize -- confirm.
    model = load_model(model_path, custom_objects={'tf': tf})

    test_df = pd.read_csv(TEST_CSV if test_csv is None else test_csv)
    test_df["datetime2"] = pd.to_datetime(test_df["datetime"])

    # NOTE(review): filename2image presumably returns a tf.data.Dataset
    # (it is .batch()-ed here) -- confirm.
    images = filename2image(test_df["filename"].values)
    predictions = model.predict(images.batch(batch_size), verbose=True)
    # One embedding per CSV row is required to attach them as a column.
    assert len(predictions) == len(test_df), (
        "got {} embeddings for {} test rows".format(len(predictions), len(test_df))
    )
    test_df["emb"] = list(predictions)

    shortterm_mean_dist, short_distribution = get_shortterm_mean_dist(test_df)
    longterm_mean_dist, long_distribution = get_longterm_mean_dist(test_df)
    notsame_mean_dist, notsame_distribution = get_notsame_mean_dist(test_df)
    return (
        shortterm_mean_dist, short_distribution,
        longterm_mean_dist, long_distribution,
        notsame_mean_dist, notsame_distribution,
    )

In [None]:
# Evaluate one saved model per training-set size (number of tracks) and keep
# its summary statistics under that size.
ntracks = [181, 362, 724, 1448, 2896, 4949]
model_path_template = "/home/jchan/beeid/notebooks/cmc_experiments/models3/211006{:04}_untagged_augmentation_simplecnnv2_convb3_dim_128/model.tf"
benchmark = defaultdict(dict)

for ntrack in ntracks:
    model_path = model_path_template.format(ntrack)
    (short_mean, short_hist,
     long_mean, long_hist,
     notsame_mean, notsame_hist) = eval_model_short_long_term(model_path)
    benchmark[ntrack].update(
        shortterm_mean=short_mean,
        longterm_mean=long_mean,
        shortterm_hist=short_hist,
        longterm_hist=long_hist,
        notsame_mean=notsame_mean,
        notsame_hist=notsame_hist,
    )


In [None]:
# One row per model (indexed by its number of training tracks) with the mean
# short-term, long-term and not-same distances side by side.
mean_dist_summary = pd.DataFrame.from_dict(
    {
        ntrack: {
            key: benchmark[ntrack][key]
            for key in ("shortterm_mean", "longterm_mean", "notsame_mean")
        }
        for ntrack in ntracks
    },
    orient="index",
)
mean_dist_summary.index.name = "ntracks"
mean_dist_summary

In [None]:
# Density histograms of the three distance distributions for each model, with
# a vertical line at each distribution's mean.
series_styles = [
    ("shortterm", "tab:blue"),
    ("longterm", "tab:orange"),
    ("notsame", "tab:red"),
]
n_cols = 3
n_rows = max(1, -(-len(ntracks) // n_cols))  # ceiling division
fig, axes_grid = plt.subplots(
    nrows=n_rows, ncols=n_cols, sharex=True, sharey=True,
    figsize=(16, 4 * n_rows), squeeze=False,
)
axes = axes_grid.ravel()
for panel, ntrack in zip(axes, ntracks):
    panel.set_title("ntracks: {}".format(ntrack))
    for name, color in series_styles:
        panel.hist(benchmark[ntrack]["{}_hist".format(name)], label=name,
                   color=color, density=True, alpha=0.7)
        # Mean line left unlabelled so the legend lists each series once.
        panel.axvline(benchmark[ntrack]["{}_mean".format(name)], color=color)
    panel.set_xlabel("distance (1 - dot product)")
for row in axes_grid:
    row[0].set_ylabel("density")
for unused in axes[len(ntracks):]:
    unused.set_visible(False)
axes[len(ntracks) - 1].legend()
fig.tight_layout();
