## Imports, Config and Seeding

In [50]:
import timm
import torch
import torchvision
from typing import Dict, Union, Callable, OrderedDict
import os, random
import numpy as np

In [51]:
def seed_all(seed: int = 1992) -> None:
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), fixes ``PYTHONHASHSEED`` and configures cuDNN so that
    kernel selection is deterministic.

    Args:
        seed: Seed value applied to all generators. Defaults to 1992.

    Note:
        Setting ``PYTHONHASHSEED`` here only affects interpreters started
        afterwards (e.g. spawned worker processes); the hash seed of the
        already-running interpreter is fixed at start-up.
    """
    print(f"Using Seed Number {seed}")

    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)  # Python built-in PRNG
    np.random.seed(seed)  # NumPy global PRNG
    torch.manual_seed(seed)  # PyTorch default generator
    # Explicitly seed every CUDA device; silently ignored when CUDA is absent.
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and possibly choose a different
    # (non-deterministic) algorithm on each run, which defeats the seeding above.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True


def seed_worker(_worker_id) -> None:
    """Re-seed NumPy and ``random`` from the current PyTorch initial seed.

    Presumably meant to be passed as a DataLoader ``worker_init_fn`` (the
    worker id argument is accepted but unused). The PyTorch seed is reduced
    modulo 2**32 so it fits the range NumPy's ``seed`` accepts.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)


# Seed all generators once, before any tensors or models are created.
seed_all(1992)

Using Seed Number 1992


In [52]:
# Use CUDA when PyTorch reports it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [53]:
# resnet18_pretrained_true = timm.create_model(model_name = "resnet34", pretrained=True, num_classes=10).to(DEVICE)

In [54]:
# With track_running_stats=True the layer owns running_mean / running_var
# buffers, initialised to zeros and ones respectively.
norm = torch.nn.InstanceNorm2d(3, track_running_stats=True)
initial_stats = (norm.running_mean, norm.running_var)
print(*initial_stats)

tensor([0., 0., 0.]) tensor([1., 1., 1.])


In [55]:
x = torch.randn(2, 3, 24, 24)

# In training mode each forward pass folds this batch's statistics into the
# running buffers (momentum=0.1), so the printed stats drift a bit every call.
for _step in range(3):
    out = norm(x)
    print(norm.running_mean, norm.running_var)

tensor([-1.3414e-03, -4.7338e-05,  1.1239e-03]) tensor([1.0010, 0.9984, 0.9989])
tensor([-2.5486e-03, -8.9943e-05,  2.1355e-03]) tensor([1.0018, 0.9969, 0.9979])
tensor([-0.0036, -0.0001,  0.0030]) tensor([1.0026, 0.9956, 0.9970])


In [20]:
# Switch to evaluation mode (same as norm.eval()). The layer then normalises
# with its running buffers instead of updating them.
# NOTE(review): the output recorded below comes from an earlier kernel session
# (execution count 20) and does not follow from the cells above it.
norm.train(False)
out = norm(x)
print(norm.running_mean, norm.running_var)

tensor([-0.0160, -0.0018,  0.0068]) tensor([1.0002, 1.0082, 0.9904])


In [56]:
def freeze_batchnorm_layers(model: torch.nn.Module) -> None:
    """Freeze the normalization layers of a PyTorch model in place.

    Every batch-norm layer (``BatchNorm1d/2d/3d`` and ``SyncBatchNorm``) is put
    into eval mode, so it normalizes with its stored running statistics
    instead of updating them. Its affine ``weight``/``bias`` (when the layer
    has them) stop receiving gradients. Instance-norm layers
    (``InstanceNorm1d/2d/3d``) are only put into eval mode. That freezes their
    running statistics when they were built with ``track_running_stats=True``.

    Args:
        model: Model (or single layer) whose normalization layers to freeze.
            The function walks ``model.modules()`` itself, covering ``model``
            and all of its descendants, so it should be called directly rather
            than through ``model.apply``.

    Note:
        ``model.train()`` switches every submodule back to training mode, so
        call this function again after each ``model.train()``.

    Example:
        >>> model = timm.create_model("efficientnet_b0", pretrained=True)
        >>> _ = model.train()
        >>> freeze_batchnorm_layers(model)  # BN layers stay frozen while training
    """
    # https://discuss.pytorch.org/t/how-to-freeze-bn-layers-while-training-the-rest-of-network-mean-and-var-wont-freeze/89736/19
    # https://discuss.pytorch.org/t/should-i-use-model-eval-when-i-freeze-batchnorm-layers-to-finetune/39495/3
    batch_norm_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
    )
    instance_norm_types = (
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
    )

    for module in model.modules():
        if isinstance(module, instance_norm_types):
            module.eval()
        elif isinstance(module, batch_norm_types):
            # With affine=False the attributes exist but are None, so a
            # hasattr() check is not enough to guard requires_grad_.
            for param in (module.weight, module.bias):
                if param is not None:
                    param.requires_grad_(False)
            module.eval()

In [57]:
# freeze_batchnorm_layers already walks norm.modules(), so call it directly;
# wrapping it in Module.apply would re-walk every submodule's subtree.
freeze_batchnorm_layers(norm)
norm  # echo the layer so the cell still displays its repr

InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)

In [58]:
# Forward pass with the layer frozen (eval mode): running stats should not move.
out = norm(x)

In [59]:
# Expected to equal the stats printed after the last training-mode pass,
# because the frozen forward above did not update the running buffers.
norm.running_mean, norm.running_var

(tensor([-0.0036, -0.0001,  0.0030]), tensor([1.0026, 0.9956, 0.9970]))