In [None]:
import csv
import os
import sys
import traceback
from datetime import datetime

import numpy as np
import tensorflow as tf

from models.resnet_config import Config
from models.resnet_utils import *
from models.resnet_utils import _max_pool

# Number of examples per batch; used both by the batching queue and by the
# loss code that works on a fixed number of rows.
BATCH_SIZE = 50
# Number of classes: caps how many labels are read from the metadata CSV and
# sets the size of the model's output layer.
NUM_CLASSES = 188

In [None]:
def _get_filenames_and_classes(dataset_dir):
    """Returns a list of filenames and inferred class names.
    Args:
        dataset_dir: A directory containing a set of subdirectories representing
        class names. Each subdirectory should contain PNG or JPG encoded images.
    Returns:
        A list of image file paths, relative to `dataset_dir` and the list of
        subdirectories, representing class names.
    """


    directories = []
    class_names = []
    test_size = []
    test_count = []

    with open('/data/dict/PE/pe_meta.csv', 'rb') as csvfile:
        reader = csv.DictReader(csvfile)
        flist = []
        for row in reader:
            #print(row['CLASS'])
            if len(class_names) < NUM_CLASSES:
                path = os.path.join(dataset_dir, row['Label'])
                if os.path.isdir(path):
                    directories.append(path)
                    class_names.append(row['Label'])
                    test_size.append(int(row['Number']))
                    test_count.append(0)



    train_filenames = []
    test_filenames = []
    return train_filenames, test_filenames, sorted(class_names)

In [None]:
# Resolve the class list once, then give every class name an integer id equal
# to its position in the returned (sorted) list.
training_filenames, testing_filenames, class_names = _get_filenames_and_classes(
    "/data/dict/PE/raw")
class_names_to_ids = {name: class_id for class_id, name in enumerate(class_names)}
print(class_names_to_ids)


In [None]:
def inference(x, is_training,
              num_classes=25,
              num_blocks=[3, 4, 6, 3],
              use_bias=False,
              bottleneck=True):
    """Builds the five-scale residual network graph on top of `x`.

    The network is assembled from the helpers imported from
    `models.resnet_utils` (`conv`, `bn`, `activation`, `stack`, `fc`,
    `_max_pool`), which are all driven by a shared `Config` object:
    scale1 is a conv/bn/activation stem configured with 64 filters, kernel
    size 7 and stride 2; scale2 starts with a 3x3 stride-2 max pool; scale2 to
    scale5 are block stacks with 64, 128, 256 and 512 internal filters. The
    result is averaged over axes 1 and 2 and optionally fed to a final `fc`
    layer.

    Args:
        x: Input tensor handed to the first convolution. Presumably a batch of
            images laid out as (batch, height, width, channels), since axes 1
            and 2 are averaged away at the end -- TODO confirm.
        is_training: Converted to a bool tensor and stored in the config under
            'is_training'.
        num_classes: Stored as 'fc_units_out'. When None, the final `fc` layer
            is skipped and the pooled features are returned.
        num_blocks: Four block counts, one each for scale2..scale5. The default
            list is only read, never mutated, in this function.
        use_bias: Stored in the config under 'use_bias'.
        bottleneck: Stored in the config under 'bottleneck'.
    Returns:
        The output of the `fc` layer, or the pooled features when
        `num_classes` is None.
    """
    c = Config()
    c['bottleneck'] = bottleneck
    c['is_training'] = tf.convert_to_tensor(is_training,
                                            dtype='bool',
                                            name='is_training')
    c['ksize'] = 3
    c['stride'] = 1
    c['use_bias'] = use_bias
    c['fc_units_out'] = num_classes
    c['num_blocks'] = num_blocks
    c['stack_stride'] = 2

    with tf.variable_scope('scale1'):
        c['conv_filters_out'] = 64
        c['ksize'] = 7
        c['stride'] = 2
        x = conv(x, c)
        x = bn(x, c)
        x = activation(x)

    with tf.variable_scope('scale2'):
        x = _max_pool(x, ksize=3, stride=2)
        c['num_blocks'] = num_blocks[0]
        c['stack_stride'] = 1
        c['block_filters_internal'] = 64
        x = stack(x, c)

    with tf.variable_scope('scale3'):
        c['num_blocks'] = num_blocks[1]
        c['block_filters_internal'] = 128
        # NOTE(review): scale2 assigned stack_stride = 1 above, so this assert
        # can only hold if Config scopes assignments to the active variable
        # scope -- confirm against models.resnet_config.
        assert c['stack_stride'] == 2
        x = stack(x, c)

    with tf.variable_scope('scale4'):
        c['num_blocks'] = num_blocks[2]
        c['block_filters_internal'] = 256
        x = stack(x, c)

    with tf.variable_scope('scale5'):
        c['num_blocks'] = num_blocks[3]
        c['block_filters_internal'] = 512
        x = stack(x, c)

    # Average over axes 1 and 2; `axis` matches the keyword used by the other
    # reductions in this notebook.
    x = tf.reduce_mean(x, axis=[1, 2], name="avg_pool")

    if num_classes is not None:
        with tf.variable_scope('fc'):
            x = fc(x, c)

    return x

