-
Notifications
You must be signed in to change notification settings - Fork 1
/
segmentation_nn.py
72 lines (60 loc) · 2.67 KB
/
segmentation_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""SegmentationNN"""
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
#from torchvision.models.vgg import model_urls
class SegmentationNN(nn.Module):
    """Fully convolutional semantic-segmentation network on a VGG-19 backbone.

    The convolutional part of torchvision's VGG-19 (``.features``) produces a
    512-channel feature map at 1/32 of the input resolution (five max-pool
    stages). A small FCN-style convolutional head turns it into one score map
    per class, and the score maps are bilinearly upsampled back to the input's
    spatial size.
    """

    def __init__(self, num_classes=23, *, pretrained=True):
        """Build the backbone and the classification head.

        Inputs:
        - num_classes: number of output channels (one score map per class).
        - pretrained: load ImageNet weights into the VGG-19 backbone
          (default True, the original behaviour); pass False to build a
          randomly initialised backbone, e.g. for offline use.
        """
        super().__init__()
        self.num_classes = num_classes
        # Registration order (vgg_feat, then fcn) is kept so state-dict keys
        # are unchanged.
        self.vgg_feat = self._make_backbone(pretrained)
        self.fcn = self._make_head(num_classes)

    @staticmethod
    def _make_backbone(pretrained):
        """Return the convolutional feature extractor of torchvision's VGG-19."""
        if not pretrained:
            return models.vgg19().features
        try:
            # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
            # weight enums; DEFAULT is the same ImageNet checkpoint.
            weights = models.VGG19_Weights.DEFAULT
        except AttributeError:
            # Older torchvision has no weight enums: use the legacy flag.
            return models.vgg19(pretrained=True).features
        return models.vgg19(weights=weights).features

    @staticmethod
    def _make_head(num_classes, in_channels=512):
        """Return the FCN-style classifier applied to the backbone features.

        The layer layout, and therefore the ``fcn.*`` state-dict keys and
        parameter shapes, matches the original head. The 7x7 convolution uses
        ``padding=3`` so it preserves the spatial size of the feature map.
        Without padding it needs a feature map of at least 7x7 (inputs of at
        least 224 px per side) and shrinks e.g. a 7x7 map to a single 1x1
        prediction that upsampling then smears over the whole image.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, 1024, 7, padding=3),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Conv2d(1024, 2048, 1),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Conv2d(2048, num_classes, 1),
        )

    def forward(self, x):
        """
        Forward pass of the convolutional neural network. Should not be called
        manually but by calling a model instance directly.
        Inputs:
        - x: input batch of shape (N, C, H, W); C must match the backbone's
          first convolution (presumably 3-channel RGB for VGG-19).
        Returns:
        - Class score maps of shape (N, num_classes, H, W).
        """
        out_size = x.shape[2:]
        scores = self.fcn(self.vgg_feat(x))
        # F.upsample is deprecated; F.interpolate with align_corners=False
        # reproduces its bilinear default without the deprecation warning.
        return F.interpolate(
            scores, size=out_size, mode='bilinear', align_corners=False
        ).contiguous()

    @property
    def is_cuda(self):
        """
        Check if model parameters are allocated on the GPU.
        Only the first parameter is inspected, so this assumes the whole
        model lives on a single device.
        """
        return next(self.parameters()).is_cuda

    def save(self, path):
        """
        Save model with its parameters to the given path. Conventionally the
        path should end with "*.model".
        This pickles the whole module object (not just its state_dict), so
        loading it requires this class to be importable.
        NOTE(review): recent PyTorch versions default ``torch.load`` to
        ``weights_only=True``, which rejects whole-module pickles; consider
        saving ``self.state_dict()`` instead.
        Inputs:
        - path: path string
        """
        print('Saving model... %s' % path)
        torch.save(self, path)