Skip to content

Commit

Permalink
patch alexnet
Browse files Browse the repository at this point in the history
  • Loading branch information
christiansafka committed Nov 29, 2018
1 parent f1072af commit 4a48301
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions img_to_vec.py
Expand Up @@ -15,6 +15,8 @@ def __init__(self, cuda=False, model='resnet-18', layer='default', layer_output_
"""
self.device = torch.device("cuda" if cuda else "cpu")
self.layer_output_size = layer_output_size
self.model_name = model

self.model, self.extraction_layer = self._get_model_and_layer(model, layer)

self.model = self.model.to(self.device)
Expand All @@ -34,7 +36,10 @@ def get_vec(self, img, tensor=False):
"""
image = self.normalize(self.to_tensor(self.scaler(img))).unsqueeze(0).to(self.device)

my_embedding = torch.zeros(1, self.layer_output_size, 1, 1)
if self.model_name == 'alexnet':
my_embedding = torch.zeros(1, self.layer_output_size)
else:
my_embedding = torch.zeros(1, self.layer_output_size, 1, 1)

def copy_data(m, i, o):
my_embedding.copy_(o.data)
Expand All @@ -46,7 +51,10 @@ def copy_data(m, i, o):
if tensor:
return my_embedding
else:
return my_embedding.numpy()[0, :, 0, 0]
if self.model_name == 'alexnet':
return my_embedding.numpy()[0, :]
else:
return my_embedding.numpy()[0, :, 0, 0]

def _get_model_and_layer(self, model_name, layer):
""" Internal method for getting layer from model
Expand Down

0 comments on commit 4a48301

Please sign in to comment.