In [2]:
import tensorflow as tf
import numpy as np
import random
NUM_CLASSES = 20

class MLP:
    def __init__(self, vocab_size, hidden_size):
        self._vocab_size = vocab_size
        self._hidden_size = hidden_size
        
    def build_graph(self):
        self._X = tf.placeholder(tf.float32, shape=[None, self._vocab_size])
        self._real_Y = tf.placeholder(tf.int32, shape=[None, ])
        
        weights_1 = tf.get_variable(
            name='weights_input_hidden',
            shape=(self._vocab_size, self._hidden_size),
            initializer=tf.random_normal_initializer(seed=2018)
        )
        
        biases_1 = tf.get_variable(
            name='biases_input_hidden',
            shape=(self._hidden_size),
            initializer=tf.random_normal_initializer(seed=2018)
        )
        
        weights_2 = tf.get_variable(
            name='weights_hidden_output',
            shape=(self._hidden_size, NUM_CLASSES),
            initializer=tf.random_normal_initializer(seed=2018)
        )
        
        biases_2 = tf.get_variable(
            name='biases_hidden_output',
            shape=(NUM_CLASSES),
            initializer=tf.random_normal_initializer(seed=2018)
        )
        
        hidden = tf.matmul(self._X, weights_1) + biases_1
        hidden = tf.sigmoid(hidden)
        logits = tf.matmul(hidden, weights_2) + biases_2
        
        labels_one_hot = tf.one_hot(indices=self._real_Y, depth=NUM_CLASSES, dtype=tf.float32)
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels_one_hot, logits=logits)
        loss = tf.reduce_mean(loss)
        
        probs = tf.nn.softmax(logits)
        predicted_labels = tf.argmax(probs, axis=1)
        predicted_labels = tf.squeeze(predicted_labels)
        
        return predicted_labels, loss
    
    def trainer(self, loss, learning_rate):
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
        return train_op
    

class DataReader:
    def __init__(self, data_path, batch_size, vocab_size):
        self._batch_size = batch_size
        with open(data_path) as f:
            d_lines = f.read().splitlines()
            
        self._data = []
        self._labels = []
        for data_id, line in enumerate(d_lines):
            vector = [0.0 for _ in range(vocab_size)]
            features = line.split('<fff>')
            label, doc_id = int(features[0]), int(features[1])
            tokens = features[2].split()
            for token in tokens:
                index, value = int(token.split(':')[0]), float(token.split(':')[1])
                vector[index] = value
            self._data.append(vector)
            self._labels.append(label)
        
        self._data = np.array(self._data)
        self._labels = np.array(self._labels)
        
        self._num_epoch = 0
        self._batch_id = 0
            
        
    def next_batch(self):
        start = self._batch_id * self._batch_size
        end = start + self._batch_size
        self._batch_id += 1
        
        if end + self._batch_size > len(self._data):
#         if True:
            end = len(self._data)
            self._num_epoch += 1
            self._batch_id = 0
            indices = list(range(len(self._data)))
            random.seed(2018)
            random.shuffle(indices)
            self._data, self._labels = self._data[indices], self._labels[indices]
        
        return self._data[start:end], self._labels[start:end]

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
def load_dataset():
    train_data_reader = DataReader(
        data_path = './datasets/20news-train-tfidf.txt',
        batch_size = 50,
        vocab_size = vocab_size
    )
    test_data_reader = DataReader(
        data_path = './datasets/20news-test-tfidf.txt',
        batch_size = 50,
        vocab_size = vocab_size
    )
    return train_data_reader, test_data_reader

def save_parameters(name, value, epoch):
    filename = name.replace(':', '-colon-') + '-epoch-{}.txt'.format(epoch)
    if len(value.shape) == 1:
        string_form = ','.join([str(number) for number in value])
    else:
        string_form = '\n'.join([','.join([str(number) for number in value[row]]) for row in range(value.shape[0])])
    
    with open('./saved-paras/' + filename, 'w') as f:
        f.write(string_form)
        
