In [1]:
import numpy as np
import torch
from sklearn.datasets import load_sample_images

In [3]:
# Stack sklearn's sample photos into one NumPy batch, convert it to a
# float32 tensor and rescale pixel values by 1/255 (assumes 8-bit pixels).
# Resulting layout is (N, H, W, C) -- see the shape printed in the next cell.
sample_img = torch.tensor(
    np.stack(load_sample_images()['images']),
    dtype=torch.float32,
).div(255)


In [4]:
sample_img.size()  # (N, H, W, C): channels are still the last axis here

torch.Size([2, 427, 640, 3])

In [6]:
# PyTorch layers expect channels-first input, so move the channel axis
# from position 3 to position 1: (N, H, W, C) -> (N, C, H, W).
sample_img_permuted = torch.permute(sample_img, (0, 3, 1, 2))
sample_img_permuted.size()

torch.Size([2, 3, 427, 640])

In [7]:
import torchvision
import torchvision.transforms.v2 as T

In [8]:
# Keep only the central 70x120 (height x width) window of every image.
cropped_img = T.CenterCrop(size=(70, 120))(sample_img_permuted)
cropped_img.size()

torch.Size([2, 3, 70, 120])

In [9]:
import torch.nn as nn

# Seed before building the layer so its randomly initialized weights are reproducible.
torch.manual_seed(42)
# 3 input channels -> 32 feature maps with a 7x7 kernel and no padding,
# so each spatial dimension shrinks by 7 - 1 = 6.
conv_lay = nn.Conv2d(3, 32, 7)
fmaps = conv_lay(cropped_img)
fmaps.size()

torch.Size([2, 32, 64, 114])

# Alternative: use `padding="same"` so the output keeps the input's height and width

In [10]:
# Same layer shape, but padding="same" keeps the output height/width
# equal to the input's (70x120 here).
conv_lay = nn.Conv2d(3, 32, 7, padding="same")
fmaps = conv_lay(cropped_img)
fmaps.size()

torch.Size([2, 32, 70, 120])

In [11]:
conv_lay.weight.size()  # (out_channels, in_channels, kernel_h, kernel_w)

torch.Size([32, 3, 7, 7])

In [12]:
conv_lay.bias.size()  # one bias term per output channel

torch.Size([32])

# Pooling layers

In [14]:
max_pool = nn.MaxPool2d(2)  # 2x2 window; stride defaults to the kernel size

Build a custom depth-wise max pooling layer, which takes the maximum across channels instead of across spatial positions.

In [15]:
import torch.nn.functional as F
class DepthPool(torch.nn.Module):
    """Max pooling along the channel (depth) axis of an (N, C, H, W) tensor.

    Spatial positions are left untouched. Each output channel is the maximum
    over a window of ``kernel_size`` consecutive input channels, and windows
    advance by ``stride`` channels.

    Args:
        kernel_size: number of consecutive channels pooled together.
        stride: step between pooling windows along the channel axis.
            Defaults to ``kernel_size`` (non-overlapping windows).
        padding: implicit padding on both ends of the channel axis, forwarded
            to ``F.max_pool1d`` (PyTorch requires it to be at most half of
            ``kernel_size``).
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size
        self.padding = padding

    def forward(self, inputs):
        """Pool ``inputs`` (batch, channels, height, width) over its channels.

        Returns a (batch, out_channels, height, width) tensor, where
        ``out_channels`` follows ``F.max_pool1d``'s output-length formula.
        """
        batch, channels, height, width = inputs.shape
        # reshape (not view): inputs such as a center-cropped slice are not
        # contiguous across height/width, and .view would raise RuntimeError.
        z = inputs.reshape(batch, channels, height * width)
        # max_pool1d pools over the last dimension, so move channels there.
        z = z.permute(0, 2, 1)
        z = F.max_pool1d(z, kernel_size=self.kernel_size,
                         stride=self.stride, padding=self.padding)
        # Move channels back in front of the flattened spatial dimension.
        z = z.permute(0, 2, 1)
        # Un-merge the spatial dims; reshape also tolerates the permuted layout.
        return z.reshape(batch, -1, height, width)


In [17]:
# Global average pooling: an adaptive pool with a 1x1 target averages each
# channel over its whole spatial extent, giving (N, C, 1, 1).
global_avg_pool = nn.AdaptiveAvgPool2d(1)
output = global_avg_pool(cropped_img)

In [18]:
# Same global average pool written by hand: average over height (dim 2) and
# width (dim 3), keeping them as size-1 axes.
output = torch.mean(cropped_img, dim=(2, 3), keepdim=True)