Skip to content

Commit

Permalink
add instancenorm
Browse files Browse the repository at this point in the history
  • Loading branch information
nhseob committed Sep 5, 2018
1 parent c3dda1c commit 2362089
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions models/resnet.py
@@ -1,4 +1,5 @@
import torch.nn as nn
import functools
import math

def conv3x3(in_planes, out_planes, stride=1):
Expand Down Expand Up @@ -97,7 +98,7 @@ def __init__(self, depth, num_classes=1000, norm_type=None, basicblock=False):
from torch.nn import InstanceNorm2d as Normlayer
elif norm_type == 'bin':
from .batchinstancenorm import BatchInstanceNorm2d as Normlayer
self.normlayer = Normlayer
self.normlayer = functools.partial(Normlayer, affine=True)

self.inplanes = 16
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
Expand All @@ -114,7 +115,7 @@ def __init__(self, depth, num_classes=1000, norm_type=None, basicblock=False):
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, self.normlayer):
elif isinstance(m, self.normlayer.func):
m.weight.data.fill_(1)
m.bias.data.zero_()

Expand Down

0 comments on commit 2362089

Please sign in to comment.