This repository has been archived by the owner on Nov 5, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 17
/
compute_one.py
59 lines (47 loc) · 1.96 KB
/
compute_one.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute Wasserstein distance between two random subsets of CIFAR10.
Note: comparing two fixed sets is a sanity check, not the target use case.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import time
import tensorflow as tf
from dataset import Dataset
from wasserstein import Wasserstein
# Command-line configuration. FLAGS is TensorFlow's global flag-value
# container; `main` reads the values registered on it below.
FLAGS = tf.flags.FLAGS

# Path pattern handed to `Dataset`. The '%d' placeholder is presumably filled
# in by `Dataset` (e.g. with a class index) -- confirm in dataset.py.
tf.flags.DEFINE_string(
    'filepattern',
    '/tmp/cifar10/cifar_train_class_%d.pic',
    'Filepattern from which to read the dataset.')

# Passed to `Dataset` as `bs`.
tf.flags.DEFINE_integer(
    'batch_size',
    1000,
    'Batch size of generator.')

# Passed to `Wasserstein.dist` as `nsteps`.
tf.flags.DEFINE_integer(
    'loss_steps',
    50,
    'Number of optimization steps.')
def main(unused_argv):
  """Builds two `Dataset` objects and prints the Wasserstein distance op's value.

  Both datasets are constructed from the same `--filepattern` and
  `--batch_size` flags. The distance op is built in a fresh graph and
  evaluated once in a new session after initializing all global variables.
  The scalar result is then printed.

  Args:
    unused_argv: Positional command-line arguments supplied by `tf.app.run`;
      not used.
  """
  # NOTE: call tf.logging.set_verbosity(tf.logging.INFO) here for verbose logs.
  print('Loading datasets...')
  # Two independently constructed copies, created in order (first, then second).
  first_set, second_set = [
      Dataset(bs=FLAGS.batch_size, filepattern=FLAGS.filepattern)
      for _ in range(2)
  ]

  print('Computing Wasserstein distance...')
  graph = tf.Graph()
  with graph.as_default():
    estimator = Wasserstein(first_set, second_set)
    # `C` and `nsteps` are forwarded verbatim to `Wasserstein.dist`; their
    # exact meaning is defined in wasserstein.py.
    distance_op = estimator.dist(C=0.1, nsteps=FLAGS.loss_steps)

    with tf.Session() as session:
      session.run(tf.global_variables_initializer())
      distance = session.run(distance_op)
      print('result: %f\n' % distance)
if __name__ == '__main__':
  # Hand control to TensorFlow's app runner, which invokes `main` with argv.
  tf.app.run(main=main)