Skip to content

Commit

Permalink
t-SNE
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Oct 30, 2018
1 parent daf0004 commit 75c1ad8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion generate_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def valid(self):

if __name__ == '__main__':

gpus = [3]
gpus = [2,3]
batch_shape = (128, 3, 32, 32)
image_channel = batch_shape[1]
nepochs = 200
Expand Down
18 changes: 18 additions & 0 deletions jdit/trainer/super.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class Loger(object):
"""this is a log recorder.
"""

def __init__(self, logdir="log"):
self.logdir = logdir
self.regist_list = []
Expand Down Expand Up @@ -192,6 +193,7 @@ class Watcher(object):
"""this is a params and images watcher
"""

def __init__(self, logdir, mode="L"):
self.logdir = logdir
self.writer = SummaryWriter(log_dir=logdir)
Expand Down Expand Up @@ -241,6 +243,22 @@ def image(self, img_tensors, global_step, tag="Train/input", grid_size=(3, 1), s
filename = "%s/plots/%s/E%03d.png" % (self.logdir, tag, global_step)
img.save(filename)

def embedding(self, mat, label_img=None, label=None, global_step=None, tag="embedding"):
""" Show PCA, t-SNE of `mat` on tensorboard
:param mat: An img tensor with shape of (N, C, H, W)
:param label_img: Label img on each data point.
:param label: Label of each img. It will convert to str.
:param global_step: Img step label.
:param tag: Tag of this plot.
"""
# images = dataset.train_data[:amount]
# images = dataset.train_data[:amount]

samples = len(mat)
features = mat.view(samples, -1)
self.writer.add_embedding(features, metadata=label, label_img=label_img, global_step=global_step, tag=tag)

def set_training_progress_images(self, img_tensors, grid_size=(3, 1)):
assert len(img_tensors.size()) == 4, "img_tensors rank should be 4, got %d instead" % len(img_tensors.size())
rows, columns = grid_size[0], grid_size[1]
Expand Down

0 comments on commit 75c1ad8

Please sign in to comment.