Skip to content

Commit

Permalink
updated models
Browse files Browse the repository at this point in the history
  • Loading branch information
ekshaks committed Nov 9, 2019
1 parent a2512a4 commit 3c8f64c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
9 changes: 4 additions & 5 deletions models/resnet.py
Expand Up @@ -6,11 +6,11 @@
# Updated to add shape annotations (BasicBlock and ResNet modules)

import sys
sys.path.append('../')
from tsalib import dim_vars

B, C, Ci, Co = dim_vars('Batch Channels ChannelsIn ChannelsOut')
H, W, Ex = dim_vars('Height Width BlockExpansion')

B, C, Ci, Co = dim_vars('Batch(b):10 Channels(c):3 ChannelsIn(ci) ChannelsOut(co)')
H, W, Ex = dim_vars('Height(h):224 Width(w):224 BlockExpansion(e):1')

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
Expand Down Expand Up @@ -147,7 +147,6 @@ def _make_layer(self, block, planes, blocks, stride=1):
return nn.Sequential(*layers)

def forward(self, x: (B, 3, H, W)): #H = W = 224

x: (B, 64, H//2, W//2) = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
Expand All @@ -160,7 +159,7 @@ def forward(self, x: (B, 3, H, W)): #H = W = 224

x: (B, 512*Ex, 1, 1) = self.avgpool(x)
x: (B, 512*Ex) = x.view(x.size(0), -1)
x: (B, num_classes) = self.fc(x)
x: (B, NC) = self.fc(x)

return x

Expand Down
4 changes: 4 additions & 0 deletions models/snippets_tf.py
@@ -1,4 +1,7 @@

from tsalib import dim_vars
from tsalib.backend import get_shape_list

def modeling_embedding_lookup(input_ids: 'bti'):
# illustrates local dim var usage, i is not declared globaly as dimvar
B, T, D = dim_vars('B(b):13 L(t):7 D(d):32')
Expand Down Expand Up @@ -93,4 +96,5 @@ def transpose_for_scores(input_tensor: 'b*t,d', batch_size: 'b', num_attention_h

####################

if __name__ == '__main__':

0 comments on commit 3c8f64c

Please sign in to comment.