In [1]:
# Print the version of the CUDA compiler (nvcc) found on this machine's PATH.
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Feb__7_19:32:13_PST_2023
Cuda compilation tools, release 12.1, V12.1.66
Build cuda_12.1.r12.1/compiler.32415258_0


In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import concept_model
import helper
from utils.log import setup_logger
from utils.ood_utils import run_ood_over_batch
from utils.test_utils import get_measures
# from test_baselines import run_eval

import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K


import tensorflow.keras.utils as utils
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow.keras.layers as layers

from utils.ood_utils import run_ood_over_batch
from utils.test_utils import get_measures
from utils.stat_utils import multivar_separa 

import os
import argparse
import logging
import numpy as np
import sys
import time
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


# Ask TensorFlow to allocate GPU memory on demand instead of reserving the
# whole device up front. The setting is applied to every visible GPU (TF
# expects it to be consistent across devices) and the loop is simply skipped
# on CPU-only machines.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu_device in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(gpu_device, True)
    except (ValueError, RuntimeError) as err:
        # Invalid device or cannot modify virtual devices once initialized.
        # Best effort: report the problem and keep going with TF defaults.
        print(f'Could not enable memory growth on {gpu_device.name}: {err}')

print(tf.config.experimental.list_physical_devices())

2024-08-11 11:59:17.505706: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-11 11:59:17.526388: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-11 11:59:17.532614: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-11 11:59:17.547915: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
class ARGS:
    """
    Notebook replacement for command-line arguments.

    Every attribute is a plain instance attribute so the object can be passed
    wherever an argparse-style namespace is expected (e.g. setup_logger,
    run_ood_over_batch).
    """

    def __init__(self):
        defaults = {
            # --- run / schedule ---
            "gpu": "0",              # value written to CUDA_VISIBLE_DEVICES
            "batch_size": 256,
            "epoch": 22,             # number of epochs to run after `offset`
            "opt": "adam",           # "adam" or "sgd"
            "thres": 0,              # threshold handed to concept_model.TopicModel
            "val_step": 2,           # validate every `val_step` epochs
            "save_step": 1,          # checkpoint every `save_step` epochs
            "offset": 18,            # epoch of the checkpoint to resume from
            "trained": False,        # True skips the training loop
            "num_concepts": 100,

            # --- objective terms and their weights ---
            "coeff_concept": 10,
            "feat_l2": False,
            "coeff_feat": 0.1,
            "feat_cosine": False,
            "coeff_cosine": 1,
            "ood": False,
            "score": None,  # "energy"
            "coeff_score": 1,
            "separability": False,
            "coeff_separa": 50,

            "num_hidden": 2,

            # --- OOD data / detector settings ---
            # NOTE(review): the ODIN/energy values are not read in this
            # notebook; presumably consumed by run_ood_over_batch - confirm.
            "out_data": "MSCOCO",  # "augAwA"
            "temperature_odin": 1000,
            "epsilon_odin": 0.0,
            "temperature_energy": 1,
        }
        for key, value in defaults.items():
            setattr(self, key, value)

        # AwA2_baseline, AwA2_feat_l2_0.1_ood_1_sep_50
        self.name = "AwA2_2_baseline_normal"
        self.logdir = f"results/{self.name}/train_logs"


args = ARGS()

# AwA2_2_feat_l2_0.1_ood_1_sep_50_normal

In [4]:
def get_data(bs, ood=True):
    """
    prepare data loaders for ID and OOD data (train/test)
    :param bs: batch size (used for every loader, including the OOD one)
    :param ood: whether to load OOD data as well (False for baseline concept learning by Yeh et al.)
    :return: (train_loader, val_loader, test_loader, OOD_loader); OOD_loader is None when ood is False
    :raises ValueError: if ood is True and args.out_data names an unknown OOD dataset
    """

    TRAIN_DIR = "data/AwA2/train"
    VAL_DIR = "data/AwA2/val"
    TEST_DIR = "data/AwA2/test"
    OOD_DIRS = {
        'MSCOCO': "data/MSCOCO",
        'augAwA': "data/AwA2-train-fractals",
    }

    # Resolve the OOD directory only when it is needed, and fail with a clear
    # message instead of an unbound-variable error further down.
    OOD_DIR = None
    if ood:
        if args.out_data not in OOD_DIRS:
            raise ValueError(
                f"unsupported args.out_data {args.out_data!r}; expected one of {sorted(OOD_DIRS)}")
        OOD_DIR = OOD_DIRS[args.out_data]

    TARGET_SIZE = (224, 224)
    BATCH_SIZE = bs
    BATCH_SIZE_OOD = bs

    print('Loading images through generators ...')
    # Training images are rescaled to [0, 1] and randomly augmented.
    train_datagen = ImageDataGenerator(rescale=1. / 255.,
                                       # rotation_range=40,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       shear_range=0.2,
                                       zoom_range=0.2,
                                       horizontal_flip=True)
    train_loader = train_datagen.flow_from_directory(TRAIN_DIR,
                                                    batch_size=BATCH_SIZE,
                                                    target_size=TARGET_SIZE,
                                                    class_mode='categorical',
                                                    shuffle=True)

    # Validation / test images are only rescaled and are read in a fixed
    # order (shuffle=False) so labels extracted once stay aligned.
    datagen = ImageDataGenerator(rescale=1.0 / 255.)
    val_loader = datagen.flow_from_directory(VAL_DIR,
                                            batch_size=BATCH_SIZE,
                                            target_size=TARGET_SIZE,
                                            class_mode='categorical',
                                            shuffle=False)
    test_loader = datagen.flow_from_directory(TEST_DIR,
                                            batch_size=BATCH_SIZE,
                                            target_size=TARGET_SIZE,
                                            class_mode='categorical',
                                            shuffle=False)
    if ood:
        # OOD training images go through the same augmentation as the ID
        # training images; class_mode=None makes the loader yield images only.
        OOD_loader = train_datagen.flow_from_directory(OOD_DIR,
                                                batch_size=BATCH_SIZE_OOD,
                                                target_size=TARGET_SIZE,
                                                class_mode=None, shuffle=True)
    else:
        OOD_loader = None

    return train_loader, val_loader, test_loader, OOD_loader


