# *SplitGP Implementation (Personalized FL)*

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Auxiliary classifier head
class AuxClassifier(nn.Module):
    """Lightweight classification head for an intermediate feature map.

    The head pools the incoming feature map to a 4x4 grid, projects it to
    128 channels with a 1x1 convolution, and classifies the flattened result
    through a 256-unit hidden layer.

    Args:
        in_channels: Number of channels in the incoming feature map.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Adaptive pooling fixes the spatial grid at 4x4, so fc1's input
        # width below does not depend on the incoming feature-map size.
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(4, 4))
        self.conv = nn.Conv2d(
            in_channels=in_channels, out_channels=128, kernel_size=1
        )
        self.fc1 = nn.Linear(in_features=128 * 4 * 4, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=num_classes)

    def forward(self, x):
        """Return logits of shape [B, num_classes] for a batched 4-D input."""
        pooled = self.avgpool(x)                  # [B, in_channels, 4, 4]
        reduced = F.relu(self.conv(pooled))       # [B, 128, 4, 4]
        flat = reduced.flatten(start_dim=1)       # [B, 2048]
        hidden = F.relu(self.fc1(flat))           # [B, 256]
        return self.fc2(hidden)                   # raw logits, no softmax


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Auxiliary classifier head
class AuxClassifier(nn.Module):
    """Auxiliary classification head attached to an intermediate feature map.

    Pipeline: adaptive average pool -> 1x1 conv + ReLU -> flatten ->
    hidden linear + ReLU -> output linear (raw logits).

    Args:
        in_channels: Number of channels in the incoming feature map.
        num_classes: Number of output logits.
        mid_channels: Output channels of the 1x1 bottleneck convolution.
        pool_size: Side length of the square grid the feature map is pooled to.
        hidden_dim: Width of the hidden fully connected layer.

    The defaults (128, 4, 256) reproduce the previously hard-coded
    architecture, so existing two-argument callers are unaffected.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        mid_channels: int = 128,
        pool_size: int = 4,
        hidden_dim: int = 256,
    ) -> None:
        super().__init__()
        # Adaptive pooling pins the spatial size, which makes the flattened
        # width fed to fc1 independent of the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.fc1 = nn.Linear(mid_channels * pool_size * pool_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape [B, num_classes] for input [B, in_channels, H, W]."""
        x = self.avgpool(x)         # shape -> [B, in_channels, pool_size, pool_size]
        x = F.relu(self.conv(x))    # reduce channels to mid_channels
        x = torch.flatten(x, 1)     # keep the batch dim, flatten the rest
        x = F.relu(self.fc1(x))     # hidden layer
        return self.fc2(x)          # logits (no softmax; CrossEntropyLoss expects raw scores)

class NetWithAux(nn.Module):
    """Two-block CNN with an auxiliary classifier on the first block's output.

    Args:
        num_classes: Number of output logits for both the main and the
            auxiliary classifier.
        in_channels: Number of channels in the input images. Defaults to 3,
            the previously hard-coded value.

    ``forward`` returns a ``(main_out, aux_out)`` tuple of logits, each of
    shape [B, num_classes].
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 3) -> None:
        super().__init__()

        # First block
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # output: [B, 64, 16, 16] for 32x32 input
        )

        # Auxiliary classifier takes features from block1
        self.aux_head = AuxClassifier(64, num_classes)

        # Second block (main path continues)
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # output: [B, 128, 8, 8] for 32x32 input
        )

        # fc_main is sized for a 128x8x8 feature map. Pooling adaptively to
        # 8x8 keeps that width valid for other input resolutions; for 32x32
        # inputs the map is already 8x8, so each pooling window covers one
        # element and the layer is an exact identity. It has no parameters,
        # so state_dict keys are unchanged.
        self.main_pool = nn.AdaptiveAvgPool2d((8, 8))

        # Main classifier
        self.fc_main = nn.Linear(128 * 8 * 8, num_classes)

    def forward(self, x: torch.Tensor):
        """Compute main and auxiliary logits for a batch of images.

        Args:
            x: Batched image tensor [B, in_channels, H, W]. H and W must be
                at least 4 so that both 2x2 max-pools have something to pool.

        Returns:
            Tuple ``(main_out, aux_out)`` of raw logits.
        """
        x = self.block1(x)

        # Auxiliary prediction from intermediate feature map
        aux_out = self.aux_head(x)

        # Continue main path
        x = self.block2(x)
        x = self.main_pool(x)       # no-op when the map is already 8x8
        x = torch.flatten(x, 1)
        main_out = self.fc_main(x)

        return main_out, aux_out


In [16]:
# Model, loss function and optimizer for a single demonstration step.
model = NetWithAux(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Synthetic CIFAR-10-shaped batch: 16 RGB images of 32x32 with labels in [0, 10).
images = torch.randn((16, 3, 32, 32))
labels = torch.randint(low=0, high=10, size=(16,))

# Clear any stale gradients before running the step.
optimizer.zero_grad()

# Forward pass yields logits from both the main and the auxiliary head.
main_out, aux_out = model(images)

# Each head is scored against the same labels; the auxiliary loss is
# down-weighted by 0.3 when forming the total objective.
loss_main = criterion(main_out, labels)
loss_aux = criterion(aux_out, labels)
loss = loss_main + 0.3 * loss_aux

# Backpropagate the combined loss and apply one optimizer update.
loss.backward()
optimizer.step()
