In [1]:
# Dependencies
import torch
import torchvision

from typing import Tuple

In [2]:
# Helper: turn two identical networks into an encoder/decoder pair (in place).
def update_models_inplace_to_encoder_decoder_pair(
	encoder: torch.nn.Module, decoder: torch.nn.Module,
	encoder_decoder_boundary: str
) -> Tuple[torch.nn.Module, torch.nn.Module]:
	r"""Converts the given `encoder` and `decoder` networks into an
	encoder-decoder pair. The input `encoder` and `decoder` instances are
	CHANGED IN PLACE PERMANENTLY.

	Parameters
	----------
	encoder:
		The model to be converted into an encoder.

	decoder:
		The model to be converted into a decoder.

	encoder_decoder_boundary:
		The name of the first layer that belongs to the decoder. If both the
		models have architecture ['A', 'B', 'C'] with the boundary at 'B', the
		`encoder` will only have 'A' as its architecture while the `decoder`
		will have 'B', 'C' as its architecture.

	Returns
	-------
	updated_encoder:
		The same `encoder` object, with the decoder layers set to identity.

	updated_decoder:
		The same `decoder` object, with the encoder layers set to identity.

	Raises
	------
	AssertionError
		If the two models do not expose the same ordered list of top-level
		module names.

	ValueError
		If `encoder` and `decoder` are the same object, or if
		`encoder_decoder_boundary` is not the name of a top-level module.
		In both cases neither model has been modified.

	Notes
	-----
	The two networks must be two DIFFERENT instances of the same
	architecture. Only the direct (top-level) child modules are considered.
	Based on the `encoder_decoder_boundary`, the boundary layer and all the
	layers AFTER it are set to `torch.nn.Identity` in the `encoder`, and all
	the layers BEFORE it are set to `torch.nn.Identity` in the `decoder`.
	"""
	# Obtain the ordered list of all the top-level module names.
	enc_modules = list(encoder._modules)
	dec_modules = list(decoder._modules)
	# Check if the lists match.
	assert enc_modules == dec_modules, \
		'[ERROR] encoder and decoder modules must match.'
	# Validate everything BEFORE mutating: the update is destructive, so a bad
	# argument must not leave either model half-converted.
	if encoder is decoder:
		# With a shared instance every layer would be replaced by identity.
		raise ValueError(
			'[ERROR] encoder and decoder must be two different instances.'
		)
	if encoder_decoder_boundary not in enc_modules:
		# Without this check the whole decoder would silently become identity.
		raise ValueError(
			'[ERROR] encoder_decoder_boundary {!r} is not one of the model '
			'layers {}.'.format(encoder_decoder_boundary, enc_modules)
		)
	boundary_index = enc_modules.index(encoder_decoder_boundary)
	# Layers before the boundary belong to the encoder: blank them in the decoder.
	for layer_name in enc_modules[:boundary_index]:
		setattr(decoder, layer_name, torch.nn.Identity())
	# The boundary layer and everything after it belong to the decoder: blank
	# them in the encoder.
	for layer_name in enc_modules[boundary_index:]:
		setattr(encoder, layer_name, torch.nn.Identity())
	# Return handles to the (same, now modified) encoder and decoder.
	return encoder, decoder

In [3]:
# Build two independent ResNet-18 instances without pretrained weights.
# The factory defaults are used instead of `pretrained=False, progress=True`:
# the defaults are equivalent, and the `pretrained` keyword is deprecated in
# newer torchvision releases in favour of `weights`.
encoder = torchvision.models.resnet18()
decoder = torchvision.models.resnet18()

In [4]:
# Show both freshly built networks, each preceded by a 79-character rule.
print(
    '*'*79,
    'encoder:\n{}'.format(encoder),
    '*'*79,
    'decoder:\n{}'.format(decoder),
    sep='\n',
)

*******************************************************************************
encoder:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05,

In [5]:
# Split the pair at 'layer2'. The returned handles are the very same objects
# as `encoder` and `decoder`, which are modified in place.
updated_encoder, updated_decoder = update_models_inplace_to_encoder_decoder_pair(
    encoder=encoder,
    decoder=decoder,
    encoder_decoder_boundary='layer2',
)

In [6]:
# Show the returned encoder/decoder handles, each preceded by a 79-character rule.
print(
    '*'*79,
    'encoder:\n{}'.format(updated_encoder),
    '*'*79,
    'decoder:\n{}'.format(updated_decoder),
    sep='\n',
)

*******************************************************************************
encoder:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05,

In [7]:
# THE UPDATES ARE IN PLACE: printing the original `encoder` and `decoder`
# names shows the modified models.
print(
    '*'*79,
    'encoder:\n{}'.format(encoder),
    '*'*79,
    'decoder:\n{}'.format(decoder),
    sep='\n',
)

*******************************************************************************
encoder:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05,