Skip to content

Commit

Permalink
minor edits
Browse files Browse the repository at this point in the history
  • Loading branch information
vboddeti committed Jun 28, 2018
1 parent f269ef3 commit 2c5e11f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 21 deletions.
10 changes: 5 additions & 5 deletions args.txt
Expand Up @@ -24,7 +24,7 @@ label_filename_test = ./test_label
batch_size = 64

# model options
model_type = resnet18
model_type = preactresnet18
model_options = {"nchannels": 3, "nfilters": 64, "nclasses": 10}
loss_type = Classification
loss_options = {}
Expand All @@ -35,14 +35,14 @@ resolution_high = 32
resolution_wide = 32

manual_seed = 0
nepochs = 200
nepochs = 350

optim_method = SGD
learning_rate = 1e-1
optim_options = {"momentum": 0.9, "weight_decay": 0.0}
optim_options = {"momentum": 0.9, "weight_decay": 5e-4}

scheduler_method = CosineAnnealingLR
scheduler_options = {"T_max": 25, "eta_min": 1e-6}
scheduler_method = MultiStepLR
scheduler_options = {"milestones": [150, 250]}

# cpu/gpu settings
cuda = True
Expand Down
17 changes: 3 additions & 14 deletions models/preactresnet.py
Expand Up @@ -81,9 +81,8 @@ def forward(self, x):
class PreActResNet(nn.Module):

def __init__(self, block, layers, nchannels=3, nfilters=64,
ndim=512, nclasses=1000):
nclasses=1000):
super(PreActResNet, self).__init__()
self.ndim = ndim
self.nclasses = nclasses
self.nchannels = nchannels
self.nfilters = nfilters
Expand All @@ -105,9 +104,6 @@ def __init__(self, block, layers, nchannels=3, nfilters=64,
stride=2)
self.fc = nn.Linear(8 * self.nfilters * block.expansion, self.nclasses)

if self.nclasses > 0:
self.fc2 = nn.Linear(self. ndim, self.nclasses)

def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
Expand All @@ -116,7 +112,7 @@ def _make_layer(self, block, planes, num_blocks, stride):
self.inplanes = planes * block.expansion
return nn.Sequential(*layers)

def forward(self, x, features=False):
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(self.relu1(self.bn1(x)))

Expand All @@ -129,14 +125,7 @@ def forward(self, x, features=False):
x = x.view(x.size(0), -1)
x = self.fc(x)

if self.nclasses > 0:
if features is True:
return [x]
else:
y = self.fc2(x)
return [x, y]
else:
return [x]
return x


def preactresnet18(**kwargs):
Expand Down
4 changes: 2 additions & 2 deletions utils.py
Expand Up @@ -49,10 +49,10 @@ def _debuginfo(self, *message):
'\033[95m', self.__class__.__name__, '\033[94m', message, '\033[0m')


def readcsvfile(filename):
def readcsvfile(filename, delimiter=','):
with open(filename, 'r') as f:
content = []
reader = csv.reader(f, delimiter=" ")
reader = csv.reader(f, delimiter=delimiter)
for row in reader:
content.append(row)
f.close()
Expand Down

0 comments on commit 2c5e11f

Please sign in to comment.