In [1]:
%config Completer.use_jedi = False

In [2]:
import numpy as np
import os
import tensorflow as tf
import json
from tqdm import tqdm
import random
from scipy.spatial.distance import cdist
from pathlib import Path
from multiprocessing import Pool

In [3]:
def config_gpus():
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
                logical_gpus = tf.config.list_logical_devices('GPU')
                print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)
    return

In [4]:
def parse_fn(example_proto):
    features = {"data": tf.io.FixedLenFeature((), tf.string),
                "label": tf.io.FixedLenFeature((), tf.int64),
                }
    parsed_features = tf.io.parse_single_example(example_proto, features)
    data = tf.io.decode_raw(parsed_features["data"], tf.float32)
    data = tf.reshape(data, shape=[100, 5])
    return data, parsed_features["label"]


def load_tfrecord(filepath, batch_size=128, shuffle=True):
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    dataset = tf.data.TFRecordDataset(filepath, num_parallel_reads=4)
    dataset = dataset.map(parse_fn, num_parallel_calls=4)
    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=100)
    return dataset.prefetch(AUTOTUNE)

In [5]:
def evaluate(model_path, data_path, metric, threshold):
    
    model = tf.keras.models.load_model(model_path)
    
    test_dataset = load_tfrecord(filepath=list(map(str, Path(data_path).glob("**/*.tfrecord"))), batch_size=600, shuffle=False)
    
    for batch in tqdm(test_dataset):
        x_test, y_test = batch
        embeds = model(x_test, training=False).numpy()
        y_test = y_test.numpy()
        
        for y in np.unique(y_test):
                        
            positive_embeds = embeds[y_test == y]
            negative_embeds = embeds[y_test != y]
            
            positive_dist = cdist(positive_embeds, positive_embeds, metric).round(4)
            negative_dist = cdist(positive_embeds, negative_embeds, metric).round(4)

            positive_dist = np.reshape(positive_dist, (-1,))
            positive_dist = positive_dist[positive_dist >= 1e-6]
            negative_dist = np.reshape(negative_dist, (-1,))
            
            if len(positive_dist) < 2 or len(negative_dist) < 2:
                continue
            
            fn = np.sum(negative_dist < threshold) / len(negative_dist) 
            fp = np.sum(positive_dist > threshold) / len(positive_dist)

            log = {"id": int(y),
                   "fp": float(np.mean(fp) * 100),
                   "fn": float(np.mean(fn) * 100)}

            dist_log = {"id": int(y),
                        "d_positive": positive_dist.tolist(),
                        "d_negative": negative_dist.tolist()}

            with open("eval_{}_threshold_{}.json".format(metric, threshold), "a") as f:
                f.write(json.dumps(log)+"\n")

#             with open("eval_{}_distance.json".format(metric), "a") as f:
#                 f.write(json.dumps(dist_log)+"\n")
        
        print(log)
                
    return

In [6]:
evaluate(model_path="ckpt/2022-06-03-18h16m56s/serving/", data_path="data/test/", metric="cosine", threshold=0.4)



1it [00:02,  2.09s/it]

{'id': 166022, 'fp': 10.0, 'fn': 0.10084033613445378}


2it [00:02,  1.00s/it]

{'id': 166045, 'fp': 0.0, 'fn': 0.0}


3it [00:02,  1.52it/s]

{'id': 166069, 'fp': 0.0, 'fn': 0.08361204013377926}


4it [00:02,  1.98it/s]

{'id': 166091, 'fp': 0.0, 'fn': 0.16806722689075632}


5it [00:03,  2.34it/s]

{'id': 166114, 'fp': 0.0, 'fn': 0.16863406408094433}


6it [00:03,  2.59it/s]

{'id': 166138, 'fp': 0.0, 'fn': 0.5852842809364548}


7it [00:03,  2.78it/s]

{'id': 166160, 'fp': 0.0, 'fn': 0.1445434834979523}


8it [00:04,  2.86it/s]

{'id': 166184, 'fp': 6.666666666666667, 'fn': 0.02805836139169473}


9it [00:04,  2.93it/s]

{'id': 166208, 'fp': 0.0, 'fn': 0.0}


10it [00:04,  2.96it/s]

{'id': 166232, 'fp': 0.0, 'fn': 0.04194630872483222}


11it [00:05,  2.94it/s]

{'id': 166255, 'fp': 0.0, 'fn': 0.0}


12it [00:05,  2.89it/s]

{'id': 166278, 'fp': 4.761904761904762, 'fn': 0.2890869669959046}


13it [00:05,  2.85it/s]

{'id': 166301, 'fp': 13.333333333333334, 'fn': 0.16835016835016833}


14it [00:06,  2.86it/s]

{'id': 166324, 'fp': 0.0, 'fn': 0.33500837520938026}


15it [00:06,  2.93it/s]

