In [None]:
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Frozen residual block computing ``id_path(x) + res_path(x)``.

    ``res_path`` is a bottleneck of four convolutions (3x3, 3x3, 3x3, 1x1),
    each preceded by a ReLU, with hidden width ``n_out // 4``. ``id_path`` is
    either ``nn.Identity`` (``use_identity=True``) or a 1x1 convolution mapping
    ``n_in`` to ``n_out`` channels. Every parameter has ``requires_grad=False``.

    Args:
        n_in: Number of input channels.
        n_out: Number of output channels.
        use_identity: Use the identity as the skip path. Requires
            ``n_in == n_out`` because the two paths are summed element-wise.
        use_float16: If True, the input and the block's parameters are cast to
            ``torch.float16`` before the forward computation.

    Raises:
        ValueError: If ``use_identity`` is True but ``n_in != n_out``.
    """

    def __init__(self, n_in, n_out, use_identity=False, use_float16=True):
        super().__init__()
        if use_identity and n_in != n_out:
            # An identity skip cannot change the channel count, so the sum in
            # forward() would otherwise fail later with an opaque shape error.
            raise ValueError(
                f"use_identity=True requires n_in == n_out, got n_in={n_in}, n_out={n_out}"
            )
        self.use_float16 = use_float16
        self.id_path = nn.Identity() if use_identity else nn.Conv2d(n_in, n_out, kernel_size=1)
        n_hid = n_out // 4
        self.res_path = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(n_in, n_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(n_hid, n_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(n_hid, n_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(n_hid, n_out, kernel_size=1),
        )
        self._set_requires_grad_false()

    def _set_requires_grad_false(self):
        """Freeze every parameter of the block (both skip and residual paths)."""
        self.requires_grad_(False)

    def _ensure_half(self):
        """Cast the block's parameters to float16 unless they already are.

        ``Module.half()`` works in place, so it only needs to run when some
        parameter is not yet float16 (e.g. first call, or after ``.float()``).
        """
        if any(p.dtype != torch.float16 for p in self.parameters()):
            self.half()

    def forward(self, x):
        """Return ``id_path(x) + res_path(x)`` (in float16 if ``use_float16``)."""
        if self.use_float16:
            x = x.half()
            self._ensure_half()
        return self.id_path(x) + self.res_path(x)

class Encoder(nn.Module):
    """Frozen convolutional encoder built from ``EncoderBlock`` stages.

    Layout of ``self.blocks`` (an ``nn.Sequential``):

    0. 7x7 conv, 3 -> 256 channels (``padding=3`` keeps the spatial size).
    1. Two 256 -> 256 identity blocks, then a 2x2 stride-2 max-pool.
    2. 256 -> 512 and 512 -> 512 blocks, then a 2x2 stride-2 max-pool.
    3. 512 -> 1024 and 1024 -> 1024 blocks, then a 2x2 stride-2 max-pool.
    4. 1024 -> 2048 and 2048 -> 2048 blocks (no pooling).
    5. ReLU followed by a 1x1 conv, 2048 -> 8192 channels.

    After the three stride-2 pools, an ``(N, 3, H, W)`` input produces an
    ``(N, 8192, H // 8, W // 8)`` output. Every parameter has
    ``requires_grad=False``.

    Args:
        use_float16: If True, the input and all parameters are cast to
            ``torch.float16`` before the forward computation; the flag is also
            passed to every ``EncoderBlock``.
    """

    def __init__(self, use_float16=True):
        super().__init__()
        self.use_float16 = use_float16
        # NOTE(review): never read in this class and not updated by
        # ``.to()``/``.cuda()``, so it does not track where parameters live.
        # Kept only for callers that may read it.
        self.device = torch.device('cpu')
        self.blocks = nn.Sequential(
            nn.Conv2d(3, 256, kernel_size=7, padding=3),
            nn.Sequential(
                EncoderBlock(256, 256, use_identity=True, use_float16=self.use_float16),
                EncoderBlock(256, 256, use_identity=True, use_float16=self.use_float16),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ),
            nn.Sequential(
                EncoderBlock(256, 512, use_identity=False, use_float16=self.use_float16),
                EncoderBlock(512, 512, use_identity=True, use_float16=self.use_float16),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ),
            nn.Sequential(
                EncoderBlock(512, 1024, use_identity=False, use_float16=self.use_float16),
                EncoderBlock(1024, 1024, use_identity=True, use_float16=self.use_float16),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ),
            nn.Sequential(
                EncoderBlock(1024, 2048, use_identity=False, use_float16=self.use_float16),
                EncoderBlock(2048, 2048, use_identity=True, use_float16=self.use_float16),
            ),
            nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(2048, 8192, kernel_size=1),
            ),
        )
        self._set_requires_grad_false()

    def _set_requires_grad_false(self):
        """Freeze every parameter of the encoder (all of ``self.blocks``)."""
        self.requires_grad_(False)

    def forward(self, x):
        """Encode ``x``; casts input and parameters to float16 if ``use_float16``."""
        if self.use_float16:
            x = x.half()
            # Module.half() converts in place, so only run it when some
            # parameter is not float16 yet (first call, or after .float()).
            if any(p.dtype != torch.float16 for p in self.parameters()):
                self.blocks.half()
        return self.blocks(x)

# Example usage: encode one random 224x224 RGB image.
# The guard keeps imports of this module side-effect free; in a Jupyter
# kernel ``__name__`` is "__main__", so the cell still runs as before.
if __name__ == "__main__":
    # Half-precision convolutions on CPU are slow and, depending on the
    # PyTorch version, may be unimplemented, so only use float16 on a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = Encoder(use_float16=device.type == "cuda").to(device)
    input_tensor = torch.randn(1, 3, 224, 224, device=device)  # Example input tensor
    with torch.no_grad():  # parameters are frozen; no autograd graph is needed
        output_tensor = encoder(input_tensor)
    print(output_tensor.shape)  # torch.Size([1, 8192, 28, 28]) -- 224 // 8 == 28


In [None]:
# NOTE(review): leftover notebook cell; the bare literal below has no effect.
1