def restore_parameters(name, epoch):
    filename = name.replace(':', '-colon-') + '-epoch-{}.txt'.format(epoch)
    with open('./saved-paras/' + filename) as f:
        lines = f.read().splitlines()
    if len(lines) == 1:
        value = [float(number) for number in lines[0].split(',')]
    else:
        value = [[float(number) for number in lines[row].split(',')] for row in range(len(lines))]

    return value



# create computation graph
with open('./datasets/words_idfs.txt') as f:
    vocab_size = len(f.read().splitlines())
    
mlp = MLP(
    vocab_size = vocab_size,
    hidden_size = 50
)
predicted_labels, loss = mlp.build_graph()
train_op = mlp.trainer(loss = loss, learning_rate=0.1)

with tf.Session() as sess:
    train_data_reader, test_data_reader = load_dataset()
    step, MAX_STEP = 0, 1000 ** 2
    
    sess.run(tf.global_variables_initializer())
    
#     # restore params
#     trainable_variables = tf.trainable_variables()
#     for variable in trainable_variables:
#         saved_value = restore_parameters(variable.name, epoch)
#         assign_op = variable.assign(saved_value)
#         sess.run(assign_op)
    
    while step < MAX_STEP:
        train_data, train_labels = train_data_reader.next_batch()
        plabels_eval, loss_eval, _ = sess.run(
            [predicted_labels, loss, train_op],
            feed_dict = {
                mlp._X: train_data,
                mlp._real_Y: train_labels
            }
        )
        step += 1
        print("step: {}, loss: {}".format(step, loss_eval))
        
        # save params
        trainable_variables = tf.trainable_variables()
        for variable in trainable_variables:
            save_parameters(
                name=variable.name,
                value=variable.eval(),
                epoch=train_data_reader._num_epoch
            )
            

# evaluate model
test_data_reader = DataReader(
    data_path = './datasets/20news-test_tfidf.txt',
    batch_size = 50,
    vocab_size = vocab_size
)
with tf.Session() as sess:
    epoch = 10
    
    trainable_variables = tf.trainable_variables()
    for variable in trainable_variables:
        saved_value = restore_parameters(variables.name, epoch)
        assign_op = variable.assign(saved_value)
        sess.run(assign_op)
        
    num_true_preds = 0
    while True:
        test_data, test_labels = test_data_reader.next_batch()
        test_plabels_eval = sess.run(
            predicted_labels,
            feed_dict={
                mlp._X: test_data,
                mlp._real_Y: test_labels
            }
        )
        matches = np.equal(test_plabels_eval, test_labels)
        num_true_preds += np.sum(matches.astype(float))
        
        if test_data_reader._batch_id == 0:
            break
            
    print('Epoch: ', epoch)
    print('Accuracy on test data: ', num_true_preds / len(test_data_reader._data))

Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See `tf.nn.softmax_cross_entropy_with_logits_v2`.

step: 1, loss: 0.5065922737121582
step: 2, loss: 0.00016962051449809223
step: 3, loss: 8.225404144468484e-07
step: 4, loss: 4.291533528544278e-08
step: 5, loss: 0.0
step: 6, loss: 0.0
step: 7, loss: 0.0
step: 8, loss: 0.0
step: 9, loss: 0.0
step: 10, loss: 15.89012336730957
step: 11, loss: 36.74134063720703
step: 12, loss: 32.45412063598633
step: 13, loss: 26.995887756347656
step: 14, loss: 20.358125686645508
step: 15, loss: 12.29566764831543
step: 16, loss: 5.123974800109863
step: 17, loss: 0.7438090443611145
step: 18, loss: 0.01452797744423151
step: 19, loss: 1.0787778592202812e-05
step: 20, loss: 7.671380444662645e-06
step: 21, loss: 2.384184938364342e-08
step: 22, loss: 21.03221893310547
step: 23, loss: 27.945518493652344
step: 24, loss: 26.411508560180664
step: 25, loss: 23.422630310058594
ste

