# torch.no_grad() versus param.requires_grad
- torch.no_grad()
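    - A context manager that implicitly disables gradient computation: operations run inside it are not recorded by autograd, and the parameters' requires_grad flags are left unchanged
    - Suited to the eval phase, or to modules inside a model's forward pass that should not be updated (those modules only do feature extraction, i.e. forward computation, with no backward update)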
- param.requires_grad
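    - Explicitly freezes the gradient updates of selected modules (layers)
    - Works at the layer/module level, down to individual parameters
    - Potentially more flexible, since the choice is made parameter by parameter
    - Used to freeze part of your model and train the rest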
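The difference between the two mechanisms can be seen on a toy model. The sketch below uses a small nn.Linear as an assumed stand-in for any nn.Module, so it runs without downloading BERT. It shows that no_grad only stops graph recording while leaving the flags alone, whereas setting requires_grad = False changes the flag itself and keeps that parameter out of backward.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)   # toy stand-in (assumption) for any nn.Module
x = torch.randn(3, 4)

# 1) torch.no_grad(): operations inside the block are not recorded by autograd,
#    but the parameters' requires_grad flags are left as they were.
with torch.no_grad():
    out_eval = model(x)
print(out_eval.requires_grad)                             # False: no graph was built
print(all(p.requires_grad for p in model.parameters()))   # True: flags unchanged

# 2) requires_grad = False: the flag itself is changed on the chosen parameter,
#    and the change persists outside any context manager.
model.bias.requires_grad = False
out_train = model(x)
out_train.sum().backward()
print(model.weight.grad is not None)  # True: weight still gets a gradient
print(model.bias.grad is None)        # True: the frozen bias gets none
```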

In [None]:
from transformers import BertModel
import torch
from torch import nn

In [None]:
model_name = "bert-base-uncased"  # Hugging Face Hub ID uses hyphens, not an underscore

In [None]:
bert = BertModel.from_pretrained(model_name)

In [None]:
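def calc_learnable_params(model):
    """Count the parameters that will receive gradients (requires_grad=True)."""
    total_params = 0
    for param in model.parameters():
        if param.requires_grad:
            total_params += param.numel()  # numel(): number of scalar entries in the tensor
    return total_params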
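The helper can be sanity-checked on a single small layer, where the count is easy to work out by hand. This sketch restates the same counting rule as a one-line generator expression (an equivalent rewrite, not the notebook's original). It applies that rule to an assumed nn.Linear(4, 3), which holds 4 × 3 = 12 weights plus 3 biases, 15 in total.

```python
from torch import nn

def calc_learnable_params(model: nn.Module) -> int:
    # Same rule as the loop version: sum numel() over parameters with requires_grad=True
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

layer = nn.Linear(4, 3)               # 4*3 weights + 3 biases = 15 parameters
print(calc_learnable_params(layer))   # 15

layer.bias.requires_grad = False      # freeze the 3 bias entries
print(calc_learnable_params(layer))   # 12
```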

In [None]:
calc_learnable_params(bert)

In [None]:
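# no_grad only stops graph recording; it does not touch requires_grad,
# so this prints the same count as the previous cell
with torch.no_grad():
    print(calc_learnable_params(bert))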
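Where no_grad does matter is the feature-extraction case from the list at the top: a module that runs inside forward but should never be updated. The sketch below uses a hypothetical two-part model, a small backbone and a trainable head, in place of BERT. It wraps only the backbone call in torch.no_grad().

```python
import torch
from torch import nn

class FeatureExtractorClassifier(nn.Module):
    """Hypothetical model: a backbone used only for feature extraction,
    followed by a trainable head."""

    def __init__(self, in_dim: int = 8, hidden: int = 16, n_classes: int = 3):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the backbone runs under no_grad: no graph is recorded for it,
        # so backward() never reaches its parameters.
        with torch.no_grad():
            feats = self.backbone(x)
        return self.head(feats)

torch.manual_seed(0)
model = FeatureExtractorClassifier()
loss = model(torch.randn(5, 8)).sum()
loss.backward()

print(all(p.grad is None for p in model.backbone.parameters()))      # True
print(all(p.grad is not None for p in model.head.parameters()))      # True
# The backbone's flags are still True, so a counter like calc_learnable_params
# would still report those parameters as learnable.
print(all(p.requires_grad for p in model.backbone.parameters()))     # True
```

One design consequence: because requires_grad stays True, the backbone parameters would still be passed to an optimizer built from model.parameters(). They simply receive no gradient. Setting requires_grad = False instead makes the freeze visible to both the counter and the optimizer setup.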

In [None]:
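# Explicitly freeze every parameter (the `if` check is unnecessary;
# bert.requires_grad_(False) does the same in one call)
for name, param in bert.named_parameters():
    param.requires_grad = False

calc_learnable_params(bert)  # 0: nothing is left to train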
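Freezing every parameter is the extreme case. The more common pattern is to freeze part of the model and train the rest. This sketch selects parameters by name prefix with a hypothetical helper, freeze_by_prefix, and uses a toy nn.Sequential. In an nn.Sequential, parameters are named "0.weight", "0.bias", and so on. Only the still-trainable parameters are then handed to the optimizer.

```python
from typing import Tuple

import torch
from torch import nn

def freeze_by_prefix(model: nn.Module, prefixes: Tuple[str, ...]) -> None:
    """Hypothetical helper: set requires_grad=False on every parameter whose
    name starts with one of `prefixes`; all other parameters stay trainable."""
    for name, param in model.named_parameters():
        if name.startswith(prefixes):  # str.startswith accepts a tuple of prefixes
            param.requires_grad = False

# Toy stand-in for a real model (assumption)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
freeze_by_prefix(model, ("0.",))  # freeze the first Linear layer only

# Pass only the trainable parameters to the optimizer
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']
```

For a Hugging Face BertModel, parameter names typically begin with "embeddings.", "encoder.layer.<i>." or "pooler.". Printing bert.named_parameters() confirms the exact names. For example, a call such as freeze_by_prefix(bert, ("embeddings.", "encoder.layer.0.")) would freeze the embeddings and the first encoder layer. The trailing dot keeps "encoder.layer.1." from also matching layers 10 and 11.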