def get_class_labels(loader, savepath):
    """
    extract groundtruth class labels from data loader
    :param loader: data loader yielding (x_batch, y_batch) pairs and supporting len()
    :param savepath: path to the numpy file used as a cache
    :return: numpy array with the labels of one full pass over the loader
             (loaded from savepath when that file already exists)
    """

    if os.path.exists(savepath):
        return np.load(savepath)

    labels = []
    # range() comes first so zip stops after len(loader) batches without
    # pulling (and discarding) one extra batch from the loader, which may
    # otherwise keep yielding batches past one epoch.
    for _, (_, y_batch) in zip(range(len(loader)), loader):
        labels.extend(y_batch)

    # Return an array on both paths (a fresh run used to return a list while
    # the cached path returned an ndarray).
    y = np.asarray(labels)
    np.save(savepath, y)
    return y

def run_eval(feature_model, predict_model, in_loader, out_loader, logger, args, num_classes):
    """
    score one pass of ID and OOD data with the OOD detector and summarise the separation
    :param feature_model: model mapping images to intermediate features
    :param predict_model: model mapping features to class predictions
    :param in_loader: ID loader yielding (x_batch, y_batch) pairs and supporting len()
    :param out_loader: OOD loader yielding x_batch only and supporting len()
    :param logger: unused here; kept for interface compatibility
    :param args: settings forwarded to run_ood_over_batch
    :param num_classes: number of ID classes, forwarded to run_ood_over_batch
    :return: (in_scores, out_scores, auroc, fpr, thres95) where the last three
             come from get_measures
    """

    def _scores_for_one_pass(loader, has_labels):
        # Start from an empty float array so an empty loader still yields
        # np.array([]), as the incremental concatenation used to.
        chunks = [np.array([])]
        # Bound the pass by the loader's OWN length; range() goes first so no
        # extra batch is pulled once the bound is reached.
        for _, batch in zip(range(len(loader)), loader):
            x = batch[0] if has_labels else batch
            chunks.append(
                run_ood_over_batch(x, feature_model, predict_model, args, num_classes).numpy())
        return np.concatenate(chunks)

    in_scores = _scores_for_one_pass(in_loader, has_labels=True)
    # The OOD pass used to stop after len(in_loader) batches, which truncates
    # or repeats OOD data whenever the two loaders differ in length.
    out_scores = _scores_for_one_pass(out_loader, has_labels=False)

    in_examples = np.expand_dims(in_scores, axis=1)
    out_examples = np.expand_dims(out_scores, axis=1)
    auroc, aupr_in, aupr_out, fpr, thres95 = get_measures(in_examples, out_examples)
    return in_scores, out_scores, auroc, fpr, thres95

In [5]:
# NOTE(review): the import cell already asked TensorFlow to list the GPUs, so
# changing CUDA_VISIBLE_DEVICES at this point may have no effect - confirm, or
# move this assignment above the TensorFlow import.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# The separability term needs OOD batches, so it implies loading OOD data.
if args.separability:
    args.ood = True
USE_OOD = args.ood
BATCH_SIZE = args.batch_size
EPOCH = args.epoch
THRESHOLD = args.thres
trained = args.trained
N_CONCEPT = args.num_concepts
offset = args.offset

# Checkpoints and the exported concept vectors live in <logdir>/<name>/.
checkpoint_dir = os.path.join(args.logdir, args.name)
topic_modelpath = os.path.join(checkpoint_dir, 'topic_epoch{}.weights.h5'.format(offset))
topic_savepath = os.path.join(checkpoint_dir, 'topic_vec_inceptionv3.npy')

logger = setup_logger(args)

# Later cells write weights and .npy files into this directory; create it up
# front so a fresh run does not fail on the first save.
os.makedirs(checkpoint_dir, exist_ok=True)

train_loader, val_loader, test_loader, ood_loader =  get_data(BATCH_SIZE, ood=USE_OOD)

# Ground-truth labels for the (unshuffled) validation and test loaders; they
# are cached on disk, so delete the .npy files if the data split changes.
y_val = get_class_labels(val_loader, savepath='data/AwA2/y_val.npy')
y_test = get_class_labels(test_loader, savepath='data/AwA2/y_test.npy')