step: 227, loss: 3.8068840503692627
step: 228, loss: 3.403703212738037
step: 229, loss: 3.389796733856201
step: 230, loss: 3.326944351196289
step: 231, loss: 3.270170211791992
step: 232, loss: 3.1929895877838135
step: 233, loss: 3.232067823410034
step: 234, loss: 3.0102570056915283
step: 235, loss: 3.5831003189086914
step: 236, loss: 3.025265693664551
step: 237, loss: 3.2527036666870117
step: 238, loss: 2.8983349800109863
step: 239, loss: 2.850038528442383
step: 240, loss: 2.7265963554382324
step: 241, loss: 2.556800127029419
step: 242, loss: 2.666511297225952
step: 243, loss: 2.7543182373046875
step: 244, loss: 2.719313144683838
step: 245, loss: 2.4614033699035645
step: 246, loss: 2.3685076236724854
step: 247, loss: 2.9488136768341064
step: 248, loss: 2.673924207687378
step: 249, loss: 2.2199668884277344
step: 250, loss: 2.404661178588867
step: 251, loss: 2.181959867477417
step: 252, loss: 2.3317389488220215
step: 253, loss: 2.13496994972229
step: 254, loss: 2.0438971519470215
step: 2

step: 456, loss: 0.11112624406814575
step: 457, loss: 0.16519871354103088
step: 458, loss: 0.3228713870048523
step: 459, loss: 0.13709202408790588
step: 460, loss: 0.16988694667816162
step: 461, loss: 0.14190511405467987
step: 462, loss: 0.41463935375213623
step: 463, loss: 0.08421148359775543
step: 464, loss: 0.25556597113609314
step: 465, loss: 0.14804629981517792
step: 466, loss: 0.22594095766544342
step: 467, loss: 0.1137775331735611
step: 468, loss: 0.19310332834720612
step: 469, loss: 0.2133554220199585
step: 470, loss: 0.18512466549873352
step: 471, loss: 0.5359179973602295
step: 472, loss: 0.17072229087352753
step: 473, loss: 0.1379334032535553
step: 474, loss: 0.28487199544906616
step: 475, loss: 0.17798402905464172
step: 476, loss: 0.12628836929798126
step: 477, loss: 0.17773181200027466
step: 478, loss: 0.21798956394195557
step: 479, loss: 0.2391718626022339
step: 480, loss: 0.2787628769874573
step: 481, loss: 0.148030087351799
step: 482, loss: 0.04380195215344429
step: 483,

step: 680, loss: 0.029345646500587463
step: 681, loss: 0.08022693544626236
step: 682, loss: 0.06573718786239624
step: 683, loss: 0.011145596392452717
step: 684, loss: 0.08723705261945724
step: 685, loss: 0.039333585649728775
step: 686, loss: 0.04381227120757103
step: 687, loss: 0.01012243889272213
step: 688, loss: 0.010051540099084377
step: 689, loss: 0.020644424483180046
step: 690, loss: 0.032692596316337585
step: 691, loss: 0.009752648882567883
step: 692, loss: 0.03246708959341049
step: 693, loss: 0.06193813309073448
step: 694, loss: 0.02907845564186573
step: 695, loss: 0.0392715185880661
step: 696, loss: 0.006971640978008509
step: 697, loss: 0.03165365383028984
step: 698, loss: 0.017996078357100487
step: 699, loss: 0.015065351501107216
step: 700, loss: 0.009882457554340363
step: 701, loss: 0.011308488436043262
step: 702, loss: 0.009653972461819649
step: 703, loss: 0.020215604454278946
step: 704, loss: 0.01863441988825798
step: 705, loss: 0.09620153158903122
step: 706, loss: 0.076685

