Skip to content

Commit

Permalink
Fix map_location=cpu.
Browse files Browse the repository at this point in the history
  • Loading branch information
haotian-liu committed Feb 7, 2021
1 parent 7b4b480 commit 9ec7b0d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions backbone.py
Expand Up @@ -129,9 +129,9 @@ def forward(self, x, partial:bool=False):
outs.append(x)
return outs

def init_backbone(self, path, map_location=None):
def init_backbone(self, path):
""" Initializes the backbone weights for training. """
state_dict = torch.load(path, map_location=map_location)
state_dict = torch.load(path, map_location='cpu')

# Replace layer1 -> layers.0 etc.
keys = list(state_dict)
Expand All @@ -154,7 +154,7 @@ class ResNetBackboneGN(ResNetBackbone):
def __init__(self, layers, num_groups=32):
super().__init__(layers, norm_layer=lambda x: nn.GroupNorm(num_groups, x))

def init_backbone(self, path, map_location=None):
def init_backbone(self, path):
""" The path here comes from detectron. So we load it differently. """
with open(path, 'rb') as f:
state_dict = pickle.load(f, encoding='latin1') # From the detectron source
Expand Down Expand Up @@ -301,10 +301,10 @@ def add_layer(self, conv_channels=1024, stride=2, depth=1, block=DarkNetBlock):
""" Add a downsample layer to the backbone as per what SSD does. """
self._make_layer(block, conv_channels // block.expansion, num_blocks=depth, stride=stride)

def init_backbone(self, path, map_location=None):
def init_backbone(self, path):
""" Initializes the backbone weights for training. """
# Note: Using strict=False is berry scary. Triple check this.
self.load_state_dict(torch.load(path), map_location, strict=False)
self.load_state_dict(torch.load(path, map_location='cpu'), strict=False)



Expand Down Expand Up @@ -407,9 +407,9 @@ def transform_key(self, k):
layerIdx = self.state_dict_lookup[int(vals[0])]
return 'layers.%s.%s' % (layerIdx, vals[1])

def init_backbone(self, path, map_location=None):
def init_backbone(self, path):
""" Initializes the backbone weights for training. """
state_dict = torch.load(path, map_location)
state_dict = torch.load(path, map_location='cpu')
state_dict = OrderedDict([(self.transform_key(k), v) for k,v in state_dict.items()])

self.load_state_dict(state_dict, strict=False)
Expand Down
4 changes: 2 additions & 2 deletions yolact.py
Expand Up @@ -1182,7 +1182,7 @@ def save_weights(self, path):

def load_weights(self, path, args=None):
""" Loads weights from a compressed save file. """
state_dict = torch.load(path, map_location=torch.device(torch.cuda.current_device()))
state_dict = torch.load(path, map_location='cpu')

# Get all possible weights
cur_state_dict = self.state_dict()
Expand Down Expand Up @@ -1268,7 +1268,7 @@ def load_weights(self, path, args=None):
def init_weights(self, backbone_path):
""" Initialize weights for training. """
# Initialize the backbone with the pretrained weights.
self.backbone.init_backbone(backbone_path, map_location=torch.device(torch.cuda.current_device()))
self.backbone.init_backbone(backbone_path)

conv_constants = getattr(nn.Conv2d(1, 1, 1), '__constants__')

Expand Down

0 comments on commit 9ec7b0d

Please sign in to comment.