# Loads model
feature_model, predict_model = helper.load_model_inception_new(train_loader, val_loader, \
           batch_size=BATCH_SIZE, input_size=(224,224), pretrain=True, \
           modelname='./results/AwA2/inceptionv3_AwA2_normal_epoch_40.weights.h5', split_idx=-5)

2024-08-11 11:59:25,578 [INFO] utils.log: <__main__.ARGS object at 0x7f5d2a28ad90>


Loading images through generators ...
Found 29841 images belonging to 50 classes.
Found 3709 images belonging to 50 classes.
Found 3772 images belonging to 50 classes.


2024-08-11 11:59:26.821584: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31141 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:1a:00.0, compute capability: 7.0




original model to be trained


  saveable.load_own_variables(weights_store.get(inner_path))
  self._warn_if_super_not_called()
I0000 00:00:1723370389.165482  211603 service.cc:146] XLA service 0x7f5c48002c60 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1723370389.165565  211603 service.cc:154]   StreamExecutor device (0): Tesla V100-SXM2-32GB, Compute Capability 7.0
2024-08-11 11:59:49.369948: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-08-11 11:59:50.430555: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8900


[1m 1/15[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m4:54[0m 21s/step - accuracy: 0.7695 - loss: 0.9815


I0000 00:00:1723370400.748995  211603 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 5s/step - accuracy: 0.8608 - loss: 0.7913
Loss of the trained original model: 0.7298120856285095
Accuracy of the trained original model: 0.8843353986740112





None


In [6]:
## Concept Learning
# A small batch of features is enough to build the topic model's variables.
x, _ = next(test_loader)
f = feature_model(x[:10])
# topic model: intermediate feature --> concept score --> recovered feature --> prediction (50 classes)
topic_model_pr = concept_model.TopicModel(f, N_CONCEPT, THRESHOLD, predict_model, args.num_hidden)
_ = topic_model_pr(f)
# summary() prints the table itself and returns None, so it is not wrapped in print().
topic_model_pr.build_graph(f).summary()

if os.path.exists(topic_modelpath):
    topic_model_pr.load_weights(topic_modelpath)
    logger.info(f'topic model loaded from {topic_modelpath}')
else:
    # Make a missing checkpoint visible: epoch numbering still continues from
    # `offset`, which would otherwise hide that training restarted from scratch.
    logger.info(f'no checkpoint found at {topic_modelpath}; topic model keeps its initial weights')

W0000 00:00:1723370478.291157  211390 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370478.345403  211390 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370478.345776  211390 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370478.346091  211390 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370478.346398  211390 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370478.346715  211390 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370478.347036  211390 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370478.347353  211390 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370478.347678  211390 gp

2024-08-11 12:01:20,275 [INFO] utils.log: topic model loaded from results/AwA2_2_baseline_normal/train_logs/AwA2_2_baseline_normal/topic_epoch18.weights.h5


None


In [7]:
## Concept Learning

# Optimizer selection. Current Keras optimizers reject the legacy `lr` and
# `decay` keyword arguments, so the SGD branch uses `learning_rate` and
# expresses the old per-step decay (lr / (1 + decay * step)) as a schedule.
if args.opt == 'sgd':
    sgd_learning_rate = keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=0.1, decay_steps=1, decay_rate=1e-6)
    optimizer = SGD(learning_rate=sgd_learning_rate, momentum=0.9, nesterov=True)
elif args.opt == 'adam':
    optimizer = Adam(learning_rate=0.01)
    # Not used by the cells shown here; kept so any later cell that refers to
    # these names keeps working.
    optimizer_state = [optimizer.iterations, optimizer.learning_rate, optimizer.beta_1, optimizer.beta_2, optimizer.weight_decay]
    optimizer_reset = tf.compat.v1.variables_initializer(optimizer_state)
else:
    raise ValueError(f"unsupported args.opt {args.opt!r}; expected 'sgd' or 'adam'")

# One accuracy metric per data split, plus a reusable softmax layer that
# turns the topic model's logits into class probabilities.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
test_acc_metric = keras.metrics.CategoricalAccuracy()
softmax = layers.Activation('softmax')

@tf.function
def train_step(x_in, y_in, x_out=None, thres=None):
    """
    run one optimisation step of the topic (concept) model on a single batch
    :param x_in: batch of ID images
    :param y_in: one-hot class labels of x_in
    :param x_out: batch of OOD images; required when args.score or args.separability is enabled
    :param thres: OOD-score threshold separating "detected ID" from "detected OOD";
                  required when args.separability is enabled
    :return: dict mapping the name of each objective term (and the total loss) to its value
    :raises ValueError: if an enabled objective term is missing one of its inputs
    """
    # Fail early with a clear message instead of an obscure error deep inside
    # the graph: the score and separability terms need OOD images, and the
    # separability term additionally reuses the detector scores and a threshold.
    if (args.score or args.separability) and x_out is None:
        raise ValueError("x_out is required when args.score or args.separability is enabled")
    if args.separability and not args.score:
        raise ValueError("args.separability requires args.score to be set (it reuses the OOD scores)")
    if args.separability and thres is None:
        raise ValueError("thres is required when args.separability is enabled")

    # Features come from the fixed feature extractor, outside the tape: only
    # the topic model's variables receive gradients below.
    f_in = feature_model(x_in)
    f_in_n = K.l2_normalize(f_in,axis=(3))


    obj_terms = {} # terms in the objective function
    COEFF_CONCEPT = args.coeff_concept #10 -> 5 -> 1 
    with tf.GradientTape() as tape:
        f_in_recov, logits_in, topic_vec_n = topic_model_pr(f_in, training=True)
        pred_in = softmax(logits_in) # class prediction using concept scores
        topic_prob_in_n = K.dot(f_in_n, topic_vec_n) # normalized concept scores

        # total loss
        CE_IN = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_in, pred_in))
        # coherency: mean of the 10 largest scores of every concept over all positions in the batch
        loss_coherency = tf.reduce_mean(tf.nn.top_k(K.transpose(K.reshape(topic_prob_in_n,(-1,N_CONCEPT))),k=10,sorted=True).values)
        # similarity: mean off-diagonal overlap between concept vectors
        loss_similarity = tf.reduce_mean(K.dot(K.transpose(topic_vec_n), topic_vec_n) - tf.eye(N_CONCEPT))
        loss = CE_IN - COEFF_CONCEPT*loss_coherency + COEFF_CONCEPT*loss_similarity  # baseline: Yeh et al.
        obj_terms['[ID] CE'] = CE_IN
        obj_terms['[ID] concept coherency'] = loss_coherency
        obj_terms['[ID] concept similarity'] = loss_similarity
        
        if args.feat_l2:
            # per-sample L2 distance between original and recovered features
            loss_l2 = tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.pow(f_in-f_in_recov,2), axis=(1,2,3))))
            loss += args.coeff_feat*loss_l2 #0.07, 0.02
            obj_terms['feature L2'] = loss_l2

        if args.feat_cosine:
            loss_cosine = tf.reduce_mean(tf.keras.losses.cosine_similarity(f_in, f_in_recov))
            loss_cosine = 1 - loss_cosine # cosine distance, range=[0, 2]
            loss += args.coeff_cosine*loss_cosine
            obj_terms['feature cosine distance'] = loss_cosine
        
        if args.score:
            # scores from the OOD detector when using the original features
            s_in = run_ood_over_batch(x_in, feature_model, predict_model, args, num_classes=50)
            s_out = run_ood_over_batch(x_out, feature_model, predict_model, args, num_classes=50)

            if args.coeff_score > 0.0:
                # scores from OOD detector when using recovered features
                s_in_recov = run_ood_over_batch(x_in, feature_model, topic_model_pr, args, num_classes=50)
                s_out_recov = run_ood_over_batch(x_out, feature_model, topic_model_pr, args, num_classes=50)

                s_original = tf.concat((s_in, s_out), axis=0)
                s_recovered = tf.concat((s_in_recov, s_out_recov), axis=0)
                loss_score = tf.reduce_mean(tf.pow(s_original - s_recovered, 2))
                loss += args.coeff_score*loss_score
                obj_terms['score difference'] = loss_score
        
        if args.separability:
            f_out = feature_model(x_out)
            f_out_n = K.l2_normalize(f_out,axis=(3))
            _, logits_out, _ = topic_model_pr(f_out, training=True)
            topic_prob_out_n = K.dot(f_out_n, topic_vec_n)
            

            # max --> smoothly approximated by logsumexp
            T = 1e+3
            prob_max_in = 1/T*tf.math.reduce_logsumexp(T*topic_prob_in_n,axis=(1,2))
            prob_min_in = -1/T*tf.math.reduce_logsumexp(-T*topic_prob_in_n,axis=(1,2))

            ## concept scores of "true" ID set and "true" OOD set
            # keep, per concept, whichever extreme (max or min) has the larger magnitude
            concept_in_true = tf.where(tf.abs(prob_max_in) > tf.abs(prob_min_in), prob_max_in, prob_min_in)
            prob_max_out = 1/T*tf.math.reduce_logsumexp(T*topic_prob_out_n,axis=(1,2))
            prob_min_out = -1/T*tf.math.reduce_logsumexp(-T*topic_prob_out_n,axis=(1,2))
            concept_out_true = tf.where(tf.abs(prob_max_out) > tf.abs(prob_min_out), prob_max_out, prob_min_out)
            
            ## concept scores of "detected" ID set and "detected" OOD set
            concept_in = tf.concat([concept_in_true[s_in>=thres], concept_out_true[s_out>=thres]], axis=0) 
            concept_out = tf.concat([concept_in_true[s_in<thres], concept_out_true[s_out<thres]], axis=0)

            # global separability
            loss_separa = multivar_separa(concept_in, concept_out)
            loss -= args.coeff_separa*loss_separa
            obj_terms['separability'] = loss_separa

    obj_terms['total loss.......'] = loss
    train_acc_metric.update_state(y_in, logits_in)

    # calculate the gradients using our tape and then update the model weights
    grads = tape.gradient(loss, topic_model_pr.trainable_variables)
    optimizer.apply_gradients(zip(grads, topic_model_pr.trainable_variables))
    return obj_terms