{'id': 166348, 'fp': 0.0, 'fn': 0.0}


16it [00:06,  2.95it/s]

{'id': 166371, 'fp': 0.0, 'fn': 0.06722689075630252}


17it [00:07,  2.95it/s]

{'id': 166394, 'fp': 0.0, 'fn': 0.0}


18it [00:07,  2.97it/s]

{'id': 166416, 'fp': 0.0, 'fn': 0.0}


19it [00:07,  2.95it/s]

{'id': 166441, 'fp': 0.0, 'fn': 0.0}


20it [00:08,  2.92it/s]

{'id': 166463, 'fp': 4.761904761904762, 'fn': 0.6504456757407854}


21it [00:08,  2.89it/s]

{'id': 166487, 'fp': 0.0, 'fn': 0.27917364600781686}


22it [00:08,  2.86it/s]

{'id': 166510, 'fp': 19.047619047619047, 'fn': 0.0481811611659841}


23it [00:09,  2.86it/s]

{'id': 166532, 'fp': 0.0, 'fn': 0.0}


24it [00:09,  2.80it/s]

{'id': 166557, 'fp': 16.666666666666664, 'fn': 0.25167785234899326}


25it [00:09,  2.77it/s]

{'id': 166580, 'fp': 0.0, 'fn': 0.6141820212171971}


26it [00:10,  2.70it/s]

{'id': 166603, 'fp': 0.0, 'fn': 0.0}


27it [00:10,  2.67it/s]

{'id': 166626, 'fp': 0.0, 'fn': 0.0}


28it [00:11,  2.65it/s]

{'id': 166649, 'fp': 30.0, 'fn': 0.0}


29it [00:11,  2.60it/s]

{'id': 166674, 'fp': 13.333333333333334, 'fn': 0.25252525252525254}


30it [00:11,  2.56it/s]

{'id': 166698, 'fp': 0.0, 'fn': 0.0}


31it [00:12,  2.51it/s]

{'id': 166721, 'fp': 0.0, 'fn': 0.0}


32it [00:12,  2.48it/s]

{'id': 166744, 'fp': 0.0, 'fn': 0.0}


33it [00:13,  2.46it/s]

{'id': 166768, 'fp': 0.0, 'fn': 0.08389261744966443}


34it [00:13,  2.44it/s]

{'id': 166791, 'fp': 0.0, 'fn': 0.02805836139169473}


35it [00:14,  2.41it/s]

{'id': 166814, 'fp': 0.0, 'fn': 0.16750418760469013}


36it [00:14,  2.38it/s]

{'id': 166836, 'fp': 0.0, 'fn': 0.06722689075630252}


37it [00:14,  2.37it/s]

{'id': 166858, 'fp': 23.809523809523807, 'fn': 0.02409058058299205}


38it [00:15,  2.35it/s]

{'id': 166881, 'fp': 0.0, 'fn': 0.9154420621536979}


39it [00:15,  2.26it/s]

{'id': 166905, 'fp': 0.0, 'fn': 0.16778523489932887}


40it [00:16,  2.24it/s]

{'id': 166928, 'fp': 13.333333333333334, 'fn': 0.16835016835016833}


41it [00:16,  2.21it/s]

{'id': 166953, 'fp': 0.0, 'fn': 0.6141820212171971}


42it [00:17,  2.20it/s]

{'id': 166976, 'fp': 0.0, 'fn': 0.0}


43it [00:17,  2.18it/s]

{'id': 166999, 'fp': 0.0, 'fn': 0.03361344537815126}


44it [00:18,  2.17it/s]

{'id': 167006, 'fp': 4.761904761904762, 'fn': 0.02409058058299205}


45it [00:18,  2.14it/s]

{'id': 167030, 'fp': 0.0, 'fn': 0.0}


46it [00:19,  2.13it/s]

{'id': 167053, 'fp': 0.0, 'fn': 0.1964085297418631}


47it [00:19,  2.10it/s]

{'id': 167078, 'fp': 0.0, 'fn': 0.0}


48it [00:20,  2.08it/s]

{'id': 167100, 'fp': 0.0, 'fn': 0.08361204013377926}


49it [00:20,  2.05it/s]

{'id': 167121, 'fp': 0.0, 'fn': 0.11223344556677892}


50it [00:21,  2.03it/s]

{'id': 167145, 'fp': 0.0, 'fn': 0.11166945840312675}


51it [00:21,  2.03it/s]

{'id': 167167, 'fp': 0.0, 'fn': 0.08417508417508417}


52it [00:22,  2.01it/s]

{'id': 167190, 'fp': 0.0, 'fn': 0.0}


53it [00:22,  1.99it/s]

{'id': 167212, 'fp': 0.0, 'fn': 0.8978675645342313}


54it [00:23,  1.98it/s]