step: 899, loss: 0.042884450405836105
step: 900, loss: 0.17204569280147552
step: 901, loss: 0.008381057530641556
step: 902, loss: 0.005290484055876732
step: 903, loss: 0.02660476602613926
step: 904, loss: 0.009726817719638348
step: 905, loss: 0.00254641892388463
step: 906, loss: 0.0033882802817970514
step: 907, loss: 0.0489305704832077
step: 908, loss: 0.013551464304327965
step: 909, loss: 0.01788674108684063
step: 910, loss: 0.0024142165202647448
step: 911, loss: 0.0039303540252149105
step: 912, loss: 0.008847295306622982
step: 913, loss: 0.037087880074977875
step: 914, loss: 0.005343422293663025
step: 915, loss: 0.019573910161852837
step: 916, loss: 0.0045377411879599094
step: 917, loss: 0.008019454777240753
step: 918, loss: 0.001844325684942305
step: 919, loss: 0.004656457807868719
step: 920, loss: 0.002317463979125023
step: 921, loss: 0.004444102756679058
step: 922, loss: 0.11221801489591599
step: 923, loss: 0.003538410644978285
step: 924, loss: 0.012986714020371437
step: 925, loss

step: 1112, loss: 0.0200326070189476
step: 1113, loss: 0.001385542913340032
step: 1114, loss: 0.02080102451145649
step: 1115, loss: 0.0017004951369017363
step: 1116, loss: 0.0023973321076482534
step: 1117, loss: 0.030768992379307747
step: 1118, loss: 0.002754118526354432
step: 1119, loss: 0.03287778049707413
step: 1120, loss: 0.017438707873225212
step: 1121, loss: 0.00553040811792016
step: 1122, loss: 0.03655075281858444
step: 1123, loss: 0.02441086806356907
step: 1124, loss: 0.02294185198843479
step: 1125, loss: 0.0016561629017814994
step: 1126, loss: 0.0025229318998754025
step: 1127, loss: 0.0027365682180970907
step: 1128, loss: 0.0015428499318659306
step: 1129, loss: 0.06559186428785324
step: 1130, loss: 0.0034861706662923098
step: 1131, loss: 0.0017665892373770475
step: 1132, loss: 0.0022278290707618
step: 1133, loss: 0.001054846215993166
step: 1134, loss: 0.03508150950074196
step: 1135, loss: 0.13068729639053345
step: 1136, loss: 0.03173399344086647
step: 1137, loss: 0.06519997864

step: 1321, loss: 0.001698166597634554
step: 1322, loss: 0.0010615320643410087
step: 1323, loss: 0.06976860761642456
step: 1324, loss: 0.002934265648946166
step: 1325, loss: 0.004334019031375647
step: 1326, loss: 0.0029400617349892855
step: 1327, loss: 0.0012644014786928892
step: 1328, loss: 0.0006774350185878575
step: 1329, loss: 0.0022243098355829716
step: 1330, loss: 0.0032355489674955606
step: 1331, loss: 0.0011286746012046933
step: 1332, loss: 0.0008080317638814449
step: 1333, loss: 0.0019507366232573986
step: 1334, loss: 0.08593236654996872
step: 1335, loss: 0.0006588103133253753
step: 1336, loss: 0.07418236136436462
step: 1337, loss: 0.0006174702430143952
step: 1338, loss: 0.03533980995416641
step: 1339, loss: 0.0013102262746542692
step: 1340, loss: 0.000928698864299804
step: 1341, loss: 0.0014840241055935621
step: 1342, loss: 0.0011980077251791954
step: 1343, loss: 0.011750426143407822
step: 1344, loss: 0.019468899816274643
step: 1345, loss: 0.0036622313782572746
step: 1346, lo

step: 1528, loss: 0.0010177125222980976
step: 1529, loss: 0.0007720980793237686
step: 1530, loss: 0.0025898509193211794
step: 1531, loss: 0.0011940659023821354
step: 1532, loss: 0.03564450517296791
step: 1533, loss: 0.0015396528178825974
step: 1534, loss: 0.0009013635572046041
step: 1535, loss: 0.000566568342037499
step: 1536, loss: 0.001389933517202735
step: 1537, loss: 0.0004902931395918131
step: 1538, loss: 0.0002609170915093273
step: 1539, loss: 0.0025856492575258017
step: 1540, loss: 0.001151951146312058
step: 1541, loss: 0.020815562456846237
step: 1542, loss: 0.0019288378534838557
step: 1543, loss: 0.0008034247439354658
step: 1544, loss: 0.001980260480195284
step: 1545, loss: 0.0008401434170082211
step: 1546, loss: 0.0009469232754781842
step: 1547, loss: 0.001456920406781137
step: 1548, loss: 0.007915521040558815
step: 1549, loss: 0.031564727425575256
step: 1550, loss: 0.0008058664388954639
step: 1551, loss: 0.001272305496968329
step: 1552, loss: 0.006085550878196955
step: 1553, 

