In [1]:
# Tuneable Params
lr = 1e-3
data_dir = "data_3_class"
model_name = "b0"  # key into MODEL_MAPPING: "b0", "b1", "b2" or "b3"
save_logs = True
epochs = 100
rotate_angle = None
horizontal_flip_prob = None
brightess_contrast = None  # NOTE: name is misspelled; renaming requires updating every cell that reads it
gaussian_blur = None
normalize = True
seed = 42  # NOTE(review): no torch/numpy seeding using this value is visible here — confirm it is applied
batch_size = 32
results_folder_name = "3_class_results_leaky"
truncated_layers = 3
bootstrap_n = 1000
pretrained = True

# Needed when the TruncatedEffNet model is instantiated; previously not defined in
# any cell, so Restart Kernel -> Run All failed with a NameError.
num_classes = 3    # matches the 3-class data_dir / results_folder_name
image_size = 224   # square input side in pixels; NOTE(review): keep in sync with the data pipeline's resize


In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import EfficientNet_B0_Weights, EfficientNet_B1_Weights, EfficientNet_B2_Weights, EfficientNet_B3_Weights

# Short key -> (torchvision.models constructor name, ImageNet-1k weights enum).
# Built once at module level so it can be handed to load_efficientnet.
MODEL_MAPPING = dict(
    b0=("efficientnet_b0", EfficientNet_B0_Weights.IMAGENET1K_V1),
    b1=("efficientnet_b1", EfficientNet_B1_Weights.IMAGENET1K_V1),
    b2=("efficientnet_b2", EfficientNet_B2_Weights.IMAGENET1K_V1),
    b3=("efficientnet_b3", EfficientNet_B3_Weights.IMAGENET1K_V1),
)

def load_efficientnet(model_name, model_mapping, pretrained):
    """
    Construct a torchvision EfficientNet variant selected by a short key.

    Args:
        model_name (str): Key into ``model_mapping`` (e.g. "b0", "b1", "b2", "b3").
        model_mapping (dict): Maps each key to a ``(constructor_name, weights)`` pair.
            ``constructor_name`` is looked up as an attribute of the ``models``
            module (torchvision.models); ``weights`` is passed as ``weights=``.
        pretrained (bool): If truthy, build with the mapped weights; otherwise
            build with ``weights=None``.

    Returns:
        torch.nn.Module: The constructed model.

    Raises:
        ValueError: If ``model_name`` is not a key of ``model_mapping``.
    """
    if model_name not in model_mapping:
        raise ValueError(f"Unsupported model name: {model_name}. Supported models are: {list(model_mapping.keys())}")

    constructor_name, mapped_weights = model_mapping[model_name]
    constructor = getattr(models, constructor_name)
    return constructor(weights=mapped_weights if pretrained else None)


# Load the backbone. An unsupported model_name raises ValueError here instead of
# being printed and swallowed: swallowing left `effnet` undefined (or stale from
# an earlier run), so later cells failed with a confusing NameError or silently
# used the wrong model.
effnet = load_efficientnet(model_name, MODEL_MAPPING, pretrained=pretrained)
print(f"Successfully loaded EfficientNet {model_name}.")

Successfully loaded EfficientNet b0.


In [3]:
# Expose the first nine children of effnet.features as block0..block8 for
# inspection; each name matches its index (blockN is child N). The child list is
# built once rather than once per name.
feature_blocks = list(effnet.features.children())
(block0, block1, block2, block3, block4,
 block5, block6, block7, block8) = feature_blocks[:9]


In [4]:
# Input channel count of the first sub-module of block0 (effnet.features child 0).
block0[0].in_channels

3

In [7]:
# Display effnet.features child 8 to inspect its layers and channel widths.
block8

Conv2dNormActivation(
  (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): SiLU(inplace=True)
)

In [None]:


class TruncatedEffNet(nn.Module):
    """Truncated EfficientNet feature extractor with a LeakyReLU MLP head.

    The backbone keeps the first ``7 - removed_layers`` children of
    ``effnet.features``. These module objects are shared with ``effnet``, not
    copied. The backbone output then goes through global average pooling,
    flattening and ``fc_lrelu`` (Linear -> LeakyReLU -> Dropout(0.2) -> Linear).

    Args:
        effnet (nn.Module): Model exposing a ``features`` container whose
            children are run in sequence (e.g. a torchvision EfficientNet).
        num_classes (int): Number of output logits.
        removed_layers (int): Truncation amount; ``7 - removed_layers``
            children of ``effnet.features`` are kept.
            NOTE(review): this notebook indexes ``features`` children 0..8, so
            even ``removed_layers=0`` drops the last two — confirm intent.
        batch_size (int): Accepted for backward compatibility. The probe pass
            that sizes the head uses a single image.
        image_size (int): Height/width of the square probe image used to infer
            the head's input width.

    Raises:
        ValueError: If ``7 - removed_layers`` is not between 1 and the number of
            children in ``effnet.features``.
    """

    def __init__(self, effnet, num_classes, removed_layers, batch_size, image_size):
        super().__init__()

        # Truncate the EfficientNet backbone.
        feature_blocks = list(effnet.features.children())
        layers = 7 - removed_layers
        # Reject an empty backbone (layers <= 0). A negative count would also
        # silently turn into a "drop from the end" slice.
        if not 1 <= layers <= len(feature_blocks):
            raise ValueError(
                f"removed_layers={removed_layers} keeps {layers} feature blocks; "
                f"expected between 1 and {len(feature_blocks)}"
            )
        self.effnet_truncated = nn.Sequential(*feature_blocks[:layers])

        # Global average pooling
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        fc_input_size = self._probe_feature_size(image_size)

        # Not used by forward(), but registered, so its parameters appear in
        # state_dict()/parameters(). Removing it would change state_dict keys
        # and break strict loading of existing checkpoints.
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(fc_input_size, num_classes),
        )

        # Head actually used by forward().
        self.fc_lrelu = nn.Sequential(
            nn.Linear(fc_input_size, 128),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

        # Not used by forward(); kept for the same state_dict-compatibility reason.
        self.fc = nn.Linear(fc_input_size, num_classes)

    def _probe_feature_size(self, image_size):
        """Return the flattened width of pooled backbone features for one 3-channel probe image."""
        # Probe in eval mode. In train mode, BatchNorm layers (shared with the
        # source effnet) update their running mean/var even under no_grad. That
        # would corrupt pretrained statistics with probe data. Each module's
        # original training flag is restored afterwards.
        saved_modes = [(module, module.training) for module in self.effnet_truncated.modules()]
        first_param = next(self.effnet_truncated.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        self.effnet_truncated.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, 3, image_size, image_size, device=device)
                pooled = self.global_avg_pool(self.effnet_truncated(probe))
                return torch.flatten(pooled, 1).size(1)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

    def forward(self, x):
        """Map an image batch to logits of shape ``(batch, num_classes)``."""
        x = self.effnet_truncated(x)  # Extract features
        x = self.global_avg_pool(x)   # (batch, C, 1, 1)
        x = torch.flatten(x, 1)       # (batch, C)
        return self.fc_lrelu(x)

# Instantiate the model with the truncated backbone
# Build the classifier on top of the truncated backbone. `num_classes` and
# `image_size` must already be defined in the kernel (e.g. in the config cell).
model = TruncatedEffNet(
    effnet=effnet,
    num_classes=num_classes,
    removed_layers=truncated_layers,
    batch_size=batch_size,
    image_size=image_size,
)