<a href="https://colab.research.google.com/github/jhy9968/ECE6179_project/blob/main/Encoder_decoder_template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [57]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision.models.resnet import ResNet18_Weights

**Encoder and Decoder Classes**

In [62]:
class Encoder(nn.Module):
  """ResNet-style backbone regrouped into four blocks that can be frozen individually.

  The backbone's top-level children are regrouped as:
    block1 -- children 0-4 (for ResNet-18: conv1, bn1, relu, maxpool, layer1)
    block2 -- the modules inside child 5 (ResNet-18: layer2)
    block3 -- the modules inside child 6 (ResNet-18: layer3)
    block4 -- the modules inside child 7 (ResNet-18: layer4)
  Children after index 7 (ResNet-18's pooling and classifier head) are not
  used, so ``forward`` returns the feature map produced by block4.

  Args:
    base_model: backbone whose ``children()`` follow the layout above. Its
      submodules are reused, not copied. If None (the default), a new
      ImageNet-pretrained ResNet-18 is built for this instance.
  """

  def __init__(self, base_model=None):
    super().__init__()
    if base_model is None:
      # Built here rather than in the signature. A call in a default argument
      # runs once, when the class is defined (loading/downloading weights at
      # that moment), and makes every Encoder() share the same modules and
      # therefore the same parameters.
      base_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1, progress=False)
    children = list(base_model.children())
    self.block1 = nn.Sequential(*children[:5])
    self.block2 = nn.Sequential(*children[5])
    self.block3 = nn.Sequential(*children[6])
    self.block4 = nn.Sequential(*children[7])

  def forward(self, x):
    """Run ``x`` through block1, block2, block3 and block4 in order."""
    for block in (self.block1, self.block2, self.block3, self.block4):
      x = block(x)
    return x

  def print_model(self):
    """Print the module tree of this encoder."""
    print(self)

  def _set_block_requires_grad(self, block, requires_grad):
    """Set ``requires_grad`` on every parameter of the block with 1-based ID ``block``.

    Raises:
      ValueError: if ``block`` does not identify one of this encoder's blocks.
    """
    blocks = list(self.children())
    for index, child in enumerate(blocks, start=1):
      if index == block:
        child.requires_grad_(requires_grad)
        return
    raise ValueError(f'block must be between 1 and {len(blocks)}, got {block!r}')

  def freeze_param(self, block):
    """Stop gradient updates for the parameters of block ``block`` (1 to 4).

    Only ``requires_grad`` is changed. BatchNorm layers inside the block (if
    any) still update their running statistics while the module is in train
    mode; call ``.eval()`` on the block as well if those should stay fixed.

    Raises:
      ValueError: if ``block`` is not a valid block ID.
    """
    self._set_block_requires_grad(block, False)
    print(f'Freeze block {block} parameters')

  def unfreeze_param(self, block):
    """Re-enable gradient updates for the parameters of block ``block`` (1 to 4).

    Raises:
      ValueError: if ``block`` is not a valid block ID.
    """
    self._set_block_requires_grad(block, True)
    print(f'Unfreeze block {block} parameters')

In [63]:
class Decoder(nn.Module):
  """Convolutional decoder that upsamples a feature map 4x in each spatial dim.

  Layers: ConvTranspose2d (2x up) -> Conv2d -> ConvTranspose2d (2x up) -> Conv2d.
  Each transposed conv (kernel 3, stride 2, padding 1, output_padding 1) maps
  H x W to exactly 2H x 2W. Each Conv2d (kernel 3, stride 1, padding 1) keeps
  the size. So an (N, inplanes, H, W) input yields (N, out_planes, 4H, 4W).

  NOTE(review): a ResNet-18 encoder downsamples by 32x, so this decoder alone
  does not restore the encoder's input resolution -- confirm that is intended.

  Args:
    inplanes: channels of the input feature map (512 matches the output of
      ResNet-18's last stage).
    intMed_planes: channels used by the intermediate layers.
    out_planes: channels of the output (default 3).
    activation: optional module applied after each of the first three layers.
      The same instance is reused at all three positions. With the default
      None no activation is applied, so the decoder as a whole is a single
      affine map of its input.
  """

  def __init__(self, inplanes=512, intMed_planes=64, out_planes=3, activation=None):
    super().__init__()
    self.inplanes = inplanes
    self.intMed_planes = intMed_planes

    # output_padding=1 makes each transposed conv produce exactly 2x its input
    # size. It only affects the computed output shape: the extra bottom row /
    # right column is computed from the input like every other position (it is
    # not zero padding), so the loss needs no special handling for it.
    self.convTrans1 = nn.ConvTranspose2d(in_channels = self.inplanes, out_channels = self.intMed_planes, kernel_size = 3, stride = 2, padding = 1, output_padding=1)
    self.conv2 = nn.Conv2d(in_channels = self.intMed_planes, out_channels = self.intMed_planes, kernel_size = 3, stride = 1, padding = 1)
    self.convTrans3 = nn.ConvTranspose2d(in_channels = self.intMed_planes, out_channels = self.intMed_planes, kernel_size = 3, stride = 2, padding = 1, output_padding=1)
    self.conv4 = nn.Conv2d(in_channels = self.intMed_planes, out_channels = out_planes, kernel_size = 3, stride = 1, padding = 1)

    # nn.Identity has no parameters, so the state_dict is unchanged when no
    # activation is requested.
    self.act = nn.Identity() if activation is None else activation

  def forward(self, x):
    """Decode ``x``; the final conv's output is returned without activation."""
    x = self.act(self.convTrans1(x))
    x = self.act(self.conv2(x))
    x = self.act(self.convTrans3(x))
    return self.conv4(x)

**Encoder instructions**

Generate an encoder

In [64]:
# Build an encoder on Encoder's default ImageNet-pretrained ResNet-18 backbone.
encoder = Encoder()

Print the encoder structure

> The encoder is split into four chunks named `block1`, `block2`, `block3` and `block4`.





In [65]:
# Print the module tree: block1 to block4 and the layers inside each.
encoder.print_model()

Encoder(
  (block1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)


Freeze or unfreeze a block of the encoder

In [52]:
# Freezing a block: pass its ID (1 to 4) to encoder.freeze_param().
# Example -- freeze block2:
encoder.freeze_param(block=2)

Freeze block 2 parameters


In [66]:
# To unfreeze a block, call encoder.unfreeze_param() with the ID of that block (1 to 4)

# e.g., to unfreeze the block2
encoder.unfreeze_param(2)

Unfreeze block 2 parameters


For an autoencoder that stores this encoder in its `encoder` attribute, call something like:

```
autoencoder.encoder.freeze_param(block_ID)
autoencoder.encoder.unfreeze_param(block_ID)
```