step: 1739, loss: 0.001587103703059256
step: 1740, loss: 0.001233970746397972
step: 1741, loss: 0.03348761051893234
step: 1742, loss: 0.00046690955059602857
step: 1743, loss: 0.0006325588328763843
step: 1744, loss: 0.18094702064990997
step: 1745, loss: 0.00028830597875639796
step: 1746, loss: 0.00025484751677140594
step: 1747, loss: 0.00356145272962749
step: 1748, loss: 0.001969809178262949
step: 1749, loss: 0.0005023680860176682
step: 1750, loss: 0.0007718189735896885
step: 1751, loss: 0.0003662451053969562
step: 1752, loss: 0.009356919676065445
step: 1753, loss: 0.0005278369062580168
step: 1754, loss: 0.00022349029313772917
step: 1755, loss: 0.0017728913808241487
step: 1756, loss: 0.0005654566339217126
step: 1757, loss: 0.006315131206065416
step: 1758, loss: 0.0008524723816663027
step: 1759, loss: 0.025322770699858665
step: 1760, loss: 0.0050970204174518585
step: 1761, loss: 0.0004545058764051646
step: 1762, loss: 0.0012792828492820263
step: 1763, loss: 0.0010145591804757714
step: 17

step: 1945, loss: 0.0006486885249614716
step: 1946, loss: 0.0006495903944596648
step: 1947, loss: 0.0008171597146429121
step: 1948, loss: 0.000508206314407289
step: 1949, loss: 0.0011400885414332151
step: 1950, loss: 0.0007058767951093614
step: 1951, loss: 0.00046683382242918015
step: 1952, loss: 0.002036853926256299
step: 1953, loss: 0.0003881593293044716
step: 1954, loss: 0.0007766684284433722
step: 1955, loss: 0.0004665809392463416
step: 1956, loss: 0.0010541188530623913
step: 1957, loss: 0.023690881207585335
step: 1958, loss: 0.0005678129964508116
step: 1959, loss: 0.0003178217157255858
step: 1960, loss: 0.0036233102437108755
step: 1961, loss: 0.0009896332630887628
step: 1962, loss: 0.0010039546759799123
step: 1963, loss: 0.0021780477836728096
step: 1964, loss: 0.0006071307579986751
step: 1965, loss: 0.0004449106636457145
step: 1966, loss: 0.00037852569948881865
step: 1967, loss: 0.0005995045648887753
step: 1968, loss: 0.0013981149531900883
step: 1969, loss: 0.00036707575782202184


step: 2150, loss: 0.000636063574347645
step: 2151, loss: 0.00044212699867784977
step: 2152, loss: 0.0006939757149666548
step: 2153, loss: 0.0006546130171045661
step: 2154, loss: 0.0006085311179049313
step: 2155, loss: 0.0008610913064330816
step: 2156, loss: 0.0007649746257811785
step: 2157, loss: 0.0002820778754539788
step: 2158, loss: 0.0006422508740797639
step: 2159, loss: 0.0006089616217650473
step: 2160, loss: 0.00024175969883799553
step: 2161, loss: 0.0010292103979736567
step: 2162, loss: 0.000518468557856977
step: 2163, loss: 0.00022540600912179798
step: 2164, loss: 0.0535660944879055
step: 2165, loss: 0.0007287567132152617
step: 2166, loss: 0.0003969747922383249
step: 2167, loss: 0.00012366224837023765
step: 2168, loss: 0.0006556889275088906
step: 2169, loss: 0.0003036729176528752
step: 2170, loss: 0.0003247674903832376
step: 2171, loss: 0.0002542577567510307
step: 2172, loss: 0.0021445625461637974
step: 2173, loss: 0.0005030063330195844
step: 2174, loss: 0.00045091670472174883


