In [1]:
import torch
import torch.nn as nn

In [168]:
class MyModel(nn.Module):
    
    def __init__(self):
        super(MyModel, self).__init__()
        
        self.fc1 = nn.Linear(4,3)
        self.fc2 = nn.Linear(3,2)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        x0 = x
        print('x0 == ', x0)
        
        x1 = self.fc1(x0)
        print('Before Dropout == ', x1)
        
        x2 = self.dropout(x1)
        print('After Dropout == ', x2)
        
        x3 = self.fc2(x2)
        print('FC2 == ', x3)
        
        return x
    
model = MyModel()

In [169]:
a = torch.randn(1,4)
print(a)

tensor([[-0.7419, -0.4527, -0.0508, -0.3620]])


### model.eval()
- Dropout이나 Batch Normalize 등을 off한다.

In [174]:
model.train()
model(a)

x0 ==  tensor([[-0.7419, -0.4527, -0.0508, -0.3620]])
Before Dropout ==  tensor([[0.1675, 0.4831, 0.2952]], grad_fn=<AddmmBackward>)
After Dropout ==  tensor([[0.0000, 0.9662, 0.0000]], grad_fn=<MulBackward0>)
FC2 ==  tensor([[-0.6466, -0.0197]], grad_fn=<AddmmBackward>)


tensor([[-0.7419, -0.4527, -0.0508, -0.3620]])

In [175]:
model.eval()
model(a)

x0 ==  tensor([[-0.7419, -0.4527, -0.0508, -0.3620]])
Before Dropout ==  tensor([[0.1675, 0.4831, 0.2952]], grad_fn=<AddmmBackward>)
After Dropout ==  tensor([[0.1675, 0.4831, 0.2952]], grad_fn=<AddmmBackward>)
FC2 ==  tensor([[-0.3929,  0.1284]], grad_fn=<AddmmBackward>)


tensor([[-0.7419, -0.4527, -0.0508, -0.3620]])

### torch.no_grad()
- 자동미분(Autograd)를 off한다.
- 메모리 사용 감소

In [176]:
with torch.no_grad():
    model.train()
    model(a)

x0 ==  tensor([[-0.7419, -0.4527, -0.0508, -0.3620]])
Before Dropout ==  tensor([[0.1675, 0.4831, 0.2952]])
After Dropout ==  tensor([[0.0000, 0.0000, 0.5904]])
FC2 ==  tensor([[-0.1809,  0.3306]])


In [177]:
with torch.no_grad():
    model.eval()
    model(a)

x0 ==  tensor([[-0.7419, -0.4527, -0.0508, -0.3620]])
Before Dropout ==  tensor([[0.1675, 0.4831, 0.2952]])
After Dropout ==  tensor([[0.1675, 0.4831, 0.2952]])
FC2 ==  tensor([[-0.3929,  0.1284]])