{'id': 167236, 'fp': 13.333333333333334, 'fn': 0.0}


55it [00:23,  1.96it/s]

{'id': 167260, 'fp': 0.0, 'fn': 0.30864197530864196}


56it [00:24,  1.94it/s]

{'id': 167283, 'fp': 0.0, 'fn': 0.0}


57it [00:24,  1.93it/s]

{'id': 167307, 'fp': 33.33333333333333, 'fn': 0.0}


58it [00:25,  1.92it/s]

{'id': 167329, 'fp': 0.0, 'fn': 0.0}


59it [00:25,  1.92it/s]

{'id': 167352, 'fp': 0.0, 'fn': 0.10084033613445378}


60it [00:26,  1.90it/s]

{'id': 167375, 'fp': 0.0, 'fn': 0.16722408026755853}


61it [00:26,  1.90it/s]

{'id': 167397, 'fp': 0.0, 'fn': 0.02409058058299205}


62it [00:27,  1.89it/s]

{'id': 167421, 'fp': 50.0, 'fn': 0.12583892617449663}


63it [00:27,  1.89it/s]

{'id': 167444, 'fp': 0.0, 'fn': 0.13445378151260504}


64it [00:28,  1.90it/s]

{'id': 167467, 'fp': 0.0, 'fn': 0.0}


65it [00:28,  1.90it/s]

{'id': 167489, 'fp': 0.0, 'fn': 0.08361204013377926}


66it [00:29,  1.90it/s]

{'id': 167512, 'fp': 0.0, 'fn': 0.14029180695847362}


67it [00:29,  1.90it/s]

{'id': 167535, 'fp': 0.0, 'fn': 0.08389261744966443}


68it [00:30,  1.88it/s]

{'id': 167558, 'fp': 0.0, 'fn': 0.0}


69it [00:31,  1.87it/s]

{'id': 167582, 'fp': 0.0, 'fn': 0.919732441471572}


70it [00:31,  1.87it/s]

{'id': 167604, 'fp': 0.0, 'fn': 0.0}


71it [00:32,  1.86it/s]

{'id': 167627, 'fp': 0.0, 'fn': 1.1764705882352942}


72it [00:32,  1.87it/s]

{'id': 167651, 'fp': 0.0, 'fn': 0.16863406408094433}


73it [00:33,  1.88it/s]

{'id': 167674, 'fp': 23.809523809523807, 'fn': 0.02409058058299205}


74it [00:33,  1.89it/s]

{'id': 167698, 'fp': 10.0, 'fn': 0.03361344537815126}


75it [00:34,  1.89it/s]

{'id': 167721, 'fp': 0.0, 'fn': 0.08361204013377926}


76it [00:34,  1.89it/s]

{'id': 167745, 'fp': 0.0, 'fn': 0.0}


77it [00:35,  1.88it/s]

{'id': 167769, 'fp': 0.0, 'fn': 0.3908431044109436}


78it [00:35,  1.88it/s]

{'id': 167793, 'fp': 0.0, 'fn': 0.0963623223319682}


79it [00:36,  1.88it/s]

{'id': 167816, 'fp': 16.666666666666664, 'fn': 0.16778523489932887}


80it [00:36,  1.87it/s]

{'id': 167838, 'fp': 0.0, 'fn': 0.0}


81it [00:37,  1.86it/s]

{'id': 167862, 'fp': 4.761904761904762, 'fn': 0.0}


82it [00:37,  1.85it/s]

{'id': 167885, 'fp': 0.0, 'fn': 0.0}


83it [00:38,  1.85it/s]

{'id': 167908, 'fp': 0.0, 'fn': 0.16722408026755853}


84it [00:39,  1.87it/s]

{'id': 167930, 'fp': 19.047619047619047, 'fn': 0.0}


85it [00:39,  1.87it/s]

{'id': 167954, 'fp': 47.61904761904761, 'fn': 0.0}


86it [00:40,  1.89it/s]

{'id': 167978, 'fp': 9.523809523809524, 'fn': 0.481811611659841}


87it [00:40,  2.14it/s]

{'id': 167999, 'fp': 0.0, 'fn': 0.0}





In [15]:
import pandas as pd

with open("eval_cosine_threshold_0.4.json") as f:
    df = [json.loads(line) for line in f]
    
df = pd.DataFrame(df)
len(df)
df

Unnamed: 0,id,fp,fn
0,161000,9.523810,0.120453
1,161001,0.000000,0.000000
2,161002,0.000000,0.457721
3,161003,0.000000,0.000000
4,161004,14.285714,0.096362
...,...,...,...
8153,167995,0.000000,0.196541
8154,167996,0.000000,0.029621
8155,167997,39.285714,0.444313
8156,167998,0.000000,0.000000


In [16]:
df['fp'].mean(), df['fn'].mean()

(5.772894602877442, 0.14282834537919326)