In [1]:
import numpy as np
import torch
from torch import nn, optim
import torchvision
from torchvision import datasets, transforms
from Models.selective_sequential import *
from Loss.triplet_regularized import *
from session import *
from LR_Schedule.cyclical import Cyclical
from LR_Schedule.cos_anneal import CosAnneal
from LR_Schedule.lr_find import lr_find
from callbacks import *
from validation import *
import Datasets.ImageData as ImageData
from Transforms.ImageTransforms import *
import util
from session import LossMeter, EvalModel
from Layers.flatten import Flatten
from torch.utils.tensorboard import SummaryWriter

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
%load_ext autoreload
%autoreload 2

torch.cuda.set_device(0); torch.backends.cudnn.benchmark=True;

    Found GPU0 GeForce GTX 770 which is of cuda capability 3.0.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability that we support is 3.5.
    


In [3]:
# Number of leading samples kept from each MNIST split (small-data experiment).
SUBSET_SIZE = 500
# Mini-batch size shared by the training and validation loaders.
BATCH_SIZE = 64

# Convert images to tensors, then normalize the single channel with mean 0.5 / std 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Training split (downloaded on first use) and its first SUBSET_SIZE samples.
trainset = datasets.MNIST('/media/drake/MX500/Datasets/mnist/train', download=True, train=True, transform=transform)
partial_trainset = torch.utils.data.Subset(trainset, np.arange(SUBSET_SIZE))

# Test split (downloaded on first use) and its first SUBSET_SIZE samples, used for validation.
valset = datasets.MNIST('/media/drake/MX500/Datasets/mnist/test', download=True, train=False, transform=transform)
partial_valset = torch.utils.data.Subset(valset, np.arange(SUBSET_SIZE))

# Only the training loader shuffles; validation order stays fixed.
trainloader = torch.utils.data.DataLoader(partial_trainset, batch_size=BATCH_SIZE, shuffle=True)
valloader = torch.utils.data.DataLoader(partial_valset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# Layer names handed to SelectiveSequential, the triplet-regularized loss and
# the embedding-space validator.
# NOTE(review): presumably these are the layers whose outputs are exposed as
# embeddings -- confirm against Models.selective_sequential.
select = ['max1', 'act1', 'out']

def make_model():
    """Build a fresh, untrained network wrapped in SelectiveSequential.

    The layers are registered by name, in order: five convolution + ReLU
    stages interleaved with three 2x2 max-pools, then a flatten and a
    three-layer fully connected head ending in 10 outputs.

    Returns:
        A new SelectiveSequential built from the module-level `select` list
        and the named layers below.
    """
    # Convolutional stages; every ReLU is constructed with inplace=True.
    conv_stages = {
        'conv64': nn.Conv2d(1, 64, kernel_size=5, padding=2),
        'act64': nn.ReLU(True),
        'max1': nn.MaxPool2d(kernel_size=2, stride=2),

        'conv192': nn.Conv2d(64, 192, kernel_size=5, padding=2),
        'act192': nn.ReLU(True),
        'max2': nn.MaxPool2d(kernel_size=2, stride=2),

        'conv384': nn.Conv2d(192, 384, kernel_size=3, padding=1),
        'act384': nn.ReLU(True),

        'conv256a': nn.Conv2d(384, 256, kernel_size=3, padding=1),
        'act256a': nn.ReLU(True),

        'conv256b': nn.Conv2d(256, 256, kernel_size=3, padding=1),
        'act256b': nn.ReLU(True),
        'max3': nn.MaxPool2d(kernel_size=2, stride=2),
    }

    # Fully connected head; fc1 takes 3 * 3 * 256 flattened features.
    dense_head = {
        'flatten': Flatten(),
        'fc1': nn.Linear(3 * 3 * 256, 512),
        'act1': nn.ReLU(True),
        'fc2': nn.Linear(512, 512),
        'act2': nn.ReLU(True),
        'out': nn.Linear(512, 10),
    }

    # Merging keeps insertion order: convolutional stages first, then the head.
    return SelectiveSequential(select, {**conv_stages, **dense_head})

model = make_model()

In [5]:
# Loss for the baseline run. The first positional argument is the slot that
# `train` (defined later in this notebook) fills with its `lmbda` value.
# NOTE(review): the meaning of the second .5 is not visible here -- confirm
# against Loss.triplet_regularized.
criterion = TripletRegularizedMultiMarginLoss(.5, .5, select)
# Session bundling the model, the loss and the Adam optimizer class.
# NOTE(review): 1e-4 is presumably the learning rate -- confirm against session.Session.
sess = Session(model, criterion, optim.Adam, 1e-4)

In [6]:
# Number of epochs for the baseline run; also sizes the LR schedule below.
num_epochs = 50
# Validator over the 500-sample validation loader; it is given a checkpoint path.
validator = EmbeddingSpaceValidator(valloader, select, CustomOneHotAccuracy,
                                    model_file="./triplet-reg.ckpt.tar")
# Scheduler length is derived from `num_epochs` (batches per epoch * epochs)
# instead of repeating the literal 50, so the two cannot drift apart.
lr_scheduler = CosAnneal(len(trainloader)*num_epochs, T_mult=1, lr_min=1e-6)
schedule = TrainingSchedule(trainloader, [lr_scheduler, validator])

In [None]:
# Train for the same epoch count that sized the LR schedule in the previous
# cell, rather than repeating the literal 50.
sess.train(schedule, num_epochs)

In [None]:
# Plot what the validator collected during the training run above.
# NOTE(review): exact plot contents depend on EmbeddingSpaceValidator -- confirm.
validator.plot()

In [None]:
# Load the checkpoint at the same path that was passed as `model_file` to the
# validator above.
# NOTE(review): presumably the validator wrote its best model there -- confirm.
sess.load("./triplet-reg.ckpt.tar")

In [None]:
# Evaluate on the complete `valset` rather than the 500-sample subset.
total_valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=False)
# No layers are selected here (empty list), unlike the validator used during training.
total_validator = EmbeddingSpaceValidator(total_valloader, [], CustomOneHotAccuracy)

