In [2]:
import torch
import torchvision.models as models
from torch import nn

In [4]:
model = models.resnet50(pretrained=True)

# Find indexes of ReLU layers
# relu_indexes = [i for i, layer in enumerate(model.modules()) if isinstance(layer, torch.nn.ReLU)]

In [6]:
relu_name = []
for name, module in model.named_modules():
    if "relu" in name:
        relu_name.append(name)
        

In [7]:
relu_name

['relu',
 'layer1.0.relu',
 'layer1.1.relu',
 'layer1.2.relu',
 'layer2.0.relu',
 'layer2.1.relu',
 'layer2.2.relu',
 'layer2.3.relu',
 'layer3.0.relu',
 'layer3.1.relu',
 'layer3.2.relu',
 'layer3.3.relu',
 'layer3.4.relu',
 'layer3.5.relu',
 'layer4.0.relu',
 'layer4.1.relu',
 'layer4.2.relu']

In [8]:

def truncate_model_to_relu(model, target_relu_name):
    new_model = nn.Sequential()
    found_target = False

    def recursive_add(module, prefix=''):
        nonlocal found_target
        for name, child in module.named_children():
            full_name = f"{prefix}_{name}" if prefix else name
            if isinstance(child, nn.ReLU) and full_name.replace('_', '.') == target_relu_name:
                new_model.add_module(full_name, child)
                found_target = True
                return
            elif isinstance(child, nn.Sequential) or len(list(child.children())) > 0:
                new_module = nn.Sequential()
                new_model.add_module(full_name, new_module)
                recursive_add(child, full_name)
                if found_target:
                    return
            else:
                new_model.add_module(full_name, child)
            
            if found_target:
                return

    recursive_add(model)
    return new_model

In [13]:
truncated_model = truncate_model_to_relu(model, relu_name[1])

In [14]:
truncated_model

Sequential(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential()
  (layer1_0): Sequential()
  (layer1_0_conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (layer1_0_bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1_0_conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1_0_bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1_0_conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (layer1_0_bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1_0_relu): ReLU(inplace=True)
)