if not trained:
    # Every layer except the last one is trained.
    for layer in topic_model_pr.layers[:-1]:
        layer.trainable = True

    if args.score and args.separability: # identify threshold from held-out set
        datagen = ImageDataGenerator(rescale=1.0 / 255.)
        if args.out_data == 'MSCOCO':
            out_gen = datagen.flow_from_directory('data/MSCOCO/test',batch_size=150,target_size=(224,224),class_mode=None,shuffle=False)
        elif args.out_data == 'augAwA':
            out_gen = datagen.flow_from_directory('data/AwA2-test-fractals',batch_size=150,target_size=(224,224),class_mode=None,shuffle=False)
        else:
            raise ValueError(f"unsupported args.out_data {args.out_data!r}; expected 'MSCOCO' or 'augAwA'")
        _, _, _, _, thres = run_eval(feature_model, predict_model, val_loader, out_gen, logger, args, 50)
        thres = float(thres)
    else:
        thres = None

    # Checkpoints are written to <logdir>/<name>/ and plots to <logdir>/;
    # make sure both exist before the first save.
    os.makedirs(os.path.join(args.logdir, args.name), exist_ok=True)

    # One record (dict) per training step; the DataFrame is rebuilt from the
    # list once per epoch instead of being grown row by row.
    obj_records = []
    df_obj_terms = pd.DataFrame()
    id_cols = ["epoch", "step"]
    for epoch in range(offset+1, offset+EPOCH+1):
        logger.info(f"\n[INFO] starting epoch {epoch}/{offset+EPOCH} ---------------------------------")
        sys.stdout.flush()
        epochStart = time.time()
        
        for step, (x_in, y_in) in enumerate(train_loader):
            
            step += 1 # starts from 1
            # the loader keeps yielding batches, so stop after one full pass
            if step > len(train_loader):
                break

            if USE_OOD:
                x_out = ood_loader.__next__()
                obj_terms = train_step(x_in, y_in, x_out, thres)
            else:
                obj_terms = train_step(x_in, y_in)

            # Log every 20 batches
            if step % 20 == 0:
                for term in obj_terms:
                    logger.info(f'[STEP{step}] {term}: {obj_terms[term]}')

            record = {term: value.numpy() for term, value in obj_terms.items()}
            record["epoch"] = epoch
            record["step"] = step
            obj_records.append(record)
        
        train_acc = train_acc_metric.result()
        logger.info("Training acc over epoch: %.4f" % (float(train_acc),))
        
        # show timing information for the epoch
        epochEnd = time.time()
        elapsed = (epochEnd - epochStart) / 60.0
        logger.info("Time taken: %.2f minutes" % (elapsed))

        if obj_records:
            df_obj_terms = pd.DataFrame(obj_records)
            # Plot every term train_step reported. The former hard-coded
            # whitelist did not match the emitted names, so the total loss and
            # the optional terms were silently left out.
            df_obj_terms_melt = pd.melt(df_obj_terms, id_vars=id_cols,
                                        value_vars=[col for col in df_obj_terms.columns if col not in id_cols],
                                        var_name="loss_term", value_name="loss_value")

            plt.figure()
            sns.lineplot(data=df_obj_terms_melt, x="epoch", y="loss_value", hue="loss_term")
            plt.savefig(args.logdir+"/train_loss.png")
            plt.close()

            # Second figure: only the three baseline terms, on their own scale.
            base_terms = ['[ID] CE', '[ID] concept coherency', '[ID] concept similarity']
            plt.figure()
            sns.lineplot(data=df_obj_terms_melt[df_obj_terms_melt["loss_term"].isin(base_terms)],
                         x="epoch", y="loss_value", hue="loss_term")
            plt.savefig(args.logdir+"/train_loss1.png")
            plt.close()


        # Reset training metrics at the end of each epoch
        train_acc_metric.reset_state()
        if epoch % args.save_step == 0:
            topic_model_pr.save_weights(os.path.join(args.logdir, args.name,'topic_epoch{}.weights.h5'.format(epoch)))

        if epoch % args.val_step == 0:
            # val_loader is unshuffled, so its features line up with y_val
            _, logits_val, _ = topic_model_pr(feature_model.predict(val_loader), training=False)
            pred_val = softmax(logits_val)
            val_acc_metric.update_state(y_val, logits_val)
            val_acc = val_acc_metric.result()
            logger.info("[EPOCH %d] Validation acc: %.4f" % (epoch, float(val_acc)))
            val_acc_metric.reset_state()
            del logits_val
        
        logger.flush()

