In [1]:
import torch
import torch.nn as nn

In [2]:
class MyModel(nn.Module):
    
    def __init__(self):
        super(MyModel, self).__init__()
        
        self.fc1 = nn.Linear(4,3)
        self.fc2 = nn.Linear(3,2)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        x0 = x
        print('x0 == ', x0)
        
        x1 = self.fc1(x0)
        print('Before Dropout == ', x1)
        
        x2 = self.dropout(x1)
        print('After Dropout == ', x2)
        
        x3 = self.fc2(x2)
        print('FC2 == ', x3)
        
        return x
    
model = MyModel()

In [3]:
a = torch.randn(1,4)
print(a)

tensor([[ 1.0146, -0.9661,  0.8195, -0.2608]])


### model.eval()
- Dropout이나 Batch Normalize 등을 off한다.

In [4]:
model.train()
model(a)

x0 ==  tensor([[ 1.0146, -0.9661,  0.8195, -0.2608]])
Before Dropout ==  tensor([[0.2435, 0.7484, 0.7701]], grad_fn=<AddmmBackward>)
After Dropout ==  tensor([[0.4869, 0.0000, 0.0000]], grad_fn=<MulBackward0>)
FC2 ==  tensor([[-0.6646,  0.2437]], grad_fn=<AddmmBackward>)


tensor([[ 1.0146, -0.9661,  0.8195, -0.2608]])

In [5]:
model.eval()
model(a)

x0 ==  tensor([[ 1.0146, -0.9661,  0.8195, -0.2608]])
Before Dropout ==  tensor([[0.2435, 0.7484, 0.7701]], grad_fn=<AddmmBackward>)
After Dropout ==  tensor([[0.2435, 0.7484, 0.7701]], grad_fn=<AddmmBackward>)
FC2 ==  tensor([[-0.0941,  0.5149]], grad_fn=<AddmmBackward>)


tensor([[ 1.0146, -0.9661,  0.8195, -0.2608]])

### torch.no_grad()
- 자동미분(Autograd)를 off한다.
- 메모리 사용 감소

In [6]:
with torch.no_grad():
    model.train()
    model(a)

x0 ==  tensor([[ 1.0146, -0.9661,  0.8195, -0.2608]])
Before Dropout ==  tensor([[0.2435, 0.7484, 0.7701]])
After Dropout ==  tensor([[0.4869, 0.0000, 1.5402]])
FC2 ==  tensor([[-0.0113,  0.7871]])


In [7]:
with torch.no_grad():
    model.eval()
    model(a)

x0 ==  tensor([[ 1.0146, -0.9661,  0.8195, -0.2608]])
Before Dropout ==  tensor([[0.2435, 0.7484, 0.7701]])
After Dropout ==  tensor([[0.2435, 0.7484, 0.7701]])
FC2 ==  tensor([[-0.0941,  0.5149]])