In [None]:
def build_model(input_, is_training=True):
    """Builds the 4-stage, 2-blocks-per-stage, non-bottleneck network.

    All variables are created under the "Resnet" variable scope.

    Args:
        input_: Input tensor forwarded to `inference`.
        is_training: Forwarded to `inference`. Defaults to True, which is what
            this function always used before the parameter existed.
    Returns:
        The output of `inference` with `NUM_CLASSES` output units.
    """
    with tf.variable_scope("Resnet"):
        return inference(input_,
                         num_classes=NUM_CLASSES,
                         is_training=is_training,
                         bottleneck=False,
                         num_blocks=[2, 2, 2, 2])

In [None]:
def get_label(filenames, name_to_id=None):
    """Maps file paths to integer class ids via each file's parent directory.

    The second-to-last '/'-separated component of every path is looked up in
    `name_to_id`.

    Args:
        filenames: Iterable of paths. Items that are not `str` (e.g. the
            `bytes` objects tf.py_func hands over on Python 3) are decoded as
            UTF-8 first.
        name_to_id: Mapping from class name to integer id. Defaults to the
            module-level `class_names_to_ids`.
    Returns:
        A 1-D int64 numpy array with one class id per path.
    Raises:
        KeyError: If a parent directory name is not a known class name.
    """
    if name_to_id is None:
        name_to_id = class_names_to_ids

    ids = []
    for path in filenames:
        if not isinstance(path, str):
            # bytes.split('/') would raise TypeError on Python 3.
            path = path.decode('utf-8')
        ids.append(name_to_id[path.split('/')[-2]])
    return np.array(ids, dtype=np.int64)

In [None]:
def get_KLD_from_specialist(specialist, target, logits, num_classes, batch_size):
    """Per-example divergence between `target` and the student's reduced distribution.

    For each of the first `batch_size` rows of `logits`, the entries at the
    `specialist` class indices are kept and all remaining classes are summed
    into one trailing "dustbin" entry, giving a row of
    `len(specialist) + 1` values. The result is
    `mean(target*log(target) - target*log(reduced), axis=1)`, i.e. the terms
    are averaged (not summed) over those `len(specialist) + 1` entries.

    Args:
        specialist: List of class indices handled by the specialist.
        target: Tensor multiplied element-wise with the reduced rows; it must
            be compatible with shape [batch_size, len(specialist) + 1].
        logits: 2-D tensor indexed by row. Despite the name, values are passed
            straight to `tf.log`, so they are expected to be probabilities
            (the caller in this notebook passes softmax output).
        num_classes: Total number of classes (size of axis 1 of `logits`).
        batch_size: Python int, number of rows to process; must be >= 1.
    Returns:
        A tensor with one value per example.
    """
    dustbin = [idx for idx in range(num_classes) if idx not in specialist]
    special = tf.constant(specialist, dtype=tf.int64)
    dust = tf.constant(dustbin, dtype=tf.int64)

    # Build every reduced row, then stack once instead of re-concatenating a
    # growing tensor on each iteration.
    rows = []
    for idx in range(batch_size):
        special_row = tf.nn.embedding_lookup(logits[idx], special)
        dust_mass = tf.reduce_sum(tf.nn.embedding_lookup(logits[idx], dust))
        rows.append(tf.concat(0, [special_row, [dust_mass]]))
    whole_batch = tf.stack(rows)

    # NOTE(review): an exact 0 in `target` or `whole_batch` makes tf.log
    # return -inf and the result NaN/inf -- confirm inputs are strictly
    # positive.
    KLD = tf.reduce_mean(target*tf.log(target)-(target*tf.log(whole_batch)), axis=1)
    return KLD

