In [1]:
import torch
import torch.nn as nn


class ResNet50Encoder(nn.Module):
    def __init__(self):
        super(ResNet50Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual blocks
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(256, 128, 4, stride=2)
        self.layer3 = self._make_layer(512, 256, 6, stride=2)
        self.layer4 = self._make_layer(1024, 512, 3, stride=2)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        layers = []
        layers.append(BottleneckResidualBlock(in_channels, out_channels, stride))
        for _ in range(1, blocks):
            layers.append(BottleneckResidualBlock(out_channels * 4, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x


class BottleneckResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BottleneckResidualBlock, self).__init__()
        mid_channels = out_channels
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(
            mid_channels, out_channels * 4, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channels * 4)
        self.relu = nn.ReLU(inplace=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * 4:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels * 4,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels * 4),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out


# Example
net = ResNet50Encoder()

input = torch.randn(1, 3, 512, 512)

net(input)

In [5]:
from torchvision.models import resnet50, ResNet50_Weights

resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

resnet.fc = torch.nn.Linear(
    resnet.fc.in_features, 2
)  #  Building base / Not building base

# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(resnet.parameters(), lr=0.001, momentum=0.9)

# Training loop
num_epochs = 25
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in dataloader:
        optimizer.zero_grad()

        outputs = resnet(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    print(f"Epoch {epoch}/{num_epochs - 1}, Loss: {epoch_loss:.4f}")

NameError: name 'dataloader' is not defined

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0792, 0.0000,  ..., 1.7765, 0.5342, 1.1520],
          [0.0000, 0.8803, 1.5960,  ..., 0.3809, 1.4557, 0.0000],
          ...,
          [0.0000, 1.1959, 0.0000,  ..., 0.3367, 1.8031, 0.0000],
          [1.8454, 1.9882, 1.3817,  ..., 0.0000, 0.1171, 0.0000],
          [0.0000, 1.3473, 0.0000,  ..., 0.0000, 0.3973, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.2125, 0.0000],
          [0.8841, 0.8949, 0.0000,  ..., 3.3369, 0.6026, 1.8078],
          [0.0000, 0.0000, 1.7070,  ..., 0.8898, 1.1849, 0.4476],
          ...,
          [0.2693, 0.1573, 2.1569,  ..., 1.2825, 0.0667, 0.0000],
          [0.0000, 0.5049, 0.0489,  ..., 0.0000, 1.6983, 0.9831],
          [1.1037, 0.2746, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.4542, 0.3448, 0.4521,  ..., 0.0000, 0.0000, 0.0395],
          [0.1326, 0.4347, 0.5358,  ..., 1.1746, 0.1234, 2.1520],
          [1.8016, 2.4450, 0.2534,  ..., 0