# Export the learned concept vectors: the first weight array of the topic
# model's first layer (one column per concept; the original note gives the
# shape as (2048, num_concepts)).
first_layer_weights = topic_model_pr.layers[0].get_weights()
topic_vec = first_layer_weights[0]
# Column-normalised copy; the epsilon guards against division by zero.
column_norms = np.linalg.norm(topic_vec, axis=0, keepdims=True)
topic_vec_n = topic_vec / (column_norms + 1e-9)
# Only the un-normalised vectors are written to disk.
np.save(topic_savepath, topic_vec)

assert np.shape(topic_vec)[1] == N_CONCEPT


# Accuracy of the topic model on the ID test set. test_loader is unshuffled,
# so the predicted features line up with y_test.
f_test = feature_model.predict(test_loader)
_, logits_test, _ = topic_model_pr(f_test, training=False)
pred_test = softmax(logits_test)
test_acc_metric.update_state(y_test, logits_test)
test_acc = test_acc_metric.result()
test_acc_message = '[ID TEST] Accuracy of topic model on test set: %f' % test_acc
logger.info(test_acc_message)

logger.flush()

2024-08-11 12:01:20,325 [INFO] utils.log: 
[INFO] starting epoch 19/40 ---------------------------------
W0000 00:00:1723370492.941789  211603 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370492.943439  211603 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370492.945017  211603 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370492.946552  211603 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370492.948489  211603 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370492.950383  211603 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370492.951988  211603 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1723370492.953759  211603 gpu_timer.cc:114]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 5s/step