step: 2355, loss: 0.00034657909418456256
step: 2356, loss: 0.000255232909694314
step: 2357, loss: 0.003727538511157036
step: 2358, loss: 0.0002089674526359886
step: 2359, loss: 0.0004476085305213928
step: 2360, loss: 0.0012747662840411067
step: 2361, loss: 0.0003515487478580326
step: 2362, loss: 0.00033271111897192895
step: 2363, loss: 0.0004489836865104735
step: 2364, loss: 0.0008022547699511051
step: 2365, loss: 0.0006456378614529967
step: 2366, loss: 0.00041735198465175927
step: 2367, loss: 0.00016703011351637542
step: 2368, loss: 0.00026848362176679075
step: 2369, loss: 0.0003002247540280223
step: 2370, loss: 0.00012810347834601998
step: 2371, loss: 0.0003523631894495338
step: 2372, loss: 7.614708010805771e-05
step: 2373, loss: 0.00020278450392652303
step: 2374, loss: 0.00022514889133162796
step: 2375, loss: 0.0002241062611574307
step: 2376, loss: 0.00035128119634464383
step: 2377, loss: 0.0002817421336658299
step: 2378, loss: 0.0002270305121783167
step: 2379, loss: 0.0004416437877

step: 2559, loss: 0.00021861871937289834
step: 2560, loss: 0.0004891320713795722
step: 2561, loss: 0.0007521420484408736
step: 2562, loss: 0.00013467270764522254
step: 2563, loss: 0.00021593373094219714
step: 2564, loss: 0.00013893033610656857
step: 2565, loss: 0.00021371433103922755
step: 2566, loss: 0.0002362833038205281
step: 2567, loss: 0.00019573904864955693
step: 2568, loss: 0.0006214980967342854
step: 2569, loss: 0.00015584201901219785
step: 2570, loss: 0.00022658945817966014
step: 2571, loss: 0.0004140192177146673
step: 2572, loss: 0.00017305112851317972
step: 2573, loss: 0.00019195425556972623
step: 2574, loss: 0.0021482028532773256
step: 2575, loss: 0.00020189976203255355
step: 2576, loss: 0.02580142579972744
step: 2577, loss: 0.0003304421261418611
step: 2578, loss: 0.001070054597221315
step: 2579, loss: 0.0003430454817134887
step: 2580, loss: 0.0003316180082038045
step: 2581, loss: 0.0001503563835285604
step: 2582, loss: 0.000518892309628427
step: 2583, loss: 0.0002953631628

step: 2763, loss: 0.00018971279496327043
step: 2764, loss: 7.605714199598879e-05


KeyboardInterrupt: 

In [7]:

# evaluate model
# ??? different vocab_size?
test_data_reader = DataReader(
    data_path = './datasets/20news-test-tfidf.txt',
    batch_size = 50,
    vocab_size = vocab_size
)
with tf.Session() as sess:
    epoch = 9
    
    trainable_variables = tf.trainable_variables()
    for variable in trainable_variables:
        saved_value = restore_parameters(variable.name, epoch)
        assign_op = variable.assign(saved_value)
        sess.run(assign_op)
        
    num_true_preds = 0
    while True:
        test_data, test_labels = test_data_reader.next_batch()
        test_plabels_eval = sess.run(
            predicted_labels,
            feed_dict={
                mlp._X: test_data,
                mlp._real_Y: test_labels
            }
        )
        matches = np.equal(test_plabels_eval, test_labels)
        num_true_preds += np.sum(matches.astype(float))
        
        if test_data_reader._batch_id == 0:
            break
            
    print('Epoch: ', epoch)
    print('Accuracy on test data: ', num_true_preds / len(test_data_reader._data))

Epoch:  9
Accuracy on test data:  0.7469463621879979
