## 1. Residual Neural Networks (ResNet)


When we stack more and more layers in a deep neural network, training becomes harder:

* **Vanishing/exploding gradients**: gradients shrink or grow as they backpropagate, making early layers learn very slowly or unstably.
* **Degradation problem**: simply adding more layers sometimes *reduces* training accuracy (not just test accuracy).

> The issue wasn’t overfitting — it was optimization difficulty.

---

#### 1.1 Key idea: “Residual” learning

Instead of making a stack of layers learn a direct mapping $H(x)$ from input $x$ to output, ResNet makes the stack learn a **residual mapping**:

$
F(x) = H(x) - x \quad \Rightarrow \quad H(x) = F(x) + x
$

So the block learns only the *change* to apply to the input. The original input is added back at the end via a **skip connection**.

---

#### 1.2 A residual block

**Standard block:**

```text
x ──┬──> [Conv → BN → ReLU → Conv → BN] ──> (+) ──> ReLU ──> output
    │                                        ↑
    └─────────── skip connection ────────────┘
```

Mathematically:

$
\text{output} = \text{ReLU}(F(x; W) + x)
$

* $F(x; W)$: the output of the two convolutions (the “residual”)
* $x$: the input carried by the identity/skip connection

This allows gradients to flow directly through the skip connection during backpropagation, which stabilizes training even with very deep networks.

#### 1.3 What does "skip connection" mean?

A **skip connection** (also called a **shortcut connection**) is literally what it sounds like:
a pathway that *skips over* one or more layers and feeds the input directly to a later point in the network.

---

**In a normal feed-forward block**

```
x ──> [Layer(s)] ──> output
```

All information flows through the layers.

---

**With a skip (shortcut) connection**

```
x ──┬──> [Layer(s)] ──> (+) ──> output
    │                    ↑
    └───── (skip x) ─────┘
```

You still process $x$ through some layers to get $F(x)$, **but at the same time you also send $x$ forward unchanged** and add it back in at the end.

Mathematically:
$
\text{output} = F(x) + x
$

* $F(x)$: what the block learned (residual)
* $x$: original input passed along the shortcut path
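The formula above can be sketched in a few lines of PyTorch. The `Residual` wrapper and the zero-initialised `nn.Linear` are illustrative assumptions, not part of ResNet itself. The point is that when the learned part $F$ outputs zero, the block reduces to the identity, so the layers only have to learn the *change* to their input.

```python
import torch
import torch.nn as nn


class Residual(nn.Module):
    """Wrap any shape-preserving module F so the block computes F(x) + x."""

    def __init__(self, fn: nn.Module) -> None:
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x) + x  # residual branch + skip branch


# Hypothetical F: a linear layer forced to output zeros
layer = nn.Linear(4, 4)
nn.init.zeros_(layer.weight)
nn.init.zeros_(layer.bias)

block = Residual(layer)
x = torch.randn(2, 4)
out = block(x)

# With F(x) = 0 the block passes its input through unchanged
block_is_identity = torch.allclose(out, x)
print(block_is_identity)
```

A plain (non-residual) layer with the same zero weights would output all zeros instead, so it would have to learn the identity mapping from scratch.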

---

#### 1.4 Two main skip-connection types

* **Identity skip** (when input/output have same shape): just add $x$ to $F(x)$.
* **Projection skip** (when shapes differ): apply a $1 \times 1$ convolution (and maybe stride) to $x$ before adding.

---


#### 1.5 Does the network learn F(x) or H(x)?


A plain stack of layers takes an input $x$ and tries to directly learn
$
H(x) \quad \text{(desired mapping)}
$

**What ResNet does instead**

ResNet rewrites the problem so that the stacked layers learn the **residual function**

$
F(x) = H(x) - x
$

and then adds the input back:

$
\text{Output} = F(x) + x = H(x)
$

---

**Residual block forward pass**

$
y = x + F(x; W)
$

where $F(x;W)$ are the layers with parameters $W$ (two convs, etc.), and $x$ is the input to the block.

---

#### 1.6 Gradient through a normal (plain) block

If you had

$
y = F(x;W)
$

then

$
\frac{\partial L}{\partial x}=\frac{\partial L}{\partial y}\frac{\partial y}{\partial x}=
\frac{\partial L}{\partial y}
\frac{\partial F(x;W)}{\partial x}
$


So all the gradient information must flow through the derivative of $F$. If $\partial F/\partial x$ is very small (vanishing gradient), the gradient almost disappears before it reaches earlier layers.

---

Gradient with skip connection:

With
$
y = x + F(x;W)
$

we have

$
\frac{\partial L}{\partial x}=\frac{\partial L}{\partial y}
\frac{\partial (x + F(x;W))}{\partial x}
=\frac{\partial L}{\partial y}\left(I + \frac{\partial F}{\partial x}\right)
$

Notice the **identity term $I$** coming from $\partial x/\partial x = I$ (which is simply $1$ in the scalar case).

This means that even if $\partial F/\partial x$ is tiny, there is still a direct gradient path:

$
\frac{\partial L}{\partial x} \approx \frac{\partial L}{\partial y}
$

So the gradient flows directly back to the input $x$ (and thus to the layers before the block) without being multiplied by small numbers. That’s the “highway” for gradients people talk about.

---



* The gradient can go **directly to the layer that produced $x$** via the identity/skip branch.
* It doesn’t have to be squeezed entirely through the more complicated $F(x)$ path.
* This keeps earlier layers trainable even in very deep networks.
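To get a feel for the numbers, here is a toy calculation. It makes strong simplifying assumptions: everything is scalar, and every block has the same local slope $\partial F/\partial x = 0.01$. Under those assumptions the gradient through a stack of plain blocks is the product of the slopes, while through residual blocks it is the product of $(1 + \text{slope})$ terms.

```python
def chain_gradient(local_slope: float, depth: int, residual: bool) -> float:
    """Gradient of the last output w.r.t. the first input in a scalar toy chain."""
    grad = 1.0
    for _ in range(depth):
        # plain block: dF/dx ; residual block: 1 + dF/dx
        factor = (1.0 + local_slope) if residual else local_slope
        grad *= factor
    return grad


plain = chain_gradient(0.01, depth=20, residual=False)
skip = chain_gradient(0.01, depth=20, residual=True)
print(f"plain: {plain:.1e}   residual: {skip:.3f}")
```

In this toy setting the plain chain's gradient is vanishingly small after 20 blocks, while the residual chain's gradient stays close to 1. Real networks have matrix-valued Jacobians and varying slopes, so this only illustrates the trend.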

---



## 2. ResNet Implementation

* **`block_cls`** → a **class** (`BasicBlock` or `Bottleneck`)
* **`block_mod`** → an **instance** of that class (a module you add to the network)

#### 2.1 Basic residual block 

```python 
import torch
import torch.nn as nn


# ---------- Base interface for residual blocks ----------
class ResidualBlockBase(nn.Module):
    """
    Base class to type residual blocks.
    Child classes must define a class attribute `expansion` (int).
    """
    expansion: int = 1  # how many times `planes` the block outputs (Basic=1, Bottleneck=4)


# ---------- Basic residual block (ResNet-18/34 style) ----------
class BasicBlock(ResidualBlockBase):
    expansion: int = 1

    def __init__(self, in_channels: int, planes: int, stride: int = 1) -> None:
        super().__init__()
        out_channels = planes * self.expansion

        self.conv1 = nn.Conv2d(in_channels, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(planes)
        self.relu  = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_channels)

        # Identity or 1x1 projection on the skip path to match shape
        needs_projection = (stride != 1) or (in_channels != out_channels)
        if needs_projection:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.skip = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.skip(x)  # <-- "skip connection" path

        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)

        y = self.conv2(y)
        y = self.bn2(y)

        y = y + identity          # residual addition: F(x) + x
        y = self.relu(y)
        return y
```



What exactly is `expansion: int = 1`?

It’s just a **class attribute** that tells the ResNet “builder” how many channels come **out** of a block relative to the `planes` you passed in.

* For a **BasicBlock**:

  * You pass `planes=64`.
  * The last conv also outputs 64 channels.
  * `expansion = 1`.
  * So the block’s output is `planes * expansion = 64 * 1 = 64`.

* For a **Bottleneck block** (ResNet-50/101/152):

  * You pass `planes=64`.
  * Inside the block:
    `1×1 conv (in→64)` → `3×3 conv (64→64)` → `1×1 conv (64→256)`.
  * The final conv outputs **256** channels.
  * `expansion = 4`.
  * So the block’s output is `planes * expansion = 64 * 4 = 256`.

The ResNet class can therefore always do:

```python
self.current_channels = planes * block_cls.expansion
```

…without hard-coding different rules for different block types.
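The bookkeeping can be traced by hand. The sketch below assumes the usual stage widths (`planes` = 64, 128, 256, 512) and a 64-channel stem; the helper name `stage_channels` is made up for this illustration. It lists the input and output channels of the first block of each stage for both expansion values.

```python
from typing import List, Tuple

STAGE_PLANES = [64, 128, 256, 512]  # `planes` passed to the four stages
STEM_CHANNELS = 64                  # channels produced by the stem


def stage_channels(expansion: int) -> List[Tuple[int, int]]:
    """Return (in_channels, out_channels) of the first block of each stage."""
    current = STEM_CHANNELS
    result = []
    for planes in STAGE_PLANES:
        out_channels = planes * expansion
        result.append((current, out_channels))
        current = out_channels  # same update as `current_channels = planes * expansion`
    return result


basic = stage_channels(expansion=1)       # BasicBlock
bottleneck = stage_channels(expansion=4)  # Bottleneck
print(basic)
print(bottleneck)
```

With `expansion = 4` the first stage already widens 64 → 256 channels, so even a stride-1 Bottleneck block at the start of a stage needs a projection on its skip path.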



```python
class ResidualBlockBase(nn.Module):
    expansion: int = 1

class BasicBlock(ResidualBlockBase): ...
class Bottleneck(ResidualBlockBase): ...
```

Inside the block, the first convolution is:

```python
self.conv1 = nn.Conv2d(in_channels, planes, kernel_size=3, stride=stride, padding=1, bias=False)
```

It uses the block's **stride** argument, so it may downsample the input. The second convolution always uses **stride=1**, so it leaves the spatial size unchanged:

```python
self.conv2 = nn.Conv2d(planes, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
```


### Skip Connection Shape Change
In a residual block, we compute:

$$
y = F(x) + x
$$

This only makes sense if:

$$
\text{shape}(F(x)) = \text{shape}(x)
$$

That means:

* same **number of channels**,
* same **height**,
* same **width**.

Otherwise, the addition operation will fail.

---

#### When do shapes mismatch?

```python
self.conv1 = nn.Conv2d(in_channels, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(planes, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
```

So after the two convolutions:

* If `stride = 1`, the **spatial resolution** (H, W) stays the same.
* If `stride = 2`, the **spatial resolution halves**:
  $$
  H_{out} = \frac{H_{in}}{2}, \quad W_{out} = \frac{W_{in}}{2}
  $$
* The **number of channels** changes from `in_channels` → `out_channels = planes * expansion`.

Hence, when:

* `stride != 1`  → size mismatch in H, W.
* `in_channels != out_channels` → mismatch in C.

We can’t directly add `x` and `F(x)`.

---

#### The solution: the skip (projection) path

This part of the code detects when a mismatch occurs:

```python
needs_projection = (stride != 1) or (in_channels != out_channels)
```

If true, it applies a **1×1 convolution**:

```python
self.skip = nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel_size=1,
              stride=stride, bias=False),
    nn.BatchNorm2d(out_channels),
)
```

---

### Why it works (mechanically)

**Matching channels**

* The 1×1 convolution maps the number of **input channels** → **output channels**:

  $$
  (B, C_{in}, H, W) \xrightarrow[\text{1×1 conv}]{C_{out}} (B, C_{out}, H, W)
  $$

This changes only the **number of channels**, not the spatial size.

**Matching height and width**

* When `stride = 2`, the convolution also halves the **height** and **width**:

  $$
  (B, C_{in}, H, W) \xrightarrow[\text{stride}=2]{\text{1×1 conv}} (B, C_{out}, H/2, W/2)
  $$

Thus, after the projection:

* the skip path produces a tensor of shape **(B, out_channels, H_out, W_out)**,
* which now matches the shape of **F(x)** from the main branch.
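A quick shape check makes this concrete. The sketch below rebuilds the two branches of a downsampling block by hand (64 → 128 channels, stride 2). The tensor sizes are example values chosen for illustration.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 56, 56)  # (B, C_in, H, W), example sizes

# Main branch F(x): 3x3 conv with stride 2, then 3x3 conv with stride 1
main = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(128),
)

# Skip branch: 1x1 projection with the same stride
skip = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)

with torch.no_grad():
    f_x = main(x)
    shortcut = skip(x)
    y = torch.relu(f_x + shortcut)  # addition works because the shapes agree

print(f_x.shape, shortcut.shape, y.shape)
```

Replacing `skip` with `nn.Identity()` here would make the addition fail, because `x` would still have 64 channels at 56×56.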

---



#### 2.2 ResNet

```python
from typing import List, Sequence, Type

import torch
import torch.nn as nn


# ---------- ResNet backbone using a block CLASS (not an instance) ----------
class ResNet(nn.Module):
    """
    Generic ResNet that can be built with different block classes (e.g., BasicBlock, Bottleneck).

    Parameters
    ----------
    block_cls : Type[ResidualBlockBase]
        The CLASS of the residual block to instantiate (e.g., `BasicBlock`), not an instance.
    layers_per_stage : Sequence[int]
        Number of residual blocks in each of the 4 stages, e.g. [2,2,2,2] or [3,4,6,3].
    num_classes : int
        Output classes for the final classifier.
    in_channels : int
        Input image channels (3 for RGB).
    """
    def __init__(
        self,
        block_cls: Type[ResidualBlockBase],
        layers_per_stage: Sequence[int],
        num_classes: int = 1000,
        in_channels: int = 3,
    ) -> None:
        super().__init__()

        assert len(layers_per_stage) == 4, "Expected 4 stages"
        self.block_cls: Type[ResidualBlockBase] = block_cls

        # ----- Stem -----
        self.stem_conv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.stem_bn   = nn.BatchNorm2d(64)
        self.stem_relu = nn.ReLU(inplace=True)
        self.stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Track current #channels flowing into the next stage
        self.current_channels: int = 64

        # ----- Stages (conv2_x .. conv5_x) -----
        self.layer1 = self._build_stage(planes=64,  num_blocks=layers_per_stage[0], first_stride=1)
        self.layer2 = self._build_stage(planes=128, num_blocks=layers_per_stage[1], first_stride=2)
        self.layer3 = self._build_stage(planes=256, num_blocks=layers_per_stage[2], first_stride=2)
        self.layer4 = self._build_stage(planes=512, num_blocks=layers_per_stage[3], first_stride=2)

        # ----- Head -----
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc      = nn.Linear(512 * self.block_cls.expansion, num_classes)

        # He init for convs; BN to ones/zeros (standard for ResNet)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)

    def _build_stage(self, planes: int, num_blocks: int, first_stride: int) -> nn.Sequential:
        """
        Build one ResNet stage (e.g., conv3_x), returning a Sequential of residual block INSTANCES.

        - The first block may downsample via `first_stride`.
        - Subsequent blocks keep stride=1.
        """
        blocks: List[nn.Module] = []

        # 1) First block in the stage: may change spatial size & width
        block_mod = self.block_cls(  # <-- instantiate the CLASS
            in_channels=self.current_channels,
            planes=planes,
            stride=first_stride,
        )
        blocks.append(block_mod)

        # After the first block, the stage's channel width is planes * expansion
        self.current_channels = planes * self.block_cls.expansion

        # 2) Remaining blocks: keep same width and stride=1
        for _ in range(1, num_blocks):
            block_mod = self.block_cls(
                in_channels=self.current_channels,
                planes=planes,
                stride=1,
            )
            blocks.append(block_mod)

        return nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stem
        x = self.stem_conv(x)
        x = self.stem_bn(x)
        x = self.stem_relu(x)
        x = self.stem_pool(x)

        # Stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Head
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
```

Among the stages, only `layer1` uses `first_stride=1`; the others use `first_stride=2`, so the first convolution of their first block halves the height and width:


```python
self.layer1 = self._build_stage(planes=64,  num_blocks=layers_per_stage[0], first_stride=1)
self.layer2 = self._build_stage(planes=128, num_blocks=layers_per_stage[1], first_stride=2)
```


With `first_stride=1`, the first stage leaves `H` and `W` unchanged:


![](images/ResNet-18-First-Block.png)

In the second, third, and fourth stages, the first convolution of the first block outputs `height/2` and `width/2`:


![](images/ResNet-18-Second-Block.png)


For the remaining blocks of each stage, we set `stride=1` explicitly, so `H` and `W` do not change:

```python
for _ in range(1, num_blocks):
    block_mod = self.block_cls(
        in_channels=self.current_channels,
        planes=planes,
        stride=1,
    )
    blocks.append(block_mod)
```




#### 2.3  ResNet18

```python
# ---------- Factory helpers ----------
def resnet18(num_classes: int = 1000, in_channels: int = 3) -> ResNet:
    """ResNet-18 uses BasicBlock with layers_per_stage = [2, 2, 2, 2]."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_channels=in_channels)
```

![](images/ResNet-18-Architecture.png)



![](images/ResNet-18-architecture-diagram-conv.png)



#### ResNet34
```python
def resnet34(num_classes: int = 1000, in_channels: int = 3) -> ResNet:
    """ResNet-34 uses BasicBlock with layers_per_stage = [3, 4, 6, 3]."""
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_channels=in_channels)
```    

```python
if __name__ == "__main__":
    model = resnet18(num_classes=10)
    x = torch.randn(2, 3, 224, 224)
    logits = model(x)
    print(type(BasicBlock))         # <class 'type'>  (a CLASS)
    print(isinstance(model.layer1[0], BasicBlock))  # True (an INSTANCE inside the stage)
    print(logits.shape)             # torch.Size([2, 10])
```




#### 2.4 Bottleneck Class

It’s the *other* residual block type used in ResNets 50/101/152.
Where `BasicBlock` does **two 3×3 convolutions**, the **Bottleneck** block does
**1×1 → 3×3 → 1×1** convolutions and sets `expansion = 4`.



* First 1×1 conv **reduces channels** (the “bottleneck”).
* Middle 3×3 conv does the heavy lifting at a reduced width (cheaper).
* Last 1×1 conv **expands channels back** to `planes * 4`.

So each block still outputs the full width, but most of the compute happens at a narrower “bottleneck” width.


This lets you build much deeper networks (50+ layers) without exploding compute.
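A rough weight count shows where the savings come from. The comparison below assumes a block that takes in and puts out 256 channels and ignores BatchNorm parameters. It compares two 3×3 convolutions at full width with the 1×1 → 3×3 → 1×1 bottleneck at `planes = 64`.

```python
def conv_weights(c_in: int, c_out: int, kernel: int) -> int:
    """Number of weights in a bias-free 2-D convolution."""
    return c_in * c_out * kernel * kernel


width = 256          # block input/output channels (assumed for this example)
planes = width // 4  # bottleneck width: 64

# Two 3x3 convs operating at the full width
two_3x3 = 2 * conv_weights(width, width, 3)

# Bottleneck: reduce with 1x1, process with 3x3, expand with 1x1
bottleneck = (
    conv_weights(width, planes, 1)
    + conv_weights(planes, planes, 3)
    + conv_weights(planes, width, 1)
)

print(two_3x3, bottleneck, round(two_3x3 / bottleneck, 1))
```

At this width the bottleneck design needs roughly 17× fewer convolution weights than two full-width 3×3 convolutions, and the multiply-add count per spatial position shrinks by the same factor.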

---

#### 2.5 Typical Bottleneck class (PyTorch)

```python
import torch
import torch.nn as nn

class Bottleneck(nn.Module):
    expansion = 4  # output channels = planes * 4

    def __init__(self, in_channels: int, planes: int, stride: int = 1) -> None:
        super().__init__()
        out_channels = planes * self.expansion

        # 1x1 conv: reduce channels
        self.conv1 = nn.Conv2d(in_channels, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # 3x3 conv: main processing
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # 1x1 conv: expand channels back up
        self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

        # skip connection projection if needed
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.skip = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.skip(x)

        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)

        y = self.conv2(y)
        y = self.bn2(y)
        y = self.relu(y)

        y = self.conv3(y)
        y = self.bn3(y)

        y = y + identity
        y = self.relu(y)
        return y
```



or, written against the shared base class:

```python
class ResidualBlockBase(nn.Module):
    expansion: int = 1

class BasicBlock(ResidualBlockBase):
    expansion = 1
    # ... two 3x3 convs ...

class Bottleneck(ResidualBlockBase):
    expansion = 4
    # ... 1x1, 3x3, 1x1 convs ...
```



* `planes` = the reduced “bottleneck” width.
* `out_channels = planes * expansion = planes * 4`.

---

#### 2.6 resnet50

```python
# ResNet-50 uses Bottleneck blocks
resnet50 = ResNet(Bottleneck, [3, 4, 6, 3])
```



#### 2.7 What exactly do **18 / 34 / 50 / 101 / 152** mean in **ResNet-X**?


They’re the **total number of weight-bearing layers** (convolutions + the final fully-connected layer) in the network.

* **ResNet-18** has 18 such layers.
* **ResNet-34** has 34 layers.
* **ResNet-50** has 50 layers.
* **ResNet-101** has 101 layers.
* **ResNet-152** has 152 layers.

Pooling layers and ReLU/BatchNorm don’t count toward the number, and neither do the 1×1 projection convolutions on the skip paths.

---

| Model          | Block type           | Blocks per stage         | Convs per block | 1st conv | Total conv layers | +FC = total layers |
| -------------- | -------------------- | ------------------------ | --------------- | -------- | ----------------- | ------------------ |
| **ResNet-18**  | BasicBlock (2 convs) | [2, 2, 2, 2] = 8 blocks  | 2               | 1        | 1 + 8×2 = 17      | +1 FC = **18**     |
| **ResNet-34**  | BasicBlock (2 convs) | [3, 4, 6, 3] = 16 blocks | 2               | 1        | 1 + 16×2 = 33     | +1 FC = **34**     |
| **ResNet-50**  | Bottleneck (3 convs) | [3, 4, 6, 3] = 16 blocks | 3               | 1        | 1 + 16×3 = 49     | +1 FC = **50**     |
| **ResNet-101** | Bottleneck (3 convs) | [3, 4, 23, 3] = 33 blocks | 3               | 1        | 1 + 33×3 = 100    | +1 FC = **101**    |
| **ResNet-152** | Bottleneck (3 convs) | [3, 8, 36, 3] = 50 blocks | 3               | 1        | 1 + 50×3 = 151    | +1 FC = **152**    |

Breakdown:

* **First conv**: the 7×7 conv at the “stem” = 1 layer.
* **Blocks per stage**: how many residual blocks in conv2_x, conv3_x, conv4_x, conv5_x.
* **Convs per block**: 2 for BasicBlock, 3 for Bottleneck.
* Add them up + 1 for the final fully-connected = the ResNet number.
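The counting rule above is easy to check in code. The helper `resnet_depth` is a made-up name for this sketch. It counts the stem conv, the convs on the main path of each block, and the final FC layer.

```python
from typing import Dict, Sequence, Tuple


def resnet_depth(blocks_per_stage: Sequence[int], convs_per_block: int) -> int:
    """Stem conv + main-path convs of all residual blocks + final FC layer."""
    stem_conv = 1
    fc = 1
    return stem_conv + sum(blocks_per_stage) * convs_per_block + fc


configs: Dict[str, Tuple[Sequence[int], int]] = {
    "ResNet-18": ([2, 2, 2, 2], 2),    # BasicBlock
    "ResNet-34": ([3, 4, 6, 3], 2),    # BasicBlock
    "ResNet-50": ([3, 4, 6, 3], 3),    # Bottleneck
    "ResNet-101": ([3, 4, 23, 3], 3),  # Bottleneck
    "ResNet-152": ([3, 8, 36, 3], 3),  # Bottleneck
}

depths = {name: resnet_depth(blocks, convs) for name, (blocks, convs) in configs.items()}
print(depths)
```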

---


* ResNet-18/34 use BasicBlock (2 convs each).
* ResNet-50/101/152 use Bottleneck (3 convs each, but with 1×1–3×3–1×1 pattern).

So you can make the network deeper either by:

* Adding more blocks per stage ([3,4,6,3] vs [2,2,2,2]),
* Or switching to Bottleneck blocks with 3 convs each.

---



## 3. ResNet-X Input Size

The **“ResNet-18/34/50/101/152” names say nothing about input size**.
They only describe the *depth* (number of weight-bearing layers).

---

#### 3.1  The “default” input size in the paper / torchvision

All the ImageNet-trained ResNets are trained on **224×224 RGB images**. The channel counts below are for BasicBlock (`expansion = 1`); with Bottleneck blocks, every stage outputs 4× as many channels:

```
Input: (N, 3, 224, 224)
Conv7×7 stride2 → (N, 64, 112, 112)
MaxPool stride2 → (N, 64, 56, 56)
layer1 (conv2_x) → (N, 64, 56, 56)
layer2 (conv3_x) → (N, 128, 28, 28)
layer3 (conv4_x) → (N, 256, 14, 14)
layer4 (conv5_x) → (N, 512, 7, 7)
GlobalAvgPool → (N, 512, 1, 1)
FC → (N, num_classes)
```

---

#### 3.2  But ResNet is **fully convolutional** until the FC layer

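All layers before the final average pool are convolutions, BatchNorm, ReLU, and one max-pool, none of which depend on the spatial size of the input.
So you can feed **any spatial size** (H×W), as long as it is large enough to leave a useful feature map after the downsampling.

* Each stride-2 layer halves H and W.
* The stem downsamples by 4 (7×7 conv with stride 2, then max-pool with stride 2), and layer2, layer3, and layer4 each downsample by 2 (layer1 uses stride 1). The total stride is therefore $4 \times 2 \times 2 \times 2 = 32$, so the final spatial size is roughly $\frac{H}{32} \times \frac{W}{32}$.

If the input is too small, the feature map collapses to 1×1 well before layer4 and the later stages see almost no spatial information. The forward pass still runs (padding keeps every dimension at least 1), but accuracy usually suffers.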

---

**Typical requirements:**

* For ImageNet: 224×224 → final 7×7 feature map → global avg pool to 1×1.
* For CIFAR-10: inputs are 32×32. People use a **modified stem** (3×3 conv stride1 instead of 7×7 stride2 + maxpool) so they don’t downsample away everything.
* For segmentation: people remove the FC and keep the convolutional part as a “backbone” to get feature maps of arbitrary size.

---

**Rule of thumb for output spatial size**

If your input is $(N, 3, H, W)$, after the four stages you’ll have:

$
H_{\text{out}} = \left\lceil\frac{H}{32}\right\rceil,
\quad
W_{\text{out}} = \left\lceil\frac{W}{32}\right\rceil
$

for the standard ResNet stem. Each stride-2 layer is padded so that it maps a size $H$ to $\lceil H/2 \rceil$; for multiples of 32 the result is simply $H/32$.
The channel dimension at the end is $512 \times \text{expansion}$.

Then global average pool reduces $H_\text{out}×W_\text{out}$ to 1×1 regardless of input.
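The rule of thumb can be verified with the standard convolution output-size formula, $\lfloor (H + 2p - k)/s \rfloor + 1$. The sketch below assumes the standard layer settings (7×7 stem conv with padding 3, 3×3 max-pool with padding 1, and a stride-2 3×3 conv with padding 1 at the start of layer2, layer3, and layer4). With that padding, each stride-2 layer maps $H$ to $\lceil H/2 \rceil$. The helper names are made up for this example.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a conv/pool layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_feature_size(size: int) -> int:
    """Spatial size after layer4 for a given input height (or width)."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # stem conv
    size = conv_out(size, kernel=3, stride=2, padding=1)  # stem max-pool
    # layer1 keeps the size (stride 1); layer2..layer4 each downsample once
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


sizes = {s: resnet_feature_size(s) for s in (224, 256, 100, 32)}
print(sizes)
```

The 32×32 case ends at a 1×1 feature map, which is why CIFAR-style variants usually swap in a gentler stem.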

---

* The “18/34/50/…” numbers do **not** encode input size.
* Default pretrained ResNets use 224×224 inputs.
* Because ResNet is convolutional + global pooling, you can use larger or smaller images (e.g. 256×256, 512×512) — you just get correspondingly larger or smaller feature maps before pooling.
* For very small inputs (like CIFAR-10), you typically adjust the stem to avoid over-downsampling.


#### 3.3 Pretrained ResNet Weight Input Size

If you download a pretrained ResNet weight from PyTorch, do you have to feed it exactly 224×224 images?

All the conv, BN, ReLU layers in ResNet have kernels and weights that don’t depend on the spatial size of the input.
They just slide across whatever height × width you give them.

So the pretrained weights from `torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)` (older torchvision releases spell this `pretrained=True`) will happily process, say, $(3,256,256)$ or $(3,512,512)$ images.
You don’t need to change anything for the conv layers.

---

#### 3.4 The **final fully-connected layer** *does* expect a fixed input width

But ResNet’s last step before the FC layer is a **global average pooling**:

```python
x = self.avgpool(x)   # (N, C, H_out, W_out) → (N, C, 1, 1)
x = torch.flatten(x, 1)  # (N, C)
x = self.fc(x)        # (N, num_classes)
```

That `AdaptiveAvgPool2d((1,1))` always gives you an (N, C, 1, 1) tensor no matter what `H_out, W_out` were.
Flatten → (N,C) → FC layer.
So the FC sees exactly `C = 512*expansion` features no matter what the input size was.

That’s why the pretrained FC layer still works.
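This behaviour can be checked in isolation. The sketch below rebuilds only the head (adaptive pooling plus a linear layer) and feeds it feature maps of two different spatial sizes. The 512 channels and 10 classes are example values.

```python
import torch
import torch.nn as nn

avgpool = nn.AdaptiveAvgPool2d((1, 1))
fc = nn.Linear(512, 10)  # 512 = final channels for expansion 1; 10 classes assumed


def head(feature_map: torch.Tensor) -> torch.Tensor:
    x = avgpool(feature_map)   # (N, C, H, W) -> (N, C, 1, 1)
    x = torch.flatten(x, 1)    # (N, C)
    return fc(x)               # (N, num_classes)


with torch.no_grad():
    logits_small = head(torch.randn(2, 512, 7, 7))    # e.g. from a 224x224 input
    logits_large = head(torch.randn(2, 512, 16, 16))  # e.g. from a 512x512 input

print(logits_small.shape, logits_large.shape)
```

Both calls produce the same output shape, so the same `fc` weights apply regardless of the input resolution.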

---


#### 3.5 If you change the number of input channels

If you go from RGB (3 channels) to grayscale (1 channel) or multispectral (5 channels), the **first conv**’s weights won’t match.
You either:

* Replace the first conv and randomly init it, or
* Load pretrained weights partially and adapt the first conv.




```python
import torch
import torchvision
import torchviz

# ResNet18_Weights.DEFAULT is equivalent to ResNet18_Weights.IMAGENET1K_V1.
resnet18 = torchvision.models.resnet18(weights='ResNet18_Weights.DEFAULT', progress=True)
```

```python
print("resnet18 input size: ", resnet18.fc.in_features)
print("resnet18 output size: ", resnet18.fc.out_features)
```

```text
resnet18 input size:  512
resnet18 output size:  1000
```

```python
# resnet18 has an average-pool layer at the end,
# so the input size does not matter much, provided the feature map size is greater than the kernel size.

input = torch.randn(size=[1, 3, 128, 128])

resnet18_graph = torchviz.make_dot(resnet18(input), dict(resnet18.named_parameters()))
resnet18_graph.format = 'svg'
resnet18_graph.save('images/resnet18_graph')
resnet18_graph.render()
```

```text
'images/resnet18_graph.svg'
```

![](images/resnet18_graph.svg)

## 4. Finetune the model on a new dataset with 10 labels


Let’s say we want to fine-tune the model on a new dataset with `10` labels. In ResNet, the classifier is the last linear layer, `model.fc`. We freeze the existing parameters and then replace `model.fc` with a new linear layer (unfrozen by default) that acts as our classifier.


```python
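for params in resnet18.parameters():
    params.requires_grad = False  # freeze every pretrained parameter

# A freshly created layer has requires_grad=True, so only the new head trains
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, 10)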

```

Now all parameters in the model, except the parameters of `model.fc`, are frozen. The only parameters that compute gradients are the `weight` and `bias` of `model.fc`.

```python
optimizer=torch.optim.SGD(resnet18.fc.parameters(),lr=1e-2,momentum=0.9)
```

## 5. Black-and-White (Single-Channel) Image Input

The pretrained ResNet-18 expects 3-channel RGB at `conv1`. With a single-channel (monochrome) input you have a few good options:

### 1) Duplicate the channel (fastest, no weight surgery)

Preprocess your 1-channel image to 3 channels by repeating it. `transforms.Grayscale`:
- If num_output_channels == 1 : returned image is single channel
- If num_output_channels == 3 : returned image is 3 channel with r == g == b

```python
# during transforms
from torchvision import transforms

tfm = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # 1→3 by duplication
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),  # ImageNet stats
])
```

Pros: zero code change to the model; you keep pretrained weights intact.
Cons: a tiny bit redundant, but works very well in practice.

### 2) Replace `conv1` with 1 input channel and **port** pretrained weights

Average the RGB kernels into a single channel:

```python
import torch
import torchvision.models as models
import torch.nn as nn

m = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
w = m.conv1.weight  # [64, 3, 7, 7]

m.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

with torch.no_grad():
    m.conv1.weight[:] = w.mean(dim=1, keepdim=True)  # [64,1,7,7]
```

(You can also use a weighted sum like `0.2989 R + 0.5870 G + 0.1140 B` instead of `.mean`.)

Pros: no wasted computation; uses pretrained filters sensibly.
Cons: a tiny bit of code; but this is the cleanest if you’re truly single-channel end-to-end.
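One detail worth knowing when porting the kernels is the output scale. Because convolution is linear, *summing* the three RGB kernels reproduces what the original layer computes on a gray image duplicated to three channels, while *averaging* gives one third of that. The sketch below checks this on a small stand-in convolution (not the real `conv1`). Whether the factor matters in practice depends on the following BatchNorm and on how much fine-tuning is done, so treat it as an observation rather than a rule.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rgb_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)  # stand-in for conv1

gray = torch.randn(1, 1, 16, 16)
gray_as_rgb = gray.repeat(1, 3, 1, 1)  # duplicate the channel three times


def make_gray_conv(weight: torch.Tensor) -> nn.Conv2d:
    """Build a 1-input-channel conv and load the given [8, 1, 3, 3] kernel."""
    conv = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False)
    with torch.no_grad():
        conv.weight.copy_(weight)
    return conv


sum_conv = make_gray_conv(rgb_conv.weight.sum(dim=1, keepdim=True))
mean_conv = make_gray_conv(rgb_conv.weight.mean(dim=1, keepdim=True))

with torch.no_grad():
    reference = rgb_conv(gray_as_rgb)
    from_sum = sum_conv(gray)
    from_mean = mean_conv(gray)

sum_matches = torch.allclose(from_sum, reference, atol=1e-5)
mean_is_one_third = torch.allclose(3 * from_mean, reference, atol=1e-5)
print(sum_matches, mean_is_one_third)
```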

### 3) Add a learnable 1→3 adapter in front

Keep the pretrained model intact and learn a shallow mapping:

```python
class GrayToRGB(nn.Module):
    def __init__(self):
        super().__init__()
        self.map = nn.Conv2d(1, 3, kernel_size=1, bias=False)
        nn.init.constant_(self.map.weight, 1.0)  # start as “repeat”: each output channel copies the input

    def forward(self, x):
        return self.map(x)

adapter = GrayToRGB()
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
full = nn.Sequential(adapter, model)
```

Pros: lets the network learn the best 1→3 projection.
Cons: a few extra params; slightly more moving parts.

---

### Normalization notes

* If you **duplicate to 3-ch**, you can keep ImageNet mean/std as above, or compute dataset-specific stats and use those for all 3 channels (same numbers repeated).
* If you **switch to 1-ch conv1**, use single mean/std (e.g., `(mean,)` and `(std,)`) matching your grayscale dataset.

### Which should you pick?

* **Quick wins / transfer learning**: Option **1** (duplicate) is perfectly fine and very common.
* **Purist, minimal compute**: Option **2** (port weights) is elegant and usually performs best.
* **Data shift concerns** (e.g., MRI/CT with unusual intensity): Option **3** gives flexibility; also consider dataset-specific normalization and fine-tuning early layers.

### How to train the option 3 Network (learnable 1→3 adapter in front)

With option 3 (a learnable 1→3 “adapter” in front of a pretrained ResNet-18), you’ve got three common training strategies. Pick one based on data size and how different your grayscale data is from ImageNet.

### A. Freeze backbone first, train adapter + head (safe start)

1–3 epochs:

* Freeze **all** ResNet18 params.
* Train only the `GrayToRGB` adapter and the final `fc`.

Then unfreeze the backbone (optionally with a lower LR) and fine-tune.

```python
import torch
import torch.nn as nn
import torchvision.models as models

class GrayToRGB(nn.Module):
    def __init__(self):
        super().__init__()
        self.map = nn.Conv2d(1, 3, kernel_size=1, bias=False)
        nn.init.constant_(self.map.weight, 1.0)  # start as a plain channel copy

    def forward(self, x): return self.map(x)

n_classes = 10  # example value; set this to the number of classes in your dataset

backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
backbone.fc = nn.Linear(backbone.fc.in_features, n_classes)  # replace head

model = nn.Sequential(GrayToRGB(), backbone)

# --- Phase 1: freeze backbone ---
for p in backbone.parameters():
    p.requires_grad = False

# optimize only adapter + fc
params = list(model[0].parameters()) + list(backbone.fc.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=1e-4)
```

After a few epochs:

```python
# --- Phase 2: unfreeze backbone with smaller LR ---
for p in backbone.parameters():
    p.requires_grad = True

optimizer = torch.optim.Adam([
    {"params": model[0].parameters(), "lr": 1e-3},          # adapter
    {"params": backbone.layer1.parameters(), "lr": 5e-4},
    {"params": backbone.layer2.parameters(), "lr": 2.5e-4},
    {"params": backbone.layer3.parameters(), "lr": 1.25e-4},
    {"params": backbone.layer4.parameters(), "lr": 1.25e-4},
    {"params": backbone.fc.parameters(),     "lr": 1e-3},
], weight_decay=1e-4)
```

### B. Train everything, but with **discriminative learning rates** (faster)

Good when you have a moderate dataset and want quick convergence without a freezing phase.

```python
optimizer = torch.optim.AdamW([
    {"params": model[0].parameters(),           "lr": 1e-3},  # adapter highest
    {"params": backbone.layer1.parameters(),    "lr": 5e-4},
    {"params": backbone.layer2.parameters(),    "lr": 3e-4},
    {"params": backbone.layer3.parameters(),    "lr": 2e-4},
    {"params": backbone.layer4.parameters(),    "lr": 2e-4},
    {"params": backbone.fc.parameters(),        "lr": 1e-3},  # head high
], weight_decay=1e-4)
```

### C. Unfreeze progressively (“gradual unfreezing”)

Start with only adapter+fc, then unfreeze layers one block at a time every few epochs (layer4 → layer3 → …). This is handy with small datasets.
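A minimal sketch of such a schedule is shown below. The helper names, the two-epoch interval, and the `TinyBackbone` stand-in (which only mimics the attribute names `layer1` to `layer4` and `fc` of a torchvision ResNet) are assumptions made for this example.

```python
from typing import List

import torch.nn as nn

UNFREEZE_ORDER = ["layer4", "layer3", "layer2", "layer1"]  # deepest stage first


def stages_unfrozen_at(epoch: int, epochs_per_stage: int = 2) -> List[str]:
    """Names of the stages that should be trainable at the given epoch."""
    count = min(epoch // epochs_per_stage, len(UNFREEZE_ORDER))
    return UNFREEZE_ORDER[:count]


def apply_gradual_unfreezing(backbone: nn.Module, epoch: int,
                             epochs_per_stage: int = 2) -> None:
    for name in stages_unfrozen_at(epoch, epochs_per_stage):
        for p in getattr(backbone, name).parameters():
            p.requires_grad = True


class TinyBackbone(nn.Module):
    """Stand-in with ResNet-like attribute names, for illustration only."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = nn.Linear(4, 4)
        self.layer2 = nn.Linear(4, 4)
        self.layer3 = nn.Linear(4, 4)
        self.layer4 = nn.Linear(4, 4)
        self.fc = nn.Linear(4, 2)


backbone = TinyBackbone()
for name, p in backbone.named_parameters():
    p.requires_grad = name.startswith("fc.")  # start with only the head trainable

apply_gradual_unfreezing(backbone, epoch=4)
trainable_modules = sorted(
    {n.split(".")[0] for n, p in backbone.named_parameters() if p.requires_grad}
)
print(trainable_modules)
```

Parameters that were frozen when the optimizer was created are not updated automatically once unfrozen. Either pass all parameters to the optimizer up front or register the new ones with `optimizer.add_param_group(...)` when a stage is unfrozen.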

---

### Do we freeze the adapter conv?

* **Usually not.** Let it learn a smart projection beyond simple channel copy.
* If the dataset is tiny and unstable, you *can* freeze it for the first few hundred steps.

### BatchNorm tips

* If batch size is small (≤16), consider putting the backbone’s BN layers in **eval** mode during early training:

  ```python
  def set_bn_eval(m):
      if isinstance(m, nn.BatchNorm2d):
          m.eval()
  backbone.apply(set_bn_eval)
  ```

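  (Params can still be trainable; this just freezes the running stats. Note that `model.train()` switches every BN layer back to training mode, so re-apply `set_bn_eval` after each call to it.)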
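One pitfall with this trick is that calling `.train()` on the parent module puts every submodule, BatchNorm included, back into training mode. The small check below uses a toy `nn.Sequential` as a stand-in for the backbone.

```python
import torch.nn as nn


def set_bn_eval(m: nn.Module) -> None:
    if isinstance(m, nn.BatchNorm2d):
        m.eval()


net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())

net.train()
net.apply(set_bn_eval)
bn_training_after_apply = net[1].training    # BN is now in eval mode
conv_training_after_apply = net[0].training  # other layers are untouched

net.train()                                  # e.g. at the start of the next epoch
bn_training_after_train = net[1].training    # BN is back in training mode

print(bn_training_after_apply, conv_training_after_apply, bn_training_after_train)
```

In a training loop this means calling `backbone.apply(set_bn_eval)` right after each `model.train()`.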

### Weight decay hygiene (optional but nice)

Avoid weight decay on BN and bias:

```python
decay, no_decay = [], []
for n, p in model.named_parameters():
    if not p.requires_grad: continue
    if p.ndim <= 1:  # biases and BatchNorm weights are 1-D (also catches BN inside downsample paths)
        no_decay.append(p)
    else:
        decay.append(p)
optimizer = torch.optim.AdamW([
    {"params": decay, "weight_decay": 1e-4, "lr": 3e-4},
    {"params": no_decay, "weight_decay": 0.0, "lr": 3e-4},
])
```

### Schedulers (keep it simple)

* **Cosine annealing** (optionally preceded by a warmup phase) or **OneCycleLR** both work well:

```python
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
# or:
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, steps_per_epoch=len(loader), epochs=E)
```
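`CosineAnnealingLR` on its own has no warmup phase. One way to add one is a `LambdaLR` with a hand-written multiplier, sketched below. The function name `warmup_cosine_factor`, the 5 warmup steps, and the 50 total steps are illustrative assumptions, and the dummy parameter only exists so that an optimizer can be built.

```python
import math

import torch


def warmup_cosine_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier: linear ramp up to 1.0, then cosine decay down to 0.0."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


# Dummy parameter, only to have something to optimise in this example
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-3)

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: warmup_cosine_factor(step, warmup_steps=5, total_steps=50),
)

initial_lr = optimizer.param_groups[0]["lr"]  # base LR times the factor at step 0
print(initial_lr)
```

Call `scheduler.step()` once per epoch or once per batch, depending on whether `total_steps` counts epochs or batches.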

### Quick guidance

* **Small dataset / big domain shift (e.g., MRI):** A or C.
* **Moderate dataset / some domain shift:** B with discriminative LRs.
* **Plenty of data:** Train all, normal LRs, standard fine-tune.