2024-08-11 12:34:55,273 [INFO] utils.log: [EPOCH 20] Validation acc: 0.8585
2024-08-11 12:34:55,277 [INFO] utils.log: 
[INFO] starting epoch 21/40 ---------------------------------
2024-08-11 12:37:39,229 [INFO] utils.log: [STEP20] [ID] CE: 0.5222805142402649
2024-08-11 12:37:39,232 [INFO] utils.log: [STEP20] [ID] concept coherency: 0.7310627698898315
2024-08-11 12:37:39,234 [INFO] utils.log: [STEP20] [ID] concept similarity: 0.025519663468003273
2024-08-11 12:37:39,236 [INFO] utils.log: [STEP20] total loss.......: -6.533150672912598
2024-08-11 12:40:23,361 [INFO] utils.log: [STEP40] [ID] CE: 0.615963339805603
2024-08-11 12:40:23,366 [INFO] utils.log: [STEP40] [ID] concept coherency: 0.7395974397659302
2024-08-11 12:40:23,368 [INFO] utils.log: [STEP40] [ID] concept similarity: 0.02593456394970417
2024-08-11 12:40:23,370 [INFO] utils.log: [STEP40] total loss.......: -6.520665168762207
2024-08-11 12:43:07,733 [INFO] utils.log: [STEP60] [ID] CE: 0.6455304622650146
2024-08-11 12:43:07,737 

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 5s/step


2024-08-11 13:08:41,874 [INFO] utils.log: [EPOCH 22] Validation acc: 0.8576
2024-08-11 13:08:41,876 [INFO] utils.log: 
[INFO] starting epoch 23/40 ---------------------------------
2024-08-11 13:11:27,608 [INFO] utils.log: [STEP20] [ID] CE: 0.6553700566291809
2024-08-11 13:11:27,611 [INFO] utils.log: [STEP20] [ID] concept coherency: 0.7390520572662354
2024-08-11 13:11:27,613 [INFO] utils.log: [STEP20] [ID] concept similarity: 0.02487799897789955
2024-08-11 13:11:27,615 [INFO] utils.log: [STEP20] total loss.......: -6.48637056350708
2024-08-11 13:14:12,954 [INFO] utils.log: [STEP40] [ID] CE: 0.6879659295082092
2024-08-11 13:14:12,958 [INFO] utils.log: [STEP40] [ID] concept coherency: 0.7113439440727234
2024-08-11 13:14:12,960 [INFO] utils.log: [STEP40] [ID] concept similarity: 0.025758959352970123
2024-08-11 13:14:12,963 [INFO] utils.log: [STEP40] total loss.......: -6.16788387298584
2024-08-11 13:16:58,600 [INFO] utils.log: [STEP60] [ID] CE: 0.46778634190559387
2024-08-11 13:16:58,603 

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 5s/step


2024-08-11 13:42:33,109 [INFO] utils.log: [EPOCH 24] Validation acc: 0.8579
2024-08-11 13:42:33,113 [INFO] utils.log: 
[INFO] starting epoch 25/40 ---------------------------------
2024-08-11 13:45:18,003 [INFO] utils.log: [STEP20] [ID] CE: 0.49730244278907776
2024-08-11 13:45:18,006 [INFO] utils.log: [STEP20] [ID] concept coherency: 0.7446531057357788
2024-08-11 13:45:18,008 [INFO] utils.log: [STEP20] [ID] concept similarity: 0.026594549417495728
2024-08-11 13:45:18,011 [INFO] utils.log: [STEP20] total loss.......: -6.68328332901001
2024-08-11 13:48:03,164 [INFO] utils.log: [STEP40] [ID] CE: 0.5792363286018372
2024-08-11 13:48:03,168 [INFO] utils.log: [STEP40] [ID] concept coherency: 0.7294885516166687
2024-08-11 13:48:03,170 [INFO] utils.log: [STEP40] [ID] concept similarity: 0.024832503870129585
2024-08-11 13:48:03,173 [INFO] utils.log: [STEP40] total loss.......: -6.467324256896973
2024-08-11 13:50:47,550 [INFO] utils.log: [STEP60] [ID] CE: 0.5715565085411072
2024-08-11 13:50:47,55

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 5s/step


