Merge pull request pytorch#645 from jlin27/master
Correct errors in the TorchModule code to match the sequential version
Jessica Lin committed Sep 5, 2019
2 parents 1bc0efe + 03ba8fa commit ffecdf7
Showing 1 changed file with 30 additions and 21 deletions.
51 changes: 30 additions & 21 deletions advanced_source/cpp_frontend.rst
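
For context, the commit message's "sequential version" refers to the ``nn::Sequential`` formulation of the same generator elsewhere in the tutorial. A sketch of what it presumably looks like is shown below, assembled from the same layer settings that appear in the corrected module; the variable name ``generator`` and the use of ``nn::Functional`` for the activations are assumptions here, not content of this commit, and ``kNoiseSize`` is assumed to be a constant defined elsewhere in the tutorial (e.g. 100).

.. code-block:: cpp

   nn::Sequential generator(
       // Layer 1: noise vector -> 256 feature maps of size 4x4.
       nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4)
                      .with_bias(false)
                      .transposed(true)),
       nn::BatchNorm(256),
       nn::Functional(torch::relu),
       // Layer 2: 256 -> 128 feature maps, upsampled to 7x7.
       nn::Conv2d(nn::Conv2dOptions(256, 128, 3)
                      .stride(2)
                      .padding(1)
                      .with_bias(false)
                      .transposed(true)),
       nn::BatchNorm(128),
       nn::Functional(torch::relu),
       // Layer 3: 128 -> 64 feature maps, upsampled to 14x14.
       nn::Conv2d(nn::Conv2dOptions(128, 64, 4)
                      .stride(2)
                      .padding(1)
                      .with_bias(false)
                      .transposed(true)),
       nn::BatchNorm(64),
       nn::Functional(torch::relu),
       // Layer 4: 64 feature maps -> a single 28x28 output image in [-1, 1].
       nn::Conv2d(nn::Conv2dOptions(64, 1, 4)
                      .stride(2)
                      .padding(1)
                      .with_bias(false)
                      .transposed(true)),
       nn::Functional(torch::tanh));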
@@ -769,24 +769,24 @@ modules in the ``forward()`` method of a module we define ourselves:

 .. code-block:: cpp

   struct GeneratorImpl : nn::Module {
-    GeneratorImpl()
-        : conv1(nn::Conv2dOptions(kNoiseSize, 512, 4)
+    GeneratorImpl(int kNoiseSize)
+        : conv1(nn::Conv2dOptions(kNoiseSize, 256, 4)
                     .with_bias(false)
                     .transposed(true)),
-      batch_norm1(512),
-      conv2(nn::Conv2dOptions(512, 256, 4)
+      batch_norm1(256),
+      conv2(nn::Conv2dOptions(256, 128, 3)
                 .stride(2)
                 .padding(1)
                 .with_bias(false)
                 .transposed(true)),
-      batch_norm2(256),
-      conv3(nn::Conv2dOptions(256, 128, 4)
+      batch_norm2(128),
+      conv3(nn::Conv2dOptions(128, 64, 4)
                 .stride(2)
                 .padding(1)
                 .with_bias(false)
                 .transposed(true)),
-      batch_norm3(128),
-      conv4(nn::Conv2dOptions(128, 64, 4)
+      batch_norm3(64),
+      conv4(nn::Conv2dOptions(64, 1, 4)
                 .stride(2)
                 .padding(1)
                 .with_bias(false)
@@ -796,19 +796,28 @@ modules in the ``forward()`` method of a module we define ourselves:
                 .stride(2)
                 .padding(1)
                 .with_bias(false)
-                .transposed(true)) {}
-
-    torch::Tensor forward(torch::Tensor x) {
-      x = torch::relu(batch_norm1(conv1(x)));
-      x = torch::relu(batch_norm2(conv2(x)));
-      x = torch::relu(batch_norm3(conv3(x)));
-      x = torch::relu(batch_norm4(conv4(x)));
-      x = torch::tanh(conv5(x));
-      return x;
-    }
-
-    nn::Conv2d conv1, conv2, conv3, conv4, conv5;
-    nn::BatchNorm batch_norm1, batch_norm2, batch_norm3, batch_norm4;
+                .transposed(true))
+    {
+      // register_module() is needed if we want to use the parameters() method later on
+      register_module("conv1", conv1);
+      register_module("conv2", conv2);
+      register_module("conv3", conv3);
+      register_module("conv4", conv4);
+      register_module("batch_norm1", batch_norm1);
register_module("batch_norm2", batch_norm1);
register_module("batch_norm3", batch_norm1);
+    }
+
+    torch::Tensor forward(torch::Tensor x) {
+      x = torch::relu(batch_norm1(conv1(x)));
+      x = torch::relu(batch_norm2(conv2(x)));
+      x = torch::relu(batch_norm3(conv3(x)));
+      x = torch::tanh(conv4(x));
+      return x;
+    }
+
+    nn::Conv2d conv1, conv2, conv3, conv4;
+    nn::BatchNorm batch_norm1, batch_norm2, batch_norm3;
   };
   TORCH_MODULE(Generator);
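
The added ``register_module()`` calls are what make the module's submodules visible to ``parameters()``, and therefore to an optimizer. Below is a minimal usage sketch, assuming a noise size of 100 and the MNIST-sized (28x28, single-channel) output implied by the layer settings; the batch size and learning rate are illustrative assumptions, not part of this commit.

.. code-block:: cpp

   #include <torch/torch.h>
   #include <iostream>

   int main() {
     const int kNoiseSize = 100;  // assumed value, matching the tutorial's constant
     Generator generator(kNoiseSize);

     // Because every conv and batch-norm submodule was registered with
     // register_module(), parameters() now returns all of their weights,
     // which is exactly what the optimizer needs to see.
     torch::optim::Adam optimizer(
         generator->parameters(), torch::optim::AdamOptions(2e-4));

     // A batch of noise vectors shaped [batch, kNoiseSize, 1, 1], as the
     // first transposed convolution expects.
     torch::Tensor noise = torch::randn({8, kNoiseSize, 1, 1});
     torch::Tensor fake = generator->forward(noise);

     std::cout << fake.sizes() << std::endl;  // expect [8, 1, 28, 28]
     return 0;
   }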
