In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
import os
import json
import datetime
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# Iterate over every visible GPU rather than indexing [0]:
#   * on a CPU-only machine the list is empty and [0] would raise IndexError;
#   * on multi-GPU machines TensorFlow refuses memory-growth settings that
#     differ between devices, so the option must be applied to all of them.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, enable=True)

In [3]:
dataset_name: str = "GEMM_EX_Wavenet"  # prefix of the CSV files under data/

In [4]:
timestamp: str = "20201209-094522"  # selects the run directory version/<timestamp>

In [5]:
# Restore the trained model from its run directory, version/<timestamp>.
model_dir = "version/{}".format(timestamp)
model = keras.models.load_model(model_dir)

In [6]:
# Print the layer-by-layer architecture and parameter counts of the restored model.
model.summary()

Model: "wave_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
preprocessing/conv (Conv1D)  multiple                  1536      
_________________________________________________________________
residual_block (ResidualBloc multiple                  7936      
_________________________________________________________________
residual_block_1 (ResidualBl multiple                  7936      
_________________________________________________________________
residual_block_2 (ResidualBl multiple                  7936      
_________________________________________________________________
residual_block_3 (ResidualBl multiple                  7936      
_________________________________________________________________
residual_block_4 (ResidualBl multiple                  7936      
_________________________________________________________________
residual_block_5 (ResidualBl multiple                  793

In [7]:
# Hyper-parameters needed to rebuild the input pipeline for this run.
# The dilation schedule is read from the run directory so that it matches the
# trained model; the remaining values are hard-coded here.
with open("version/{}/dilations.json".format(timestamp), "r") as dilations_file:
    loaded_dilations = json.load(dilations_file)["DILATIONS"]

param_list = {
    "DILATIONS": loaded_dilations,
    "FILTER_WIDTH": 2,             # == kernel_size
    "INITIAL_FILTER_WIDTH": 32,
}
# Window length (in rows) of every model input.
param_list["RECEPTIVE_FIELD"] = (
    (param_list["FILTER_WIDTH"] - 1) * sum(param_list["DILATIONS"])
    + param_list["INITIAL_FILTER_WIDTH"]
)
param_list.update(
    OUT_CHANNELS=33476,            # vocab_size: depth of the one-hot targets
    BATCH_SIZE=8,
)

In [8]:
# Load the held-out split; every column is parsed as float32.
test_csv_path = "data/{}_test_set.csv".format(dataset_name)
test_set = pd.read_csv(test_csv_path, dtype=np.float32)

In [9]:
# Model inputs: sliding windows of RECEPTIVE_FIELD consecutive rows, shift 1.
# Window i covers rows i .. i+RECEPTIVE_FIELD-1 and is paired with the row
# immediately after it as its target. The last row is excluded because no
# row follows it; incomplete trailing windows are dropped.
receptive_field = param_list["RECEPTIVE_FIELD"]
x_windows = tf.data.Dataset.from_tensor_slices(test_set[:-1]).window(
    size=receptive_field, shift=1, stride=1, drop_remainder=True
)
x_test = x_windows.flat_map(lambda window: window.batch(receptive_field)).batch(
    param_list["BATCH_SIZE"]
)

In [10]:
# Targets: the integer 'gpa' value of the row following each input window.
# y_test_slices keeps the raw integer labels (one length-1 vector per sample)
# for computing metrics later; y_test holds the one-hot encoded, batched form.
target_ids = test_set[param_list["RECEPTIVE_FIELD"]:]['gpa'].astype(np.int32)
y_test_slices = tf.data.Dataset.from_tensor_slices(target_ids).window(
    size=1, shift=1, stride=1, drop_remainder=True
).flat_map(lambda window: window.batch(1))
y_test = y_test_slices.map(
    lambda label: tf.one_hot(label, param_list["OUT_CHANNELS"], axis=-1)
).batch(param_list["BATCH_SIZE"])

In [11]:
# Pair each batch of input windows with its batch of one-hot targets for evaluate().
test_data = tf.data.Dataset.zip((x_test, y_test))

In [12]:
# Unpacking into two values assumes the model was compiled with exactly one
# metric (returned after the loss) — presumably accuracy; TODO confirm.
evaluation = model.evaluate(test_data)
loss, acc = evaluation



In [13]:
# Predict class ids batch by batch. Predicting the whole test set in one call
# did not work (presumably the OUT_CHANNELS-wide output for every sample does
# not fit in memory), so only the argmax of each batch is kept.
# model(x, training=False) is used instead of model.predict(x): predict() is
# designed for whole datasets and is slow when called once per batch in a loop.
y_pred_batches = []
for x in x_test:
    batch_scores = model(x, training=False)
    y_pred_batches.append(tf.argmax(batch_scores, axis=-1).numpy())

assert y_pred_batches, "x_test is empty: test set has no more rows than RECEPTIVE_FIELD"
y_pred = np.concatenate(y_pred_batches, axis=0)

In [14]:
# Materialise the integer ground-truth labels in the same order as y_pred.
y_true = np.array(list(y_test_slices.as_numpy_iterator()))

In [15]:
# Predictions and labels must align one-to-one before computing metrics.
assert y_pred.shape == y_true.shape, "shape mismatch: y_pred {} vs y_true {}".format(
    y_pred.shape, y_true.shape
)
y_pred.shape, y_true.shape

((8945, 1), (8945, 1))

In [16]:
# Precision / recall / F1 under three averaging strategies.
# For single-label multiclass data, "micro" averaging equals plain accuracy.
# Labels are stored as (N, 1) arrays; flatten them once up front instead of
# re-raveling six times per loop iteration.
y_true_flat = np.ravel(y_true)
y_pred_flat = np.ravel(y_pred)

p, r, f = [], [], []
average_method = ["micro", "macro", "weighted"]

for method in average_method:
    # Classes that are never predicted (or never present) have undefined
    # per-class precision/recall. sklearn's default ("warn") already scores
    # them as 0 but emits an UndefinedMetricWarning on every call;
    # zero_division=0 gives the same values without the warning noise.
    precision = precision_score(y_true_flat, y_pred_flat, average=method, zero_division=0)
    recall = recall_score(y_true_flat, y_pred_flat, average=method, zero_division=0)
    f1 = f1_score(y_true_flat, y_pred_flat, average=method, zero_division=0)

    p.append(precision)
    r.append(recall)
    f.append(f1)

In [17]:
# One row per metric, one column per averaging strategy.
report = pd.DataFrame(
    [p, r, f],
    index=["precision", "recall", "f1"],
    columns=average_method,
)
report

Unnamed: 0,micro,macro,weighted
precision,0.213304,0.001084,0.40275
recall,0.213304,0.000894,0.213304
f1,0.213304,0.000315,0.075158


In [18]:
# Single-row table with the loss and accuracy returned by model.evaluate().
accuracy = pd.DataFrame({"loss": [loss], "accuracy": [acc]})
accuracy

Unnamed: 0,loss,accuracy
0,5.986146,0.213304


In [19]:
# Persist both result tables next to the model in its run directory.
# The metrics report keeps its row labels; the accuracy table does not.
for frame, csv_path, keep_index in (
    (report, "version/{}/report.csv".format(timestamp), True),
    (accuracy, "version/{}/accuracy.csv".format(timestamp), False),
):
    frame.to_csv(csv_path, index=keep_index)