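## This notebook illustrates how `model.train()` and `model.eval()` work.

`model.train()` sets the `training` flag of the module (and all its submodules) to `True`, and `model.eval()` sets it to `False`. Built-in layers such as `BatchNorm` change their behavior based on this flag. Custom `forward` code can branch on it too, as the `Network` below does.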

In [1]:
import torch
from torch import nn

In [2]:
class Network(nn.Module):
    """Toy model whose forward pass depends on ``self.training``.

    In training mode (after ``model.train()``) the flattened conv features are
    passed through a linear classifier. In evaluation mode (after
    ``model.eval()``) the flattened features themselves are returned.

    Args:
        in_features: Length of the flattened feature vector fed to the
            classifier. The default 512 equals 32 channels * 4 * 4, i.e. it
            matches a 1-channel 4x4 input (the 3x3 conv with padding=1 keeps
            the spatial size unchanged).
        num_classes: Number of classifier outputs.
    """

    def __init__(self, in_features=512, num_classes=20):
        super().__init__()
        self.features = nn.Sequential(
            # kernel_size=3 with padding=1 preserves H and W.
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            # affine=False: normalization only, no learnable scale/shift.
            nn.BatchNorm2d(32, affine=False),
        )
        self.classifier = nn.Linear(in_features, num_classes, bias=False)

    def forward(self, x):
        """Return classifier logits in train mode, flattened features in eval mode.

        Args:
            x: Input batch for ``Conv2d(1, ...)``; assumed shape (N, 1, H, W).

        Returns:
            Tensor of shape (N, num_classes) while training, otherwise
            (N, 32 * H * W).
        """
        x = self.features(x)
        x = x.flatten(1)  # (N, C, H, W) -> (N, C * H * W)

        # Truthiness check rather than an identity test against False.
        if not self.training:
            return x

        return self.classifier(x)

In [3]:
network = Network()

# train() turns the `training` flag on; eval() turns it off ("not training").
for switch_mode in (network.train, network.eval):
    switch_mode()
    print(network.training)

True
False


In [4]:
# One 1-channel 4x4 sample; named to avoid shadowing the builtin `input`.
sample_input = torch.randn(1, 1, 4, 4)

network.train()
output_1 = network(sample_input)  # classifier logits: 20 values per sample
print(output_1.size())
assert output_1.size() == (1, 20)

network.eval()
output_2 = network(sample_input)  # flattened features: 32 * 4 * 4 = 512
print(output_2.size())
assert output_2.size() == (1, 512)

[2022-03-19 15:22:29.491 1-8-1-cpu-py36-ml-t3-medium-62c6b413a5e8d67a1da6b0c48d04:288 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2022-03-19 15:22:29.841 1-8-1-cpu-py36-ml-t3-medium-62c6b413a5e8d67a1da6b0c48d04:288 INFO profiler_config_parser.py:102] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.
torch.Size([1, 20])
torch.Size([1, 512])