2024-08-11 14:16:21,613 [INFO] utils.log: [EPOCH 26] Validation acc: 0.8665
2024-08-11 14:16:21,615 [INFO] utils.log: 
[INFO] starting epoch 27/40 ---------------------------------
2024-08-11 14:19:06,860 [INFO] utils.log: [STEP20] [ID] CE: 0.5741291046142578
2024-08-11 14:19:06,863 [INFO] utils.log: [STEP20] [ID] concept coherency: 0.7394365668296814
2024-08-11 14:19:06,866 [INFO] utils.log: [STEP20] [ID] concept similarity: 0.025446033105254173
2024-08-11 14:19:06,868 [INFO] utils.log: [STEP20] total loss.......: -6.565776348114014
2024-08-11 14:21:51,788 [INFO] utils.log: [STEP40] [ID] CE: 0.6019413471221924
2024-08-11 14:21:51,791 [INFO] utils.log: [STEP40] [ID] concept coherency: 0.7330456972122192
2024-08-11 14:21:51,793 [INFO] utils.log: [STEP40] [ID] concept similarity: 0.02586238645017147
2024-08-11 14:21:51,795 [INFO] utils.log: [STEP40] total loss.......: -6.469891548156738
2024-08-11 14:24:37,028 [INFO] utils.log: [STEP60] [ID] CE: 0.6877170205116272
2024-08-11 14:24:37,032

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 5s/step


2024-08-11 14:50:13,565 [INFO] utils.log: [EPOCH 28] Validation acc: 0.8563
2024-08-11 14:50:13,568 [INFO] utils.log: 
[INFO] starting epoch 29/40 ---------------------------------
2024-08-11 14:52:58,242 [INFO] utils.log: [STEP20] [ID] CE: 0.5278351902961731
2024-08-11 14:52:58,245 [INFO] utils.log: [STEP20] [ID] concept coherency: 0.7292600274085999
2024-08-11 14:52:58,248 [INFO] utils.log: [STEP20] [ID] concept similarity: 0.025862833485007286
2024-08-11 14:52:58,251 [INFO] utils.log: [STEP20] total loss.......: -6.506136417388916
2024-08-11 14:55:43,291 [INFO] utils.log: [STEP40] [ID] CE: 0.48582571744918823
2024-08-11 14:55:43,294 [INFO] utils.log: [STEP40] [ID] concept coherency: 0.7164028286933899
2024-08-11 14:55:43,296 [INFO] utils.log: [STEP40] [ID] concept similarity: 0.02672719769179821
2024-08-11 14:55:43,299 [INFO] utils.log: [STEP40] total loss.......: -6.410930633544922
2024-08-11 14:58:28,391 [INFO] utils.log: [STEP60] [ID] CE: 0.6348688006401062
2024-08-11 14:58:28,39

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 5s/step


2024-08-11 15:24:03,468 [INFO] utils.log: [EPOCH 30] Validation acc: 0.8644
2024-08-11 15:24:03,472 [INFO] utils.log: 
[INFO] starting epoch 31/40 ---------------------------------
2024-08-11 15:26:48,932 [INFO] utils.log: [STEP20] [ID] CE: 0.594481885433197
2024-08-11 15:26:48,934 [INFO] utils.log: [STEP20] [ID] concept coherency: 0.7376254796981812
2024-08-11 15:26:48,937 [INFO] utils.log: [STEP20] [ID] concept similarity: 0.026824191212654114
2024-08-11 15:26:48,939 [INFO] utils.log: [STEP20] total loss.......: -6.51353120803833
2024-08-11 15:29:33,665 [INFO] utils.log: [STEP40] [ID] CE: 0.6154650449752808
2024-08-11 15:29:33,667 [INFO] utils.log: [STEP40] [ID] concept coherency: 0.7442961931228638
2024-08-11 15:29:33,670 [INFO] utils.log: [STEP40] [ID] concept similarity: 0.025445206090807915
2024-08-11 15:29:33,672 [INFO] utils.log: [STEP40] total loss.......: -6.573044300079346
2024-08-11 15:32:17,765 [INFO] utils.log: [STEP60] [ID] CE: 0.6165094375610352
2024-08-11 15:32:17,768 

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 5s/step


2024-08-11 15:57:53,794 [INFO] utils.log: [EPOCH 32] Validation acc: 0.8620
2024-08-11 15:57:53,798 [INFO] utils.log: 
[INFO] starting epoch 33/40 ---------------------------------
2024-08-11 16:00:38,386 [INFO] utils.log: [STEP20] [ID] CE: 0.38344910740852356
2024-08-11 16:00:38,389 [INFO] utils.log: [STEP20] [ID] concept coherency: 0.7404606342315674
2024-08-11 16:00:38,392 [INFO] utils.log: [STEP20] [ID] concept similarity: 0.024701813235878944
2024-08-11 16:00:38,394 [INFO] utils.log: [STEP20] total loss.......: -6.774138927459717
2024-08-11 16:03:23,566 [INFO] utils.log: [STEP40] [ID] CE: 0.539472222328186
2024-08-11 16:03:23,569 [INFO] utils.log: [STEP40] [ID] concept coherency: 0.7312296628952026
2024-08-11 16:03:23,571 [INFO] utils.log: [STEP40] [ID] concept similarity: 0.02510889247059822
2024-08-11 16:03:23,573 [INFO] utils.log: [STEP40] total loss.......: -6.521735668182373
2024-08-11 16:06:08,938 [INFO] utils.log: [STEP60] [ID] CE: 0.5675948858261108
2024-08-11 16:06:08,941

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 5s/step


