In [8]:
import torch
import torch.nn as nn
import numpy as np

In [9]:
def layer_init(layer:torch.nn.Module,
               std=np.sqrt(2), 
               bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    # torch.nn.init.normal_(layer.weight, std)
    # torch.nn.init.normal_(layer.bias, std)
    return layer

In [51]:
class Resnet(nn.Module):
  def __init__(self):
    super().__init__()
    # TODO
    kernel_size = 3
    stride = 1
    in_channels = 144
    mid_channels = 128
    final_channels = 256
    num_layers = 3
    self.leakyrelu = nn.LeakyReLU()
    self.first_block = nn.Sequential(
          layer_init(nn.Conv2d(in_channels, mid_channels, kernel_size=kernel_size, stride=stride,
                            padding=1, bias=True)),
          nn.BatchNorm2d(mid_channels),
          nn.GELU())
    
    self.conv_layers = nn.ModuleList()
    self.channels = [mid_channels for i in range(num_layers)]
    for i in range(num_layers):  
        conv_block = nn.Sequential(
                                layer_init(nn.Conv2d(mid_channels, mid_channels, kernel_size=kernel_size, stride=stride,
                                padding=1, bias=True)),
                                nn.BatchNorm2d(mid_channels),
                                nn.GELU())
        self.conv_layers.append(conv_block)
    self.is_avg_pooling = True
    self.final_block = nn.Sequential(
            nn.Conv2d(mid_channels, final_channels, kernel_size=kernel_size, stride=stride,
                            padding=1, bias=True),
            nn.BatchNorm2d(final_channels),
            nn.GELU())
    self.avg_pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    input_size = 14 * 14
    self.fc = layer_init(nn.Linear(input_size, 256))

  def forward(self, x):
    if len(x.shape) < 4:
      x = x.unsqueeze(0)
    x = self.first_block(x)
    shortcut = x
    for conv_block in self.conv_layers:
      x = conv_block(x)
      x += shortcut
      shortcut = x
    # x = self.final_block(x)
    # x = self.avg_pooling(x)
    # x = x.view(x.size(0), -1)
    x = x.view(x.size(0), x.size(1), -1)
    x = self.fc(x)
    return x

In [52]:
image_embed_net = Resnet()

In [53]:
total_num_params = 0
for name, p in image_embed_net.named_parameters():
    total_num_params += p.numel()
total_num_params 

955904

In [54]:
def image_embedding(image_embed_net, image: torch.Tensor):
    # image: [batch_size, 4, 84, 84]
    patch_width_height = (14, 14)
    batch_size = image.size(0)
    image = image.to(dtype=torch.float32) / 255.0
    image = image.reshape(batch_size, -1, *patch_width_height)
    # [32, 144, 14, 14]
    print(image.shape)
    embed = image_embed_net(image)
    return embed 

In [55]:
image = torch.randn([32, 4, 84, 84])
embed = image_embedding(image_embed_net, image)
embed.shape

torch.Size([32, 144, 14, 14])


torch.Size([32, 128, 256])