# Run the validator against the current session.
total_validator.run(sess)

In [None]:
# Display the highest accuracy recorded by the full-set validator, with a label.
# NOTE(review): the label says "without reg", but the session evaluated here was
# created with TripletRegularizedMultiMarginLoss(.5, .5, select) -- confirm the
# label matches the run being reported.
np.max(total_validator.val_accuracies), "Best accuracy without reg"

In [10]:
def train(lmbda, train_loader, val_loader, num_epochs=20):
    """Train a fresh model with the given loss weight and report its best accuracy.

    A new model (from `make_model`), loss, session, LR schedule and validator
    are created on every call, so runs do not share weights or optimizer state.

    Args:
        lmbda: First argument passed to TripletRegularizedMultiMarginLoss.
        train_loader: Loader iterated for training; its length sizes the
            cosine-annealing schedule.
        val_loader: Loader handed to the validator.
        num_epochs: Number of training epochs. Defaults to 20, the value that
            was previously hard-coded.

    Returns:
        The maximum of the validator's recorded `val_accuracies`.
    """
    print(f"Training: {lmbda}")
    # The validator is created with an empty layer-selection list.
    validator = EmbeddingSpaceValidator(val_loader, [], CustomOneHotAccuracy)
    # One annealing cycle spanning every batch of every epoch.
    lr_scheduler = CosAnneal(len(train_loader)*num_epochs, T_mult=1, lr_min=1e-6)
    schedule = TrainingSchedule(train_loader, [lr_scheduler, validator])
    criterion = TripletRegularizedMultiMarginLoss(lmbda, .5, select)
    sess = Session(make_model(), criterion, optim.Adam, 1e-4)
    sess.train(schedule, num_epochs)
    return np.max(validator.val_accuracies)

losses = {}

def search(lower, middle, upper, max_interval = .01, lower_acc=None, upper_acc=None):
    if (upper-lower < max_interval): return

    if lower_acc == None: 
        lower_acc = train(lower, trainloader, valloader)
        losses[lower] = lower_acc
    
    middle_acc = train(middle, trainloader, valloader)
    losses[middle] = middle_acc
    
    if upper_acc == None: 
        upper_acc = train(upper, trainloader, valloader)
        losses[upper] = upper_acc
        
    lower_mean = (lower_acc + middle_acc) / 2
    upper_mean = (upper_acc + middle_acc) / 2
    
    if lower_acc > upper_acc:
        search(lower, (middle-lower)/2, middle, max_interval, lower_acc, middle_acc)
    else:
        search(middle, (upper-middle)/2, upper, max_interval, middle_acc, upper_acc)