2024-08-11 16:31:46,901 [INFO] utils.log: [EPOCH 34] Validation acc: 0.8547
2024-08-11 16:31:46,905 [INFO] utils.log: 
[INFO] starting epoch 35/40 ---------------------------------
2024-08-11 16:34:32,540 [INFO] utils.log: [STEP20] [ID] CE: 0.6270381212234497
2024-08-11 16:34:32,543 [INFO] utils.log: [STEP20] [ID] concept coherency: 0.728710412979126
2024-08-11 16:34:32,545 [INFO] utils.log: [STEP20] [ID] concept similarity: 0.025896230712532997
2024-08-11 16:34:32,547 [INFO] utils.log: [STEP20] total loss.......: -6.401103973388672
2024-08-11 16:37:18,204 [INFO] utils.log: [STEP40] [ID] CE: 0.6464423537254333
2024-08-11 16:37:18,207 [INFO] utils.log: [STEP40] [ID] concept coherency: 0.7238888144493103
2024-08-11 16:37:18,209 [INFO] utils.log: [STEP40] [ID] concept similarity: 0.025722330436110497
2024-08-11 16:37:18,211 [INFO] utils.log: [STEP40] total loss.......: -6.3352227210998535
2024-08-11 16:40:03,464 [INFO] utils.log: [STEP60] [ID] CE: 0.44699627161026
2024-08-11 16:40:03,467 

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 5s/step


2024-08-11 17:05:46,384 [INFO] utils.log: [EPOCH 36] Validation acc: 0.8633
2024-08-11 17:05:46,387 [INFO] utils.log: 
[INFO] starting epoch 37/40 ---------------------------------
2024-08-11 17:08:31,176 [INFO] utils.log: [STEP20] [ID] CE: 0.6632957458496094
2024-08-11 17:08:31,179 [INFO] utils.log: [STEP20] [ID] concept coherency: 0.7447587847709656
2024-08-11 17:08:31,181 [INFO] utils.log: [STEP20] [ID] concept similarity: 0.025655727833509445
2024-08-11 17:08:31,183 [INFO] utils.log: [STEP20] total loss.......: -6.527734756469727
2024-08-11 17:11:17,609 [INFO] utils.log: [STEP40] [ID] CE: 0.4868473410606384
2024-08-11 17:11:17,612 [INFO] utils.log: [STEP40] [ID] concept coherency: 0.7374150156974792
2024-08-11 17:11:17,614 [INFO] utils.log: [STEP40] [ID] concept similarity: 0.02612275630235672
2024-08-11 17:11:17,616 [INFO] utils.log: [STEP40] total loss.......: -6.626075267791748
2024-08-11 17:14:02,997 [INFO] utils.log: [STEP60] [ID] CE: 0.5569136142730713
2024-08-11 17:14:03,001

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 5s/step


2024-08-11 17:39:45,390 [INFO] utils.log: [EPOCH 38] Validation acc: 0.8630
2024-08-11 17:39:45,394 [INFO] utils.log: 
[INFO] starting epoch 39/40 ---------------------------------
2024-08-11 17:42:30,116 [INFO] utils.log: [STEP20] [ID] CE: 0.6996302604675293
2024-08-11 17:42:30,119 [INFO] utils.log: [STEP20] [ID] concept coherency: 0.7180746793746948
2024-08-11 17:42:30,122 [INFO] utils.log: [STEP20] [ID] concept similarity: 0.026308557018637657
2024-08-11 17:42:30,124 [INFO] utils.log: [STEP20] total loss.......: -6.218031406402588
2024-08-11 17:45:15,993 [INFO] utils.log: [STEP40] [ID] CE: 0.7252137660980225
2024-08-11 17:45:15,996 [INFO] utils.log: [STEP40] [ID] concept coherency: 0.7421110272407532
2024-08-11 17:45:15,999 [INFO] utils.log: [STEP40] [ID] concept similarity: 0.026377279311418533
2024-08-11 17:45:16,001 [INFO] utils.log: [STEP40] total loss.......: -6.432123184204102
2024-08-11 17:48:01,491 [INFO] utils.log: [STEP60] [ID] CE: 0.6589313745498657
2024-08-11 17:48:01,49

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 5s/step


2024-08-11 18:13:49,739 [INFO] utils.log: [EPOCH 40] Validation acc: 0.8673
  self._warn_if_super_not_called()


[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m87s[0m 6s/step


2024-08-11 18:15:28,173 [INFO] utils.log: [ID TEST] Accuracy of topic model on test set: 0.875398


[]