In [1]:
%matplotlib inline

import sys

# Put the parent directory on the import path so the project-local imports
# used later in this notebook (`settings`, `src.*`) can be resolved.
# Guarded so that re-running the cell does not pile up duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [4]:
import os, pickle, torch

from settings import ROOT_DIR

# Checkpoint to inspect: `resume` names the sub-directory under
# static/checkpoints/triplet_cnn, `start_epoch` selects the `_net_<epoch>.pth`
# file inside it.
resume = '19-05-08T20-18'
start_epoch = 1

checkpoint_directory = os.path.join(ROOT_DIR, 'static', 'checkpoints', 'triplet_cnn', resume)
checkpoint_path = os.path.join(checkpoint_directory, '_net_%s.pth' % start_epoch)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only point this at checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects a fully pickled module like this one -- confirm the installed version.
# map_location='cpu' lets a checkpoint saved from another device load on any
# machine; the later cells pass DataLoader batches to the network without
# moving them to a device, so the weights have to live on the CPU anyway.
net = torch.load(checkpoint_path, map_location='cpu')

# The network is only used for inference below, so leave training mode.
net.eval()
net

TripletNetwork(
  (embedding_network): ConvolutionalNetwork(
    (convolution_1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (convolution_2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (convolution_3): Conv2d(16, 20, kernel_size=(4, 4), stride=(1, 1))
    (fully_connected_1): Linear(in_features=16820, out_features=15000, bias=True)
    (fully_connected_2): Linear(in_features=15000, out_features=1200, bias=True)
    (fully_connected_3): Linear(in_features=1200, out_features=125, bias=True)
  )
)

In [5]:
from torch.utils.data.dataloader import default_collate

from src.datasets import get_dataset

# Seed torch's global RNG before building the shuffled loader, so the batch
# drawn below is intended to be the same one on every "Restart & Run All".
# NOTE(review): any randomness inside the dataset itself that does not go
# through torch's RNG is not covered by this -- verify against get_dataset.
SEED = 0
torch.manual_seed(SEED)

dataset_name = 'sketchy_test_photos_triplets'
batch_size = 1
workers = 4
collate = default_collate

dataset = get_dataset(dataset_name)

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True,
    num_workers=workers, collate_fn=collate
)

# Only a single batch is needed as the query for the cells below.
real_batch = next(iter(data_loader))

# The first element of the batch is used as the query; it is indexed as
# (image, class). NOTE(review): presumably this is the triplet's anchor --
# confirm against the dataset implementation.
query_sample = real_batch[0]
query_image = query_sample[0]
query_class = query_sample[1]

In [6]:
# Embed the query image with the network loaded from the checkpoint.
# This is inference only, so run it without autograd bookkeeping: no graph is
# built and no intermediate activations are kept alive for a backward pass.
# NOTE(review): the recorded values all sit near -4.83, which is about
# -ln(125) for the 125 output features -- this looks like a log-softmax over
# nearly uniform scores; confirm against the network's forward().
with torch.no_grad():
    query_vector = net.embedding_network(query_image)
query_vector

tensor([[-4.8405, -4.8691, -4.7285, -4.9261, -4.8282, -4.8584, -4.9078, -4.8530,
         -4.7228, -4.8582, -4.7934, -4.7548, -4.8168, -4.7971, -4.9568, -4.8788,
         -4.8110, -4.9251, -4.8131, -4.9528, -4.8160, -4.7920, -4.8033, -4.9407,
         -4.8070, -4.7939, -4.8498, -4.9279, -4.7366, -4.7871, -4.7383, -4.8220,
         -4.7530, -4.7975, -4.8418, -4.7156, -4.8193, -4.7632, -4.8207, -4.7986,
         -4.7419, -4.8491, -4.8751, -4.8214, -4.7689, -4.8845, -4.7574, -4.7301,
         -4.8514, -4.8190, -4.7895, -4.9910, -4.7128, -4.8877, -4.8052, -4.8583,
         -4.7615, -4.9760, -4.9106, -4.8773, -4.8694, -4.8280, -4.7208, -4.8753,
         -4.8422, -4.8293, -4.8958, -4.7478, -4.8037, -4.7191, -4.8631, -4.8977,
         -4.8559, -4.7383, -4.9056, -4.8241, -4.7967, -4.7940, -4.8386, -4.8291,
         -4.8083, -4.7643, -4.7819, -4.8458, -4.7980, -4.7950, -4.9019, -4.8353,
         -4.7793, -4.8258, -4.8760, -4.9048, -4.8730, -4.9390, -4.9217, -4.8800,
         -4.9042, -4.7426, -

In [7]:
from src.models.convolutional_network import ConvolutionalNetwork

# A second embedding network built with default constructor arguments; no
# checkpoint is loaded into it in this notebook, so it serves as a comparison
# point for the network restored from disk.
convnet = ConvolutionalNetwork()

# It is only used for a forward pass below, so leave training mode.
convnet.eval()

# Bare last expression instead of print(): the notebook renders the repr, the
# same way the loaded network is displayed earlier.
convnet

ConvolutionalNetwork(
  (convolution_1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (convolution_2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (convolution_3): Conv2d(16, 20, kernel_size=(4, 4), stride=(1, 1))
  (fully_connected_1): Linear(in_features=16820, out_features=15000, bias=True)
  (fully_connected_2): Linear(in_features=15000, out_features=1200, bias=True)
  (fully_connected_3): Linear(in_features=1200, out_features=125, bias=True)
)


In [8]:
# Embed the same query image with the comparison network so its output can be
# set side by side with the checkpointed network's embedding.
# Inference only: disable autograd so no graph is recorded for this pass.
with torch.no_grad():
    vector = convnet(query_image)
vector

tensor([[-4.8118, -4.8128, -4.8104, -4.8181, -4.7960, -4.8277, -4.8574, -4.8383,
         -4.8347, -4.8058, -4.8239, -4.8313, -4.8592, -4.8107, -4.8338, -4.8500,
         -4.8273, -4.8491, -4.8414, -4.8232, -4.8450, -4.8199, -4.8098, -4.8470,
         -4.8533, -4.8339, -4.8438, -4.8270, -4.8265, -4.8418, -4.8271, -4.8430,
         -4.8264, -4.7970, -4.8502, -4.8211, -4.8189, -4.8087, -4.8228, -4.8006,
         -4.8204, -4.8408, -4.8583, -4.8308, -4.8350, -4.8416, -4.8437, -4.8100,
         -4.8268, -4.8064, -4.8228, -4.8432, -4.8121, -4.8323, -4.8584, -4.8114,
         -4.8311, -4.8014, -4.7982, -4.8286, -4.8425, -4.7974, -4.8456, -4.8547,
         -4.7996, -4.8247, -4.8004, -4.8343, -4.8485, -4.8541, -4.8573, -4.8163,
         -4.8478, -4.8182, -4.8096, -4.8444, -4.8586, -4.8369, -4.8224, -4.8139,
         -4.8097, -4.8360, -4.8020, -4.8419, -4.8052, -4.8041, -4.8172, -4.8002,
         -4.8381, -4.8490, -4.8276, -4.8191, -4.8034, -4.8024, -4.8487, -4.8125,
         -4.8034, -4.8192, -

---

In [None]:
-