In [None]:
def build_loss(logits, ge, sp1, sp2):
    """Builds the loss against the generalist and the two specialists.

    Three per-example divergences are computed: `ce0` against the generalist
    over all classes, and `ce1` / `ce2` against each specialist over its
    reduced class set (see `get_KLD_from_specialist`). For every example the
    generalist's predicted class selects which terms apply: the generalist
    term always, and a specialist term only if the predicted class belongs to
    that specialist. The selected terms are averaged per example and the
    per-example values are summed over the batch.

    Args:
        logits: Pre-softmax outputs of the model being trained,
            [BATCH_SIZE, NUM_CLASSES].
        ge: Pre-softmax generalist outputs, same shape as `logits`; softened
            with temperature 2.
        sp1: Pre-softmax outputs of specialist 1, softened with temperature 2.
            Presumably [BATCH_SIZE, 4] (3 classes + dustbin) -- TODO confirm.
        sp2: Same as `sp1` for specialist 2.
    Returns:
        Scalar loss tensor (a sum over the batch, not a mean).
    """
    # The model's own distribution uses temperature 1; the reference
    # distributions are softened with temperature 2.
    logits = tf.nn.softmax(logits)
    ge = tf.nn.softmax(tf.div(ge, 2))
    sp1 = tf.nn.softmax(tf.div(sp1, 2))
    sp2 = tf.nn.softmax(tf.div(sp2, 2))

    ce0 = tf.reduce_mean(ge*tf.log(ge) - (ge*tf.log(logits)), axis=1)
    print("ce_0 :", ce0.get_shape())

    # sp1 - Downloader : 47, Virut : 166, MSIL : 85
    # sp2 - Bundler : 19, Virut : 166, LoadMoney : 83
    # These indices require NUM_CLASSES >= 167.
    s1 = [47, 166, 85]
    s2 = [19, 166, 83]

    # Row c is the selector [generalist, specialist1, specialist2] applied
    # when the generalist predicts class c. Columns 1 and 2 must be set to 1
    # for the specialists' classes: writing 0 into the zero-initialised table
    # would leave every row at [1, 0, 0] and silently drop ce1/ce2.
    lookup_table = np.zeros((NUM_CLASSES, 3))
    lookup_table[:, 0] = 1
    lookup_table[s1, 1] = 1
    lookup_table[s2, 2] = 1
    lookup_table = tf.constant(lookup_table, dtype=tf.float32)

    ce1 = get_KLD_from_specialist(specialist = s1, logits = logits, target = sp1, num_classes = NUM_CLASSES, batch_size = BATCH_SIZE)
    ce2 = get_KLD_from_specialist(specialist = s2, logits = logits, target = sp2, num_classes = NUM_CLASSES, batch_size = BATCH_SIZE)

    # [3, BATCH_SIZE] -> [BATCH_SIZE, 3] so each row lines up with a selector.
    ce_per_example = tf.transpose(tf.stack((ce0, ce1, ce2)))

    gen_predict = tf.arg_max(ge, dimension=1)
    sp_set = tf.nn.embedding_lookup(lookup_table, gen_predict)
    # Number of active terms per example (always >= 1: the generalist).
    k = tf.reduce_sum(sp_set, axis=1)
    per_example_loss = tf.reduce_sum(sp_set*ce_per_example, axis=1)/k

    cross_entropy_mean = tf.reduce_sum(per_example_loss)

    return cross_entropy_mean

In [None]:
def build_accuracy(logits, labels):
    """Top-1 accuracy of `logits` against `labels`, also exported as a summary.

    Args:
        logits: Per-class scores, one row per example.
        labels: Integer class id per example.
    Returns:
        Scalar float32 tensor: the fraction of examples whose label is the
        top-scoring class. The same value is registered as the scalar
        summary 'accuracy'.
    """
    correct_prediction = tf.nn.in_top_k(logits, labels, 1)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # Tag was previously misspelled 'accuarcy'; runs logged before this fix
    # show up under the old tag in TensorBoard.
    tf.summary.scalar('accuracy', accuracy)
    return accuracy

