Skip to content

Commit

Permalink
update: remove unused variables and update comments
Browse files Browse the repository at this point in the history
  • Loading branch information
khanrc committed Nov 25, 2018
1 parent f9a2cbf commit 029de08
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion models/augment_cells.py
@@ -1,6 +1,6 @@
""" CNN cell for network augmentation """
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import ops
import genotypes as gt

Expand Down
7 changes: 4 additions & 3 deletions models/augment_cnn.py
@@ -1,15 +1,15 @@
""" CNN for network augmentation """
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.augment_cells import AugmentCell
from models import ops
import genotypes as gt


class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """
def __init__(self, input_size, C, n_classes):
""" assuming input size 7x7 or 8x8 """
assert input_size == 7 or input_size == 8
assert input_size in [7, 8]
super().__init__()
self.net = nn.Sequential(
nn.ReLU(inplace=True),
Expand Down Expand Up @@ -95,6 +95,7 @@ def forward(self, x):
return logits, aux_logits

def drop_path_prob(self, p):
""" Set drop path probability """
for module in self.modules():
if isinstance(module, ops.DropPath_):
module.p = p
3 changes: 2 additions & 1 deletion models/ops.py
@@ -1,3 +1,4 @@
""" Operations """
import torch
import torch.nn as nn
import genotypes as gt
Expand All @@ -21,7 +22,7 @@
def drop_path_(x, drop_prob, training):
if training and drop_prob > 0.:
keep_prob = 1. - drop_prob
# per data point mask
# per data point mask; assuming x in cuda.
mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
x.div_(keep_prob).mul_(mask)

Expand Down
3 changes: 1 addition & 2 deletions models/search_cells.py
@@ -1,8 +1,7 @@
""" CNN cell for architecture search """
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import ops
import genotypes as gt


class SearchCell(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions models/search_cnn.py
@@ -1,3 +1,4 @@
""" CNN for architecture search """
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -65,7 +66,6 @@ def _init_alphas(self):

self.alpha_normal = nn.ParameterList()
self.alpha_reduce = nn.ParameterList()
device = torch.device('cuda:0')

for i in range(self.n_nodes):
self.alpha_normal.append(nn.Parameter(1e-3*torch.randn(i+2, n_ops)))
Expand All @@ -77,7 +77,7 @@ def forward(self, x):
weights_normal = [F.softmax(alpha, dim=-1) for alpha in self.alpha_normal]
weights_reduce = [F.softmax(alpha, dim=-1) for alpha in self.alpha_reduce]

for i, cell in enumerate(self.cells):
for cell in self.cells:
weights = weights_reduce if cell.reduction else weights_normal
s0, s1 = s1, cell(s0, s1, weights)

Expand Down
2 changes: 1 addition & 1 deletion visualize.py
@@ -1,5 +1,5 @@
""" Network architecture visualizer using graphviz """
import sys
import genotypes
from graphviz import Digraph
import genotypes as gt

Expand Down

0 comments on commit 029de08

Please sign in to comment.