In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from functools import partial

In [4]:
resnet = torchvision.models.resnet152(weights=True)



In [7]:
#get the number of parameters
cnt = 0
for n,p in resnet.named_parameters():
  cnt+=p.numel()

In [12]:
for n, child in resnet.named_children():
  print(n)


conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc


In [20]:
#names of notation of layers
def conv(in_channels, out_channels, kernel_size, stride,padding = 0):
  return nn.Conv2d(in_channels, out_channels, kernel_size, stride,padding,bias = False)

conv1x1 = partial(conv, kernel_size = 1, stride = 1, padding = 0)
conv3x3 = partial(conv, kernel_size = 3, stride = 1, padding = 1)

def bn(num_features):
  return nn.BatchNorm2d(num_features,eps = 1e-5, momentum  = .1, affine = True, track_running_stats=True)


In [34]:
class Bottleneck(nn.Module):
  def __init__(self,in_channels, middle_channels, out_channels, stride = 2):
    #super(nn.Module, self).__init__()
    self.conv1 = conv1x1(in_channels, middle_channels)
    self.bn1 = bn(middle_channels)
    if stride == 1:
      self.conv2 = conv3x3(middle_channels, middle_channels)
    else:
      self.conv2 = conv(middle_channels, middle_channels, 3, stride, 1)
    self.bn2 = bn(middle_channels)
    self.conv3 = conv1x1(middle_channels, out_channels)
    self.bn3 = bn(out_channels)
    self.relu = nn.ReLU(inplace = False)

    if stride > 1 or in_channels!= out_channels:
      self.downsample = nn.Sequential(conv(in_channels, out_channels,kernel_size = 1, stride = stride), bn(out_channels))

    def forward(self,x):
      y= self.conv1(x)
      y= self.bn1(x)
      y= self.relu(x)
      y= self.conv2(x)
      y= self.bn2(x)
      y= self.relu(x)
      y= self.conv3(x)
      y= self.bn3(x)
      if self.downsample is not None:
        x = self.downsample(x)
      y = y+x
      return y


In [35]:
model  = Bottleneck(64,64,256,stride = 1)
model.load_state_dict(resnet.layer1[0].state_dict())



AttributeError: ignored