In [None]:
# Input pipeline: four single-epoch queues (file paths plus three per-example
# reference outputs) and the ops that turn one file's raw bytes into a
# 224x224x1 float image.
# NOTE(review): `path`, `gen`, `sp1` and `sp2` are not defined in any cell
# visible here; they must already exist in the kernel -- confirm their origin.
graph = tf.Graph()
with graph.as_default():
    # NOTE(review): each queue shuffles independently; the four streams
    # presumably stay aligned only because they share seed=0 and have equal
    # lengths -- confirm.
    filename_queue = tf.train.input_producer(path, num_epochs=1, capacity=1, shuffle=True, seed=0)
    gen_queue = tf.train.input_producer(gen, num_epochs=1, capacity=1, shuffle=True, seed=0)
    sp1_queue = tf.train.input_producer(sp1, num_epochs=1, capacity=1, shuffle=True, seed=0)
    sp2_queue = tf.train.input_producer(sp2, num_epochs=1, capacity=1, shuffle=True, seed=0)

    # One element from each queue per example.
    fn = filename_queue.dequeue()
    ge = gen_queue.dequeue()

    s1 = sp1_queue.dequeue()
    s2 = sp2_queue.dequeue()

    string_tensor = tf.read_file(fn)

    # Interpret the file as a flat uint8 vector and zero-pad it so its length
    # is a multiple of 224. When the length already is a multiple of 224, a
    # full extra row of 224 zeros is appended.
    content = tf.decode_raw(string_tensor, tf.uint8)
    zero_padding_size = tf.constant(224) - tf.mod(tf.shape(content), tf.constant(224))
    zero_padding = tf.zeros(zero_padding_size, dtype=tf.uint8)
    content = tf.concat(0, [content, zero_padding])

    # Rows of 224 bytes -> [rows, 224, 1], resized to 224x224 with
    # nearest-neighbour sampling, then mapped from [0, 255] to [-1, 1].
    content = tf.reshape(content, [-1, 224, 1])
    content = tf.image.resize_images(content, [224, 224], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    img = tf.div(tf.sub(tf.cast(content, tf.float32), 127.5), 127.5)


In [None]:
# Training cell: batches the pipeline outputs, builds model/loss/optimizer,
# restores the latest checkpoint if there is one, and trains until the
# single-epoch input queues are exhausted.
with graph.as_default():
    fn_batch, img_batch, ge_batch, sp1_batch, sp2_batch = tf.train.batch([fn, img, ge, s1, s2], BATCH_SIZE, capacity=1, num_threads=1)

    # Labels are derived in Python from each file's parent directory name.
    label_batch = tf.py_func(get_label, [fn_batch], [tf.int64])[0]

    logits = build_model(img_batch)
    loss = build_loss(logits, ge_batch, sp1_batch, sp2_batch)
    accuracy = build_accuracy(logits, label_batch)

    global_step = tf.Variable(0, name="global_step", trainable=False)

    lr = tf.Variable(0.0001, trainable=False)
    opt = tf.train.AdamOptimizer(lr).minimize(loss, global_step)

    # Local variables hold the epoch counters created by num_epochs=1.
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init_op)

    saver = tf.train.Saver()
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter("/data/tensorboard_log/dict/dist_0324_2/", sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    checkpoint_dir = "/data/dict/checkpoint_dist_0324_2/"

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print('no checkpoint found...')

    summary_iter = 10
    save_iter = 100

    valacc = []
    count = 0
    # Read the step up front so the end-of-epoch checkpoint below has a step
    # number even when the very first batch already exhausts the queues.
    _global_step = sess.run(global_step)
    try:
        while not coord.should_stop():
            _,  _loss, acc, merged, _global_step = sess.run([opt, loss, accuracy, merged_summary, global_step])
            count += 1
            print('loss : %g, accuracy : %g'%(_loss, acc))
            valacc.append(acc)
            if _global_step % summary_iter == 0:
                writer.add_summary(merged, _global_step)
                print('write summary')
            if _global_step % save_iter == 0:
                saver.save(sess, os.path.join(checkpoint_dir,'model.ckpt'), _global_step)
                print('save checkpoint')

    except tf.errors.OutOfRangeError:
        # Raised once the single-epoch queues run dry: normal termination.
        saver.save(sess, os.path.join(checkpoint_dir,'model.ckpt'), _global_step)
        print('eval accuracy : ',np.mean(valacc))
        print("End of epochs")

    except Exception:
        # Keep the cell from failing, but show the full traceback rather than
        # only the exception message.
        traceback.print_exc()

    finally:
        coord.request_stop()
        print("count: {}".format(count))
        # Wait for the queue-runner threads and push pending summaries to disk.
        coord.join(threads)
        writer.flush()