In [None]:
# Search the first loss argument (`lmbda`) over [0, 1], starting from the
# midpoint 0.5; results are accumulated in the `losses` dict.
search(0, .5, 1)

Training: 0


HBox(children=(IntProgress(value=0, description='Epochs', max=20, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.134 
train loss:  0.8985  train BCE :  0.8986 
valid loss:  0.8934  valid BCE :  0.8934


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.144 
train loss:  0.8926  train BCE :  0.8881 
valid loss:  0.8763  valid BCE :  0.8763


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.204 
train loss:  0.8729  train BCE :  0.8428 
valid loss:  0.7792  valid BCE :  0.7792


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.254 
train loss:  0.8148  train BCE :  0.6927 
valid loss:  0.6637  valid BCE :  0.6637


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.43 
train loss:  0.7436  train BCE :  0.5572 
valid loss:  0.5072  valid BCE :  0.5072


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.612 
train loss:  0.6481  train BCE :  0.3496 
valid loss:  0.3118  valid BCE :  0.3118


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.662 
train loss:  0.5494  train BCE :  0.2022 
valid loss:  0.2185  valid BCE :  0.2185


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.674 
train loss:  0.4671  train BCE :  0.149 
valid loss:  0.1828  valid BCE :  0.1828


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.668 
train loss:  0.4002  train BCE :  0.1228 
valid loss:  0.1891  valid BCE :  0.1891


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.71 
train loss:  0.3466  train BCE :  0.1125 
valid loss:  0.1583  valid BCE :  0.1583


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.718 
train loss:  0.3035  train BCE :  0.1078 
valid loss:  0.141  valid BCE :  0.141


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.766 
train loss:  0.2654  train BCE :  0.0853 
valid loss:  0.1241  valid BCE :  0.1241


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.764 
train loss:  0.2339  train BCE :  0.0787 
valid loss:  0.1201  valid BCE :  0.1201


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.76 
train loss:  0.2077  train BCE :  0.0747 
valid loss:  0.1183  valid BCE :  0.1183


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.78 
train loss:  0.1851  train BCE :  0.07 
valid loss:  0.1126  valid BCE :  0.1126


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.782 
train loss:  0.1663  train BCE :  0.0683 
valid loss:  0.1106  valid BCE :  0.1106


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.776 
train loss:  0.1503  train BCE :  0.0664 
valid loss:  0.1109  valid BCE :  0.1109


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.778 
train loss:  0.1369  train BCE :  0.0653 
valid loss:  0.1107  valid BCE :  0.1107


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.782 
train loss:  0.1255  train BCE :  0.0648 
valid loss:  0.1106  valid BCE :  0.1106


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.784 
train loss:  0.1161  train BCE :  0.0644 
valid loss:  0.1103  valid BCE :  0.1103

Training: 0.5


HBox(children=(IntProgress(value=0, description='Epochs', max=20, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.21 
train loss:  2.5762  train BCE :  0.897 
valid loss:  2.6089  valid BCE :  0.8916


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.218 
train loss:  2.5654  train BCE :  0.8752 
valid loss:  2.5309  valid BCE :  0.8643


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.19 
train loss:  2.5378  train BCE :  0.8146 
valid loss:  2.5176  valid BCE :  0.7892


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.344 
train loss:  2.5153  train BCE :  0.7135 
valid loss:  2.4288  valid BCE :  0.6987


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.314 
train loss:  2.4523  train BCE :  0.5943 
valid loss:  2.3441  valid BCE :  0.5704


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.668 
train loss:  2.3919  train BCE :  0.448 
valid loss:  2.2605  valid BCE :  0.4222


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.726 
train loss:  2.3156  train BCE :  0.3248 
valid loss:  2.1668  valid BCE :  0.3226


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.78 
train loss:  2.2514  train BCE :  0.2409 
valid loss:  2.1049  valid BCE :  0.2674


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.852 
train loss:  2.179  train BCE :  0.1868 
valid loss:  2.0587  valid BCE :  0.2319


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.878 
train loss:  2.1205  train BCE :  0.1615 
valid loss:  2.0105  valid BCE :  0.1943


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.886 
train loss:  2.0587  train BCE :  0.1302 
valid loss:  1.9834  valid BCE :  0.1777


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.916 
train loss:  2.0057  train BCE :  0.1127 
valid loss:  1.9599  valid BCE :  0.1682


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.916 
train loss:  1.9431  train BCE :  0.106 
valid loss:  1.948  valid BCE :  0.1697


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.914 
train loss:  1.9026  train BCE :  0.1034 
valid loss:  1.9157  valid BCE :  0.159


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.908 
train loss:  1.8616  train BCE :  0.0989 
valid loss:  1.9137  valid BCE :  0.1652


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.932 
train loss:  1.8337  train BCE :  0.095 
valid loss:  1.9006  valid BCE :  0.1546


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.924 
train loss:  1.8044  train BCE :  0.0921 
valid loss:  1.8979  valid BCE :  0.1563


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.93 
train loss:  1.7715  train BCE :  0.0912 
valid loss:  1.8952  valid BCE :  0.1533


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.934 
train loss:  1.745  train BCE :  0.0896 
valid loss:  1.8922  valid BCE :  0.1528


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.938 
train loss:  1.7224  train BCE :  0.0891 
valid loss:  1.8915  valid BCE :  0.1525

Training: 1


HBox(children=(IntProgress(value=0, description='Epochs', max=20, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.084 
train loss:  4.3669  train BCE :  0.8971 
valid loss:  4.3355  valid BCE :  0.8943


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.188 
train loss:  4.3497  train BCE :  0.8784 
valid loss:  4.3024  valid BCE :  0.8641


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.18 
train loss:  4.3346  train BCE :  0.8203 
valid loss:  4.2551  valid BCE :  0.809


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.142 
train loss:  4.2876  train BCE :  0.7605 
valid loss:  4.1861  valid BCE :  0.756


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.172 
train loss:  4.2253  train BCE :  0.699 
valid loss:  4.1368  valid BCE :  0.6987


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.15 
train loss:  4.187  train BCE :  0.6323 
valid loss:  4.0726  valid BCE :  0.6433


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.414 
train loss:  4.1151  train BCE :  0.5612 
valid loss:  4.0185  valid BCE :  0.5633


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.388 
train loss:  4.0463  train BCE :  0.485 
valid loss:  3.9147  valid BCE :  0.513


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.566 
train loss:  3.9591  train BCE :  0.4205 
valid loss:  3.8731  valid BCE :  0.4398


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.568 
train loss:  3.8841  train BCE :  0.3574 
valid loss:  3.8145  valid BCE :  0.4099


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.648 
train loss:  3.7985  train BCE :  0.3346 
valid loss:  3.7768  valid BCE :  0.3814


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.75 
train loss:  3.7283  train BCE :  0.3001 
valid loss:  3.7457  valid BCE :  0.3477


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.692 
train loss:  3.6317  train BCE :  0.2754 
valid loss:  3.7225  valid BCE :  0.3393


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.744 
train loss:  3.5753  train BCE :  0.265 
valid loss:  3.6866  valid BCE :  0.3314


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.752 
train loss:  3.5256  train BCE :  0.2496 
valid loss:  3.6862  valid BCE :  0.3111


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.758 
train loss:  3.4763  train BCE :  0.2335 
valid loss:  3.6888  valid BCE :  0.3002


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.772 
train loss:  3.4233  train BCE :  0.2266 
valid loss:  3.678  valid BCE :  0.298


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.764 
train loss:  3.3815  train BCE :  0.2238 
valid loss:  3.6718  valid BCE :  0.2972


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.764 
train loss:  3.339  train BCE :  0.2239 
valid loss:  3.665  valid BCE :  0.2976


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.774 
train loss:  3.3063  train BCE :  0.2234 
valid loss:  3.6646  valid BCE :  0.2971

Training: 0.25


HBox(children=(IntProgress(value=0, description='Epochs', max=20, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.194 
train loss:  1.7276  train BCE :  0.9 
valid loss:  1.6918  valid BCE :  0.8937


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.144 
train loss:  1.6906  train BCE :  0.8782 
valid loss:  1.6568  valid BCE :  0.8459


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.536 
train loss:  1.6608  train BCE :  0.788 
valid loss:  1.6008  valid BCE :  0.7402


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.456 
train loss:  1.619  train BCE :  0.656 
valid loss:  1.5066  valid BCE :  0.6207


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.672 
train loss:  1.5564  train BCE :  0.4907 
valid loss:  1.3858  valid BCE :  0.461


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.654 
train loss:  1.4837  train BCE :  0.3475 
valid loss:  1.2867  valid BCE :  0.3311


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.79 
train loss:  1.4072  train BCE :  0.2336 
valid loss:  1.2041  valid BCE :  0.2305


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.828 
train loss:  1.3304  train BCE :  0.1615 
valid loss:  1.1443  valid BCE :  0.1847


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.818 
train loss:  1.2601  train BCE :  0.1293 
valid loss:  1.1183  valid BCE :  0.1801


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.872 
train loss:  1.2062  train BCE :  0.1048 
valid loss:  1.0876  valid BCE :  0.1456


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.832 
train loss:  1.1648  train BCE :  0.0888 
valid loss:  1.0707  valid BCE :  0.1573


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.872 
train loss:  1.1198  train BCE :  0.0981 
valid loss:  1.0419  valid BCE :  0.1399


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.888 
train loss:  1.0812  train BCE :  0.0864 
valid loss:  1.0349  valid BCE :  0.1347


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.892 
train loss:  1.0489  train BCE :  0.0725 
valid loss:  1.0182  valid BCE :  0.1247


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.906 
train loss:  1.0187  train BCE :  0.068 
valid loss:  1.0104  valid BCE :  0.1208


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.91 
train loss:  0.9929  train BCE :  0.0646 
valid loss:  1.006  valid BCE :  0.1186


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.91 
train loss:  0.9684  train BCE :  0.0611 
valid loss:  1.0052  valid BCE :  0.1175


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.914 
train loss:  0.9461  train BCE :  0.06 
valid loss:  1.0026  valid BCE :  0.1163


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.914 
train loss:  0.9272  train BCE :  0.0594 
valid loss:  1.0022  valid BCE :  0.1154


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.916 
train loss:  0.9101  train BCE :  0.0588 
valid loss:  1.0017  valid BCE :  0.1153

Training: 0.125


HBox(children=(IntProgress(value=0, description='Epochs', max=20, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.104 
train loss:  1.3473  train BCE :  0.8989 
valid loss:  1.3363  valid BCE :  0.893


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.198 
train loss:  1.3249  train BCE :  0.8719 
valid loss:  1.2932  valid BCE :  0.8454


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.218 
train loss:  1.2856  train BCE :  0.7679 
valid loss:  1.1763  valid BCE :  0.7022


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.662 
train loss:  1.2062  train BCE :  0.5581 
valid loss:  0.9996  valid BCE :  0.4639


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.624 
train loss:  1.1225  train BCE :  0.3267 
valid loss:  0.8885  valid BCE :  0.304


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.752 
train loss:  1.0377  train BCE :  0.1993 
valid loss:  0.8113  valid BCE :  0.2279


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.786 
train loss:  0.9659  train BCE :  0.1544 
valid loss:  0.7536  valid BCE :  0.1912


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.8 
train loss:  0.9024  train BCE :  0.1294 
valid loss:  0.7064  valid BCE :  0.163


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.792 
train loss:  0.846  train BCE :  0.1071 
valid loss:  0.6857  valid BCE :  0.151


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.866 
train loss:  0.7994  train BCE :  0.0953 
valid loss:  0.6628  valid BCE :  0.134


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.848 
train loss:  0.7581  train BCE :  0.0831 
valid loss:  0.6498  valid BCE :  0.1324


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.882 
train loss:  0.7214  train BCE :  0.0767 
valid loss:  0.6308  valid BCE :  0.1201


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.878 
train loss:  0.6913  train BCE :  0.0674 
valid loss:  0.6263  valid BCE :  0.1128


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.888 
train loss:  0.664  train BCE :  0.0641 
valid loss:  0.6228  valid BCE :  0.1117


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.896 
train loss:  0.6409  train BCE :  0.0611 
valid loss:  0.616  valid BCE :  0.1109


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.89 
train loss:  0.6212  train BCE :  0.059 
valid loss:  0.6109  valid BCE :  0.108


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.898 
train loss:  0.6045  train BCE :  0.058 
valid loss:  0.6081  valid BCE :  0.1075


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.9 
train loss:  0.5903  train BCE :  0.0564 
valid loss:  0.6081  valid BCE :  0.1077


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.902 
train loss:  0.5787  train BCE :  0.0562 
valid loss:  0.6078  valid BCE :  0.1077


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.9 
train loss:  0.5684  train BCE :  0.0559 
valid loss:  0.6075  valid BCE :  0.1077

Training: 0.1875


HBox(children=(IntProgress(value=0, description='Epochs', max=20, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.108 
train loss:  1.5862  train BCE :  0.8987 
valid loss:  1.5553  valid BCE :  0.8912


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.28 
train loss:  1.5526  train BCE :  0.8675 
valid loss:  1.5088  valid BCE :  0.8429


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.36 
train loss:  1.5203  train BCE :  0.7679 
valid loss:  1.4355  valid BCE :  0.72


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.496 
train loss:  1.4506  train BCE :  0.5891 
valid loss:  1.2747  valid BCE :  0.5092


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.594 
train loss:  1.3672  train BCE :  0.3607 
valid loss:  1.1468  valid BCE :  0.346


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.768 
train loss:  1.2836  train BCE :  0.242 
valid loss:  1.0485  valid BCE :  0.2456


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.786 
train loss:  1.2096  train BCE :  0.1843 
valid loss:  0.991  valid BCE :  0.2027


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.852 
train loss:  1.1402  train BCE :  0.1315 
valid loss:  0.9351  valid BCE :  0.1647


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.862 
train loss:  1.0768  train BCE :  0.1072 
valid loss:  0.903  valid BCE :  0.1477


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.866 
train loss:  1.0212  train BCE :  0.0962 
valid loss:  0.8918  valid BCE :  0.1309


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.88 
train loss:  0.9768  train BCE :  0.0807 
valid loss:  0.8739  valid BCE :  0.1227


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.878 
train loss:  0.9368  train BCE :  0.0701 
valid loss:  0.853  valid BCE :  0.1208


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.904 
train loss:  0.9057  train BCE :  0.0662 
valid loss:  0.8392  valid BCE :  0.114


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.904 
train loss:  0.8746  train BCE :  0.0592 
valid loss:  0.833  valid BCE :  0.11


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.906 
train loss:  0.8486  train BCE :  0.0556 
valid loss:  0.8261  valid BCE :  0.1061


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.916 
train loss:  0.8258  train BCE :  0.054 
valid loss:  0.8202  valid BCE :  0.1036


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.91 
train loss:  0.8037  train BCE :  0.0519 
valid loss:  0.8162  valid BCE :  0.1031


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…



val accuracy:  0.912 
train loss:  0.7882  train BCE :  0.051 
valid loss:  0.815  valid BCE :  0.1024


HBox(children=(IntProgress(value=0, description='Steps', max=8, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Validating', max=8, style=ProgressStyle(description_width='in…

In [9]:
# Show the `losses` dict as this cell's output (bare last expression).
# NOTE(review): presumably populated by the training sweep in an earlier cell,
# keyed by the swept hyperparameter value — confirm against that cell.
losses

{0.01: 0.612,
 0.5: 0.582,
 1: 0.334,
 0.245: 0.534,
 0.1175: 0.646,
 0.05375: 0.63,
 0.031875: 0.702,
 0.042812499999999996: 0.602,
 0.005468749999999998: 0.482}

In [None]:
# Log embeddings for several named layers of `model` through
# `tensorboard_embeddings`, using the first N_VIS validation samples.
N_VIS = 500                        # number of leading validation samples to use
LOG_DIR = './mnist_tripletreg'     # every layer is logged to the same directory
LAYERS = ['max1', 'max2', 'max3', 'act1', 'act2', 'out']

visualization_set = torch.utils.data.dataset.Subset(valset, np.arange(N_VIS))
# shuffle=False keeps the loader's sample order aligned with the slices below.
dataloader = torch.utils.data.DataLoader(visualization_set, batch_size=64, shuffle=False)

# Targets for the same leading N_VIS samples.
vis_labels = valset.targets[:N_VIS]
# Raw pixel data divided by 255, inverted (1 - x) and reshaped to (N, 1, 28, 28).
# NOTE(review): presumably used as per-point thumbnails — confirm in tensorboard_embeddings.
vis_images = 1.0 - valset.data[:N_VIS].reshape(-1, 1, 28, 28) / 255.0

# One call per layer, in the same order as before; each call receives a
# single-element list holding the layer name.
for layer_name in LAYERS:
    tensorboard_embeddings(model, [layer_name],
                           dataloader,
                           vis_labels,
                           vis_images,
                           LOG_DIR)